diff --git a/app.go b/app.go index 4efba5e..0125dae 100644 --- a/app.go +++ b/app.go @@ -2,8 +2,11 @@ package cli import ( "fmt" + "io" "io/ioutil" "os" + "text/tabwriter" + "text/template" "time" ) @@ -37,6 +40,8 @@ type App struct { Author string // Author e-mail Email string + // Stdout writer to write output to + Stdout io.Writer } // Tries to find out when this binary was compiled. @@ -60,11 +65,28 @@ func NewApp() *App { Compiled: compileTime(), Author: "Author", Email: "unknown@email", + Stdout: os.Stdout, } } // Entry point to the cli app. Parses the arguments slice and routes to the proper flag/args combination func (a *App) Run(arguments []string) error { + if HelpPrinter == nil { + defer func() { + HelpPrinter = nil + }() + + HelpPrinter = func(templ string, data interface{}) { + w := tabwriter.NewWriter(a.Stdout, 0, 8, 1, '\t', 0) + t := template.Must(template.New("help").Parse(templ)) + err := t.Execute(w, data) + if err != nil { + panic(err) + } + w.Flush() + } + } + // append help to commands if a.Command(helpCommand.Name) == nil { a.Commands = append(a.Commands, helpCommand) @@ -83,18 +105,18 @@ func (a *App) Run(arguments []string) error { err := set.Parse(arguments[1:]) nerr := normalizeFlags(a.Flags, set) if nerr != nil { - fmt.Println(nerr) + io.WriteString(a.Stdout, fmt.Sprintln(nerr)) context := NewContext(a, set, set) ShowAppHelp(context) - fmt.Println("") + io.WriteString(a.Stdout, fmt.Sprintln("")) return nerr } context := NewContext(a, set, set) if err != nil { - fmt.Printf("Incorrect Usage.\n\n") + io.WriteString(a.Stdout, fmt.Sprintf("Incorrect Usage.\n\n")) ShowAppHelp(context) - fmt.Println("") + io.WriteString(a.Stdout, fmt.Sprintln("")) return err } @@ -154,18 +176,18 @@ func (a *App) RunAsSubcommand(ctx *Context) error { context := NewContext(a, set, set) if nerr != nil { - fmt.Println(nerr) + io.WriteString(a.Stdout, fmt.Sprintln(nerr)) if len(a.Commands) > 0 { ShowSubcommandHelp(context) } else { ShowCommandHelp(ctx, context.Args().First()) } - fmt.Println("") + io.WriteString(a.Stdout, fmt.Sprintln("")) return nerr } if err != nil { - fmt.Printf("Incorrect Usage.\n\n") + io.WriteString(a.Stdout, fmt.Sprintf("Incorrect Usage.\n\n")) ShowSubcommandHelp(context) return err } diff --git a/app_test.go b/app_test.go index 0b9e154..e8937e1 100644 --- a/app_test.go +++ b/app_test.go @@ -262,6 +262,50 @@ func TestApp_ParseSliceFlags(t *testing.T) { } } +func TestApp_DefaultStdout(t *testing.T) { + app := cli.NewApp() + + if app.Stdout != os.Stdout { + t.Error("Default output writer not set.") + } +} + +type fakeWriter struct { + written []byte +} + +func (fw *fakeWriter) Write(p []byte) (n int, err error) { + if fw.written == nil { + fw.written = p + } else { + fw.written = append(fw.written, p...) + } + + return len(p), nil +} + +func (fw *fakeWriter) GetWritten() (b []byte) { + return fw.written +} + +func TestApp_SetStdout(t *testing.T) { + mockWriter := &fakeWriter{} + + app := cli.NewApp() + app.Name = "test" + app.Stdout = mockWriter + + err := app.Run([]string{"help"}) + + if err != nil { + t.Fatalf("Run error: %s", err) + } + + if len(mockWriter.written) == 0 { + t.Error("App did not write output to desired writer.") + } +} + func TestApp_BeforeFunc(t *testing.T) { beforeRun, subcommandRun := false, false beforeError := fmt.Errorf("fail") diff --git a/command.go b/command.go index 9d8fff4..3d470c8 100644 --- a/command.go +++ b/command.go @@ -2,6 +2,7 @@ package cli import ( "fmt" + "io" "io/ioutil" "strings" ) @@ -70,18 +71,18 @@ func (c Command) Run(ctx *Context) error { } if err != nil { - fmt.Printf("Incorrect Usage.\n\n") + io.WriteString(ctx.App.Stdout, fmt.Sprintf("Incorrect Usage.\n\n")) ShowCommandHelp(ctx, c.Name) - fmt.Println("") + io.WriteString(ctx.App.Stdout, fmt.Sprintln("")) return err } nerr := normalizeFlags(c.Flags, set) if nerr != nil { - fmt.Println(nerr) - fmt.Println("") + io.WriteString(ctx.App.Stdout, fmt.Sprintln(nerr)) + io.WriteString(ctx.App.Stdout, fmt.Sprintln("")) ShowCommandHelp(ctx, c.Name) - fmt.Println("") + io.WriteString(ctx.App.Stdout, fmt.Sprintln("")) return nerr } context := NewContext(ctx.App, set, ctx.globalSet) diff --git a/help.go b/help.go index 7c04005..ae60a13 100644 --- a/help.go +++ b/help.go @@ -2,9 +2,7 @@ package cli import ( "fmt" - "os" - "text/tabwriter" - "text/template" + "io" ) // The text template for the Default help topic. @@ -90,7 +88,9 @@ var helpSubcommand = Command{ } // Prints help for the App -var HelpPrinter = printHelp +type helpPrinter func(templ string, data interface{}) + +var HelpPrinter helpPrinter = nil func ShowAppHelp(c *Context) { HelpPrinter(AppHelpTemplate, c.App) @@ -99,9 +99,9 @@ func ShowAppHelp(c *Context) { // Prints the list of subcommands as the default app completion method func DefaultAppComplete(c *Context) { for _, command := range c.App.Commands { - fmt.Println(command.Name) + io.WriteString(c.App.Stdout, fmt.Sprintln(command.Name)) if command.ShortName != "" { - fmt.Println(command.ShortName) + io.WriteString(c.App.Stdout, fmt.Sprintln(command.ShortName)) } } } @@ -118,7 +118,7 @@ func ShowCommandHelp(c *Context, command string) { if c.App.CommandNotFound != nil { c.App.CommandNotFound(c, command) } else { - fmt.Printf("No help topic for '%v'\n", command) + io.WriteString(c.App.Stdout, fmt.Sprintf("No help topic for '%v'\n", command)) } } @@ -129,7 +129,7 @@ func ShowSubcommandHelp(c *Context) { // Prints the version number of the App func ShowVersion(c *Context) { - fmt.Printf("%v version %v\n", c.App.Name, c.App.Version) + io.WriteString(c.App.Stdout, fmt.Sprintf("%v version %v\n", c.App.Name, c.App.Version)) } // Prints the lists of commands within a given context @@ -148,16 +148,6 @@ func ShowCommandCompletions(ctx *Context, command string) { } } -func printHelp(templ string, data interface{}) { - w := tabwriter.NewWriter(os.Stdout, 0, 8, 1, '\t', 0) - t := template.Must(template.New("help").Parse(templ)) - err := t.Execute(w, data) - if err != nil { - panic(err) - } - w.Flush() -} - func checkVersion(c *Context) bool { if c.GlobalBool("version") { ShowVersion(c)