add typed error assertions

This commit is contained in:
Lynn Cyrin 2019-08-01 21:35:15 -07:00
parent d4740d10d0
commit 78db152323
No known key found for this signature in database
GPG Key ID: EE9CCB427DFEC897
2 changed files with 28 additions and 8 deletions

View File

@ -1002,6 +1002,9 @@ func TestRequiredFlagAppRunBehavior(t *testing.T) {
if test.expectedAnError && err == nil { if test.expectedAnError && err == nil {
t.Errorf("expected an error, but there was none") t.Errorf("expected an error, but there was none")
} }
if _, ok := err.(requiredFlagsErr); test.expectedAnError && !ok {
t.Errorf("expected a requiredFlagsErr, but got: %s", err)
}
if !test.expectedAnError && err != nil { if !test.expectedAnError && err != nil {
t.Errorf("did not expected an error, but there was one: %s", err) t.Errorf("did not expected an error, but there was one: %s", err)
} }

View File

@ -287,7 +287,29 @@ func normalizeFlags(flags []Flag, set *flag.FlagSet) error {
return nil return nil
} }
func checkRequiredFlags(flags []Flag, set *flag.FlagSet) error { type requiredFlagsErr interface {
error
getMissingFlags() []string
}
type errRequiredFlags struct {
missingFlags []string
}
func (e *errRequiredFlags) Error() string {
numberOfMissingFlags := len(e.missingFlags)
if numberOfMissingFlags == 1 {
return fmt.Sprintf("Required flag %q not set", e.missingFlags[0])
}
joinedMissingFlags := strings.Join(e.missingFlags, ", ")
return fmt.Sprintf("Required flags %q not set", joinedMissingFlags)
}
func (e *errRequiredFlags) getMissingFlags() []string {
return e.missingFlags
}
func checkRequiredFlags(flags []Flag, set *flag.FlagSet) requiredFlagsErr {
visited := make(map[string]bool) visited := make(map[string]bool)
set.Visit(func(f *flag.Flag) { set.Visit(func(f *flag.Flag) {
visited[f.Name] = true visited[f.Name] = true
@ -303,13 +325,8 @@ func checkRequiredFlags(flags []Flag, set *flag.FlagSet) error {
} }
} }
numberOfMissingFlags := len(missingFlags) if len(missingFlags) != 0 {
if numberOfMissingFlags == 1 { return &errRequiredFlags{missingFlags: missingFlags}
return fmt.Errorf("Required flag %q not set", missingFlags[0])
}
if numberOfMissingFlags >= 2 {
joinedMissingFlags := strings.Join(missingFlags, ", ")
return fmt.Errorf("Required flags %q not set", joinedMissingFlags)
} }
return nil return nil