diff --git a/app.go b/app.go index 2f992d0..6884920 100644 --- a/app.go +++ b/app.go @@ -164,6 +164,9 @@ func (a *App) Run(arguments []string) (err error) { if a.Before != nil { err := a.Before(context) if err != nil { + fmt.Fprintln(a.Writer, err) + fmt.Fprintln(a.Writer) + ShowAppHelp(context) return err } } diff --git a/app_test.go b/app_test.go index 59fa75a..9a09405 100644 --- a/app_test.go +++ b/app_test.go @@ -942,6 +942,11 @@ func TestApp_Run_SubcommandDoesNotOverwriteErrorFromBefore(t *testing.T) { app := NewApp() app.Commands = []Command{ Command{ + Subcommands: []Command{ + Command{ + Name: "sub", + }, + }, Name: "bar", Before: func(c *Context) error { return fmt.Errorf("before error") }, After: func(c *Context) error { return fmt.Errorf("after error") }, diff --git a/command.go b/command.go index 824e77b..e42178e 100644 --- a/command.go +++ b/command.go @@ -54,8 +54,8 @@ func (c Command) FullName() string { } // Invokes the command given the context, parses ctx.Args() to generate command-specific flags -func (c Command) Run(ctx *Context) error { - if len(c.Subcommands) > 0 || c.Before != nil || c.After != nil { +func (c Command) Run(ctx *Context) (err error) { + if len(c.Subcommands) > 0 { return c.startApp(ctx) } @@ -74,7 +74,6 @@ func (c Command) Run(ctx *Context) error { set := flagSet(c.Name, c.Flags) set.SetOutput(ioutil.Discard) - var err error if !c.SkipFlagParsing { firstFlagIndex := -1 terminatorIndex := -1 @@ -133,6 +132,30 @@ func (c Command) Run(ctx *Context) error { if checkCommandHelp(context, c.Name) { return nil } + + if c.After != nil { + defer func() { + afterErr := c.After(context) + if afterErr != nil { + if err != nil { + err = NewMultiError(err, afterErr) + } else { + err = afterErr + } + } + }() + } + + if c.Before != nil { + err := c.Before(context) + if err != nil { + fmt.Fprintln(ctx.App.Writer, err) + fmt.Fprintln(ctx.App.Writer) + ShowCommandHelp(ctx, c.Name) + return err + } + } + context.Command = c c.Action(context) return nil diff --git a/command_test.go b/command_test.go index ac10652..50bd875 100644 --- a/command_test.go +++ b/command_test.go @@ -5,6 +5,8 @@ import ( "flag" "io/ioutil" "testing" + "fmt" + "strings" ) func TestCommandFlagParsing(t *testing.T) { @@ -43,3 +45,26 @@ func TestCommandFlagParsing(t *testing.T) { expect(t, []string(context.Args()), c.testArgs) } } + +func TestCommand_Run_DoesNotOverwriteErrorFromBefore(t *testing.T) { + app := NewApp() + app.Commands = []Command{ + Command{ + Name: "bar", + Before: func(c *Context) error { return fmt.Errorf("before error") }, + After: func(c *Context) error { return fmt.Errorf("after error") }, + }, + } + + err := app.Run([]string{"foo", "bar"}) + if err == nil { + t.Fatalf("expected to receive error from Run, got none") + } + + if !strings.Contains(err.Error(), "before error") { + t.Errorf("expected text of error from Before method, but got none in \"%v\"", err) + } + if !strings.Contains(err.Error(), "after error") { + t.Errorf("expected text of error from After method, but got none in \"%v\"", err) + } +}