Add ability to use custom Flag types

Users can now use custom flags types (conforming to the Flag interface)
in their applications. They can also use custom flags for the three
global flags (Help, Version, bash completion).
This commit is contained in:
Joe Richey joerichey@google.com 2017-05-05 20:07:18 -07:00
parent d70f47eeca
commit 1794792adf
3 changed files with 70 additions and 13 deletions

View File

@ -1520,6 +1520,63 @@ func TestApp_OnUsageError_WithWrongFlagValue_ForSubcommand(t *testing.T) {
} }
} }
// A custom flag that conforms to the relevant interfaces, but has none of the
// fields that the other flag types do.
type customBoolFlag struct {
Nombre string
}
// Don't use the normal FlagStringer
func (c *customBoolFlag) String() string {
return "***" + c.Nombre + "***"
}
func (c *customBoolFlag) GetName() string {
return c.Nombre
}
func (c *customBoolFlag) Apply(set *flag.FlagSet) {
set.String(c.Nombre, c.Nombre, "")
}
func TestCustomFlagsUnused(t *testing.T) {
app := NewApp()
app.Flags = []Flag{&customBoolFlag{"custom"}}
err := app.Run([]string{"foo"})
if err != nil {
t.Errorf("Run returned unexpected error: %v", err)
}
}
func TestCustomFlagsUsed(t *testing.T) {
app := NewApp()
app.Flags = []Flag{&customBoolFlag{"custom"}}
err := app.Run([]string{"foo", "--custom=bar"})
if err != nil {
t.Errorf("Run returned unexpected error: %v", err)
}
}
func TestCustomHelpVersionFlags(t *testing.T) {
app := NewApp()
// Be sure to reset the global flags
defer func(helpFlag Flag, versionFlag Flag) {
HelpFlag = helpFlag
VersionFlag = versionFlag
}(HelpFlag, VersionFlag)
HelpFlag = &customBoolFlag{"help-custom"}
VersionFlag = &customBoolFlag{"version-custom"}
err := app.Run([]string{"foo", "--help-custom=bar"})
if err != nil {
t.Errorf("Run returned unexpected error: %v", err)
}
}
func TestHandleAction_WithNonFuncAction(t *testing.T) { func TestHandleAction_WithNonFuncAction(t *testing.T) {
app := NewApp() app := NewApp()
app.Action = 42 app.Action = 42
@ -1642,7 +1699,7 @@ func TestShellCompletionForIncompleteFlags(t *testing.T) {
for _, flag := range ctx.App.Flags { for _, flag := range ctx.App.Flags {
for _, name := range strings.Split(flag.GetName(), ",") { for _, name := range strings.Split(flag.GetName(), ",") {
if name == BashCompletionFlag.Name { if name == BashCompletionFlag.GetName() {
continue continue
} }
@ -1659,7 +1716,7 @@ func TestShellCompletionForIncompleteFlags(t *testing.T) {
app.Action = func(ctx *Context) error { app.Action = func(ctx *Context) error {
return fmt.Errorf("should not get here") return fmt.Errorf("should not get here")
} }
err := app.Run([]string{"", "--test-completion", "--" + BashCompletionFlag.Name}) err := app.Run([]string{"", "--test-completion", "--" + BashCompletionFlag.GetName()})
if err != nil { if err != nil {
t.Errorf("app should not return an error: %s", err) t.Errorf("app should not return an error: %s", err)
} }

12
flag.go
View File

@ -14,13 +14,13 @@ import (
const defaultPlaceholder = "value" const defaultPlaceholder = "value"
// BashCompletionFlag enables bash-completion for all commands and subcommands // BashCompletionFlag enables bash-completion for all commands and subcommands
var BashCompletionFlag = BoolFlag{ var BashCompletionFlag Flag = BoolFlag{
Name: "generate-bash-completion", Name: "generate-bash-completion",
Hidden: true, Hidden: true,
} }
// VersionFlag prints the version for the application // VersionFlag prints the version for the application
var VersionFlag = BoolFlag{ var VersionFlag Flag = BoolFlag{
Name: "version, v", Name: "version, v",
Usage: "print the version", Usage: "print the version",
} }
@ -28,7 +28,7 @@ var VersionFlag = BoolFlag{
// HelpFlag prints the help for all commands and subcommands // HelpFlag prints the help for all commands and subcommands
// Set to the zero value (BoolFlag{}) to disable flag -- keeps subcommand // Set to the zero value (BoolFlag{}) to disable flag -- keeps subcommand
// unless HideHelp is set to true) // unless HideHelp is set to true)
var HelpFlag = BoolFlag{ var HelpFlag Flag = BoolFlag{
Name: "help, h", Name: "help, h",
Usage: "show help", Usage: "show help",
} }
@ -630,7 +630,8 @@ func (f Float64Flag) ApplyWithError(set *flag.FlagSet) error {
func visibleFlags(fl []Flag) []Flag { func visibleFlags(fl []Flag) []Flag {
visible := []Flag{} visible := []Flag{}
for _, flag := range fl { for _, flag := range fl {
if !flagValue(flag).FieldByName("Hidden").Bool() { field := flagValue(flag).FieldByName("Hidden")
if !field.IsValid() || !field.Bool() {
visible = append(visible, flag) visible = append(visible, flag)
} }
} }
@ -723,9 +724,8 @@ func stringifyFlag(f Flag) string {
needsPlaceholder := false needsPlaceholder := false
defaultValueString := "" defaultValueString := ""
val := fv.FieldByName("Value")
if val.IsValid() { if val := fv.FieldByName("Value"); val.IsValid() {
needsPlaceholder = true needsPlaceholder = true
defaultValueString = fmt.Sprintf(" (default: %v)", val.Interface()) defaultValueString = fmt.Sprintf(" (default: %v)", val.Interface())

10
help.go
View File

@ -212,8 +212,8 @@ func printHelp(out io.Writer, templ string, data interface{}) {
func checkVersion(c *Context) bool { func checkVersion(c *Context) bool {
found := false found := false
if VersionFlag.Name != "" { if VersionFlag.GetName() != "" {
eachName(VersionFlag.Name, func(name string) { eachName(VersionFlag.GetName(), func(name string) {
if c.GlobalBool(name) || c.Bool(name) { if c.GlobalBool(name) || c.Bool(name) {
found = true found = true
} }
@ -224,8 +224,8 @@ func checkVersion(c *Context) bool {
func checkHelp(c *Context) bool { func checkHelp(c *Context) bool {
found := false found := false
if HelpFlag.Name != "" { if HelpFlag.GetName() != "" {
eachName(HelpFlag.Name, func(name string) { eachName(HelpFlag.GetName(), func(name string) {
if c.GlobalBool(name) || c.Bool(name) { if c.GlobalBool(name) || c.Bool(name) {
found = true found = true
} }
@ -260,7 +260,7 @@ func checkShellCompleteFlag(a *App, arguments []string) (bool, []string) {
pos := len(arguments) - 1 pos := len(arguments) - 1
lastArg := arguments[pos] lastArg := arguments[pos]
if lastArg != "--"+BashCompletionFlag.Name { if lastArg != "--"+BashCompletionFlag.GetName() {
return false, arguments return false, arguments
} }