add typed error assertions
This commit is contained in:
33
context.go
33
context.go
@@ -287,7 +287,29 @@ func normalizeFlags(flags []Flag, set *flag.FlagSet) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkRequiredFlags(flags []Flag, set *flag.FlagSet) error {
|
||||
type requiredFlagsErr interface {
|
||||
error
|
||||
getMissingFlags() []string
|
||||
}
|
||||
|
||||
type errRequiredFlags struct {
|
||||
missingFlags []string
|
||||
}
|
||||
|
||||
func (e *errRequiredFlags) Error() string {
|
||||
numberOfMissingFlags := len(e.missingFlags)
|
||||
if numberOfMissingFlags == 1 {
|
||||
return fmt.Sprintf("Required flag %q not set", e.missingFlags[0])
|
||||
}
|
||||
joinedMissingFlags := strings.Join(e.missingFlags, ", ")
|
||||
return fmt.Sprintf("Required flags %q not set", joinedMissingFlags)
|
||||
}
|
||||
|
||||
func (e *errRequiredFlags) getMissingFlags() []string {
|
||||
return e.missingFlags
|
||||
}
|
||||
|
||||
func checkRequiredFlags(flags []Flag, set *flag.FlagSet) requiredFlagsErr {
|
||||
visited := make(map[string]bool)
|
||||
set.Visit(func(f *flag.Flag) {
|
||||
visited[f.Name] = true
|
||||
@@ -303,13 +325,8 @@ func checkRequiredFlags(flags []Flag, set *flag.FlagSet) error {
|
||||
}
|
||||
}
|
||||
|
||||
numberOfMissingFlags := len(missingFlags)
|
||||
if numberOfMissingFlags == 1 {
|
||||
return fmt.Errorf("Required flag %q not set", missingFlags[0])
|
||||
}
|
||||
if numberOfMissingFlags >= 2 {
|
||||
joinedMissingFlags := strings.Join(missingFlags, ", ")
|
||||
return fmt.Errorf("Required flags %q not set", joinedMissingFlags)
|
||||
if len(missingFlags) != 0 {
|
||||
return &errRequiredFlags{missingFlags: missingFlags}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
Reference in New Issue
Block a user