diff --git a/context.go b/context.go index 85ba419..4273389 100644 --- a/context.go +++ b/context.go @@ -289,23 +289,34 @@ func normalizeFlags(flags []Flag, set *flag.FlagSet) error { type requiredFlagsErr interface { error - getMissingFlags() []string + getMissingFlags() map[string]bool } type errRequiredFlags struct { - missingFlags []string + missingFlags map[string]bool } func (e *errRequiredFlags) Error() string { - numberOfMissingFlags := len(e.missingFlags) + var missingFlagNames []string + var missingFlagNamesReqErr []string + + for k, v := range e.missingFlags { + if v == false { + missingFlagNames = append(missingFlagNames, k) + } else { + missingFlagNamesReqErr = append(missingFlagNamesReqErr, k) + } + } + + numberOfMissingFlags := len(missingFlagNames) if numberOfMissingFlags == 1 { - return fmt.Sprintf("Required flag %q not set", e.missingFlags[0]) + return fmt.Sprintf("Required flag %q not set", missingFlagNames[0]) } - joinedMissingFlags := strings.Join(e.missingFlags, ", ") + joinedMissingFlags := strings.Join(missingFlagNames, ", ") return fmt.Sprintf("Required flags %q not set", joinedMissingFlags) } -func (e *errRequiredFlags) getMissingFlags() []string { +func (e *errRequiredFlags) getMissingFlags() map[string]bool { return e.missingFlags }