adjust custom-error related interface, struct and methods to reflect change in RequiredFlag interface
This commit is contained in:
parent
eb1734ba59
commit
52a016034a
55
context.go
55
context.go
@ -289,68 +289,67 @@ func normalizeFlags(flags []Flag, set *flag.FlagSet) error {
|
|||||||
|
|
||||||
type requiredFlagsErr interface {
|
type requiredFlagsErr interface {
|
||||||
error
|
error
|
||||||
getMissingFlags() map[string]bool
|
getMissingDefaultFlags() []string
|
||||||
|
getMissingCustomFlags() []string
|
||||||
}
|
}
|
||||||
|
|
||||||
type errRequiredFlags struct {
|
type errRequiredFlags struct {
|
||||||
missingFlags map[string]bool
|
missingDefaultFlags []string
|
||||||
|
missingCustomFlags []string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *errRequiredFlags) Error() string {
|
func (e *errRequiredFlags) Error() string {
|
||||||
var missingFlagNames []string
|
|
||||||
var missingFlagNamesReqErr []string
|
|
||||||
|
|
||||||
for k, v := range e.missingFlags {
|
|
||||||
if v == false {
|
|
||||||
missingFlagNames = append(missingFlagNames, k)
|
|
||||||
} else {
|
|
||||||
missingFlagNamesReqErr = append(missingFlagNamesReqErr, k)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var allErrors []string
|
var allErrors []string
|
||||||
numberOfMissingFlags := len(missingFlagNames)
|
numberOfMissingFlags := len(e.missingDefaultFlags)
|
||||||
numberOfMissingReqErrFlags := len(missingFlagNamesReqErr)
|
numberOfMissingReqErrFlags := len(e.missingCustomFlags)
|
||||||
|
|
||||||
if numberOfMissingFlags > 0 {
|
if numberOfMissingFlags > 0 {
|
||||||
if numberOfMissingFlags == 1 {
|
if numberOfMissingFlags == 1 {
|
||||||
allErrors = append(allErrors, fmt.Sprintf("Required flag %q not set", missingFlagNames[0]))
|
allErrors = append(allErrors, fmt.Sprintf("Required flag %q not set", e.missingDefaultFlags[0]))
|
||||||
} else {
|
} else {
|
||||||
joinedMissingFlags := strings.Join(missingFlagNames, ", ")
|
joinedMissingFlags := strings.Join(e.missingDefaultFlags, ", ")
|
||||||
allErrors = append(allErrors, fmt.Sprintf("Required flags %q not set", joinedMissingFlags))
|
allErrors = append(allErrors, fmt.Sprintf("Required flags %q not set", joinedMissingFlags))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if numberOfMissingReqErrFlags > 0 {
|
if numberOfMissingReqErrFlags > 0 {
|
||||||
|
for i := range e.missingCustomFlags {
|
||||||
// handle user defined errors and append
|
allErrors = append(allErrors, e.missingCustomFlags[i])
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return strings.Join(allErrors, "\n")
|
return strings.Join(allErrors, "\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *errRequiredFlags) getMissingFlags() map[string]bool {
|
func (e *errRequiredFlags) getMissingDefaultFlags() []string {
|
||||||
return e.missingFlags
|
return e.missingDefaultFlags
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *errRequiredFlags) getMissingCustomFlags() []string {
|
||||||
|
return e.missingCustomFlags
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkRequiredFlags(flags []Flag, context *Context) requiredFlagsErr {
|
func checkRequiredFlags(flags []Flag, context *Context) requiredFlagsErr {
|
||||||
missingFlags := make(map[string]bool)
|
var missingDefaultFlags []string
|
||||||
|
var missingCustomFlags []string
|
||||||
for _, f := range flags {
|
for _, f := range flags {
|
||||||
if rf, ok := f.(RequiredFlag); ok && rf.IsRequired() {
|
if rf, ok := f.(RequiredFlag); ok && rf.IsRequired() {
|
||||||
key := strings.Split(f.GetName(), ",")[0]
|
key := strings.Split(f.GetName(), ",")[0]
|
||||||
if !context.IsSet(key) {
|
if !context.IsSet(key) {
|
||||||
if re, ok := f.(RequiredFlagsErr); ok && re.FlagsErrRequired() {
|
if re, ok := f.(RequiredFlagErr); ok && re.IsCustom() {
|
||||||
missingFlags[key] = true
|
missingCustomFlags = append(missingCustomFlags, re.GetMessage())
|
||||||
} else {
|
} else {
|
||||||
missingFlags[key] = false
|
missingDefaultFlags = append(missingDefaultFlags, key)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(missingFlags) != 0 {
|
if len(missingDefaultFlags) != 0 || len(missingCustomFlags) != 0 {
|
||||||
return &errRequiredFlags{missingFlags: missingFlags}
|
return &errRequiredFlags{
|
||||||
|
missingDefaultFlags: missingDefaultFlags,
|
||||||
|
missingCustomFlags: missingCustomFlags,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
3
flag.go
3
flag.go
@ -83,14 +83,13 @@ type RequiredFlag interface {
|
|||||||
IsRequired() bool
|
IsRequired() bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// RequiredFlagsErr is an interface that allows users to redefine errors on required flags
|
// RequiredFlagErr is an interface that allows users to redefine errors on required flags
|
||||||
// it allows flags with user-defined errors to be backwards compatible with the Flag interface
|
// it allows flags with user-defined errors to be backwards compatible with the Flag interface
|
||||||
type RequiredFlagErr interface {
|
type RequiredFlagErr interface {
|
||||||
Flag
|
Flag
|
||||||
|
|
||||||
IsCustom() bool
|
IsCustom() bool
|
||||||
GetMessage() string
|
GetMessage() string
|
||||||
HasInterpolation() bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// DocGenerationFlag is an interface that allows documentation generation for the flag
|
// DocGenerationFlag is an interface that allows documentation generation for the flag
|
||||||
|
@ -10,9 +10,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type FlagErr struct {
|
type FlagErr struct {
|
||||||
Custom bool
|
Custom bool
|
||||||
Message string
|
Message string
|
||||||
Interpolate bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// BoolFlag is a flag with type bool
|
// BoolFlag is a flag with type bool
|
||||||
@ -136,11 +135,6 @@ func (f BoolTFlag) GetMessage() string {
|
|||||||
return f.RequiredFlagErr.Message
|
return f.RequiredFlagErr.Message
|
||||||
}
|
}
|
||||||
|
|
||||||
// FlagsErrRequired returns whether or not the flag is required
|
|
||||||
func (f BoolTFlag) FlagsErrRequired() bool {
|
|
||||||
return f.RequiredFlagsErr
|
|
||||||
}
|
|
||||||
|
|
||||||
// TakesValue returns true of the flag takes a value, otherwise false
|
// TakesValue returns true of the flag takes a value, otherwise false
|
||||||
func (f BoolTFlag) TakesValue() bool {
|
func (f BoolTFlag) TakesValue() bool {
|
||||||
return false
|
return false
|
||||||
@ -224,11 +218,6 @@ func (f DurationFlag) GetMessage() string {
|
|||||||
return f.RequiredFlagErr.Message
|
return f.RequiredFlagErr.Message
|
||||||
}
|
}
|
||||||
|
|
||||||
// FlagsErrRequired returns whether or not the flag is required
|
|
||||||
func (f DurationFlag) FlagsErrRequired() bool {
|
|
||||||
return f.RequiredFlagsErr
|
|
||||||
}
|
|
||||||
|
|
||||||
// TakesValue returns true of the flag takes a value, otherwise false
|
// TakesValue returns true of the flag takes a value, otherwise false
|
||||||
func (f DurationFlag) TakesValue() bool {
|
func (f DurationFlag) TakesValue() bool {
|
||||||
return true
|
return true
|
||||||
|
Loading…
Reference in New Issue
Block a user