diff --git a/context.go b/context.go index 312efb5..81416da 100644 --- a/context.go +++ b/context.go @@ -3,6 +3,7 @@ package cli import ( "context" "flag" + "fmt" "strings" ) @@ -46,10 +47,11 @@ func (cCtx *Context) NumFlags() int { // Set sets a context flag to a value. func (cCtx *Context) Set(name, value string) error { - if cCtx.flagSet.Lookup(name) == nil { - cCtx.onInvalidFlag(name) + if fs := cCtx.lookupFlagSet(name); fs != nil { + return fs.Set(name, value) } - return cCtx.flagSet.Set(name, value) + + return fmt.Errorf("no such flag -%s", name) } // IsSet determines if the flag was actually set diff --git a/context_test.go b/context_test.go index 6601155..246590d 100644 --- a/context_test.go +++ b/context_test.go @@ -643,3 +643,19 @@ func TestCheckRequiredFlags(t *testing.T) { }) } } + +func TestContext_ParentContext_Set(t *testing.T) { + parentSet := flag.NewFlagSet("parent", flag.ContinueOnError) + parentSet.String("Name", "", "") + + context := NewContext( + nil, + flag.NewFlagSet("child", flag.ContinueOnError), + NewContext(nil, parentSet, nil), + ) + + err := context.Set("Name", "aaa") + if err != nil { + t.Errorf("expect nil. set parent context flag return err: %s", err) + } +}