From 1794792adfc0f1cea5a4e56eefc36fd849018d05 Mon Sep 17 00:00:00 2001 From: "Joe Richey joerichey@google.com" Date: Fri, 5 May 2017 20:07:18 -0700 Subject: [PATCH] Add ability to use custom Flag types Users can now use custom flags types (conforming to the Flag interface) in their applications. They can also use custom flags for the three global flags (Help, Version, bash completion). --- app_test.go | 61 +++++++++++++++++++++++++++++++++++++++++++++++++++-- flag.go | 12 +++++------ help.go | 10 ++++----- 3 files changed, 70 insertions(+), 13 deletions(-) diff --git a/app_test.go b/app_test.go index 10f1562..e14ddaf 100644 --- a/app_test.go +++ b/app_test.go @@ -1520,6 +1520,63 @@ func TestApp_OnUsageError_WithWrongFlagValue_ForSubcommand(t *testing.T) { } } +// A custom flag that conforms to the relevant interfaces, but has none of the +// fields that the other flag types do. +type customBoolFlag struct { + Nombre string +} + +// Don't use the normal FlagStringer +func (c *customBoolFlag) String() string { + return "***" + c.Nombre + "***" +} + +func (c *customBoolFlag) GetName() string { + return c.Nombre +} + +func (c *customBoolFlag) Apply(set *flag.FlagSet) { + set.String(c.Nombre, c.Nombre, "") +} + +func TestCustomFlagsUnused(t *testing.T) { + app := NewApp() + app.Flags = []Flag{&customBoolFlag{"custom"}} + + err := app.Run([]string{"foo"}) + if err != nil { + t.Errorf("Run returned unexpected error: %v", err) + } +} + +func TestCustomFlagsUsed(t *testing.T) { + app := NewApp() + app.Flags = []Flag{&customBoolFlag{"custom"}} + + err := app.Run([]string{"foo", "--custom=bar"}) + if err != nil { + t.Errorf("Run returned unexpected error: %v", err) + } +} + +func TestCustomHelpVersionFlags(t *testing.T) { + app := NewApp() + + // Be sure to reset the global flags + defer func(helpFlag Flag, versionFlag Flag) { + HelpFlag = helpFlag + VersionFlag = versionFlag + }(HelpFlag, VersionFlag) + + HelpFlag = &customBoolFlag{"help-custom"} + VersionFlag = &customBoolFlag{"version-custom"} + + err := app.Run([]string{"foo", "--help-custom=bar"}) + if err != nil { + t.Errorf("Run returned unexpected error: %v", err) + } +} + func TestHandleAction_WithNonFuncAction(t *testing.T) { app := NewApp() app.Action = 42 @@ -1642,7 +1699,7 @@ func TestShellCompletionForIncompleteFlags(t *testing.T) { for _, flag := range ctx.App.Flags { for _, name := range strings.Split(flag.GetName(), ",") { - if name == BashCompletionFlag.Name { + if name == BashCompletionFlag.GetName() { continue } @@ -1659,7 +1716,7 @@ func TestShellCompletionForIncompleteFlags(t *testing.T) { app.Action = func(ctx *Context) error { return fmt.Errorf("should not get here") } - err := app.Run([]string{"", "--test-completion", "--" + BashCompletionFlag.Name}) + err := app.Run([]string{"", "--test-completion", "--" + BashCompletionFlag.GetName()}) if err != nil { t.Errorf("app should not return an error: %s", err) } diff --git a/flag.go b/flag.go index 7dd8a2c..877ff35 100644 --- a/flag.go +++ b/flag.go @@ -14,13 +14,13 @@ import ( const defaultPlaceholder = "value" // BashCompletionFlag enables bash-completion for all commands and subcommands -var BashCompletionFlag = BoolFlag{ +var BashCompletionFlag Flag = BoolFlag{ Name: "generate-bash-completion", Hidden: true, } // VersionFlag prints the version for the application -var VersionFlag = BoolFlag{ +var VersionFlag Flag = BoolFlag{ Name: "version, v", Usage: "print the version", } @@ -28,7 +28,7 @@ var VersionFlag = BoolFlag{ // HelpFlag prints the help for all commands and subcommands // Set to the zero value (BoolFlag{}) to disable flag -- keeps subcommand // unless HideHelp is set to true) -var HelpFlag = BoolFlag{ +var HelpFlag Flag = BoolFlag{ Name: "help, h", Usage: "show help", } @@ -630,7 +630,8 @@ func (f Float64Flag) ApplyWithError(set *flag.FlagSet) error { func visibleFlags(fl []Flag) []Flag { visible := []Flag{} for _, flag := range fl { - if !flagValue(flag).FieldByName("Hidden").Bool() { + field := flagValue(flag).FieldByName("Hidden") + if !field.IsValid() || !field.Bool() { visible = append(visible, flag) } } @@ -723,9 +724,8 @@ func stringifyFlag(f Flag) string { needsPlaceholder := false defaultValueString := "" - val := fv.FieldByName("Value") - if val.IsValid() { + if val := fv.FieldByName("Value"); val.IsValid() { needsPlaceholder = true defaultValueString = fmt.Sprintf(" (default: %v)", val.Interface()) diff --git a/help.go b/help.go index d00e4da..df4cb56 100644 --- a/help.go +++ b/help.go @@ -212,8 +212,8 @@ func printHelp(out io.Writer, templ string, data interface{}) { func checkVersion(c *Context) bool { found := false - if VersionFlag.Name != "" { - eachName(VersionFlag.Name, func(name string) { + if VersionFlag.GetName() != "" { + eachName(VersionFlag.GetName(), func(name string) { if c.GlobalBool(name) || c.Bool(name) { found = true } @@ -224,8 +224,8 @@ func checkVersion(c *Context) bool { func checkHelp(c *Context) bool { found := false - if HelpFlag.Name != "" { - eachName(HelpFlag.Name, func(name string) { + if HelpFlag.GetName() != "" { + eachName(HelpFlag.GetName(), func(name string) { if c.GlobalBool(name) || c.Bool(name) { found = true } @@ -260,7 +260,7 @@ func checkShellCompleteFlag(a *App, arguments []string) (bool, []string) { pos := len(arguments) - 1 lastArg := arguments[pos] - if lastArg != "--"+BashCompletionFlag.Name { + if lastArg != "--"+BashCompletionFlag.GetName() { return false, arguments }