Use iterative logic to determine missing flag
This commit is contained in:
parent
1547ac2f6a
commit
c6ee3b4904
21
context.go
21
context.go
@ -313,18 +313,21 @@ func checkRequiredFlags(flags []Flag, context *Context) requiredFlagsErr {
|
|||||||
var missingFlags []string
|
var missingFlags []string
|
||||||
for _, f := range flags {
|
for _, f := range flags {
|
||||||
if rf, ok := f.(RequiredFlag); ok && rf.IsRequired() {
|
if rf, ok := f.(RequiredFlag); ok && rf.IsRequired() {
|
||||||
key := strings.Split(f.GetName(), ",")
|
var flagPresent bool
|
||||||
if len(key) > 1 {
|
var flagName string
|
||||||
// has short name
|
for _, key := range strings.Split(f.GetName(), ",") {
|
||||||
if !context.IsSet(strings.TrimSpace(key[0])) && !context.IsSet(strings.TrimSpace(key[1])) {
|
if len(key) > 1 {
|
||||||
missingFlags = append(missingFlags, key[0])
|
flagName = key
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
// does not have short name
|
if context.IsSet(strings.TrimSpace(key)) {
|
||||||
if !context.IsSet(strings.TrimSpace(key[0])) {
|
flagPresent = true
|
||||||
missingFlags = append(missingFlags, key[0])
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !flagPresent {
|
||||||
|
missingFlags = append(missingFlags, flagName)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -524,6 +524,13 @@ func TestCheckRequiredFlags(t *testing.T) {
|
|||||||
},
|
},
|
||||||
parseInput: []string{"-N", "asd", "-N", "qwe"},
|
parseInput: []string{"-N", "asd", "-N", "qwe"},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
testCase: "required_flag_with_short_name",
|
||||||
|
flags: []Flag{
|
||||||
|
StringSliceFlag{Name: "names, N, n", Required: true},
|
||||||
|
},
|
||||||
|
parseInput: []string{"-n", "asd", "-n", "qwe"},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
for _, test := range tdata {
|
for _, test := range tdata {
|
||||||
t.Run(test.testCase, func(t *testing.T) {
|
t.Run(test.testCase, func(t *testing.T) {
|
||||||
|
Loading…
Reference in New Issue
Block a user