adjust custom-error related interface, struct and methods to reflect change in RequiredFlag interface

main
Aaron Berns 5 years ago
parent eb1734ba59
commit 52a016034a

@ -289,68 +289,67 @@ func normalizeFlags(flags []Flag, set *flag.FlagSet) error {
type requiredFlagsErr interface {
error
getMissingFlags() map[string]bool
getMissingDefaultFlags() []string
getMissingCustomFlags() []string
}
type errRequiredFlags struct {
missingFlags map[string]bool
missingDefaultFlags []string
missingCustomFlags []string
}
func (e *errRequiredFlags) Error() string {
var missingFlagNames []string
var missingFlagNamesReqErr []string
for k, v := range e.missingFlags {
if v == false {
missingFlagNames = append(missingFlagNames, k)
} else {
missingFlagNamesReqErr = append(missingFlagNamesReqErr, k)
}
}
var allErrors []string
numberOfMissingFlags := len(missingFlagNames)
numberOfMissingReqErrFlags := len(missingFlagNamesReqErr)
numberOfMissingFlags := len(e.missingDefaultFlags)
numberOfMissingReqErrFlags := len(e.missingCustomFlags)
if numberOfMissingFlags > 0 {
if numberOfMissingFlags == 1 {
allErrors = append(allErrors, fmt.Sprintf("Required flag %q not set", missingFlagNames[0]))
allErrors = append(allErrors, fmt.Sprintf("Required flag %q not set", e.missingDefaultFlags[0]))
} else {
joinedMissingFlags := strings.Join(missingFlagNames, ", ")
joinedMissingFlags := strings.Join(e.missingDefaultFlags, ", ")
allErrors = append(allErrors, fmt.Sprintf("Required flags %q not set", joinedMissingFlags))
}
}
if numberOfMissingReqErrFlags > 0 {
// handle user defined errors and append
for i := range e.missingCustomFlags {
allErrors = append(allErrors, e.missingCustomFlags[i])
}
}
return strings.Join(allErrors, "\n")
}
func (e *errRequiredFlags) getMissingFlags() map[string]bool {
return e.missingFlags
func (e *errRequiredFlags) getMissingDefaultFlags() []string {
return e.missingDefaultFlags
}
func (e *errRequiredFlags) getMissingCustomFlags() []string {
return e.missingCustomFlags
}
func checkRequiredFlags(flags []Flag, context *Context) requiredFlagsErr {
missingFlags := make(map[string]bool)
var missingDefaultFlags []string
var missingCustomFlags []string
for _, f := range flags {
if rf, ok := f.(RequiredFlag); ok && rf.IsRequired() {
key := strings.Split(f.GetName(), ",")[0]
if !context.IsSet(key) {
if re, ok := f.(RequiredFlagsErr); ok && re.FlagsErrRequired() {
missingFlags[key] = true
if re, ok := f.(RequiredFlagErr); ok && re.IsCustom() {
missingCustomFlags = append(missingCustomFlags, re.GetMessage())
} else {
missingFlags[key] = false
missingDefaultFlags = append(missingDefaultFlags, key)
}
}
}
}
if len(missingFlags) != 0 {
return &errRequiredFlags{missingFlags: missingFlags}
if len(missingDefaultFlags) != 0 || len(missingCustomFlags) != 0 {
return &errRequiredFlags{
missingDefaultFlags: missingDefaultFlags,
missingCustomFlags: missingCustomFlags,
}
}
return nil

@ -83,14 +83,13 @@ type RequiredFlag interface {
IsRequired() bool
}
// RequiredFlagsErr is an interface that allows users to redefine errors on required flags
// RequiredFlagErr is an interface that allows users to redefine errors on required flags
// it allows flags with user-defined errors to be backwards compatible with the Flag interface
type RequiredFlagErr interface {
Flag
IsCustom() bool
GetMessage() string
HasInterpolation() bool
}
// DocGenerationFlag is an interface that allows documentation generation for the flag

@ -10,9 +10,8 @@ import (
)
type FlagErr struct {
Custom bool
Message string
Interpolate bool
Custom bool
Message string
}
// BoolFlag is a flag with type bool
@ -136,11 +135,6 @@ func (f BoolTFlag) GetMessage() string {
return f.RequiredFlagErr.Message
}
// FlagsErrRequired returns whether or not the flag is required
func (f BoolTFlag) FlagsErrRequired() bool {
return f.RequiredFlagsErr
}
// TakesValue returns true of the flag takes a value, otherwise false
func (f BoolTFlag) TakesValue() bool {
return false
@ -224,11 +218,6 @@ func (f DurationFlag) GetMessage() string {
return f.RequiredFlagErr.Message
}
// FlagsErrRequired returns whether or not the flag is required
func (f DurationFlag) FlagsErrRequired() bool {
return f.RequiredFlagsErr
}
// TakesValue returns true of the flag takes a value, otherwise false
func (f DurationFlag) TakesValue() bool {
return true

Loading…
Cancel
Save