diff --git a/context.go b/context.go index 0fc1532..2daa2e7 100644 --- a/context.go +++ b/context.go @@ -289,68 +289,67 @@ func normalizeFlags(flags []Flag, set *flag.FlagSet) error { type requiredFlagsErr interface { error - getMissingFlags() map[string]bool + getMissingDefaultFlags() []string + getMissingCustomFlags() []string } type errRequiredFlags struct { - missingFlags map[string]bool + missingDefaultFlags []string + missingCustomFlags []string } func (e *errRequiredFlags) Error() string { - var missingFlagNames []string - var missingFlagNamesReqErr []string - - for k, v := range e.missingFlags { - if v == false { - missingFlagNames = append(missingFlagNames, k) - } else { - missingFlagNamesReqErr = append(missingFlagNamesReqErr, k) - } - } - var allErrors []string - numberOfMissingFlags := len(missingFlagNames) - numberOfMissingReqErrFlags := len(missingFlagNamesReqErr) + numberOfMissingFlags := len(e.missingDefaultFlags) + numberOfMissingReqErrFlags := len(e.missingCustomFlags) if numberOfMissingFlags > 0 { if numberOfMissingFlags == 1 { - allErrors = append(allErrors, fmt.Sprintf("Required flag %q not set", missingFlagNames[0])) + allErrors = append(allErrors, fmt.Sprintf("Required flag %q not set", e.missingDefaultFlags[0])) } else { - joinedMissingFlags := strings.Join(missingFlagNames, ", ") + joinedMissingFlags := strings.Join(e.missingDefaultFlags, ", ") allErrors = append(allErrors, fmt.Sprintf("Required flags %q not set", joinedMissingFlags)) } } if numberOfMissingReqErrFlags > 0 { - - // handle user defined errors and append - + for i := range e.missingCustomFlags { + allErrors = append(allErrors, e.missingCustomFlags[i]) + } } return strings.Join(allErrors, "\n") } -func (e *errRequiredFlags) getMissingFlags() map[string]bool { - return e.missingFlags +func (e *errRequiredFlags) getMissingDefaultFlags() []string { + return e.missingDefaultFlags +} + +func (e *errRequiredFlags) getMissingCustomFlags() []string { + return e.missingCustomFlags } func checkRequiredFlags(flags []Flag, context *Context) requiredFlagsErr { - missingFlags := make(map[string]bool) + var missingDefaultFlags []string + var missingCustomFlags []string for _, f := range flags { if rf, ok := f.(RequiredFlag); ok && rf.IsRequired() { key := strings.Split(f.GetName(), ",")[0] if !context.IsSet(key) { - if re, ok := f.(RequiredFlagsErr); ok && re.FlagsErrRequired() { - missingFlags[key] = true + if re, ok := f.(RequiredFlagErr); ok && re.IsCustom() { + missingCustomFlags = append(missingCustomFlags, re.GetMessage()) } else { - missingFlags[key] = false + missingDefaultFlags = append(missingDefaultFlags, key) } } } } - if len(missingFlags) != 0 { - return &errRequiredFlags{missingFlags: missingFlags} + if len(missingDefaultFlags) != 0 || len(missingCustomFlags) != 0 { + return &errRequiredFlags{ + missingDefaultFlags: missingDefaultFlags, + missingCustomFlags: missingCustomFlags, + } } return nil diff --git a/flag.go b/flag.go index 1f77c3e..be6b184 100644 --- a/flag.go +++ b/flag.go @@ -83,14 +83,13 @@ type RequiredFlag interface { IsRequired() bool } -// RequiredFlagsErr is an interface that allows users to redefine errors on required flags +// RequiredFlagErr is an interface that allows users to redefine errors on required flags // it allows flags with user-defined errors to be backwards compatible with the Flag interface type RequiredFlagErr interface { Flag IsCustom() bool GetMessage() string - HasInterpolation() bool } // DocGenerationFlag is an interface that allows documentation generation for the flag diff --git a/flag_generated.go b/flag_generated.go index 8cc022d..c4d1591 100644 --- a/flag_generated.go +++ b/flag_generated.go @@ -10,9 +10,8 @@ import ( ) type FlagErr struct { - Custom bool - Message string - Interpolate bool + Custom bool + Message string } // BoolFlag is a flag with type bool @@ -136,11 +135,6 @@ func (f BoolTFlag) GetMessage() string { return f.RequiredFlagErr.Message } -// FlagsErrRequired returns whether or not the flag is required -func (f BoolTFlag) FlagsErrRequired() bool { - return f.RequiredFlagsErr -} - // TakesValue returns true of the flag takes a value, otherwise false func (f BoolTFlag) TakesValue() bool { return false @@ -224,11 +218,6 @@ func (f DurationFlag) GetMessage() string { return f.RequiredFlagErr.Message } -// FlagsErrRequired returns whether or not the flag is required -func (f DurationFlag) FlagsErrRequired() bool { - return f.RequiredFlagsErr -} - // TakesValue returns true of the flag takes a value, otherwise false func (f DurationFlag) TakesValue() bool { return true