Add count option for bool flags
This commit is contained in:
parent
eea567aa20
commit
03487fc7f0
54
flag_bool.go
54
flag_bool.go
@ -1,11 +1,56 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// boolValue needs to implement the boolFlag internal interface in flag
|
||||
// to be able to capture bool fields and values
|
||||
//
|
||||
// type boolFlag interface {
|
||||
// Value
|
||||
// IsBoolFlag() bool
|
||||
// }
|
||||
type boolValue struct {
|
||||
destination *bool
|
||||
count *int
|
||||
}
|
||||
|
||||
func newBoolValue(val bool, p *bool, count *int) *boolValue {
|
||||
*p = val
|
||||
return &boolValue{
|
||||
destination: p,
|
||||
count: count,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *boolValue) Set(s string) error {
|
||||
v, err := strconv.ParseBool(s)
|
||||
if err != nil {
|
||||
err = errors.New("parse error")
|
||||
return err
|
||||
}
|
||||
*b.destination = v
|
||||
if b.count != nil {
|
||||
*b.count = *b.count + 1
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (b *boolValue) Get() interface{} { return *b.destination }
|
||||
|
||||
func (b *boolValue) String() string {
|
||||
if b.destination != nil {
|
||||
return strconv.FormatBool(*b.destination)
|
||||
}
|
||||
return strconv.FormatBool(false)
|
||||
}
|
||||
|
||||
func (b *boolValue) IsBoolFlag() bool { return true }
|
||||
|
||||
// GetValue returns the flags value as string representation and an empty
|
||||
// string if the flag takes no value at all.
|
||||
func (f *BoolFlag) GetValue() string {
|
||||
@ -41,11 +86,14 @@ func (f *BoolFlag) Apply(set *flag.FlagSet) error {
|
||||
}
|
||||
|
||||
for _, name := range f.Names() {
|
||||
var value flag.Value
|
||||
if f.Destination != nil {
|
||||
set.BoolVar(f.Destination, name, f.Value, f.Usage)
|
||||
continue
|
||||
value = newBoolValue(f.Value, f.Destination, f.Count)
|
||||
} else {
|
||||
t := new(bool)
|
||||
value = newBoolValue(f.Value, t, f.Count)
|
||||
}
|
||||
set.Bool(name, f.Value, f.Usage)
|
||||
set.Var(value, name, f.Usage)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
13
flag_test.go
13
flag_test.go
@ -62,6 +62,19 @@ func TestBoolFlagValueFromContext(t *testing.T) {
|
||||
expect(t, ff.Get(ctx), false)
|
||||
}
|
||||
|
||||
func TestBoolFlagApply_SetsCount(t *testing.T) {
|
||||
v := false
|
||||
count := 0
|
||||
fl := BoolFlag{Name: "wat", Aliases: []string{"W", "huh"}, Destination: &v, Count: &count}
|
||||
set := flag.NewFlagSet("test", 0)
|
||||
_ = fl.Apply(set)
|
||||
|
||||
err := set.Parse([]string{"--wat", "-W", "--huh"})
|
||||
expect(t, err, nil)
|
||||
expect(t, v, true)
|
||||
expect(t, count, 3)
|
||||
}
|
||||
|
||||
func TestFlagsFromEnv(t *testing.T) {
|
||||
newSetFloat64Slice := func(defaults ...float64) Float64Slice {
|
||||
s := NewFloat64Slice(defaults...)
|
||||
|
Loading…
Reference in New Issue
Block a user