diff --git a/altsrc/flag_test.go b/altsrc/flag_test.go index ac4d1f5..4e25be6 100644 --- a/altsrc/flag_test.go +++ b/altsrc/flag_test.go @@ -296,7 +296,7 @@ func TestFloat64ApplyInputSourceMethodEnvVarSet(t *testing.T) { } func runTest(t *testing.T, test testApplyInputSource) *cli.Context { - inputSource := &MapInputSource{valueMap: map[string]interface{}{test.FlagName: test.MapValue}} + inputSource := &MapInputSource{valueMap: map[interface{}]interface{}{test.FlagName: test.MapValue}} set := flag.NewFlagSet(test.FlagSetName, flag.ContinueOnError) c := cli.NewContext(nil, set, nil) if test.EnvVarName != "" && test.EnvVarValue != "" { diff --git a/altsrc/map_input_source.go b/altsrc/map_input_source.go index f1670fb..19f87af 100644 --- a/altsrc/map_input_source.go +++ b/altsrc/map_input_source.go @@ -3,6 +3,7 @@ package altsrc import ( "fmt" "reflect" + "strings" "time" "github.com/codegangsta/cli" @@ -11,7 +12,31 @@ import ( // MapInputSource implements InputSourceContext to return // data from the map that is loaded. type MapInputSource struct { - valueMap map[string]interface{} + valueMap map[interface{}]interface{} +} + +// nestedVal checks if the name has '.' delimiters. +// If so, it tries to traverse the tree by the '.' delimited sections to find +// a nested value for the key. +func nestedVal(name string, tree map[interface{}]interface{}) (interface{}, bool) { + if sections := strings.Split(name, "."); len(sections) > 1 { + node := tree + for _, section := range sections[:len(sections)-1] { + if child, ok := node[section]; !ok { + return nil, false + } else { + if ctype, ok := child.(map[interface{}]interface{}); !ok { + return nil, false + } else { + node = ctype + } + } + } + if val, ok := node[sections[len(sections)-1]]; ok { + return val, true + } + } + return nil, false } // Int returns an int from the map if it exists otherwise returns 0 @@ -22,7 +47,14 @@ func (fsm *MapInputSource) Int(name string) (int, error) { if !isType { return 0, incorrectTypeForFlagError(name, "int", otherGenericValue) } - + return otherValue, nil + } + nestedGenericValue, exists := nestedVal(name, fsm.valueMap) + if exists { + otherValue, isType := nestedGenericValue.(int) + if !isType { + return 0, incorrectTypeForFlagError(name, "int", nestedGenericValue) + } return otherValue, nil } @@ -39,6 +71,14 @@ func (fsm *MapInputSource) Duration(name string) (time.Duration, error) { } return otherValue, nil } + nestedGenericValue, exists := nestedVal(name, fsm.valueMap) + if exists { + otherValue, isType := nestedGenericValue.(time.Duration) + if !isType { + return 0, incorrectTypeForFlagError(name, "duration", nestedGenericValue) + } + return otherValue, nil + } return 0, nil } @@ -53,6 +93,14 @@ func (fsm *MapInputSource) Float64(name string) (float64, error) { } return otherValue, nil } + nestedGenericValue, exists := nestedVal(name, fsm.valueMap) + if exists { + otherValue, isType := nestedGenericValue.(float64) + if !isType { + return 0, incorrectTypeForFlagError(name, "float64", nestedGenericValue) + } + return otherValue, nil + } return 0, nil } @@ -67,6 +115,14 @@ func (fsm *MapInputSource) String(name string) (string, error) { } return otherValue, nil } + nestedGenericValue, exists := nestedVal(name, fsm.valueMap) + if exists { + otherValue, isType := nestedGenericValue.(string) + if !isType { + return "", incorrectTypeForFlagError(name, "string", nestedGenericValue) + } + return otherValue, nil + } return "", nil } @@ -81,6 +137,14 @@ func (fsm *MapInputSource) StringSlice(name string) ([]string, error) { } return otherValue, nil } + nestedGenericValue, exists := nestedVal(name, fsm.valueMap) + if exists { + otherValue, isType := nestedGenericValue.([]string) + if !isType { + return nil, incorrectTypeForFlagError(name, "[]string", nestedGenericValue) + } + return otherValue, nil + } return nil, nil } @@ -95,6 +159,14 @@ func (fsm *MapInputSource) IntSlice(name string) ([]int, error) { } return otherValue, nil } + nestedGenericValue, exists := nestedVal(name, fsm.valueMap) + if exists { + otherValue, isType := nestedGenericValue.([]int) + if !isType { + return nil, incorrectTypeForFlagError(name, "[]int", nestedGenericValue) + } + return otherValue, nil + } return nil, nil } @@ -109,6 +181,14 @@ func (fsm *MapInputSource) Generic(name string) (cli.Generic, error) { } return otherValue, nil } + nestedGenericValue, exists := nestedVal(name, fsm.valueMap) + if exists { + otherValue, isType := nestedGenericValue.(cli.Generic) + if !isType { + return nil, incorrectTypeForFlagError(name, "cli.Generic", nestedGenericValue) + } + return otherValue, nil + } return nil, nil } @@ -123,6 +203,14 @@ func (fsm *MapInputSource) Bool(name string) (bool, error) { } return otherValue, nil } + nestedGenericValue, exists := nestedVal(name, fsm.valueMap) + if exists { + otherValue, isType := nestedGenericValue.(bool) + if !isType { + return false, incorrectTypeForFlagError(name, "bool", nestedGenericValue) + } + return otherValue, nil + } return false, nil } @@ -137,6 +225,14 @@ func (fsm *MapInputSource) BoolT(name string) (bool, error) { } return otherValue, nil } + nestedGenericValue, exists := nestedVal(name, fsm.valueMap) + if exists { + otherValue, isType := nestedGenericValue.(bool) + if !isType { + return true, incorrectTypeForFlagError(name, "bool", nestedGenericValue) + } + return otherValue, nil + } return true, nil } diff --git a/altsrc/yaml_command_test.go b/altsrc/yaml_command_test.go index 275bc64..29ead8d 100644 --- a/altsrc/yaml_command_test.go +++ b/altsrc/yaml_command_test.go @@ -76,6 +76,40 @@ func TestCommandYamlFileTestGlobalEnvVarWins(t *testing.T) { expect(t, err, nil) } +func TestCommandYamlFileTestGlobalEnvVarWinsNested(t *testing.T) { + app := cli.NewApp() + set := flag.NewFlagSet("test", 0) + ioutil.WriteFile("current.yaml", []byte(`top: + test: 15`), 0666) + defer os.Remove("current.yaml") + + os.Setenv("THE_TEST", "10") + defer os.Setenv("THE_TEST", "") + test := []string{"test-cmd", "--load", "current.yaml"} + set.Parse(test) + + c := cli.NewContext(app, set, nil) + + command := &cli.Command{ + Name: "test-cmd", + Aliases: []string{"tc"}, + Usage: "this is for testing", + Description: "testing", + Action: func(c *cli.Context) { + val := c.Int("top.test") + expect(t, val, 10) + }, + Flags: []cli.Flag{ + NewIntFlag(cli.IntFlag{Name: "top.test", EnvVar: "THE_TEST"}), + cli.StringFlag{Name: "load"}}, + } + command.Before = InitInputSourceWithContext(command.Flags, NewYamlSourceFromFlagFunc("load")) + + err := command.Run(c) + + expect(t, err, nil) +} + func TestCommandYamlFileTestSpecifiedFlagWins(t *testing.T) { app := cli.NewApp() set := flag.NewFlagSet("test", 0) @@ -107,6 +141,38 @@ func TestCommandYamlFileTestSpecifiedFlagWins(t *testing.T) { expect(t, err, nil) } +func TestCommandYamlFileTestSpecifiedFlagWinsNested(t *testing.T) { + app := cli.NewApp() + set := flag.NewFlagSet("test", 0) + ioutil.WriteFile("current.yaml", []byte(`top: + test: 15`), 0666) + defer os.Remove("current.yaml") + + test := []string{"test-cmd", "--load", "current.yaml", "--top.test", "7"} + set.Parse(test) + + c := cli.NewContext(app, set, nil) + + command := &cli.Command{ + Name: "test-cmd", + Aliases: []string{"tc"}, + Usage: "this is for testing", + Description: "testing", + Action: func(c *cli.Context) { + val := c.Int("top.test") + expect(t, val, 7) + }, + Flags: []cli.Flag{ + NewIntFlag(cli.IntFlag{Name: "top.test"}), + cli.StringFlag{Name: "load"}}, + } + command.Before = InitInputSourceWithContext(command.Flags, NewYamlSourceFromFlagFunc("load")) + + err := command.Run(c) + + expect(t, err, nil) +} + func TestCommandYamlFileTestDefaultValueFileWins(t *testing.T) { app := cli.NewApp() set := flag.NewFlagSet("test", 0) @@ -138,6 +204,38 @@ func TestCommandYamlFileTestDefaultValueFileWins(t *testing.T) { expect(t, err, nil) } +func TestCommandYamlFileTestDefaultValueFileWinsNested(t *testing.T) { + app := cli.NewApp() + set := flag.NewFlagSet("test", 0) + ioutil.WriteFile("current.yaml", []byte(`top: + test: 15`), 0666) + defer os.Remove("current.yaml") + + test := []string{"test-cmd", "--load", "current.yaml"} + set.Parse(test) + + c := cli.NewContext(app, set, nil) + + command := &cli.Command{ + Name: "test-cmd", + Aliases: []string{"tc"}, + Usage: "this is for testing", + Description: "testing", + Action: func(c *cli.Context) { + val := c.Int("top.test") + expect(t, val, 15) + }, + Flags: []cli.Flag{ + NewIntFlag(cli.IntFlag{Name: "top.test", Value: 7}), + cli.StringFlag{Name: "load"}}, + } + command.Before = InitInputSourceWithContext(command.Flags, NewYamlSourceFromFlagFunc("load")) + + err := command.Run(c) + + expect(t, err, nil) +} + func TestCommandYamlFileFlagHasDefaultGlobalEnvYamlSetGlobalEnvWins(t *testing.T) { app := cli.NewApp() set := flag.NewFlagSet("test", 0) @@ -170,3 +268,37 @@ func TestCommandYamlFileFlagHasDefaultGlobalEnvYamlSetGlobalEnvWins(t *testing.T expect(t, err, nil) } + +func TestCommandYamlFileFlagHasDefaultGlobalEnvYamlSetGlobalEnvWinsNested(t *testing.T) { + app := cli.NewApp() + set := flag.NewFlagSet("test", 0) + ioutil.WriteFile("current.yaml", []byte(`top: + test: 15`), 0666) + defer os.Remove("current.yaml") + + os.Setenv("THE_TEST", "11") + defer os.Setenv("THE_TEST", "") + + test := []string{"test-cmd", "--load", "current.yaml"} + set.Parse(test) + + c := cli.NewContext(app, set, nil) + + command := &cli.Command{ + Name: "test-cmd", + Aliases: []string{"tc"}, + Usage: "this is for testing", + Description: "testing", + Action: func(c *cli.Context) { + val := c.Int("top.test") + expect(t, val, 11) + }, + Flags: []cli.Flag{ + NewIntFlag(cli.IntFlag{Name: "top.test", Value: 7, EnvVar: "THE_TEST"}), + cli.StringFlag{Name: "load"}}, + } + command.Before = InitInputSourceWithContext(command.Flags, NewYamlSourceFromFlagFunc("load")) + err := command.Run(c) + + expect(t, err, nil) +} diff --git a/altsrc/yaml_file_loader.go b/altsrc/yaml_file_loader.go index 4fb0965..01797ad 100644 --- a/altsrc/yaml_file_loader.go +++ b/altsrc/yaml_file_loader.go @@ -24,7 +24,7 @@ type yamlSourceContext struct { // NewYamlSourceFromFile creates a new Yaml InputSourceContext from a filepath. func NewYamlSourceFromFile(file string) (InputSourceContext, error) { ysc := &yamlSourceContext{FilePath: file} - var results map[string]interface{} + var results map[interface{}]interface{} err := readCommandYaml(ysc.FilePath, &results) if err != nil { return nil, fmt.Errorf("Unable to load Yaml file '%s': inner error: \n'%v'", ysc.FilePath, err.Error())