From c6ee3b4904ed76d34f277c315c2097ae7b22d38f Mon Sep 17 00:00:00 2001 From: Ajitem Sahasrabuddhe Date: Wed, 11 Sep 2019 14:34:41 +0530 Subject: [PATCH] Use iterative logic to determine missing flag --- context.go | 21 ++++++++++++--------- context_test.go | 7 +++++++ 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/context.go b/context.go index 485f529..2f18f3f 100644 --- a/context.go +++ b/context.go @@ -313,18 +313,21 @@ func checkRequiredFlags(flags []Flag, context *Context) requiredFlagsErr { var missingFlags []string for _, f := range flags { if rf, ok := f.(RequiredFlag); ok && rf.IsRequired() { - key := strings.Split(f.GetName(), ",") - if len(key) > 1 { - // has short name - if !context.IsSet(strings.TrimSpace(key[0])) && !context.IsSet(strings.TrimSpace(key[1])) { - missingFlags = append(missingFlags, key[0]) + var flagPresent bool + var flagName string + for _, key := range strings.Split(f.GetName(), ",") { + if len(key) > 1 { + flagName = key } - } else { - // does not have short name - if !context.IsSet(strings.TrimSpace(key[0])) { - missingFlags = append(missingFlags, key[0]) + + if context.IsSet(strings.TrimSpace(key)) { + flagPresent = true } } + + if !flagPresent { + missingFlags = append(missingFlags, flagName) + } } } diff --git a/context_test.go b/context_test.go index 1d52921..13c3701 100644 --- a/context_test.go +++ b/context_test.go @@ -524,6 +524,13 @@ func TestCheckRequiredFlags(t *testing.T) { }, parseInput: []string{"-N", "asd", "-N", "qwe"}, }, + { + testCase: "required_flag_with_short_name", + flags: []Flag{ + StringSliceFlag{Name: "names, N, n", Required: true}, + }, + parseInput: []string{"-n", "asd", "-n", "qwe"}, + }, } for _, test := range tdata { t.Run(test.testCase, func(t *testing.T) {