adjust custom-error related interface, struct and methods to reflect change in RequiredFlag interface
This commit is contained in:
parent
eb1734ba59
commit
52a016034a
55
context.go
55
context.go
@ -289,68 +289,67 @@ func normalizeFlags(flags []Flag, set *flag.FlagSet) error {
|
||||
|
||||
type requiredFlagsErr interface {
|
||||
error
|
||||
getMissingFlags() map[string]bool
|
||||
getMissingDefaultFlags() []string
|
||||
getMissingCustomFlags() []string
|
||||
}
|
||||
|
||||
type errRequiredFlags struct {
|
||||
missingFlags map[string]bool
|
||||
missingDefaultFlags []string
|
||||
missingCustomFlags []string
|
||||
}
|
||||
|
||||
func (e *errRequiredFlags) Error() string {
|
||||
var missingFlagNames []string
|
||||
var missingFlagNamesReqErr []string
|
||||
|
||||
for k, v := range e.missingFlags {
|
||||
if v == false {
|
||||
missingFlagNames = append(missingFlagNames, k)
|
||||
} else {
|
||||
missingFlagNamesReqErr = append(missingFlagNamesReqErr, k)
|
||||
}
|
||||
}
|
||||
|
||||
var allErrors []string
|
||||
numberOfMissingFlags := len(missingFlagNames)
|
||||
numberOfMissingReqErrFlags := len(missingFlagNamesReqErr)
|
||||
numberOfMissingFlags := len(e.missingDefaultFlags)
|
||||
numberOfMissingReqErrFlags := len(e.missingCustomFlags)
|
||||
|
||||
if numberOfMissingFlags > 0 {
|
||||
if numberOfMissingFlags == 1 {
|
||||
allErrors = append(allErrors, fmt.Sprintf("Required flag %q not set", missingFlagNames[0]))
|
||||
allErrors = append(allErrors, fmt.Sprintf("Required flag %q not set", e.missingDefaultFlags[0]))
|
||||
} else {
|
||||
joinedMissingFlags := strings.Join(missingFlagNames, ", ")
|
||||
joinedMissingFlags := strings.Join(e.missingDefaultFlags, ", ")
|
||||
allErrors = append(allErrors, fmt.Sprintf("Required flags %q not set", joinedMissingFlags))
|
||||
}
|
||||
}
|
||||
|
||||
if numberOfMissingReqErrFlags > 0 {
|
||||
|
||||
// handle user defined errors and append
|
||||
|
||||
for i := range e.missingCustomFlags {
|
||||
allErrors = append(allErrors, e.missingCustomFlags[i])
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(allErrors, "\n")
|
||||
}
|
||||
|
||||
func (e *errRequiredFlags) getMissingFlags() map[string]bool {
|
||||
return e.missingFlags
|
||||
func (e *errRequiredFlags) getMissingDefaultFlags() []string {
|
||||
return e.missingDefaultFlags
|
||||
}
|
||||
|
||||
func (e *errRequiredFlags) getMissingCustomFlags() []string {
|
||||
return e.missingCustomFlags
|
||||
}
|
||||
|
||||
func checkRequiredFlags(flags []Flag, context *Context) requiredFlagsErr {
|
||||
missingFlags := make(map[string]bool)
|
||||
var missingDefaultFlags []string
|
||||
var missingCustomFlags []string
|
||||
for _, f := range flags {
|
||||
if rf, ok := f.(RequiredFlag); ok && rf.IsRequired() {
|
||||
key := strings.Split(f.GetName(), ",")[0]
|
||||
if !context.IsSet(key) {
|
||||
if re, ok := f.(RequiredFlagsErr); ok && re.FlagsErrRequired() {
|
||||
missingFlags[key] = true
|
||||
if re, ok := f.(RequiredFlagErr); ok && re.IsCustom() {
|
||||
missingCustomFlags = append(missingCustomFlags, re.GetMessage())
|
||||
} else {
|
||||
missingFlags[key] = false
|
||||
missingDefaultFlags = append(missingDefaultFlags, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(missingFlags) != 0 {
|
||||
return &errRequiredFlags{missingFlags: missingFlags}
|
||||
if len(missingDefaultFlags) != 0 || len(missingCustomFlags) != 0 {
|
||||
return &errRequiredFlags{
|
||||
missingDefaultFlags: missingDefaultFlags,
|
||||
missingCustomFlags: missingCustomFlags,
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
3
flag.go
3
flag.go
@ -83,14 +83,13 @@ type RequiredFlag interface {
|
||||
IsRequired() bool
|
||||
}
|
||||
|
||||
// RequiredFlagsErr is an interface that allows users to redefine errors on required flags
|
||||
// RequiredFlagErr is an interface that allows users to redefine errors on required flags
|
||||
// it allows flags with user-defined errors to be backwards compatible with the Flag interface
|
||||
type RequiredFlagErr interface {
|
||||
Flag
|
||||
|
||||
IsCustom() bool
|
||||
GetMessage() string
|
||||
HasInterpolation() bool
|
||||
}
|
||||
|
||||
// DocGenerationFlag is an interface that allows documentation generation for the flag
|
||||
|
@ -10,9 +10,8 @@ import (
|
||||
)
|
||||
|
||||
type FlagErr struct {
|
||||
Custom bool
|
||||
Message string
|
||||
Interpolate bool
|
||||
Custom bool
|
||||
Message string
|
||||
}
|
||||
|
||||
// BoolFlag is a flag with type bool
|
||||
@ -136,11 +135,6 @@ func (f BoolTFlag) GetMessage() string {
|
||||
return f.RequiredFlagErr.Message
|
||||
}
|
||||
|
||||
// FlagsErrRequired returns whether or not the flag is required
|
||||
func (f BoolTFlag) FlagsErrRequired() bool {
|
||||
return f.RequiredFlagsErr
|
||||
}
|
||||
|
||||
// TakesValue returns true of the flag takes a value, otherwise false
|
||||
func (f BoolTFlag) TakesValue() bool {
|
||||
return false
|
||||
@ -224,11 +218,6 @@ func (f DurationFlag) GetMessage() string {
|
||||
return f.RequiredFlagErr.Message
|
||||
}
|
||||
|
||||
// FlagsErrRequired returns whether or not the flag is required
|
||||
func (f DurationFlag) FlagsErrRequired() bool {
|
||||
return f.RequiredFlagsErr
|
||||
}
|
||||
|
||||
// TakesValue returns true of the flag takes a value, otherwise false
|
||||
func (f DurationFlag) TakesValue() bool {
|
||||
return true
|
||||
|
Loading…
Reference in New Issue
Block a user