diff --git a/app.go b/app.go index 051c801..9ed492f 100644 --- a/app.go +++ b/app.go @@ -228,7 +228,7 @@ func (a *App) Run(arguments []string) (err error) { return nil } - cerr := checkRequiredFlags(a.Flags, set) + cerr := checkRequiredFlags(a.Flags, context) if cerr != nil { ShowAppHelp(context) return cerr @@ -358,7 +358,7 @@ func (a *App) RunAsSubcommand(ctx *Context) (err error) { } } - cerr := checkRequiredFlags(a.Flags, set) + cerr := checkRequiredFlags(a.Flags, context) if cerr != nil { ShowSubcommandHelp(context) return cerr diff --git a/command.go b/command.go index cbf06bb..e3b57db 100644 --- a/command.go +++ b/command.go @@ -135,7 +135,7 @@ func (c Command) Run(ctx *Context) (err error) { return nil } - cerr := checkRequiredFlags(c.Flags, set) + cerr := checkRequiredFlags(c.Flags, context) if cerr != nil { ShowCommandHelp(context, c.Name) return cerr diff --git a/context.go b/context.go index 8af3264..3e516c8 100644 --- a/context.go +++ b/context.go @@ -309,17 +309,12 @@ func (e *errRequiredFlags) getMissingFlags() []string { return e.missingFlags } -func checkRequiredFlags(flags []Flag, set *flag.FlagSet) requiredFlagsErr { - visited := make(map[string]bool) - set.Visit(func(f *flag.Flag) { - visited[f.Name] = true - }) - +func checkRequiredFlags(flags []Flag, context *Context) requiredFlagsErr { var missingFlags []string for _, f := range flags { if rf, ok := f.(RequiredFlag); ok && rf.IsRequired() { key := strings.Split(f.GetName(), ",")[0] - if !visited[key] { + if !context.IsSet(key) { missingFlags = append(missingFlags, key) } } diff --git a/context_test.go b/context_test.go index 585ca82..f2fc250 100644 --- a/context_test.go +++ b/context_test.go @@ -495,9 +495,11 @@ func TestCheckRequiredFlags(t *testing.T) { flags.Apply(set) } set.Parse(test.parseInput) + ctx := &Context{} + context := NewContext(ctx.App, set, ctx) // logic under test - err := checkRequiredFlags(test.flags, set) + err := checkRequiredFlags(test.flags, context) // assertions if test.expectedAnError && err == nil {