add propagation tests

This commit is contained in:
marwan-at-work 2019-08-06 14:04:51 -04:00
parent 1f7d1684b8
commit 98e64f4507
2 changed files with 23 additions and 4 deletions

View File

@ -236,7 +236,7 @@ func ExampleApp_Run_shellComplete() {
os.Args = []string{"greet", fmt.Sprintf("--%s", genCompName())} os.Args = []string{"greet", fmt.Sprintf("--%s", genCompName())}
app := &App{ app := &App{
Name: "greet", Name: "greet",
EnableShellCompletion: true, EnableShellCompletion: true,
Commands: []*Command{ Commands: []*Command{
{ {
@ -503,7 +503,6 @@ func TestApp_Float64Flag(t *testing.T) {
} }
func TestApp_ParseSliceFlags(t *testing.T) { func TestApp_ParseSliceFlags(t *testing.T) {
var parsedOption, firstArg string
var parsedIntSlice []int var parsedIntSlice []int
var parsedStringSlice []string var parsedStringSlice []string
@ -518,8 +517,6 @@ func TestApp_ParseSliceFlags(t *testing.T) {
Action: func(c *Context) error { Action: func(c *Context) error {
parsedIntSlice = c.IntSlice("p") parsedIntSlice = c.IntSlice("p")
parsedStringSlice = c.StringSlice("ip") parsedStringSlice = c.StringSlice("ip")
parsedOption = c.String("option")
firstArg = c.Args().First()
return nil return nil
}, },
}, },

View File

@ -1,6 +1,7 @@
package cli package cli
import ( import (
"context"
"flag" "flag"
"sort" "sort"
"testing" "testing"
@ -262,3 +263,24 @@ func TestContext_lookupFlagSet(t *testing.T) {
t.Fail() t.Fail()
} }
} }
// TestContextPropagation tests that
// *cli.Context always has a valid
// context.Context
func TestContextPropagation(t *testing.T) {
ctx := NewContext(nil, nil, nil)
if ctx.Context == nil {
t.Fatal("expected a non nil context when no parent is present")
}
parent := NewContext(nil, nil, nil)
parent.Context = context.WithValue(context.Background(), "key", "val")
ctx = NewContext(nil, nil, parent)
val := ctx.Value("key")
if val == nil {
t.Fatal("expected a parent context to be inherited but got nil")
}
valstr, _ := val.(string)
if valstr != "val" {
t.Fatalf("expected the context value to be %q but got %q", "val", valstr)
}
}