diff --git a/app.go b/app.go index a5b0aa3..89c741b 100644 --- a/app.go +++ b/app.go @@ -168,9 +168,7 @@ func (a *App) Run(arguments []string) (err error) { if err != nil { if a.OnUsageError != nil { err := a.OnUsageError(context, err, false) - if err != nil { - HandleExitCoder(err) - } + HandleExitCoder(err) return err } else { fmt.Fprintf(a.Writer, "%s\n\n", "Incorrect Usage.") @@ -224,9 +222,7 @@ func (a *App) Run(arguments []string) (err error) { // Run default Action err = HandleAction(a.Action, context) - if err != nil { - HandleExitCoder(err) - } + HandleExitCoder(err) return err } @@ -237,7 +233,7 @@ func (a *App) RunAndExitOnError() { contactSysadmin, runAndExitOnErrorDeprecationURL) if err := a.Run(os.Args); err != nil { fmt.Fprintln(os.Stderr, err) - os.Exit(1) + OsExiter(1) } } @@ -346,9 +342,7 @@ func (a *App) RunAsSubcommand(ctx *Context) (err error) { // Run default Action err = HandleAction(a.Action, context) - if err != nil { - HandleExitCoder(err) - } + HandleExitCoder(err) return err } @@ -438,7 +432,7 @@ func HandleAction(action interface{}, context *Context) (err error) { return errInvalidActionSignature } - if retErr, ok := reflect.ValueOf(vals[0]).Interface().(error); ok { + if retErr, ok := vals[0].Interface().(error); vals[0].IsValid() && ok { return retErr } diff --git a/errors.go b/errors.go index 1a6a8c7..5f1e83b 100644 --- a/errors.go +++ b/errors.go @@ -6,6 +6,8 @@ import ( "strings" ) +var OsExiter = os.Exit + type MultiError struct { Errors []error } @@ -26,6 +28,7 @@ func (m MultiError) Error() string { // ExitCoder is the interface checked by `App` and `Command` for a custom exit // code type ExitCoder interface { + error ExitCode() int } @@ -56,15 +59,20 @@ func (ee *ExitError) ExitCode() int { } // HandleExitCoder checks if the error fulfills the ExitCoder interface, and if -// so prints the error to stderr (if it is non-empty) and calls os.Exit with the +// so prints the error to stderr (if it is non-empty) and calls OsExiter with the // given exit code. If the given error is a MultiError, then this func is // called on all members of the Errors slice. func HandleExitCoder(err error) { + if err == nil { + return + } + if exitErr, ok := err.(ExitCoder); ok { if err.Error() != "" { fmt.Fprintln(os.Stderr, err) } - os.Exit(exitErr.ExitCode()) + OsExiter(exitErr.ExitCode()) + return } if multiErr, ok := err.(MultiError); ok { diff --git a/errors_test.go b/errors_test.go new file mode 100644 index 0000000..6863105 --- /dev/null +++ b/errors_test.go @@ -0,0 +1,60 @@ +package cli + +import ( + "errors" + "os" + "testing" +) + +func TestHandleExitCoder_nil(t *testing.T) { + exitCode := 0 + called := false + + OsExiter = func(rc int) { + exitCode = rc + called = true + } + + defer func() { OsExiter = os.Exit }() + + HandleExitCoder(nil) + + expect(t, exitCode, 0) + expect(t, called, false) +} + +func TestHandleExitCoder_ExitCoder(t *testing.T) { + exitCode := 0 + called := false + + OsExiter = func(rc int) { + exitCode = rc + called = true + } + + defer func() { OsExiter = os.Exit }() + + HandleExitCoder(NewExitError("galactic perimiter breach", 9)) + + expect(t, exitCode, 9) + expect(t, called, true) +} + +func TestHandleExitCoder_MultiErrorWithExitCoder(t *testing.T) { + exitCode := 0 + called := false + + OsExiter = func(rc int) { + exitCode = rc + called = true + } + + defer func() { OsExiter = os.Exit }() + + exitErr := NewExitError("galactic perimiter breach", 9) + err := NewMultiError(errors.New("wowsa"), errors.New("egad"), exitErr) + HandleExitCoder(err) + + expect(t, exitCode, 9) + expect(t, called, true) +}