From 78db152323afb7934f9f0dd207eeaf34147bb300 Mon Sep 17 00:00:00 2001 From: Lynn Cyrin Date: Thu, 1 Aug 2019 21:35:15 -0700 Subject: [PATCH] add typed error assertions --- app_test.go | 3 +++ context.go | 33 +++++++++++++++++++++++++-------- 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/app_test.go b/app_test.go index b469644..1e42f12 100644 --- a/app_test.go +++ b/app_test.go @@ -1002,6 +1002,9 @@ func TestRequiredFlagAppRunBehavior(t *testing.T) { if test.expectedAnError && err == nil { t.Errorf("expected an error, but there was none") } + if _, ok := err.(requiredFlagsErr); test.expectedAnError && !ok { + t.Errorf("expected a requiredFlagsErr, but got: %s", err) + } if !test.expectedAnError && err != nil { t.Errorf("did not expected an error, but there was one: %s", err) } diff --git a/context.go b/context.go index 91a9575..498fd2a 100644 --- a/context.go +++ b/context.go @@ -287,7 +287,29 @@ func normalizeFlags(flags []Flag, set *flag.FlagSet) error { return nil } -func checkRequiredFlags(flags []Flag, set *flag.FlagSet) error { +type requiredFlagsErr interface { + error + getMissingFlags() []string +} + +type errRequiredFlags struct { + missingFlags []string +} + +func (e *errRequiredFlags) Error() string { + numberOfMissingFlags := len(e.missingFlags) + if numberOfMissingFlags == 1 { + return fmt.Sprintf("Required flag %q not set", e.missingFlags[0]) + } + joinedMissingFlags := strings.Join(e.missingFlags, ", ") + return fmt.Sprintf("Required flags %q not set", joinedMissingFlags) +} + +func (e *errRequiredFlags) getMissingFlags() []string { + return e.missingFlags +} + +func checkRequiredFlags(flags []Flag, set *flag.FlagSet) requiredFlagsErr { visited := make(map[string]bool) set.Visit(func(f *flag.Flag) { visited[f.Name] = true @@ -303,13 +325,8 @@ func checkRequiredFlags(flags []Flag, set *flag.FlagSet) error { } } - numberOfMissingFlags := len(missingFlags) - if numberOfMissingFlags == 1 { - return fmt.Errorf("Required flag %q not set", missingFlags[0]) - } - if numberOfMissingFlags >= 2 { - joinedMissingFlags := strings.Join(missingFlags, ", ") - return fmt.Errorf("Required flags %q not set", joinedMissingFlags) + if len(missingFlags) != 0 { + return &errRequiredFlags{missingFlags: missingFlags} } return nil