add typed error assertions

This commit is contained in:
Lynn Cyrin 2019-08-01 21:35:15 -07:00
parent d4740d10d0
commit 78db152323
No known key found for this signature in database
GPG Key ID: EE9CCB427DFEC897
2 changed files with 28 additions and 8 deletions

View File

@ -1002,6 +1002,9 @@ func TestRequiredFlagAppRunBehavior(t *testing.T) {
if test.expectedAnError && err == nil {
t.Errorf("expected an error, but there was none")
}
if _, ok := err.(requiredFlagsErr); test.expectedAnError && !ok {
t.Errorf("expected a requiredFlagsErr, but got: %s", err)
}
if !test.expectedAnError && err != nil {
t.Errorf("did not expected an error, but there was one: %s", err)
}

View File

@ -287,7 +287,29 @@ func normalizeFlags(flags []Flag, set *flag.FlagSet) error {
return nil
}
func checkRequiredFlags(flags []Flag, set *flag.FlagSet) error {
type requiredFlagsErr interface {
error
getMissingFlags() []string
}
type errRequiredFlags struct {
missingFlags []string
}
func (e *errRequiredFlags) Error() string {
numberOfMissingFlags := len(e.missingFlags)
if numberOfMissingFlags == 1 {
return fmt.Sprintf("Required flag %q not set", e.missingFlags[0])
}
joinedMissingFlags := strings.Join(e.missingFlags, ", ")
return fmt.Sprintf("Required flags %q not set", joinedMissingFlags)
}
func (e *errRequiredFlags) getMissingFlags() []string {
return e.missingFlags
}
func checkRequiredFlags(flags []Flag, set *flag.FlagSet) requiredFlagsErr {
visited := make(map[string]bool)
set.Visit(func(f *flag.Flag) {
visited[f.Name] = true
@ -303,13 +325,8 @@ func checkRequiredFlags(flags []Flag, set *flag.FlagSet) error {
}
}
numberOfMissingFlags := len(missingFlags)
if numberOfMissingFlags == 1 {
return fmt.Errorf("Required flag %q not set", missingFlags[0])
}
if numberOfMissingFlags >= 2 {
joinedMissingFlags := strings.Join(missingFlags, ", ")
return fmt.Errorf("Required flags %q not set", joinedMissingFlags)
if len(missingFlags) != 0 {
return &errRequiredFlags{missingFlags: missingFlags}
}
return nil