From 65d50017d4f34772b8d767fb7478b9416b5d30c5 Mon Sep 17 00:00:00 2001 From: Fabian Ruff Date: Mon, 18 May 2015 17:39:48 +0200 Subject: [PATCH] search context hierachy for global flags --- app.go | 6 ++--- app_test.go | 10 ++++++++- command.go | 2 +- command_test.go | 4 ++-- context.go | 60 +++++++++++++++++++++++++++++++++++++------------ context_test.go | 24 +++++++++++--------- 6 files changed, 75 insertions(+), 31 deletions(-) diff --git a/app.go b/app.go index 891416d..5e551b8 100644 --- a/app.go +++ b/app.go @@ -104,12 +104,12 @@ func (a *App) Run(arguments []string) (err error) { nerr := normalizeFlags(a.Flags, set) if nerr != nil { fmt.Fprintln(a.Writer, nerr) - context := NewContext(a, set, set) + context := NewContext(a, set, nil) ShowAppHelp(context) fmt.Fprintln(a.Writer) return nerr } - context := NewContext(a, set, set) + context := NewContext(a, set, nil) if err != nil { fmt.Fprintf(a.Writer, "Incorrect Usage.\n\n") @@ -190,7 +190,7 @@ func (a *App) RunAsSubcommand(ctx *Context) (err error) { set.SetOutput(ioutil.Discard) err = set.Parse(ctx.Args().Tail()) nerr := normalizeFlags(a.Flags, set) - context := NewContext(a, set, ctx.globalSet) + context := NewContext(a, set, ctx) if nerr != nil { fmt.Fprintln(a.Writer, nerr) diff --git a/app_test.go b/app_test.go index ae8bb0f..4a0aa3c 100644 --- a/app_test.go +++ b/app_test.go @@ -597,6 +597,7 @@ func TestAppCommandNotFound(t *testing.T) { func TestGlobalFlagsInSubcommands(t *testing.T) { subcommandRun := false + parentFlag := false app := cli.NewApp() app.Flags = []cli.Flag{ @@ -606,6 +607,9 @@ func TestGlobalFlagsInSubcommands(t *testing.T) { app.Commands = []cli.Command{ cli.Command{ Name: "foo", + Flags: []cli.Flag{ + cli.BoolFlag{Name: "parent, p", Usage: "Parent flag"}, + }, Subcommands: []cli.Command{ { Name: "bar", @@ -613,15 +617,19 @@ func TestGlobalFlagsInSubcommands(t *testing.T) { if c.GlobalBool("debug") { subcommandRun = true } + if c.GlobalBool("parent") { + parentFlag = true + } }, }, }, }, } - app.Run([]string{"command", "-d", "foo", "bar"}) + app.Run([]string{"command", "-d", "foo", "-p", "bar"}) expect(t, subcommandRun, true) + expect(t, parentFlag, true) } func TestApp_Run_CommandWithSubcommandHasHelpTopic(t *testing.T) { diff --git a/command.go b/command.go index d0bbd0c..b721c0a 100644 --- a/command.go +++ b/command.go @@ -105,7 +105,7 @@ func (c Command) Run(ctx *Context) error { fmt.Fprintln(ctx.App.Writer) return nerr } - context := NewContext(ctx.App, set, ctx.globalSet) + context := NewContext(ctx.App, set, ctx) if checkCommandCompletions(context, c.Name) { return nil diff --git a/command_test.go b/command_test.go index 4125b0c..db81db2 100644 --- a/command_test.go +++ b/command_test.go @@ -13,7 +13,7 @@ func TestCommandDoNotIgnoreFlags(t *testing.T) { test := []string{"blah", "blah", "-break"} set.Parse(test) - c := cli.NewContext(app, set, set) + c := cli.NewContext(app, set, nil) command := cli.Command{ Name: "test-cmd", @@ -33,7 +33,7 @@ func TestCommandIgnoreFlags(t *testing.T) { test := []string{"blah", "blah"} set.Parse(test) - c := cli.NewContext(app, set, set) + c := cli.NewContext(app, set, nil) command := cli.Command{ Name: "test-cmd", diff --git a/context.go b/context.go index 37221bd..5b67129 100644 --- a/context.go +++ b/context.go @@ -16,14 +16,14 @@ type Context struct { App *App Command Command flagSet *flag.FlagSet - globalSet *flag.FlagSet setFlags map[string]bool globalSetFlags map[string]bool + parentContext *Context } // Creates a new context. For use in when invoking an App or Command action. -func NewContext(app *App, set *flag.FlagSet, globalSet *flag.FlagSet) *Context { - return &Context{App: app, flagSet: set, globalSet: globalSet} +func NewContext(app *App, set *flag.FlagSet, parentCtx *Context) *Context { + return &Context{App: app, flagSet: set, parentContext: parentCtx} } // Looks up the value of a local int flag, returns 0 if no int flag exists @@ -73,37 +73,58 @@ func (c *Context) Generic(name string) interface{} { // Looks up the value of a global int flag, returns 0 if no int flag exists func (c *Context) GlobalInt(name string) int { - return lookupInt(name, c.globalSet) + if fs := lookupParentFlagSet(name, c); fs != nil { + return lookupInt(name, fs) + } + return 0 } // Looks up the value of a global time.Duration flag, returns 0 if no time.Duration flag exists func (c *Context) GlobalDuration(name string) time.Duration { - return lookupDuration(name, c.globalSet) + if fs := lookupParentFlagSet(name, c); fs != nil { + return lookupDuration(name, fs) + } + return 0 } // Looks up the value of a global bool flag, returns false if no bool flag exists func (c *Context) GlobalBool(name string) bool { - return lookupBool(name, c.globalSet) + if fs := lookupParentFlagSet(name, c); fs != nil { + return lookupBool(name, fs) + } + return false } // Looks up the value of a global string flag, returns "" if no string flag exists func (c *Context) GlobalString(name string) string { - return lookupString(name, c.globalSet) + if fs := lookupParentFlagSet(name, c); fs != nil { + return lookupString(name, fs) + } + return "" } // Looks up the value of a global string slice flag, returns nil if no string slice flag exists func (c *Context) GlobalStringSlice(name string) []string { - return lookupStringSlice(name, c.globalSet) + if fs := lookupParentFlagSet(name, c); fs != nil { + return lookupStringSlice(name, fs) + } + return nil } // Looks up the value of a global int slice flag, returns nil if no int slice flag exists func (c *Context) GlobalIntSlice(name string) []int { - return lookupIntSlice(name, c.globalSet) + if fs := lookupParentFlagSet(name, c); fs != nil { + return lookupIntSlice(name, fs) + } + return nil } // Looks up the value of a global generic flag, returns nil if no generic flag exists func (c *Context) GlobalGeneric(name string) interface{} { - return lookupGeneric(name, c.globalSet) + if fs := lookupParentFlagSet(name, c); fs != nil { + return lookupGeneric(name, fs) + } + return nil } // Returns the number of flags set @@ -126,11 +147,13 @@ func (c *Context) IsSet(name string) bool { func (c *Context) GlobalIsSet(name string) bool { if c.globalSetFlags == nil { c.globalSetFlags = make(map[string]bool) - c.globalSet.Visit(func(f *flag.Flag) { - c.globalSetFlags[f.Name] = true - }) + for ctx := c.parentContext; ctx != nil && c.globalSetFlags[name] == false; ctx = ctx.parentContext { + ctx.flagSet.Visit(func(f *flag.Flag) { + c.globalSetFlags[f.Name] = true + }) + } } - return c.globalSetFlags[name] == true + return c.globalSetFlags[name] } // Returns a slice of flag names used in this context. @@ -201,6 +224,15 @@ func (a Args) Swap(from, to int) error { return nil } +func lookupParentFlagSet(name string, ctx *Context) *flag.FlagSet { + for ctx := ctx.parentContext; ctx != nil; ctx = ctx.parentContext { + if f := ctx.flagSet.Lookup(name); f != nil { + return ctx.flagSet + } + } + return nil +} + func lookupInt(name string, set *flag.FlagSet) int { f := set.Lookup(name) if f != nil { diff --git a/context_test.go b/context_test.go index d4a1877..6c27d06 100644 --- a/context_test.go +++ b/context_test.go @@ -13,8 +13,9 @@ func TestNewContext(t *testing.T) { set.Int("myflag", 12, "doc") globalSet := flag.NewFlagSet("test", 0) globalSet.Int("myflag", 42, "doc") + globalCtx := cli.NewContext(nil, globalSet, nil) command := cli.Command{Name: "mycommand"} - c := cli.NewContext(nil, set, globalSet) + c := cli.NewContext(nil, set, globalCtx) c.Command = command expect(t, c.Int("myflag"), 12) expect(t, c.GlobalInt("myflag"), 42) @@ -24,42 +25,42 @@ func TestNewContext(t *testing.T) { func TestContext_Int(t *testing.T) { set := flag.NewFlagSet("test", 0) set.Int("myflag", 12, "doc") - c := cli.NewContext(nil, set, set) + c := cli.NewContext(nil, set, nil) expect(t, c.Int("myflag"), 12) } func TestContext_Duration(t *testing.T) { set := flag.NewFlagSet("test", 0) set.Duration("myflag", time.Duration(12*time.Second), "doc") - c := cli.NewContext(nil, set, set) + c := cli.NewContext(nil, set, nil) expect(t, c.Duration("myflag"), time.Duration(12*time.Second)) } func TestContext_String(t *testing.T) { set := flag.NewFlagSet("test", 0) set.String("myflag", "hello world", "doc") - c := cli.NewContext(nil, set, set) + c := cli.NewContext(nil, set, nil) expect(t, c.String("myflag"), "hello world") } func TestContext_Bool(t *testing.T) { set := flag.NewFlagSet("test", 0) set.Bool("myflag", false, "doc") - c := cli.NewContext(nil, set, set) + c := cli.NewContext(nil, set, nil) expect(t, c.Bool("myflag"), false) } func TestContext_BoolT(t *testing.T) { set := flag.NewFlagSet("test", 0) set.Bool("myflag", true, "doc") - c := cli.NewContext(nil, set, set) + c := cli.NewContext(nil, set, nil) expect(t, c.BoolT("myflag"), true) } func TestContext_Args(t *testing.T) { set := flag.NewFlagSet("test", 0) set.Bool("myflag", false, "doc") - c := cli.NewContext(nil, set, set) + c := cli.NewContext(nil, set, nil) set.Parse([]string{"--myflag", "bat", "baz"}) expect(t, len(c.Args()), 2) expect(t, c.Bool("myflag"), true) @@ -71,7 +72,8 @@ func TestContext_IsSet(t *testing.T) { set.String("otherflag", "hello world", "doc") globalSet := flag.NewFlagSet("test", 0) globalSet.Bool("myflagGlobal", true, "doc") - c := cli.NewContext(nil, set, globalSet) + globalCtx := cli.NewContext(nil, globalSet, nil) + c := cli.NewContext(nil, set, globalCtx) set.Parse([]string{"--myflag", "bat", "baz"}) globalSet.Parse([]string{"--myflagGlobal", "bat", "baz"}) expect(t, c.IsSet("myflag"), true) @@ -87,7 +89,8 @@ func TestContext_GlobalIsSet(t *testing.T) { globalSet := flag.NewFlagSet("test", 0) globalSet.Bool("myflagGlobal", true, "doc") globalSet.Bool("myflagGlobalUnset", true, "doc") - c := cli.NewContext(nil, set, globalSet) + globalCtx := cli.NewContext(nil, globalSet, nil) + c := cli.NewContext(nil, set, globalCtx) set.Parse([]string{"--myflag", "bat", "baz"}) globalSet.Parse([]string{"--myflagGlobal", "bat", "baz"}) expect(t, c.GlobalIsSet("myflag"), false) @@ -104,7 +107,8 @@ func TestContext_NumFlags(t *testing.T) { set.String("otherflag", "hello world", "doc") globalSet := flag.NewFlagSet("test", 0) globalSet.Bool("myflagGlobal", true, "doc") - c := cli.NewContext(nil, set, globalSet) + globalCtx := cli.NewContext(nil, globalSet, nil) + c := cli.NewContext(nil, set, globalCtx) set.Parse([]string{"--myflag", "--otherflag=foo"}) globalSet.Parse([]string{"--myflagGlobal"}) expect(t, c.NumFlags(), 2)