diff --git a/app.go b/app.go index 5e551b8..8f3bb30 100644 --- a/app.go +++ b/app.go @@ -132,10 +132,14 @@ func (a *App) Run(arguments []string) (err error) { if a.After != nil { defer func() { - // err is always nil here. - // There is a check to see if it is non-nil - // just few lines before. - err = a.After(context) + afterErr := a.After(context) + if afterErr != nil { + if err != nil { + err = NewMultiError(err, afterErr) + } else { + err = afterErr + } + } }() } @@ -225,10 +229,14 @@ func (a *App) RunAsSubcommand(ctx *Context) (err error) { if a.After != nil { defer func() { - // err is always nil here. - // There is a check to see if it is non-nil - // just few lines before. - err = a.After(context) + afterErr := a.After(context) + if afterErr != nil { + if err != nil { + err = NewMultiError(err, afterErr) + } else { + err = afterErr + } + } }() } diff --git a/app_test.go b/app_test.go index fb8111d..57ec93e 100644 --- a/app_test.go +++ b/app_test.go @@ -717,3 +717,45 @@ func TestApp_Run_CommandWithSubcommandHasHelpTopic(t *testing.T) { } } } + +func TestApp_Run_DoesNotOverwriteErrorFromBefore(t *testing.T) { + app := cli.NewApp() + app.Action = func(c *cli.Context) {} + app.Before = func(c *cli.Context) error { return fmt.Errorf("before error") } + app.After = func(c *cli.Context) error { return fmt.Errorf("after error") } + + err := app.Run([]string{"foo"}) + if err == nil { + t.Fatalf("expected to recieve error from Run, got none") + } + + if !strings.Contains(err.Error(), "before error") { + t.Errorf("expected text of error from Before method, but got none in \"%v\"", err) + } + if !strings.Contains(err.Error(), "after error") { + t.Errorf("expected text of error from After method, but got none in \"%v\"", err) + } +} + +func TestApp_Run_SubcommandDoesNotOverwriteErrorFromBefore(t *testing.T) { + app := cli.NewApp() + app.Commands = []cli.Command{ + cli.Command{ + Name: "bar", + Before: func(c *cli.Context) error { return fmt.Errorf("before error") }, + After: func(c *cli.Context) error { return fmt.Errorf("after error") }, + }, + } + + err := app.Run([]string{"foo", "bar"}) + if err == nil { + t.Fatalf("expected to recieve error from Run, got none") + } + + if !strings.Contains(err.Error(), "before error") { + t.Errorf("expected text of error from Before method, but got none in \"%v\"", err) + } + if !strings.Contains(err.Error(), "after error") { + t.Errorf("expected text of error from After method, but got none in \"%v\"", err) + } +} diff --git a/cli.go b/cli.go index b742545..31dc912 100644 --- a/cli.go +++ b/cli.go @@ -17,3 +17,24 @@ // app.Run(os.Args) // } package cli + +import ( + "strings" +) + +type MultiError struct { + Errors []error +} + +func NewMultiError(err ...error) MultiError { + return MultiError{Errors: err} +} + +func (m MultiError) Error() string { + errs := make([]string, len(m.Errors)) + for i, err := range m.Errors { + errs[i] = err.Error() + } + + return strings.Join(errs, "\n") +}