diff --git a/context.go b/context.go index db7cd69..ecfc032 100644 --- a/context.go +++ b/context.go @@ -313,9 +313,20 @@ func checkRequiredFlags(flags []Flag, context *Context) requiredFlagsErr { var missingFlags []string for _, f := range flags { if rf, ok := f.(RequiredFlag); ok && rf.IsRequired() { - key := strings.Split(f.GetName(), ",")[0] - if !context.IsSet(key) { - missingFlags = append(missingFlags, key) + var flagPresent bool + var flagName string + for _, key := range strings.Split(f.GetName(), ",") { + if len(key) > 1 { + flagName = key + } + + if context.IsSet(strings.TrimSpace(key)) { + flagPresent = true + } + } + + if !flagPresent && flagName != "" { + missingFlags = append(missingFlags, flagName) } } } diff --git a/context_test.go b/context_test.go index 9e594dd..28f5e08 100644 --- a/context_test.go +++ b/context_test.go @@ -517,6 +517,20 @@ func TestCheckRequiredFlags(t *testing.T) { }, parseInput: []string{"--requiredFlag", "myinput", "--requiredFlagTwo", "myinput"}, }, + { + testCase: "required_flag_with_short_name", + flags: []Flag{ + StringSliceFlag{Name: "names, N", Required: true}, + }, + parseInput: []string{"-N", "asd", "-N", "qwe"}, + }, + { + testCase: "required_flag_with_multiple_short_names", + flags: []Flag{ + StringSliceFlag{Name: "names, N, n", Required: true}, + }, + parseInput: []string{"-n", "asd", "-n", "qwe"}, + }, } for _, test := range tdata { t.Run(test.testCase, func(t *testing.T) {