diff --git a/app.go b/app.go index 1d22bd1..f30595c 100644 --- a/app.go +++ b/app.go @@ -78,6 +78,8 @@ type App struct { CommandNotFound CommandNotFoundFunc // Execute this function if a usage error occurs OnUsageError OnUsageErrorFunc + // Execute this function when an unknown flag is accessed from the context + UnknownFlagHandler UnknownFlagFunc // Compilation date Compiled time.Time // List of all authors who contributed diff --git a/context.go b/context.go index e79ec34..a14a341 100644 --- a/context.go +++ b/context.go @@ -46,6 +46,9 @@ func (cCtx *Context) NumFlags() int { // Set sets a context flag to a value. func (cCtx *Context) Set(name, value string) error { + if cCtx.flagSet.Lookup(name) == nil { + cCtx.onUnknownFlag(name) + } return cCtx.flagSet.Set(name, value) } @@ -158,7 +161,7 @@ func (cCtx *Context) lookupFlagSet(name string) *flag.FlagSet { return c.flagSet } } - + cCtx.onUnknownFlag(name) return nil } @@ -190,6 +193,12 @@ func (cCtx *Context) checkRequiredFlags(flags []Flag) requiredFlagsErr { return nil } +func (cCtx *Context) onUnknownFlag(name string) { + if cCtx.App != nil && cCtx.App.UnknownFlagHandler != nil { + cCtx.App.UnknownFlagHandler(cCtx, name) + } +} + func makeFlagNameVisitor(names *[]string) func(*flag.Flag) { return func(f *flag.Flag) { nameParts := strings.Split(f.Name, ",") diff --git a/context_test.go b/context_test.go index 55a9ead..1ffe040 100644 --- a/context_test.go +++ b/context_test.go @@ -150,6 +150,19 @@ func TestContext_Value(t *testing.T) { expect(t, c.Value("unknown-flag"), nil) } +func TestContext_Value_UnknownFlagHandler(t *testing.T) { + set := flag.NewFlagSet("test", 0) + var flagName string + app := &App{ + UnknownFlagHandler: func(_ *Context, name string) { + flagName = name + }, + } + c := NewContext(app, set, nil) + c.Value("missing") + expect(t, flagName, "missing") +} + func TestContext_Args(t *testing.T) { set := flag.NewFlagSet("test", 0) set.Bool("myflag", false, "doc") @@ -258,6 +271,19 @@ func TestContext_Set(t *testing.T) { expect(t, c.IsSet("int"), true) } +func TestContext_Set_StrictLookup(t *testing.T) { + set := flag.NewFlagSet("test", 0) + var flagName string + app := &App{ + UnknownFlagHandler: func(_ *Context, name string) { + flagName = name + }, + } + c := NewContext(app, set, nil) + c.Set("missing", "") + expect(t, flagName, "missing") +} + func TestContext_LocalFlagNames(t *testing.T) { set := flag.NewFlagSet("test", 0) set.Bool("one-flag", false, "doc") diff --git a/funcs.go b/funcs.go index 0a9b22c..1342bd2 100644 --- a/funcs.go +++ b/funcs.go @@ -23,6 +23,9 @@ type CommandNotFoundFunc func(*Context, string) // is displayed and the execution is interrupted. type OnUsageErrorFunc func(cCtx *Context, err error, isSubcommand bool) error +// UnknownFlagFunc is executed when an unknown flag is accessed from the context. +type UnknownFlagFunc func(*Context, string) + // ExitErrHandlerFunc is executed if provided in order to handle exitError values // returned by Actions and Before/After functions. type ExitErrHandlerFunc func(cCtx *Context, err error)