From cdc7af744e07ac8dbb34793f8b392af46ba443f7 Mon Sep 17 00:00:00 2001 From: Lynn Cyrin Date: Wed, 17 Jul 2019 00:16:40 -0700 Subject: [PATCH] add handling for multiple required flags --- context.go | 21 ++++++++++++++++++++- context_test.go | 23 ++++++++++++++++------- 2 files changed, 36 insertions(+), 8 deletions(-) diff --git a/context.go b/context.go index 8caf90d..383749d 100644 --- a/context.go +++ b/context.go @@ -293,13 +293,32 @@ func checkRequiredFlags(flags []Flag, set *flag.FlagSet) error { visited[f.Name] = true }) + var missingFlags []string for _, f := range flags { if rf, ok := f.(RequiredFlag); ok && rf.IsRequired() { key := strings.Split(f.GetName(), ",")[0] if !visited[key] { - return fmt.Errorf("Required flag %q not set", f.GetName()) + missingFlags = append(missingFlags, f.GetName()) } } } + + numberOfMissingFlags := len(missingFlags) + if numberOfMissingFlags == 1 { + return fmt.Errorf("Required flag %q not set", missingFlags[0]) + } + if numberOfMissingFlags >= 2 { + var missingFlagsOutput string + for idx, f := range missingFlags { + // if not the last item, append with a ", " + if idx != numberOfMissingFlags-1 { + missingFlagsOutput = fmt.Sprintf("%s%s, ", missingFlagsOutput, f) + } else { + missingFlagsOutput = fmt.Sprintf("%s%s", missingFlagsOutput, f) + } + } + return fmt.Errorf("Required flags %q not set", missingFlagsOutput) + } + return nil } diff --git a/context_test.go b/context_test.go index fac1d4b..585ca82 100644 --- a/context_test.go +++ b/context_test.go @@ -3,6 +3,7 @@ package cli import ( "flag" "os" + "strings" "testing" "time" ) @@ -404,10 +405,11 @@ func TestContext_GlobalSet(t *testing.T) { func TestCheckRequiredFlags(t *testing.T) { tdata := []struct { - testCase string - parseInput []string - flags []Flag - expectedAnError bool + testCase string + parseInput []string + flags []Flag + expectedAnError bool + expectedErrorContents []string }{ { testCase: "empty", @@ -423,7 +425,8 @@ func TestCheckRequiredFlags(t *testing.T) { flags: []Flag{ StringFlag{Name: "requiredFlag", Required: true}, }, - expectedAnError: true, + expectedAnError: true, + expectedErrorContents: []string{"requiredFlag"}, }, { testCase: "required_and_present", @@ -460,10 +463,11 @@ func TestCheckRequiredFlags(t *testing.T) { { testCase: "two_required", flags: []Flag{ - StringFlag{Name: "requiredFlag", Required: true}, + StringFlag{Name: "requiredFlagOne", Required: true}, StringFlag{Name: "requiredFlagTwo", Required: true}, }, - expectedAnError: true, + expectedAnError: true, + expectedErrorContents: []string{"requiredFlagOne", "requiredFlagTwo"}, }, { testCase: "two_required_and_one_present", @@ -502,6 +506,11 @@ func TestCheckRequiredFlags(t *testing.T) { if !test.expectedAnError && err != nil { t.Errorf("did not expected an error, but there was one: %s", err) } + for _, errString := range test.expectedErrorContents { + if !strings.Contains(err.Error(), errString) { + t.Errorf("expected error %q to contain %q, but it didn't!", err.Error(), errString) + } + } }) } }