diff --git a/app_test.go b/app_test.go index 46ec33f..b224d63 100644 --- a/app_test.go +++ b/app_test.go @@ -627,6 +627,23 @@ func TestAppCommandNotFound(t *testing.T) { expect(t, subcommandRun, false) } +func TestGlobalFlag(t *testing.T) { + var globalFlag string + var globalFlagSet bool + app := cli.NewApp() + app.Flags = []cli.Flag{ + cli.StringFlag{Name: "global, g", Usage: "global"}, + } + app.Action = func(c *cli.Context) { + globalFlag = c.GlobalString("global") + globalFlagSet = c.GlobalIsSet("global") + } + app.Run([]string{"command", "-g", "foo"}) + expect(t, globalFlag, "foo") + expect(t, globalFlagSet, true) + +} + func TestGlobalFlagsInSubcommands(t *testing.T) { subcommandRun := false parentFlag := false diff --git a/context.go b/context.go index c75607e..f541f41 100644 --- a/context.go +++ b/context.go @@ -73,7 +73,7 @@ func (c *Context) Generic(name string) interface{} { // Looks up the value of a global int flag, returns 0 if no int flag exists func (c *Context) GlobalInt(name string) int { - if fs := lookupParentFlagSet(name, c); fs != nil { + if fs := lookupGlobalFlagSet(name, c); fs != nil { return lookupInt(name, fs) } return 0 @@ -81,7 +81,7 @@ func (c *Context) GlobalInt(name string) int { // Looks up the value of a global time.Duration flag, returns 0 if no time.Duration flag exists func (c *Context) GlobalDuration(name string) time.Duration { - if fs := lookupParentFlagSet(name, c); fs != nil { + if fs := lookupGlobalFlagSet(name, c); fs != nil { return lookupDuration(name, fs) } return 0 @@ -89,7 +89,7 @@ func (c *Context) GlobalDuration(name string) time.Duration { // Looks up the value of a global bool flag, returns false if no bool flag exists func (c *Context) GlobalBool(name string) bool { - if fs := lookupParentFlagSet(name, c); fs != nil { + if fs := lookupGlobalFlagSet(name, c); fs != nil { return lookupBool(name, fs) } return false @@ -97,7 +97,7 @@ func (c *Context) GlobalBool(name string) bool { // Looks up the value of a global string flag, returns "" if no string flag exists func (c *Context) GlobalString(name string) string { - if fs := lookupParentFlagSet(name, c); fs != nil { + if fs := lookupGlobalFlagSet(name, c); fs != nil { return lookupString(name, fs) } return "" @@ -105,7 +105,7 @@ func (c *Context) GlobalString(name string) string { // Looks up the value of a global string slice flag, returns nil if no string slice flag exists func (c *Context) GlobalStringSlice(name string) []string { - if fs := lookupParentFlagSet(name, c); fs != nil { + if fs := lookupGlobalFlagSet(name, c); fs != nil { return lookupStringSlice(name, fs) } return nil @@ -113,7 +113,7 @@ func (c *Context) GlobalStringSlice(name string) []string { // Looks up the value of a global int slice flag, returns nil if no int slice flag exists func (c *Context) GlobalIntSlice(name string) []int { - if fs := lookupParentFlagSet(name, c); fs != nil { + if fs := lookupGlobalFlagSet(name, c); fs != nil { return lookupIntSlice(name, fs) } return nil @@ -121,7 +121,7 @@ func (c *Context) GlobalIntSlice(name string) []int { // Looks up the value of a global generic flag, returns nil if no generic flag exists func (c *Context) GlobalGeneric(name string) interface{} { - if fs := lookupParentFlagSet(name, c); fs != nil { + if fs := lookupGlobalFlagSet(name, c); fs != nil { return lookupGeneric(name, fs) } return nil @@ -147,7 +147,11 @@ func (c *Context) IsSet(name string) bool { func (c *Context) GlobalIsSet(name string) bool { if c.globalSetFlags == nil { c.globalSetFlags = make(map[string]bool) - for ctx := c.parentContext; ctx != nil && c.globalSetFlags[name] == false; ctx = ctx.parentContext { + ctx := c + if ctx.parentContext != nil { + ctx = ctx.parentContext + } + for ; ctx != nil && c.globalSetFlags[name] == false; ctx = ctx.parentContext { ctx.flagSet.Visit(func(f *flag.Flag) { c.globalSetFlags[f.Name] = true }) @@ -229,8 +233,11 @@ func (a Args) Swap(from, to int) error { return nil } -func lookupParentFlagSet(name string, ctx *Context) *flag.FlagSet { - for ctx := ctx.parentContext; ctx != nil; ctx = ctx.parentContext { +func lookupGlobalFlagSet(name string, ctx *Context) *flag.FlagSet { + if ctx.parentContext != nil { + ctx = ctx.parentContext + } + for ; ctx != nil; ctx = ctx.parentContext { if f := ctx.flagSet.Lookup(name); f != nil { return ctx.flagSet }