Add count option for bool flags

This commit is contained in:
Naveen Gogineni 2021-03-23 14:11:36 -04:00
parent 94c9951e4a
commit b4e0ec3b8a
2 changed files with 63 additions and 3 deletions

View File

@ -1,11 +1,55 @@
package cli
import (
"errors"
"flag"
"fmt"
"strconv"
)
// boolValue needs to implement the boolFlag internal interface in flag
// to be able to capture bool fields and values
// type boolFlag interface {
// Value
// IsBoolFlag() bool
// }
type boolValue struct {
destination *bool
count *int
}
func newBoolValue(val bool, p *bool, count *int) *boolValue {
*p = val
return &boolValue{
destination: p,
count: count,
}
}
func (b *boolValue) Set(s string) error {
v, err := strconv.ParseBool(s)
if err != nil {
err = errors.New("parse error")
return err
}
*b.destination = v
if b.count != nil {
*b.count = *b.count + 1
}
return err
}
func (b *boolValue) Get() interface{} { return *b.destination }
func (b *boolValue) String() string {
if b.destination != nil {
return strconv.FormatBool(*b.destination)
}
return strconv.FormatBool(false)
}
func (b *boolValue) IsBoolFlag() bool { return true }
// TakesValue returns true of the flag takes a value, otherwise false
func (f *BoolFlag) TakesValue() bool {
return false
@ -61,11 +105,14 @@ func (f *BoolFlag) Apply(set *flag.FlagSet) error {
}
for _, name := range f.Names() {
var value flag.Value
if f.Destination != nil {
set.BoolVar(f.Destination, name, f.Value, f.Usage)
continue
value = newBoolValue(f.Value, f.Destination, f.Count)
} else {
t := new(bool)
value = newBoolValue(f.Value, t, f.Count)
}
set.Bool(name, f.Value, f.Usage)
set.Var(value, name, f.Usage)
}
return nil

View File

@ -62,6 +62,19 @@ func TestBoolFlagValueFromContext(t *testing.T) {
expect(t, ff.Get(ctx), false)
}
func TestBoolFlagApply_SetsCount(t *testing.T) {
v := false
count := 0
fl := BoolFlag{Name: "wat", Aliases: []string{"W", "huh"}, Destination: &v, Count: &count}
set := flag.NewFlagSet("test", 0)
_ = fl.Apply(set)
err := set.Parse([]string{"--wat", "-W", "--huh"})
expect(t, err, nil)
expect(t, v, true)
expect(t, count, 3)
}
func TestFlagsFromEnv(t *testing.T) {
newSetFloat64Slice := func(defaults ...float64) Float64Slice {
s := NewFloat64Slice(defaults...)