diff --git a/context.go b/context.go index af80bdb..002fe6e 100644 --- a/context.go +++ b/context.go @@ -13,12 +13,11 @@ import ( // can be used to retrieve context-specific Args and // parsed command-line options. type Context struct { - App *App - Command Command - flagSet *flag.FlagSet - setFlags map[string]bool - globalSetFlags map[string]bool - parentContext *Context + App *App + Command Command + + flagSet *flag.FlagSet + parentContext *Context } // NewContext creates a new context. For use in when invoking an App or Command action. @@ -117,7 +116,7 @@ func (c *Context) Set(name, value string) error { func (c *Context) IsSet(name string) bool { if fs := lookupFlagSet(name, c); fs != nil { isSet := false - c.flagSet.Visit(func(f *flag.Flag) { + fs.Visit(func(f *flag.Flag) { if f.Name == name { isSet = true } @@ -146,25 +145,13 @@ func (c *Context) FlagNames() []string { return names } -// Parent returns the parent context, if any -func (c *Context) Parent() *Context { - return c.parentContext -} - // Lineage returns *this* context and all of its ancestor contexts in order from // child to parent func (c *Context) Lineage() []*Context { lineage := []*Context{} - cur := c - for { + for cur := c; cur != nil; cur = cur.parentContext { lineage = append(lineage, cur) - - if cur.parentContext == nil { - break - } - - cur = cur.parentContext } return lineage @@ -220,19 +207,6 @@ func (a Args) Swap(from, to int) error { return nil } -func globalContext(ctx *Context) *Context { - if ctx == nil { - return nil - } - - for { - if ctx.parentContext == nil { - return ctx - } - ctx = ctx.parentContext - } -} - func lookupFlagSet(name string, ctx *Context) *flag.FlagSet { for _, c := range ctx.Lineage() { if f := c.flagSet.Lookup(name); f != nil { diff --git a/context_test.go b/context_test.go index e819dbe..bdb77fd 100644 --- a/context_test.go +++ b/context_test.go @@ -84,18 +84,22 @@ func TestContext_NArg(t *testing.T) { func TestContext_IsSet(t *testing.T) { set := flag.NewFlagSet("test", 0) - set.Bool("myflag", false, "doc") - set.String("otherflag", "hello world", "doc") - globalSet := flag.NewFlagSet("test", 0) - globalSet.Bool("myflagGlobal", true, "doc") - globalCtx := NewContext(nil, globalSet, nil) - c := NewContext(nil, set, globalCtx) - set.Parse([]string{"--myflag", "bat", "baz"}) - globalSet.Parse([]string{"--myflagGlobal", "bat", "baz"}) - expect(t, c.IsSet("myflag"), true) - expect(t, c.IsSet("otherflag"), false) - expect(t, c.IsSet("bogusflag"), false) - expect(t, c.IsSet("myflagGlobal"), false) + set.Bool("one-flag", false, "doc") + set.Bool("two-flag", false, "doc") + set.String("three-flag", "hello world", "doc") + parentSet := flag.NewFlagSet("test", 0) + parentSet.Bool("top-flag", true, "doc") + parentCtx := NewContext(nil, parentSet, nil) + ctx := NewContext(nil, set, parentCtx) + + set.Parse([]string{"--one-flag", "--two-flag", "--three-flag", "frob"}) + parentSet.Parse([]string{"--top-flag"}) + + expect(t, ctx.IsSet("one-flag"), true) + expect(t, ctx.IsSet("two-flag"), true) + expect(t, ctx.IsSet("three-flag"), true) + expect(t, ctx.IsSet("top-flag"), true) + expect(t, ctx.IsSet("bogus"), false) } func TestContext_NumFlags(t *testing.T) { @@ -124,14 +128,14 @@ func TestContext_LocalFlagNames(t *testing.T) { set := flag.NewFlagSet("test", 0) set.Bool("one-flag", false, "doc") set.String("two-flag", "hello world", "doc") - globalSet := flag.NewFlagSet("test", 0) - globalSet.Bool("top-flag", true, "doc") - globalCtx := NewContext(nil, globalSet, nil) - c := NewContext(nil, set, globalCtx) + parentSet := flag.NewFlagSet("test", 0) + parentSet.Bool("top-flag", true, "doc") + parentCtx := NewContext(nil, parentSet, nil) + ctx := NewContext(nil, set, parentCtx) set.Parse([]string{"--one-flag", "--two-flag=foo"}) - globalSet.Parse([]string{"--top-flag"}) + parentSet.Parse([]string{"--top-flag"}) - actualFlags := c.LocalFlagNames() + actualFlags := ctx.LocalFlagNames() sort.Strings(actualFlags) expect(t, actualFlags, []string{"one-flag", "two-flag"}) @@ -141,15 +145,52 @@ func TestContext_FlagNames(t *testing.T) { set := flag.NewFlagSet("test", 0) set.Bool("one-flag", false, "doc") set.String("two-flag", "hello world", "doc") - globalSet := flag.NewFlagSet("test", 0) - globalSet.Bool("top-flag", true, "doc") - globalCtx := NewContext(nil, globalSet, nil) - c := NewContext(nil, set, globalCtx) + parentSet := flag.NewFlagSet("test", 0) + parentSet.Bool("top-flag", true, "doc") + parentCtx := NewContext(nil, parentSet, nil) + ctx := NewContext(nil, set, parentCtx) set.Parse([]string{"--one-flag", "--two-flag=foo"}) - globalSet.Parse([]string{"--top-flag"}) + parentSet.Parse([]string{"--top-flag"}) - actualFlags := c.FlagNames() + actualFlags := ctx.FlagNames() sort.Strings(actualFlags) expect(t, actualFlags, []string{"one-flag", "top-flag", "two-flag"}) } + +func TestContext_Lineage(t *testing.T) { + set := flag.NewFlagSet("test", 0) + set.Bool("local-flag", false, "doc") + parentSet := flag.NewFlagSet("test", 0) + parentSet.Bool("top-flag", true, "doc") + parentCtx := NewContext(nil, parentSet, nil) + ctx := NewContext(nil, set, parentCtx) + set.Parse([]string{"--local-flag"}) + parentSet.Parse([]string{"--top-flag"}) + + lineage := ctx.Lineage() + expect(t, len(lineage), 2) + expect(t, lineage[0], ctx) + expect(t, lineage[1], parentCtx) +} + +func TestContext_lookupFlagSet(t *testing.T) { + set := flag.NewFlagSet("test", 0) + set.Bool("local-flag", false, "doc") + parentSet := flag.NewFlagSet("test", 0) + parentSet.Bool("top-flag", true, "doc") + parentCtx := NewContext(nil, parentSet, nil) + ctx := NewContext(nil, set, parentCtx) + set.Parse([]string{"--local-flag"}) + parentSet.Parse([]string{"--top-flag"}) + + fs := lookupFlagSet("top-flag", ctx) + expect(t, fs, parentCtx.flagSet) + + fs = lookupFlagSet("local-flag", ctx) + expect(t, fs, ctx.flagSet) + + if fs := lookupFlagSet("frob", ctx); fs != nil { + t.Fail() + } +}