Ensure context lookups traverse lineage

Closes #487
This commit is contained in:
Dan Buch 2016-07-23 21:53:55 -04:00
parent 924956d859
commit ca2a0f72bf
No known key found for this signature in database
GPG Key ID: FAEF12936DD3E3EC
3 changed files with 98 additions and 24 deletions

View File

@ -35,57 +35,89 @@ func TestNewContext(t *testing.T) {
func TestContext_Int(t *testing.T) { func TestContext_Int(t *testing.T) {
set := flag.NewFlagSet("test", 0) set := flag.NewFlagSet("test", 0)
set.Int("myflag", 12, "doc") set.Int("myflag", 12, "doc")
c := NewContext(nil, set, nil) parentSet := flag.NewFlagSet("test", 0)
parentSet.Int("top-flag", 13, "doc")
parentCtx := NewContext(nil, parentSet, nil)
c := NewContext(nil, set, parentCtx)
expect(t, c.Int("myflag"), 12) expect(t, c.Int("myflag"), 12)
expect(t, c.Int("top-flag"), 13)
} }
func TestContext_Int64(t *testing.T) { func TestContext_Int64(t *testing.T) {
set := flag.NewFlagSet("test", 0) set := flag.NewFlagSet("test", 0)
set.Int64("myflagInt64", 12, "doc") set.Int64("myflagInt64", 12, "doc")
c := NewContext(nil, set, nil) parentSet := flag.NewFlagSet("test", 0)
parentSet.Int64("top-flag", 13, "doc")
parentCtx := NewContext(nil, parentSet, nil)
c := NewContext(nil, set, parentCtx)
expect(t, c.Int64("myflagInt64"), int64(12)) expect(t, c.Int64("myflagInt64"), int64(12))
expect(t, c.Int64("top-flag"), int64(13))
} }
func TestContext_Uint(t *testing.T) { func TestContext_Uint(t *testing.T) {
set := flag.NewFlagSet("test", 0) set := flag.NewFlagSet("test", 0)
set.Uint("myflagUint", uint(13), "doc") set.Uint("myflagUint", uint(13), "doc")
c := NewContext(nil, set, nil) parentSet := flag.NewFlagSet("test", 0)
parentSet.Uint("top-flag", uint(14), "doc")
parentCtx := NewContext(nil, parentSet, nil)
c := NewContext(nil, set, parentCtx)
expect(t, c.Uint("myflagUint"), uint(13)) expect(t, c.Uint("myflagUint"), uint(13))
expect(t, c.Uint("top-flag"), uint(14))
} }
func TestContext_Uint64(t *testing.T) { func TestContext_Uint64(t *testing.T) {
set := flag.NewFlagSet("test", 0) set := flag.NewFlagSet("test", 0)
set.Uint64("myflagUint64", uint64(9), "doc") set.Uint64("myflagUint64", uint64(9), "doc")
c := NewContext(nil, set, nil) parentSet := flag.NewFlagSet("test", 0)
parentSet.Uint64("top-flag", uint64(10), "doc")
parentCtx := NewContext(nil, parentSet, nil)
c := NewContext(nil, set, parentCtx)
expect(t, c.Uint64("myflagUint64"), uint64(9)) expect(t, c.Uint64("myflagUint64"), uint64(9))
expect(t, c.Uint64("top-flag"), uint64(10))
} }
func TestContext_Float64(t *testing.T) { func TestContext_Float64(t *testing.T) {
set := flag.NewFlagSet("test", 0) set := flag.NewFlagSet("test", 0)
set.Float64("myflag", float64(17), "doc") set.Float64("myflag", float64(17), "doc")
c := NewContext(nil, set, nil) parentSet := flag.NewFlagSet("test", 0)
parentSet.Float64("top-flag", float64(18), "doc")
parentCtx := NewContext(nil, parentSet, nil)
c := NewContext(nil, set, parentCtx)
expect(t, c.Float64("myflag"), float64(17)) expect(t, c.Float64("myflag"), float64(17))
expect(t, c.Float64("top-flag"), float64(18))
} }
func TestContext_Duration(t *testing.T) { func TestContext_Duration(t *testing.T) {
set := flag.NewFlagSet("test", 0) set := flag.NewFlagSet("test", 0)
set.Duration("myflag", time.Duration(12*time.Second), "doc") set.Duration("myflag", 12*time.Second, "doc")
c := NewContext(nil, set, nil) parentSet := flag.NewFlagSet("test", 0)
expect(t, c.Duration("myflag"), time.Duration(12*time.Second)) parentSet.Duration("top-flag", 13*time.Second, "doc")
parentCtx := NewContext(nil, parentSet, nil)
c := NewContext(nil, set, parentCtx)
expect(t, c.Duration("myflag"), 12*time.Second)
expect(t, c.Duration("top-flag"), 13*time.Second)
} }
func TestContext_String(t *testing.T) { func TestContext_String(t *testing.T) {
set := flag.NewFlagSet("test", 0) set := flag.NewFlagSet("test", 0)
set.String("myflag", "hello world", "doc") set.String("myflag", "hello world", "doc")
c := NewContext(nil, set, nil) parentSet := flag.NewFlagSet("test", 0)
parentSet.String("top-flag", "hai veld", "doc")
parentCtx := NewContext(nil, parentSet, nil)
c := NewContext(nil, set, parentCtx)
expect(t, c.String("myflag"), "hello world") expect(t, c.String("myflag"), "hello world")
expect(t, c.String("top-flag"), "hai veld")
} }
func TestContext_Bool(t *testing.T) { func TestContext_Bool(t *testing.T) {
set := flag.NewFlagSet("test", 0) set := flag.NewFlagSet("test", 0)
set.Bool("myflag", false, "doc") set.Bool("myflag", false, "doc")
c := NewContext(nil, set, nil) parentSet := flag.NewFlagSet("test", 0)
parentSet.Bool("top-flag", true, "doc")
parentCtx := NewContext(nil, parentSet, nil)
c := NewContext(nil, set, parentCtx)
expect(t, c.Bool("myflag"), false) expect(t, c.Bool("myflag"), false)
expect(t, c.Bool("top-flag"), true)
} }
func TestContext_Args(t *testing.T) { func TestContext_Args(t *testing.T) {

View File

@ -33,7 +33,10 @@ func (f *BoolFlag) Names() []string {
// Bool looks up the value of a local BoolFlag, returns // Bool looks up the value of a local BoolFlag, returns
// false if not found // false if not found
func (c *Context) Bool(name string) bool { func (c *Context) Bool(name string) bool {
return lookupBool(name, c.flagSet) if fs := lookupFlagSet(name, c); fs != nil {
return lookupBool(name, fs)
}
return false
} }
func lookupBool(name string, set *flag.FlagSet) bool { func lookupBool(name string, set *flag.FlagSet) bool {
@ -73,7 +76,10 @@ func (f *DurationFlag) Names() []string {
// Duration looks up the value of a local DurationFlag, returns // Duration looks up the value of a local DurationFlag, returns
// 0 if not found // 0 if not found
func (c *Context) Duration(name string) time.Duration { func (c *Context) Duration(name string) time.Duration {
return lookupDuration(name, c.flagSet) if fs := lookupFlagSet(name, c); fs != nil {
return lookupDuration(name, fs)
}
return 0
} }
func lookupDuration(name string, set *flag.FlagSet) time.Duration { func lookupDuration(name string, set *flag.FlagSet) time.Duration {
@ -113,7 +119,10 @@ func (f *Float64Flag) Names() []string {
// Float64 looks up the value of a local Float64Flag, returns // Float64 looks up the value of a local Float64Flag, returns
// 0 if not found // 0 if not found
func (c *Context) Float64(name string) float64 { func (c *Context) Float64(name string) float64 {
return lookupFloat64(name, c.flagSet) if fs := lookupFlagSet(name, c); fs != nil {
return lookupFloat64(name, fs)
}
return 0
} }
func lookupFloat64(name string, set *flag.FlagSet) float64 { func lookupFloat64(name string, set *flag.FlagSet) float64 {
@ -152,7 +161,10 @@ func (f *GenericFlag) Names() []string {
// Generic looks up the value of a local GenericFlag, returns // Generic looks up the value of a local GenericFlag, returns
// nil if not found // nil if not found
func (c *Context) Generic(name string) interface{} { func (c *Context) Generic(name string) interface{} {
return lookupGeneric(name, c.flagSet) if fs := lookupFlagSet(name, c); fs != nil {
return lookupGeneric(name, fs)
}
return nil
} }
func lookupGeneric(name string, set *flag.FlagSet) interface{} { func lookupGeneric(name string, set *flag.FlagSet) interface{} {
@ -192,7 +204,10 @@ func (f *Int64Flag) Names() []string {
// Int64 looks up the value of a local Int64Flag, returns // Int64 looks up the value of a local Int64Flag, returns
// 0 if not found // 0 if not found
func (c *Context) Int64(name string) int64 { func (c *Context) Int64(name string) int64 {
return lookupInt64(name, c.flagSet) if fs := lookupFlagSet(name, c); fs != nil {
return lookupInt64(name, fs)
}
return 0
} }
func lookupInt64(name string, set *flag.FlagSet) int64 { func lookupInt64(name string, set *flag.FlagSet) int64 {
@ -232,7 +247,10 @@ func (f *IntFlag) Names() []string {
// Int looks up the value of a local IntFlag, returns // Int looks up the value of a local IntFlag, returns
// 0 if not found // 0 if not found
func (c *Context) Int(name string) int { func (c *Context) Int(name string) int {
return lookupInt(name, c.flagSet) if fs := lookupFlagSet(name, c); fs != nil {
return lookupInt(name, fs)
}
return 0
} }
func lookupInt(name string, set *flag.FlagSet) int { func lookupInt(name string, set *flag.FlagSet) int {
@ -271,7 +289,10 @@ func (f *IntSliceFlag) Names() []string {
// IntSlice looks up the value of a local IntSliceFlag, returns // IntSlice looks up the value of a local IntSliceFlag, returns
// nil if not found // nil if not found
func (c *Context) IntSlice(name string) []int { func (c *Context) IntSlice(name string) []int {
return lookupIntSlice(name, c.flagSet) if fs := lookupFlagSet(name, c); fs != nil {
return lookupIntSlice(name, fs)
}
return nil
} }
func lookupIntSlice(name string, set *flag.FlagSet) []int { func lookupIntSlice(name string, set *flag.FlagSet) []int {
@ -310,7 +331,10 @@ func (f *Int64SliceFlag) Names() []string {
// Int64Slice looks up the value of a local Int64SliceFlag, returns // Int64Slice looks up the value of a local Int64SliceFlag, returns
// nil if not found // nil if not found
func (c *Context) Int64Slice(name string) []int64 { func (c *Context) Int64Slice(name string) []int64 {
return lookupInt64Slice(name, c.flagSet) if fs := lookupFlagSet(name, c); fs != nil {
return lookupInt64Slice(name, fs)
}
return nil
} }
func lookupInt64Slice(name string, set *flag.FlagSet) []int64 { func lookupInt64Slice(name string, set *flag.FlagSet) []int64 {
@ -349,7 +373,10 @@ func (f *Float64SliceFlag) Names() []string {
// Float64Slice looks up the value of a local Float64SliceFlag, returns // Float64Slice looks up the value of a local Float64SliceFlag, returns
// nil if not found // nil if not found
func (c *Context) Float64Slice(name string) []float64 { func (c *Context) Float64Slice(name string) []float64 {
return lookupFloat64Slice(name, c.flagSet) if fs := lookupFlagSet(name, c); fs != nil {
return lookupFloat64Slice(name, fs)
}
return nil
} }
func lookupFloat64Slice(name string, set *flag.FlagSet) []float64 { func lookupFloat64Slice(name string, set *flag.FlagSet) []float64 {
@ -389,7 +416,10 @@ func (f *StringFlag) Names() []string {
// String looks up the value of a local StringFlag, returns // String looks up the value of a local StringFlag, returns
// "" if not found // "" if not found
func (c *Context) String(name string) string { func (c *Context) String(name string) string {
return lookupString(name, c.flagSet) if fs := lookupFlagSet(name, c); fs != nil {
return lookupString(name, fs)
}
return ""
} }
func lookupString(name string, set *flag.FlagSet) string { func lookupString(name string, set *flag.FlagSet) string {
@ -428,7 +458,10 @@ func (f *StringSliceFlag) Names() []string {
// StringSlice looks up the value of a local StringSliceFlag, returns // StringSlice looks up the value of a local StringSliceFlag, returns
// nil if not found // nil if not found
func (c *Context) StringSlice(name string) []string { func (c *Context) StringSlice(name string) []string {
return lookupStringSlice(name, c.flagSet) if fs := lookupFlagSet(name, c); fs != nil {
return lookupStringSlice(name, fs)
}
return nil
} }
func lookupStringSlice(name string, set *flag.FlagSet) []string { func lookupStringSlice(name string, set *flag.FlagSet) []string {
@ -468,7 +501,10 @@ func (f *Uint64Flag) Names() []string {
// Uint64 looks up the value of a local Uint64Flag, returns // Uint64 looks up the value of a local Uint64Flag, returns
// 0 if not found // 0 if not found
func (c *Context) Uint64(name string) uint64 { func (c *Context) Uint64(name string) uint64 {
return lookupUint64(name, c.flagSet) if fs := lookupFlagSet(name, c); fs != nil {
return lookupUint64(name, fs)
}
return 0
} }
func lookupUint64(name string, set *flag.FlagSet) uint64 { func lookupUint64(name string, set *flag.FlagSet) uint64 {
@ -508,7 +544,10 @@ func (f *UintFlag) Names() []string {
// Uint looks up the value of a local UintFlag, returns // Uint looks up the value of a local UintFlag, returns
// 0 if not found // 0 if not found
func (c *Context) Uint(name string) uint { func (c *Context) Uint(name string) uint {
return lookupUint(name, c.flagSet) if fs := lookupFlagSet(name, c); fs != nil {
return lookupUint(name, fs)
}
return 0
} }
func lookupUint(name string, set *flag.FlagSet) uint { func lookupUint(name string, set *flag.FlagSet) uint {

View File

@ -169,7 +169,10 @@ def _write_cli_flag_types(outfile, types):
// {name} looks up the value of a local {name}Flag, returns // {name} looks up the value of a local {name}Flag, returns
// {context_default} if not found // {context_default} if not found
func (c *Context) {name}(name string) {context_type} {{ func (c *Context) {name}(name string) {context_type} {{
return lookup{name}(name, c.flagSet) if fs := lookupFlagSet(name, c); fs != nil {{
return lookup{name}(name, fs)
}}
return {context_default}
}} }}
func lookup{name}(name string, set *flag.FlagSet) {context_type} {{ func lookup{name}(name string, set *flag.FlagSet) {context_type} {{