diff --git a/app_test.go b/app_test.go index 796f665..84631fd 100644 --- a/app_test.go +++ b/app_test.go @@ -236,7 +236,7 @@ func ExampleApp_Run_shellComplete() { os.Args = []string{"greet", fmt.Sprintf("--%s", genCompName())} app := &App{ - Name: "greet", + Name: "greet", EnableShellCompletion: true, Commands: []*Command{ { @@ -525,6 +525,8 @@ func TestApp_ParseSliceFlags(t *testing.T) { }, }, } + var _ = parsedOption + var _ = firstArg app.Run([]string{"", "cmd", "-p", "22", "-p", "80", "-ip", "8.8.8.8", "-ip", "8.8.4.4", "my-arg"}) diff --git a/context.go b/context.go index 9802594..f1a01b4 100644 --- a/context.go +++ b/context.go @@ -1,11 +1,14 @@ package cli import ( + "context" "errors" "flag" "os" + "os/signal" "reflect" "strings" + "syscall" ) // Context is a type that is passed through to @@ -13,6 +16,7 @@ import ( // can be used to retrieve context-specific args and // parsed command-line options. type Context struct { + context.Context App *App Command *Command shellComplete bool @@ -24,10 +28,20 @@ type Context struct { // NewContext creates a new context. For use in when invoking an App or Command action. func NewContext(app *App, set *flag.FlagSet, parentCtx *Context) *Context { c := &Context{App: app, flagSet: set, parentContext: parentCtx} - if parentCtx != nil { + c.Context = parentCtx.Context c.shellComplete = parentCtx.shellComplete } + if c.Context == nil { + ctx, cancel := context.WithCancel(context.Background()) + go func() { + defer cancel() + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) + <-sigs + }() + c.Context = ctx + } return c } diff --git a/context_test.go b/context_test.go index 0509488..7333ae0 100644 --- a/context_test.go +++ b/context_test.go @@ -1,6 +1,7 @@ package cli import ( + "context" "flag" "sort" "testing" @@ -262,3 +263,33 @@ func TestContext_lookupFlagSet(t *testing.T) { t.Fail() } } + +func TestNonNilContext(t *testing.T) { + ctx := NewContext(nil, nil, nil) + if ctx.Context == nil { + t.Fatal("expected a non nil context when no parent is present") + } +} + +// TestContextPropagation tests that +// *cli.Context always has a valid +// context.Context +func TestContextPropagation(t *testing.T) { + parent := NewContext(nil, nil, nil) + parent.Context = context.WithValue(context.Background(), "key", "val") + ctx := NewContext(nil, nil, parent) + val := ctx.Value("key") + if val == nil { + t.Fatal("expected a parent context to be inherited but got nil") + } + valstr, _ := val.(string) + if valstr != "val" { + t.Fatalf("expected the context value to be %q but got %q", "val", valstr) + } + parent = NewContext(nil, nil, nil) + parent.Context = nil + ctx = NewContext(nil, nil, parent) + if ctx.Context == nil { + t.Fatal("expected context to not be nil even if the parent's context is nil") + } +}