From ca2a0f72bfc23378d7b6fcfa06939a89b1fa9909 Mon Sep 17 00:00:00 2001 From: Dan Buch Date: Sat, 23 Jul 2016 21:53:55 -0400 Subject: [PATCH] Ensure context lookups traverse lineage Closes #487 --- context_test.go | 52 +++++++++++++++++++++++++++++------- flag_generated.go | 65 ++++++++++++++++++++++++++++++++++++--------- generate-flag-types | 5 +++- 3 files changed, 98 insertions(+), 24 deletions(-) diff --git a/context_test.go b/context_test.go index ac61dba..1083eeb 100644 --- a/context_test.go +++ b/context_test.go @@ -35,57 +35,89 @@ func TestNewContext(t *testing.T) { func TestContext_Int(t *testing.T) { set := flag.NewFlagSet("test", 0) set.Int("myflag", 12, "doc") - c := NewContext(nil, set, nil) + parentSet := flag.NewFlagSet("test", 0) + parentSet.Int("top-flag", 13, "doc") + parentCtx := NewContext(nil, parentSet, nil) + c := NewContext(nil, set, parentCtx) expect(t, c.Int("myflag"), 12) + expect(t, c.Int("top-flag"), 13) } func TestContext_Int64(t *testing.T) { set := flag.NewFlagSet("test", 0) set.Int64("myflagInt64", 12, "doc") - c := NewContext(nil, set, nil) + parentSet := flag.NewFlagSet("test", 0) + parentSet.Int64("top-flag", 13, "doc") + parentCtx := NewContext(nil, parentSet, nil) + c := NewContext(nil, set, parentCtx) expect(t, c.Int64("myflagInt64"), int64(12)) + expect(t, c.Int64("top-flag"), int64(13)) } func TestContext_Uint(t *testing.T) { set := flag.NewFlagSet("test", 0) set.Uint("myflagUint", uint(13), "doc") - c := NewContext(nil, set, nil) + parentSet := flag.NewFlagSet("test", 0) + parentSet.Uint("top-flag", uint(14), "doc") + parentCtx := NewContext(nil, parentSet, nil) + c := NewContext(nil, set, parentCtx) expect(t, c.Uint("myflagUint"), uint(13)) + expect(t, c.Uint("top-flag"), uint(14)) } func TestContext_Uint64(t *testing.T) { set := flag.NewFlagSet("test", 0) set.Uint64("myflagUint64", uint64(9), "doc") - c := NewContext(nil, set, nil) + parentSet := flag.NewFlagSet("test", 0) + parentSet.Uint64("top-flag", uint64(10), "doc") + parentCtx := NewContext(nil, parentSet, nil) + c := NewContext(nil, set, parentCtx) expect(t, c.Uint64("myflagUint64"), uint64(9)) + expect(t, c.Uint64("top-flag"), uint64(10)) } func TestContext_Float64(t *testing.T) { set := flag.NewFlagSet("test", 0) set.Float64("myflag", float64(17), "doc") - c := NewContext(nil, set, nil) + parentSet := flag.NewFlagSet("test", 0) + parentSet.Float64("top-flag", float64(18), "doc") + parentCtx := NewContext(nil, parentSet, nil) + c := NewContext(nil, set, parentCtx) expect(t, c.Float64("myflag"), float64(17)) + expect(t, c.Float64("top-flag"), float64(18)) } func TestContext_Duration(t *testing.T) { set := flag.NewFlagSet("test", 0) - set.Duration("myflag", time.Duration(12*time.Second), "doc") - c := NewContext(nil, set, nil) - expect(t, c.Duration("myflag"), time.Duration(12*time.Second)) + set.Duration("myflag", 12*time.Second, "doc") + parentSet := flag.NewFlagSet("test", 0) + parentSet.Duration("top-flag", 13*time.Second, "doc") + parentCtx := NewContext(nil, parentSet, nil) + c := NewContext(nil, set, parentCtx) + expect(t, c.Duration("myflag"), 12*time.Second) + expect(t, c.Duration("top-flag"), 13*time.Second) } func TestContext_String(t *testing.T) { set := flag.NewFlagSet("test", 0) set.String("myflag", "hello world", "doc") - c := NewContext(nil, set, nil) + parentSet := flag.NewFlagSet("test", 0) + parentSet.String("top-flag", "hai veld", "doc") + parentCtx := NewContext(nil, parentSet, nil) + c := NewContext(nil, set, parentCtx) expect(t, c.String("myflag"), "hello world") + expect(t, c.String("top-flag"), "hai veld") } func TestContext_Bool(t *testing.T) { set := flag.NewFlagSet("test", 0) set.Bool("myflag", false, "doc") - c := NewContext(nil, set, nil) + parentSet := flag.NewFlagSet("test", 0) + parentSet.Bool("top-flag", true, "doc") + parentCtx := NewContext(nil, parentSet, nil) + c := NewContext(nil, set, parentCtx) expect(t, c.Bool("myflag"), false) + expect(t, c.Bool("top-flag"), true) } func TestContext_Args(t *testing.T) { diff --git a/flag_generated.go b/flag_generated.go index 626ec78..e224fb8 100644 --- a/flag_generated.go +++ b/flag_generated.go @@ -33,7 +33,10 @@ func (f *BoolFlag) Names() []string { // Bool looks up the value of a local BoolFlag, returns // false if not found func (c *Context) Bool(name string) bool { - return lookupBool(name, c.flagSet) + if fs := lookupFlagSet(name, c); fs != nil { + return lookupBool(name, fs) + } + return false } func lookupBool(name string, set *flag.FlagSet) bool { @@ -73,7 +76,10 @@ func (f *DurationFlag) Names() []string { // Duration looks up the value of a local DurationFlag, returns // 0 if not found func (c *Context) Duration(name string) time.Duration { - return lookupDuration(name, c.flagSet) + if fs := lookupFlagSet(name, c); fs != nil { + return lookupDuration(name, fs) + } + return 0 } func lookupDuration(name string, set *flag.FlagSet) time.Duration { @@ -113,7 +119,10 @@ func (f *Float64Flag) Names() []string { // Float64 looks up the value of a local Float64Flag, returns // 0 if not found func (c *Context) Float64(name string) float64 { - return lookupFloat64(name, c.flagSet) + if fs := lookupFlagSet(name, c); fs != nil { + return lookupFloat64(name, fs) + } + return 0 } func lookupFloat64(name string, set *flag.FlagSet) float64 { @@ -152,7 +161,10 @@ func (f *GenericFlag) Names() []string { // Generic looks up the value of a local GenericFlag, returns // nil if not found func (c *Context) Generic(name string) interface{} { - return lookupGeneric(name, c.flagSet) + if fs := lookupFlagSet(name, c); fs != nil { + return lookupGeneric(name, fs) + } + return nil } func lookupGeneric(name string, set *flag.FlagSet) interface{} { @@ -192,7 +204,10 @@ func (f *Int64Flag) Names() []string { // Int64 looks up the value of a local Int64Flag, returns // 0 if not found func (c *Context) Int64(name string) int64 { - return lookupInt64(name, c.flagSet) + if fs := lookupFlagSet(name, c); fs != nil { + return lookupInt64(name, fs) + } + return 0 } func lookupInt64(name string, set *flag.FlagSet) int64 { @@ -232,7 +247,10 @@ func (f *IntFlag) Names() []string { // Int looks up the value of a local IntFlag, returns // 0 if not found func (c *Context) Int(name string) int { - return lookupInt(name, c.flagSet) + if fs := lookupFlagSet(name, c); fs != nil { + return lookupInt(name, fs) + } + return 0 } func lookupInt(name string, set *flag.FlagSet) int { @@ -271,7 +289,10 @@ func (f *IntSliceFlag) Names() []string { // IntSlice looks up the value of a local IntSliceFlag, returns // nil if not found func (c *Context) IntSlice(name string) []int { - return lookupIntSlice(name, c.flagSet) + if fs := lookupFlagSet(name, c); fs != nil { + return lookupIntSlice(name, fs) + } + return nil } func lookupIntSlice(name string, set *flag.FlagSet) []int { @@ -310,7 +331,10 @@ func (f *Int64SliceFlag) Names() []string { // Int64Slice looks up the value of a local Int64SliceFlag, returns // nil if not found func (c *Context) Int64Slice(name string) []int64 { - return lookupInt64Slice(name, c.flagSet) + if fs := lookupFlagSet(name, c); fs != nil { + return lookupInt64Slice(name, fs) + } + return nil } func lookupInt64Slice(name string, set *flag.FlagSet) []int64 { @@ -349,7 +373,10 @@ func (f *Float64SliceFlag) Names() []string { // Float64Slice looks up the value of a local Float64SliceFlag, returns // nil if not found func (c *Context) Float64Slice(name string) []float64 { - return lookupFloat64Slice(name, c.flagSet) + if fs := lookupFlagSet(name, c); fs != nil { + return lookupFloat64Slice(name, fs) + } + return nil } func lookupFloat64Slice(name string, set *flag.FlagSet) []float64 { @@ -389,7 +416,10 @@ func (f *StringFlag) Names() []string { // String looks up the value of a local StringFlag, returns // "" if not found func (c *Context) String(name string) string { - return lookupString(name, c.flagSet) + if fs := lookupFlagSet(name, c); fs != nil { + return lookupString(name, fs) + } + return "" } func lookupString(name string, set *flag.FlagSet) string { @@ -428,7 +458,10 @@ func (f *StringSliceFlag) Names() []string { // StringSlice looks up the value of a local StringSliceFlag, returns // nil if not found func (c *Context) StringSlice(name string) []string { - return lookupStringSlice(name, c.flagSet) + if fs := lookupFlagSet(name, c); fs != nil { + return lookupStringSlice(name, fs) + } + return nil } func lookupStringSlice(name string, set *flag.FlagSet) []string { @@ -468,7 +501,10 @@ func (f *Uint64Flag) Names() []string { // Uint64 looks up the value of a local Uint64Flag, returns // 0 if not found func (c *Context) Uint64(name string) uint64 { - return lookupUint64(name, c.flagSet) + if fs := lookupFlagSet(name, c); fs != nil { + return lookupUint64(name, fs) + } + return 0 } func lookupUint64(name string, set *flag.FlagSet) uint64 { @@ -508,7 +544,10 @@ func (f *UintFlag) Names() []string { // Uint looks up the value of a local UintFlag, returns // 0 if not found func (c *Context) Uint(name string) uint { - return lookupUint(name, c.flagSet) + if fs := lookupFlagSet(name, c); fs != nil { + return lookupUint(name, fs) + } + return 0 } func lookupUint(name string, set *flag.FlagSet) uint { diff --git a/generate-flag-types b/generate-flag-types index 4ac7c04..6244ff9 100755 --- a/generate-flag-types +++ b/generate-flag-types @@ -169,7 +169,10 @@ def _write_cli_flag_types(outfile, types): // {name} looks up the value of a local {name}Flag, returns // {context_default} if not found func (c *Context) {name}(name string) {context_type} {{ - return lookup{name}(name, c.flagSet) + if fs := lookupFlagSet(name, c); fs != nil {{ + return lookup{name}(name, fs) + }} + return {context_default} }} func lookup{name}(name string, set *flag.FlagSet) {context_type} {{