diff --git a/app.go b/app.go index 2c97251..5c616e6 100644 --- a/app.go +++ b/app.go @@ -278,7 +278,7 @@ func (a *App) RunContext(ctx context.Context, arguments []string) (err error) { return nil } - cerr := checkRequiredFlags(a.Flags, context) + cerr := context.checkRequiredFlags(a.Flags) if cerr != nil { _ = ShowAppHelp(context) return cerr @@ -397,7 +397,7 @@ func (a *App) RunAsSubcommand(ctx *Context) (err error) { } } - cerr := checkRequiredFlags(a.Flags, context) + cerr := context.checkRequiredFlags(a.Flags) if cerr != nil { _ = ShowSubcommandHelp(context) return cerr diff --git a/command.go b/command.go index 3503a55..9a6b877 100644 --- a/command.go +++ b/command.go @@ -127,7 +127,7 @@ func (c *Command) Run(ctx *Context) (err error) { return nil } - cerr := checkRequiredFlags(c.Flags, context) + cerr := context.checkRequiredFlags(c.Flags) if cerr != nil { _ = ShowCommandHelp(context, c.Name) return cerr diff --git a/context.go b/context.go index 65e0d1e..94cbb65 100644 --- a/context.go +++ b/context.go @@ -2,9 +2,7 @@ package cli import ( "context" - "errors" "flag" - "fmt" "strings" ) @@ -53,20 +51,18 @@ func (c *Context) Set(name, value string) error { // IsSet determines if the flag was actually set func (c *Context) IsSet(name string) bool { - if fs := lookupFlagSet(name, c); fs != nil { - if fs := lookupFlagSet(name, c); fs != nil { - isSet := false - fs.Visit(func(f *flag.Flag) { - if f.Name == name { - isSet = true - } - }) - if isSet { - return true + if fs := c.lookupFlagSet(name); fs != nil { + isSet := false + fs.Visit(func(f *flag.Flag) { + if f.Name == name { + isSet = true } + }) + if isSet { + return true } - f := lookupFlag(name, c) + f := c.lookupFlag(name) if f == nil { return false } @@ -108,7 +104,7 @@ func (c *Context) Lineage() []*Context { // Value returns the value of the flag corresponding to `name` func (c *Context) Value(name string) interface{} { - if fs := lookupFlagSet(name, c); fs != nil { + if fs := c.lookupFlagSet(name); fs != nil { return fs.Lookup(name).Value.(flag.Getter).Get() } return nil @@ -125,7 +121,7 @@ func (c *Context) NArg() int { return c.Args().Len() } -func lookupFlag(name string, ctx *Context) Flag { +func (ctx *Context) lookupFlag(name string) Flag { for _, c := range ctx.Lineage() { if c.Command == nil { continue @@ -153,7 +149,7 @@ func lookupFlag(name string, ctx *Context) Flag { return nil } -func lookupFlagSet(name string, ctx *Context) *flag.FlagSet { +func (ctx *Context) lookupFlagSet(name string) *flag.FlagSet { for _, c := range ctx.Lineage() { if f := c.flagSet.Lookup(name); f != nil { return c.flagSet @@ -163,89 +159,7 @@ func lookupFlagSet(name string, ctx *Context) *flag.FlagSet { return nil } -func copyFlag(name string, ff *flag.Flag, set *flag.FlagSet) { - switch ff.Value.(type) { - case Serializer: - _ = set.Set(name, ff.Value.(Serializer).Serialize()) - default: - _ = set.Set(name, ff.Value.String()) - } -} - -func normalizeFlags(flags []Flag, set *flag.FlagSet) error { - visited := make(map[string]bool) - set.Visit(func(f *flag.Flag) { - visited[f.Name] = true - }) - for _, f := range flags { - parts := f.Names() - if len(parts) == 1 { - continue - } - var ff *flag.Flag - for _, name := range parts { - name = strings.Trim(name, " ") - if visited[name] { - if ff != nil { - return errors.New("Cannot use two forms of the same flag: " + name + " " + ff.Name) - } - ff = set.Lookup(name) - } - } - if ff == nil { - continue - } - for _, name := range parts { - name = strings.Trim(name, " ") - if !visited[name] { - copyFlag(name, ff, set) - } - } - } - return nil -} - -func makeFlagNameVisitor(names *[]string) func(*flag.Flag) { - return func(f *flag.Flag) { - nameParts := strings.Split(f.Name, ",") - name := strings.TrimSpace(nameParts[0]) - - for _, part := range nameParts { - part = strings.TrimSpace(part) - if len(part) > len(name) { - name = part - } - } - - if name != "" { - *names = append(*names, name) - } - } -} - -type requiredFlagsErr interface { - error - getMissingFlags() []string -} - -type errRequiredFlags struct { - missingFlags []string -} - -func (e *errRequiredFlags) Error() string { - numberOfMissingFlags := len(e.missingFlags) - if numberOfMissingFlags == 1 { - return fmt.Sprintf("Required flag %q not set", e.missingFlags[0]) - } - joinedMissingFlags := strings.Join(e.missingFlags, ", ") - return fmt.Sprintf("Required flags %q not set", joinedMissingFlags) -} - -func (e *errRequiredFlags) getMissingFlags() []string { - return e.missingFlags -} - -func checkRequiredFlags(flags []Flag, context *Context) requiredFlagsErr { +func (context *Context) checkRequiredFlags(flags []Flag) requiredFlagsErr { var missingFlags []string for _, f := range flags { if rf, ok := f.(RequiredFlag); ok && rf.IsRequired() { @@ -274,3 +188,21 @@ func checkRequiredFlags(flags []Flag, context *Context) requiredFlagsErr { return nil } + +func makeFlagNameVisitor(names *[]string) func(*flag.Flag) { + return func(f *flag.Flag) { + nameParts := strings.Split(f.Name, ",") + name := strings.TrimSpace(nameParts[0]) + + for _, part := range nameParts { + part = strings.TrimSpace(part) + if len(part) > len(name) { + name = part + } + } + + if name != "" { + *names = append(*names, name) + } + } +} diff --git a/context_test.go b/context_test.go index 35feefe..b37876c 100644 --- a/context_test.go +++ b/context_test.go @@ -316,13 +316,13 @@ func TestContext_lookupFlagSet(t *testing.T) { _ = set.Parse([]string{"--local-flag"}) _ = parentSet.Parse([]string{"--top-flag"}) - fs := lookupFlagSet("top-flag", ctx) + fs := ctx.lookupFlagSet("top-flag") expect(t, fs, parentCtx.flagSet) - fs = lookupFlagSet("local-flag", ctx) + fs = ctx.lookupFlagSet("local-flag") expect(t, fs, ctx.flagSet) - if fs := lookupFlagSet("frob", ctx); fs != nil { + if fs := ctx.lookupFlagSet("frob"); fs != nil { t.Fail() } } @@ -576,7 +576,7 @@ func TestCheckRequiredFlags(t *testing.T) { ctx.Command.Flags = test.flags // logic under test - err := checkRequiredFlags(test.flags, ctx) + err := ctx.checkRequiredFlags(test.flags) // assertions if test.expectedAnError && err == nil { diff --git a/errors.go b/errors.go index 751ef9b..8f641fb 100644 --- a/errors.go +++ b/errors.go @@ -47,6 +47,28 @@ func (m *multiError) Errors() []error { return errs } +type requiredFlagsErr interface { + error + getMissingFlags() []string +} + +type errRequiredFlags struct { + missingFlags []string +} + +func (e *errRequiredFlags) Error() string { + numberOfMissingFlags := len(e.missingFlags) + if numberOfMissingFlags == 1 { + return fmt.Sprintf("Required flag %q not set", e.missingFlags[0]) + } + joinedMissingFlags := strings.Join(e.missingFlags, ", ") + return fmt.Sprintf("Required flags %q not set", joinedMissingFlags) +} + +func (e *errRequiredFlags) getMissingFlags() []string { + return e.missingFlags +} + // ErrorFormatter is the interface that will suitably format the error output type ErrorFormatter interface { Format(s fmt.State, verb rune) diff --git a/flag.go b/flag.go index aff8d5b..a693386 100644 --- a/flag.go +++ b/flag.go @@ -1,6 +1,7 @@ package cli import ( + "errors" "flag" "fmt" "io/ioutil" @@ -130,6 +131,48 @@ func flagSet(name string, flags []Flag) (*flag.FlagSet, error) { return set, nil } +func copyFlag(name string, ff *flag.Flag, set *flag.FlagSet) { + switch ff.Value.(type) { + case Serializer: + _ = set.Set(name, ff.Value.(Serializer).Serialize()) + default: + _ = set.Set(name, ff.Value.String()) + } +} + +func normalizeFlags(flags []Flag, set *flag.FlagSet) error { + visited := make(map[string]bool) + set.Visit(func(f *flag.Flag) { + visited[f.Name] = true + }) + for _, f := range flags { + parts := f.Names() + if len(parts) == 1 { + continue + } + var ff *flag.Flag + for _, name := range parts { + name = strings.Trim(name, " ") + if visited[name] { + if ff != nil { + return errors.New("Cannot use two forms of the same flag: " + name + " " + ff.Name) + } + ff = set.Lookup(name) + } + } + if ff == nil { + continue + } + for _, name := range parts { + name = strings.Trim(name, " ") + if !visited[name] { + copyFlag(name, ff, set) + } + } + } + return nil +} + func visibleFlags(fl []Flag) []Flag { var visible []Flag for _, f := range fl { diff --git a/flag_bool.go b/flag_bool.go index bc9ea35..85270e4 100644 --- a/flag_bool.go +++ b/flag_bool.go @@ -87,7 +87,7 @@ func (f *BoolFlag) Apply(set *flag.FlagSet) error { // Bool looks up the value of a local BoolFlag, returns // false if not found func (c *Context) Bool(name string) bool { - if fs := lookupFlagSet(name, c); fs != nil { + if fs := c.lookupFlagSet(name); fs != nil { return lookupBool(name, fs) } return false diff --git a/flag_duration.go b/flag_duration.go index 22a2e67..7b59a38 100644 --- a/flag_duration.go +++ b/flag_duration.go @@ -86,7 +86,7 @@ func (f *DurationFlag) Apply(set *flag.FlagSet) error { // Duration looks up the value of a local DurationFlag, returns // 0 if not found func (c *Context) Duration(name string) time.Duration { - if fs := lookupFlagSet(name, c); fs != nil { + if fs := c.lookupFlagSet(name); fs != nil { return lookupDuration(name, fs) } return 0 diff --git a/flag_float64.go b/flag_float64.go index 91c778c..d2a6458 100644 --- a/flag_float64.go +++ b/flag_float64.go @@ -87,7 +87,7 @@ func (f *Float64Flag) Apply(set *flag.FlagSet) error { // Float64 looks up the value of a local Float64Flag, returns // 0 if not found func (c *Context) Float64(name string) float64 { - if fs := lookupFlagSet(name, c); fs != nil { + if fs := c.lookupFlagSet(name); fs != nil { return lookupFloat64(name, fs) } return 0 diff --git a/flag_float64_slice.go b/flag_float64_slice.go index 706ee6c..49a04d4 100644 --- a/flag_float64_slice.go +++ b/flag_float64_slice.go @@ -146,7 +146,7 @@ func (f *Float64SliceFlag) Apply(set *flag.FlagSet) error { // Float64Slice looks up the value of a local Float64SliceFlag, returns // nil if not found func (c *Context) Float64Slice(name string) []float64 { - if fs := lookupFlagSet(name, c); fs != nil { + if fs := c.lookupFlagSet(name); fs != nil { return lookupFloat64Slice(name, fs) } return nil diff --git a/flag_generic.go b/flag_generic.go index b0c8ff4..d6800a8 100644 --- a/flag_generic.go +++ b/flag_generic.go @@ -89,7 +89,7 @@ func (f GenericFlag) Apply(set *flag.FlagSet) error { // Generic looks up the value of a local GenericFlag, returns // nil if not found func (c *Context) Generic(name string) interface{} { - if fs := lookupFlagSet(name, c); fs != nil { + if fs := c.lookupFlagSet(name); fs != nil { return lookupGeneric(name, fs) } return nil diff --git a/flag_int.go b/flag_int.go index ac39d4a..e9da3fa 100644 --- a/flag_int.go +++ b/flag_int.go @@ -87,7 +87,7 @@ func (f *IntFlag) Apply(set *flag.FlagSet) error { // Int looks up the value of a local IntFlag, returns // 0 if not found func (c *Context) Int(name string) int { - if fs := lookupFlagSet(name, c); fs != nil { + if fs := c.lookupFlagSet(name); fs != nil { return lookupInt(name, fs) } return 0 diff --git a/flag_int64.go b/flag_int64.go index e099912..6c55458 100644 --- a/flag_int64.go +++ b/flag_int64.go @@ -86,7 +86,7 @@ func (f *Int64Flag) Apply(set *flag.FlagSet) error { // Int64 looks up the value of a local Int64Flag, returns // 0 if not found func (c *Context) Int64(name string) int64 { - if fs := lookupFlagSet(name, c); fs != nil { + if fs := c.lookupFlagSet(name); fs != nil { return lookupInt64(name, fs) } return 0 diff --git a/flag_int64_slice.go b/flag_int64_slice.go index 2c9a15a..773ef8a 100644 --- a/flag_int64_slice.go +++ b/flag_int64_slice.go @@ -145,7 +145,7 @@ func (f *Int64SliceFlag) Apply(set *flag.FlagSet) error { // Int64Slice looks up the value of a local Int64SliceFlag, returns // nil if not found func (c *Context) Int64Slice(name string) []int64 { - if fs := lookupFlagSet(name, c); fs != nil { + if fs := c.lookupFlagSet(name); fs != nil { return lookupInt64Slice(name, fs) } return nil diff --git a/flag_int_slice.go b/flag_int_slice.go index a73ca6b..8feef5f 100644 --- a/flag_int_slice.go +++ b/flag_int_slice.go @@ -156,7 +156,7 @@ func (f *IntSliceFlag) Apply(set *flag.FlagSet) error { // IntSlice looks up the value of a local IntSliceFlag, returns // nil if not found func (c *Context) IntSlice(name string) []int { - if fs := lookupFlagSet(name, c); fs != nil { + if fs := c.lookupFlagSet(name); fs != nil { return lookupIntSlice(name, fs) } return nil diff --git a/flag_path.go b/flag_path.go index 8070dc4..37c6f27 100644 --- a/flag_path.go +++ b/flag_path.go @@ -75,7 +75,7 @@ func (f *PathFlag) Apply(set *flag.FlagSet) error { // Path looks up the value of a local PathFlag, returns // "" if not found func (c *Context) Path(name string) string { - if fs := lookupFlagSet(name, c); fs != nil { + if fs := c.lookupFlagSet(name); fs != nil { return lookupPath(name, fs) } diff --git a/flag_string.go b/flag_string.go index 400bb53..a43f7c2 100644 --- a/flag_string.go +++ b/flag_string.go @@ -76,7 +76,7 @@ func (f *StringFlag) Apply(set *flag.FlagSet) error { // String looks up the value of a local StringFlag, returns // "" if not found func (c *Context) String(name string) string { - if fs := lookupFlagSet(name, c); fs != nil { + if fs := c.lookupFlagSet(name); fs != nil { return lookupString(name, fs) } return "" diff --git a/flag_string_slice.go b/flag_string_slice.go index 3549703..3934a60 100644 --- a/flag_string_slice.go +++ b/flag_string_slice.go @@ -163,7 +163,7 @@ func (f *StringSliceFlag) Apply(set *flag.FlagSet) error { // StringSlice looks up the value of a local StringSliceFlag, returns // nil if not found func (c *Context) StringSlice(name string) []string { - if fs := lookupFlagSet(name, c); fs != nil { + if fs := c.lookupFlagSet(name); fs != nil { return lookupStringSlice(name, fs) } return nil diff --git a/flag_timestamp.go b/flag_timestamp.go index da95512..8266e23 100644 --- a/flag_timestamp.go +++ b/flag_timestamp.go @@ -148,7 +148,7 @@ func (f *TimestampFlag) Apply(set *flag.FlagSet) error { // Timestamp gets the timestamp from a flag name func (c *Context) Timestamp(name string) *time.Time { - if fs := lookupFlagSet(name, c); fs != nil { + if fs := c.lookupFlagSet(name); fs != nil { return lookupTimestamp(name, fs) } return nil diff --git a/flag_uint.go b/flag_uint.go index 2e5e76b..6730e69 100644 --- a/flag_uint.go +++ b/flag_uint.go @@ -86,7 +86,7 @@ func (f *UintFlag) GetValue() string { // Uint looks up the value of a local UintFlag, returns // 0 if not found func (c *Context) Uint(name string) uint { - if fs := lookupFlagSet(name, c); fs != nil { + if fs := c.lookupFlagSet(name); fs != nil { return lookupUint(name, fs) } return 0 diff --git a/flag_uint64.go b/flag_uint64.go index 8fc3289..4af65fa 100644 --- a/flag_uint64.go +++ b/flag_uint64.go @@ -86,7 +86,7 @@ func (f *Uint64Flag) GetValue() string { // Uint64 looks up the value of a local Uint64Flag, returns // 0 if not found func (c *Context) Uint64(name string) uint64 { - if fs := lookupFlagSet(name, c); fs != nil { + if fs := c.lookupFlagSet(name); fs != nil { return lookupUint64(name, fs) } return 0