adjust custom-error related interface, struct and methods to reflect change in RequiredFlag interface

main
Aaron Berns 5 years ago
parent eb1734ba59
commit 52a016034a

@ -289,68 +289,67 @@ func normalizeFlags(flags []Flag, set *flag.FlagSet) error {
type requiredFlagsErr interface { type requiredFlagsErr interface {
error error
getMissingFlags() map[string]bool getMissingDefaultFlags() []string
getMissingCustomFlags() []string
} }
type errRequiredFlags struct { type errRequiredFlags struct {
missingFlags map[string]bool missingDefaultFlags []string
missingCustomFlags []string
} }
func (e *errRequiredFlags) Error() string { func (e *errRequiredFlags) Error() string {
var missingFlagNames []string
var missingFlagNamesReqErr []string
for k, v := range e.missingFlags {
if v == false {
missingFlagNames = append(missingFlagNames, k)
} else {
missingFlagNamesReqErr = append(missingFlagNamesReqErr, k)
}
}
var allErrors []string var allErrors []string
numberOfMissingFlags := len(missingFlagNames) numberOfMissingFlags := len(e.missingDefaultFlags)
numberOfMissingReqErrFlags := len(missingFlagNamesReqErr) numberOfMissingReqErrFlags := len(e.missingCustomFlags)
if numberOfMissingFlags > 0 { if numberOfMissingFlags > 0 {
if numberOfMissingFlags == 1 { if numberOfMissingFlags == 1 {
allErrors = append(allErrors, fmt.Sprintf("Required flag %q not set", missingFlagNames[0])) allErrors = append(allErrors, fmt.Sprintf("Required flag %q not set", e.missingDefaultFlags[0]))
} else { } else {
joinedMissingFlags := strings.Join(missingFlagNames, ", ") joinedMissingFlags := strings.Join(e.missingDefaultFlags, ", ")
allErrors = append(allErrors, fmt.Sprintf("Required flags %q not set", joinedMissingFlags)) allErrors = append(allErrors, fmt.Sprintf("Required flags %q not set", joinedMissingFlags))
} }
} }
if numberOfMissingReqErrFlags > 0 { if numberOfMissingReqErrFlags > 0 {
for i := range e.missingCustomFlags {
// handle user defined errors and append allErrors = append(allErrors, e.missingCustomFlags[i])
}
} }
return strings.Join(allErrors, "\n") return strings.Join(allErrors, "\n")
} }
func (e *errRequiredFlags) getMissingFlags() map[string]bool { func (e *errRequiredFlags) getMissingDefaultFlags() []string {
return e.missingFlags return e.missingDefaultFlags
}
func (e *errRequiredFlags) getMissingCustomFlags() []string {
return e.missingCustomFlags
} }
func checkRequiredFlags(flags []Flag, context *Context) requiredFlagsErr { func checkRequiredFlags(flags []Flag, context *Context) requiredFlagsErr {
missingFlags := make(map[string]bool) var missingDefaultFlags []string
var missingCustomFlags []string
for _, f := range flags { for _, f := range flags {
if rf, ok := f.(RequiredFlag); ok && rf.IsRequired() { if rf, ok := f.(RequiredFlag); ok && rf.IsRequired() {
key := strings.Split(f.GetName(), ",")[0] key := strings.Split(f.GetName(), ",")[0]
if !context.IsSet(key) { if !context.IsSet(key) {
if re, ok := f.(RequiredFlagsErr); ok && re.FlagsErrRequired() { if re, ok := f.(RequiredFlagErr); ok && re.IsCustom() {
missingFlags[key] = true missingCustomFlags = append(missingCustomFlags, re.GetMessage())
} else { } else {
missingFlags[key] = false missingDefaultFlags = append(missingDefaultFlags, key)
} }
} }
} }
} }
if len(missingFlags) != 0 { if len(missingDefaultFlags) != 0 || len(missingCustomFlags) != 0 {
return &errRequiredFlags{missingFlags: missingFlags} return &errRequiredFlags{
missingDefaultFlags: missingDefaultFlags,
missingCustomFlags: missingCustomFlags,
}
} }
return nil return nil

@ -83,14 +83,13 @@ type RequiredFlag interface {
IsRequired() bool IsRequired() bool
} }
// RequiredFlagsErr is an interface that allows users to redefine errors on required flags // RequiredFlagErr is an interface that allows users to redefine errors on required flags
// it allows flags with user-defined errors to be backwards compatible with the Flag interface // it allows flags with user-defined errors to be backwards compatible with the Flag interface
type RequiredFlagErr interface { type RequiredFlagErr interface {
Flag Flag
IsCustom() bool IsCustom() bool
GetMessage() string GetMessage() string
HasInterpolation() bool
} }
// DocGenerationFlag is an interface that allows documentation generation for the flag // DocGenerationFlag is an interface that allows documentation generation for the flag

@ -10,9 +10,8 @@ import (
) )
type FlagErr struct { type FlagErr struct {
Custom bool Custom bool
Message string Message string
Interpolate bool
} }
// BoolFlag is a flag with type bool // BoolFlag is a flag with type bool
@ -136,11 +135,6 @@ func (f BoolTFlag) GetMessage() string {
return f.RequiredFlagErr.Message return f.RequiredFlagErr.Message
} }
// FlagsErrRequired returns whether or not the flag is required
func (f BoolTFlag) FlagsErrRequired() bool {
return f.RequiredFlagsErr
}
// TakesValue returns true of the flag takes a value, otherwise false // TakesValue returns true of the flag takes a value, otherwise false
func (f BoolTFlag) TakesValue() bool { func (f BoolTFlag) TakesValue() bool {
return false return false
@ -224,11 +218,6 @@ func (f DurationFlag) GetMessage() string {
return f.RequiredFlagErr.Message return f.RequiredFlagErr.Message
} }
// FlagsErrRequired returns whether or not the flag is required
func (f DurationFlag) FlagsErrRequired() bool {
return f.RequiredFlagsErr
}
// TakesValue returns true of the flag takes a value, otherwise false // TakesValue returns true of the flag takes a value, otherwise false
func (f DurationFlag) TakesValue() bool { func (f DurationFlag) TakesValue() bool {
return true return true

Loading…
Cancel
Save