From cbb9e015b89225aa090c41085bdb0933f6290d96 Mon Sep 17 00:00:00 2001 From: Ajitem Sahasrabuddhe Date: Wed, 11 Sep 2019 09:21:45 +0530 Subject: [PATCH] Improve Code and Add Test Case --- context.go | 6 ++++-- context_test.go | 7 +++++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/context.go b/context.go index 485f529..8f1dcd8 100644 --- a/context.go +++ b/context.go @@ -314,14 +314,16 @@ func checkRequiredFlags(flags []Flag, context *Context) requiredFlagsErr { for _, f := range flags { if rf, ok := f.(RequiredFlag); ok && rf.IsRequired() { key := strings.Split(f.GetName(), ",") + shortName := strings.TrimSpace(key[0]) if len(key) > 1 { // has short name - if !context.IsSet(strings.TrimSpace(key[0])) && !context.IsSet(strings.TrimSpace(key[1])) { + longName := strings.TrimSpace(key[1]) + if !context.IsSet(shortName) && !context.IsSet(longName) { missingFlags = append(missingFlags, key[0]) } } else { // does not have short name - if !context.IsSet(strings.TrimSpace(key[0])) { + if !context.IsSet(shortName) { missingFlags = append(missingFlags, key[0]) } } diff --git a/context_test.go b/context_test.go index 9e594dd..1d52921 100644 --- a/context_test.go +++ b/context_test.go @@ -517,6 +517,13 @@ func TestCheckRequiredFlags(t *testing.T) { }, parseInput: []string{"--requiredFlag", "myinput", "--requiredFlagTwo", "myinput"}, }, + { + testCase: "required_flag_with_short_name", + flags: []Flag{ + StringSliceFlag{Name: "names, N", Required: true}, + }, + parseInput: []string{"-N", "asd", "-N", "qwe"}, + }, } for _, test := range tdata { t.Run(test.testCase, func(t *testing.T) {