diff --git a/errors.go b/errors.go index 225e1bb..a818727 100644 --- a/errors.go +++ b/errors.go @@ -83,7 +83,7 @@ type ExitCoder interface { type exitError struct { exitCode int - message interface{} + err error } // NewExitError calls Exit to create a new ExitCoder. @@ -101,20 +101,35 @@ func NewExitError(message interface{}, exitCode int) ExitCoder { // by overriding the ExitErrHandler function on an App or the package-global // OsExiter function. func Exit(message interface{}, exitCode int) ExitCoder { + var err error + + switch e := message.(type) { + case ErrorFormatter: + err = fmt.Errorf("%+v", message) + case error: + err = e + default: + err = fmt.Errorf("%+v", message) + } + return &exitError{ - message: message, + err: err, exitCode: exitCode, } } func (ee *exitError) Error() string { - return fmt.Sprintf("%v", ee.message) + return ee.err.Error() } func (ee *exitError) ExitCode() int { return ee.exitCode } +func (ee *exitError) Unwrap() error { + return ee.err +} + // HandleExitCoder handles errors implementing ExitCoder by printing their // message and calling OsExiter with the given exit code. // diff --git a/errors_test.go b/errors_test.go index d0b1b4f..337009c 100644 --- a/errors_test.go +++ b/errors_test.go @@ -45,6 +45,25 @@ func TestHandleExitCoder_ExitCoder(t *testing.T) { expect(t, called, true) } +func TestHandleExitCoder_ErrorExitCoder(t *testing.T) { + exitCode := 0 + called := false + + OsExiter = func(rc int) { + if !called { + exitCode = rc + called = true + } + } + + defer func() { OsExiter = fakeOsExiter }() + + HandleExitCoder(Exit(errors.New("galactic perimeter breach"), 9)) + + expect(t, exitCode, 9) + expect(t, called, true) +} + func TestHandleExitCoder_MultiErrorWithExitCoder(t *testing.T) { exitCode := 0 called := false