diff --git a/context.go b/context.go index dc0d1ef..0c7a07e 100644 --- a/context.go +++ b/context.go @@ -105,6 +105,18 @@ func (cCtx *Context) Lineage() []*Context { return lineage } +// NumOccurrences returns the num of occurences of this flag +func (cCtx *Context) Count(name string) int { + if fs := cCtx.lookupFlagSet(name); fs != nil { + if bf, ok := fs.Lookup(name).Value.(*boolValue); ok { + if bf.count != nil { + return *bf.count + } + } + } + return 0 +} + // Value returns the value of the flag corresponding to `name` func (cCtx *Context) Value(name string) interface{} { if fs := cCtx.lookupFlagSet(name); fs != nil { diff --git a/flag_bool.go b/flag_bool.go index dc32402..aad26cc 100644 --- a/flag_bool.go +++ b/flag_bool.go @@ -105,14 +105,18 @@ func (f *BoolFlag) Apply(set *flag.FlagSet) error { f.HasBeenSet = true } + count := f.Count + dest := f.Destination + + if count == nil { + count = new(int) + } + if dest == nil { + dest = new(bool) + } + for _, name := range f.Names() { - var value flag.Value - if f.Destination != nil { - value = newBoolValue(f.Value, f.Destination, f.Count) - } else { - t := new(bool) - value = newBoolValue(f.Value, t, f.Count) - } + value := newBoolValue(f.Value, dest, count) set.Var(value, name, f.Usage) } diff --git a/flag_test.go b/flag_test.go index 167f31e..8451ccb 100644 --- a/flag_test.go +++ b/flag_test.go @@ -67,14 +67,39 @@ func TestBoolFlagApply_SetsCount(t *testing.T) { count := 0 fl := BoolFlag{Name: "wat", Aliases: []string{"W", "huh"}, Destination: &v, Count: &count} set := flag.NewFlagSet("test", 0) - _ = fl.Apply(set) + err := fl.Apply(set) + expect(t, err, nil) - err := set.Parse([]string{"--wat", "-W", "--huh"}) + err = set.Parse([]string{"--wat", "-W", "--huh"}) expect(t, err, nil) expect(t, v, true) expect(t, count, 3) } +func TestBoolFlagCountFromContext(t *testing.T) { + set := flag.NewFlagSet("test", 0) + ctx := NewContext(nil, set, nil) + tf := &BoolFlag{Name: "tf", Aliases: []string{"w", "huh"}} + err := tf.Apply(set) + expect(t, err, nil) + + err = set.Parse([]string{"-tf", "-w", "-huh"}) + expect(t, err, nil) + expect(t, tf.Get(ctx), true) + expect(t, ctx.Count("tf"), 3) + + set1 := flag.NewFlagSet("test", 0) + ctx1 := NewContext(nil, set1, nil) + tf1 := &BoolFlag{Name: "tf", Aliases: []string{"w", "huh"}} + err = tf1.Apply(set1) + expect(t, err, nil) + + err = set1.Parse([]string{}) + expect(t, err, nil) + expect(t, tf1.Get(ctx1), false) + expect(t, ctx1.Count("tf"), 0) +} + func TestFlagsFromEnv(t *testing.T) { newSetFloat64Slice := func(defaults ...float64) Float64Slice { s := NewFloat64Slice(defaults...) diff --git a/godoc-current.txt b/godoc-current.txt index db2b5ac..382d4f7 100644 --- a/godoc-current.txt +++ b/godoc-current.txt @@ -633,6 +633,9 @@ func (cCtx *Context) Args() Args func (cCtx *Context) Bool(name string) bool Bool looks up the value of a local BoolFlag, returns false if not found +func (cCtx *Context) Count(name string) int + NumOccurrences returns the num of occurences of this flag + func (cCtx *Context) Duration(name string) time.Duration Duration looks up the value of a local DurationFlag, returns 0 if not found