add handling for multiple required flags

This commit is contained in:
Lynn Cyrin 2019-07-17 00:16:40 -07:00
parent 9293f5b3cc
commit cdc7af744e
No known key found for this signature in database
GPG Key ID: EE9CCB427DFEC897
2 changed files with 36 additions and 8 deletions

View File

@ -293,13 +293,32 @@ func checkRequiredFlags(flags []Flag, set *flag.FlagSet) error {
visited[f.Name] = true visited[f.Name] = true
}) })
var missingFlags []string
for _, f := range flags { for _, f := range flags {
if rf, ok := f.(RequiredFlag); ok && rf.IsRequired() { if rf, ok := f.(RequiredFlag); ok && rf.IsRequired() {
key := strings.Split(f.GetName(), ",")[0] key := strings.Split(f.GetName(), ",")[0]
if !visited[key] { if !visited[key] {
return fmt.Errorf("Required flag %q not set", f.GetName()) missingFlags = append(missingFlags, f.GetName())
} }
} }
} }
numberOfMissingFlags := len(missingFlags)
if numberOfMissingFlags == 1 {
return fmt.Errorf("Required flag %q not set", missingFlags[0])
}
if numberOfMissingFlags >= 2 {
var missingFlagsOutput string
for idx, f := range missingFlags {
// if not the last item, append with a ", "
if idx != numberOfMissingFlags-1 {
missingFlagsOutput = fmt.Sprintf("%s%s, ", missingFlagsOutput, f)
} else {
missingFlagsOutput = fmt.Sprintf("%s%s", missingFlagsOutput, f)
}
}
return fmt.Errorf("Required flags %q not set", missingFlagsOutput)
}
return nil return nil
} }

View File

@ -3,6 +3,7 @@ package cli
import ( import (
"flag" "flag"
"os" "os"
"strings"
"testing" "testing"
"time" "time"
) )
@ -404,10 +405,11 @@ func TestContext_GlobalSet(t *testing.T) {
func TestCheckRequiredFlags(t *testing.T) { func TestCheckRequiredFlags(t *testing.T) {
tdata := []struct { tdata := []struct {
testCase string testCase string
parseInput []string parseInput []string
flags []Flag flags []Flag
expectedAnError bool expectedAnError bool
expectedErrorContents []string
}{ }{
{ {
testCase: "empty", testCase: "empty",
@ -423,7 +425,8 @@ func TestCheckRequiredFlags(t *testing.T) {
flags: []Flag{ flags: []Flag{
StringFlag{Name: "requiredFlag", Required: true}, StringFlag{Name: "requiredFlag", Required: true},
}, },
expectedAnError: true, expectedAnError: true,
expectedErrorContents: []string{"requiredFlag"},
}, },
{ {
testCase: "required_and_present", testCase: "required_and_present",
@ -460,10 +463,11 @@ func TestCheckRequiredFlags(t *testing.T) {
{ {
testCase: "two_required", testCase: "two_required",
flags: []Flag{ flags: []Flag{
StringFlag{Name: "requiredFlag", Required: true}, StringFlag{Name: "requiredFlagOne", Required: true},
StringFlag{Name: "requiredFlagTwo", Required: true}, StringFlag{Name: "requiredFlagTwo", Required: true},
}, },
expectedAnError: true, expectedAnError: true,
expectedErrorContents: []string{"requiredFlagOne", "requiredFlagTwo"},
}, },
{ {
testCase: "two_required_and_one_present", testCase: "two_required_and_one_present",
@ -502,6 +506,11 @@ func TestCheckRequiredFlags(t *testing.T) {
if !test.expectedAnError && err != nil { if !test.expectedAnError && err != nil {
t.Errorf("did not expected an error, but there was one: %s", err) t.Errorf("did not expected an error, but there was one: %s", err)
} }
for _, errString := range test.expectedErrorContents {
if !strings.Contains(err.Error(), errString) {
t.Errorf("expected error %q to contain %q, but it didn't!", err.Error(), errString)
}
}
}) })
} }
} }