From ca9df40abd0db61673fb645f9826d8fbb061b6bb Mon Sep 17 00:00:00 2001 From: Ilia Choly Date: Tue, 30 Aug 2022 18:51:16 -0400 Subject: [PATCH] Add App.InvalidFlagAccessHandler (#1446) * Add App.UnknownFlagHandler * Rename App.UnknownFlagHandler to App.InvalidFlagAccessHandler * Traverse parent contexts --- app.go | 2 ++ context.go | 15 ++++++++++++++- context_test.go | 38 ++++++++++++++++++++++++++++++++++++++ funcs.go | 3 +++ 4 files changed, 57 insertions(+), 1 deletion(-) diff --git a/app.go b/app.go index b39c8b8..1d2f6e8 100644 --- a/app.go +++ b/app.go @@ -78,6 +78,8 @@ type App struct { CommandNotFound CommandNotFoundFunc // Execute this function if a usage error occurs OnUsageError OnUsageErrorFunc + // Execute this function when an invalid flag is accessed from the context + InvalidFlagAccessHandler InvalidFlagAccessFunc // Compilation date Compiled time.Time // List of all authors who contributed diff --git a/context.go b/context.go index 3a78c26..dc0d1ef 100644 --- a/context.go +++ b/context.go @@ -46,6 +46,9 @@ func (cCtx *Context) NumFlags() int { // Set sets a context flag to a value. func (cCtx *Context) Set(name, value string) error { + if cCtx.flagSet.Lookup(name) == nil { + cCtx.onInvalidFlag(name) + } return cCtx.flagSet.Set(name, value) } @@ -158,7 +161,7 @@ func (cCtx *Context) lookupFlagSet(name string) *flag.FlagSet { return c.flagSet } } - + cCtx.onInvalidFlag(name) return nil } @@ -190,6 +193,16 @@ func (cCtx *Context) checkRequiredFlags(flags []Flag) requiredFlagsErr { return nil } +func (cCtx *Context) onInvalidFlag(name string) { + for cCtx != nil { + if cCtx.App != nil && cCtx.App.InvalidFlagAccessHandler != nil { + cCtx.App.InvalidFlagAccessHandler(cCtx, name) + break + } + cCtx = cCtx.parentContext + } +} + func makeFlagNameVisitor(names *[]string) func(*flag.Flag) { return func(f *flag.Flag) { nameParts := strings.Split(f.Name, ",") diff --git a/context_test.go b/context_test.go index 55a9ead..6601155 100644 --- a/context_test.go +++ b/context_test.go @@ -150,6 +150,31 @@ func TestContext_Value(t *testing.T) { expect(t, c.Value("unknown-flag"), nil) } +func TestContext_Value_InvalidFlagAccessHandler(t *testing.T) { + var flagName string + app := &App{ + InvalidFlagAccessHandler: func(_ *Context, name string) { + flagName = name + }, + Commands: []*Command{ + { + Name: "command", + Subcommands: []*Command{ + { + Name: "subcommand", + Action: func(ctx *Context) error { + ctx.Value("missing") + return nil + }, + }, + }, + }, + }, + } + expect(t, app.Run([]string{"run", "command", "subcommand"}), nil) + expect(t, flagName, "missing") +} + func TestContext_Args(t *testing.T) { set := flag.NewFlagSet("test", 0) set.Bool("myflag", false, "doc") @@ -258,6 +283,19 @@ func TestContext_Set(t *testing.T) { expect(t, c.IsSet("int"), true) } +func TestContext_Set_InvalidFlagAccessHandler(t *testing.T) { + set := flag.NewFlagSet("test", 0) + var flagName string + app := &App{ + InvalidFlagAccessHandler: func(_ *Context, name string) { + flagName = name + }, + } + c := NewContext(app, set, nil) + c.Set("missing", "") + expect(t, flagName, "missing") +} + func TestContext_LocalFlagNames(t *testing.T) { set := flag.NewFlagSet("test", 0) set.Bool("one-flag", false, "doc") diff --git a/funcs.go b/funcs.go index 0a9b22c..e77b0d0 100644 --- a/funcs.go +++ b/funcs.go @@ -23,6 +23,9 @@ type CommandNotFoundFunc func(*Context, string) // is displayed and the execution is interrupted. type OnUsageErrorFunc func(cCtx *Context, err error, isSubcommand bool) error +// InvalidFlagAccessFunc is executed when an invalid flag is accessed from the context. +type InvalidFlagAccessFunc func(*Context, string) + // ExitErrHandlerFunc is executed if provided in order to handle exitError values // returned by Actions and Before/After functions. type ExitErrHandlerFunc func(cCtx *Context, err error)