diff --git a/help.go b/help.go index e6ba0de..5bfe401 100644 --- a/help.go +++ b/help.go @@ -186,21 +186,29 @@ func printHelp(out io.Writer, templ string, data interface{}) { } func checkVersion(c *Context) bool { - if c.GlobalBool("version") || c.GlobalBool("v") || c.Bool("version") || c.Bool("v") { - ShowVersion(c) - return true + found := false + if VersionFlag.Name != "" { + eachName(VersionFlag.Name, func(name string) { + if c.GlobalBool(name) || c.Bool(name) { + ShowVersion(c) + found = true + } + }) } - - return false + return found } func checkHelp(c *Context) bool { - if c.GlobalBool("h") || c.GlobalBool("help") || c.Bool("h") || c.Bool("help") { - ShowAppHelp(c) - return true + found := false + if HelpFlag.Name != "" { + eachName(HelpFlag.Name, func(name string) { + if c.GlobalBool(name) || c.Bool(name) { + ShowAppHelp(c) + found = true + } + }) } - - return false + return found } func checkCommandHelp(c *Context, name string) bool { diff --git a/help_test.go b/help_test.go index 42d0284..350e263 100644 --- a/help_test.go +++ b/help_test.go @@ -34,3 +34,61 @@ func Test_ShowAppHelp_NoVersion(t *testing.T) { t.Errorf("expected\n%snot to include %s", output.String(), "VERSION:") } } + +func Test_Help_Custom_Flags(t *testing.T) { + oldFlag := HelpFlag + defer func() { + HelpFlag = oldFlag + }() + + HelpFlag = BoolFlag{ + Name: "help, x", + Usage: "show help", + } + + app := App{ + Flags: []Flag{ + BoolFlag{Name: "foo, h"}, + }, + Action: func(ctx *Context) { + if ctx.Bool("h") != true { + t.Errorf("custom help flag not set") + } + }, + } + output := new(bytes.Buffer) + app.Writer = output + app.Run([]string{"test", "-h"}) + if output.Len() > 0 { + t.Errorf("unexpected output: %s", output.String()) + } +} + +func Test_Version_Custom_Flags(t *testing.T) { + oldFlag := VersionFlag + defer func() { + VersionFlag = oldFlag + }() + + VersionFlag = BoolFlag{ + Name: "version, V", + Usage: "show version", + } + + app := App{ + Flags: []Flag{ + BoolFlag{Name: "foo, v"}, + }, + Action: func(ctx *Context) { + if ctx.Bool("v") != true { + t.Errorf("custom version flag not set") + } + }, + } + output := new(bytes.Buffer) + app.Writer = output + app.Run([]string{"test", "-v"}) + if output.Len() > 0 { + t.Errorf("unexpected output: %s", output.String()) + } +}