diff --git a/command.go b/command.go index fac754d..824e77b 100644 --- a/command.go +++ b/command.go @@ -74,34 +74,40 @@ func (c Command) Run(ctx *Context) error { set := flagSet(c.Name, c.Flags) set.SetOutput(ioutil.Discard) - firstFlagIndex := -1 - terminatorIndex := -1 - for index, arg := range ctx.Args() { - if arg == "--" { - terminatorIndex = index - break - } else if strings.HasPrefix(arg, "-") && firstFlagIndex == -1 { - firstFlagIndex = index - } - } - var err error - if firstFlagIndex > -1 && !c.SkipFlagParsing { - args := ctx.Args() - regularArgs := make([]string, len(args[1:firstFlagIndex])) - copy(regularArgs, args[1:firstFlagIndex]) - - var flagArgs []string - if terminatorIndex > -1 { - flagArgs = args[firstFlagIndex:terminatorIndex] - regularArgs = append(regularArgs, args[terminatorIndex:]...) - } else { - flagArgs = args[firstFlagIndex:] + if !c.SkipFlagParsing { + firstFlagIndex := -1 + terminatorIndex := -1 + for index, arg := range ctx.Args() { + if arg == "--" { + terminatorIndex = index + break + } else if strings.HasPrefix(arg, "-") && firstFlagIndex == -1 { + firstFlagIndex = index + } } - err = set.Parse(append(flagArgs, regularArgs...)) + if firstFlagIndex > -1 { + args := ctx.Args() + regularArgs := make([]string, len(args[1:firstFlagIndex])) + copy(regularArgs, args[1:firstFlagIndex]) + + var flagArgs []string + if terminatorIndex > -1 { + flagArgs = args[firstFlagIndex:terminatorIndex] + regularArgs = append(regularArgs, args[terminatorIndex:]...) + } else { + flagArgs = args[firstFlagIndex:] + } + + err = set.Parse(append(flagArgs, regularArgs...)) + } else { + err = set.Parse(ctx.Args().Tail()) + } } else { - err = set.Parse(ctx.Args().Tail()) + if c.SkipFlagParsing { + err = set.Parse(append([]string{"--"}, ctx.Args().Tail()...)) + } } if err != nil { diff --git a/command_test.go b/command_test.go index fd39548..dd9fc87 100644 --- a/command_test.go +++ b/command_test.go @@ -1,47 +1,43 @@ package cli import ( + "errors" "flag" "testing" ) -func TestCommandDoNotIgnoreFlags(t *testing.T) { - app := NewApp() - set := flag.NewFlagSet("test", 0) - test := []string{"blah", "blah", "-break"} - set.Parse(test) - - c := NewContext(app, set, nil) - - command := Command{ - Name: "test-cmd", - Aliases: []string{"tc"}, - Usage: "this is for testing", - Description: "testing", - Action: func(_ *Context) {}, +func TestCommandFlagParsing(t *testing.T) { + cases := []struct { + testArgs []string + skipFlagParsing bool + expectedErr error + }{ + {[]string{"blah", "blah", "-break"}, false, errors.New("flag provided but not defined: -break")}, // Test normal "not ignoring flags" flow + {[]string{"blah", "blah"}, true, nil}, // Test SkipFlagParsing without any args that look like flags + {[]string{"blah", "-break"}, true, nil}, // Test SkipFlagParsing with random flag arg + {[]string{"blah", "-help"}, true, nil}, // Test SkipFlagParsing with "special" help flag arg } - err := command.Run(c) - expect(t, err.Error(), "flag provided but not defined: -break") -} + for _, c := range cases { + app := NewApp() + set := flag.NewFlagSet("test", 0) + set.Parse(c.testArgs) -func TestCommandIgnoreFlags(t *testing.T) { - app := NewApp() - set := flag.NewFlagSet("test", 0) - test := []string{"blah", "blah", "-break"} - set.Parse(test) + context := NewContext(app, set, nil) - c := NewContext(app, set, nil) + command := Command{ + Name: "test-cmd", + Aliases: []string{"tc"}, + Usage: "this is for testing", + Description: "testing", + Action: func(_ *Context) {}, + } - command := Command{ - Name: "test-cmd", - Aliases: []string{"tc"}, - Usage: "this is for testing", - Description: "testing", - Action: func(_ *Context) {}, - SkipFlagParsing: true, + command.SkipFlagParsing = c.skipFlagParsing + + err := command.Run(context) + + expect(t, err, c.expectedErr) + expect(t, []string(context.Args()), c.testArgs) } - err := command.Run(c) - - expect(t, err, nil) } diff --git a/helpers_test.go b/helpers_test.go index 3ce8e93..b1b7339 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -7,13 +7,13 @@ import ( /* Test Helpers */ func expect(t *testing.T, a interface{}, b interface{}) { - if a != b { + if !reflect.DeepEqual(a, b) { t.Errorf("Expected %v (type %v) - Got %v (type %v)", b, reflect.TypeOf(b), a, reflect.TypeOf(a)) } } func refute(t *testing.T, a interface{}, b interface{}) { - if a == b { + if reflect.DeepEqual(a, b) { t.Errorf("Did not expect %v (type %v) - Got %v (type %v)", b, reflect.TypeOf(b), a, reflect.TypeOf(a)) } }