diff --git a/app.go b/app.go index f4c4af8..31a9070 100644 --- a/app.go +++ b/app.go @@ -87,14 +87,25 @@ func (a *App) Run(arguments []string) error { set.SetOutput(ioutil.Discard) err := set.Parse(arguments[1:]) nerr := normalizeFlags(a.Flags, set) + cerr := checkRequiredFlags(a.Flags, set) + + context := NewContext(a, set, set) + if nerr != nil { fmt.Println(nerr) - context := NewContext(a, set, set) + fmt.Println("") ShowAppHelp(context) fmt.Println("") return nerr } - context := NewContext(a, set, set) + + if cerr != nil { + fmt.Println(cerr) + fmt.Println("") + ShowAppHelp(context) + fmt.Println("") + return cerr + } if err != nil { fmt.Printf("Incorrect Usage.\n\n") @@ -164,10 +175,13 @@ func (a *App) RunAsSubcommand(ctx *Context) error { set.SetOutput(ioutil.Discard) err := set.Parse(ctx.Args().Tail()) nerr := normalizeFlags(a.Flags, set) + cerr := checkRequiredFlags(a.Flags, set) + context := NewContext(a, set, ctx.globalSet) if nerr != nil { fmt.Println(nerr) + fmt.Println("") if len(a.Commands) > 0 { ShowSubcommandHelp(context) } else { @@ -177,6 +191,20 @@ func (a *App) RunAsSubcommand(ctx *Context) error { return nerr } + if cerr != nil { + fmt.Println(cerr) + fmt.Println("") + if len(a.Commands) > 0 { + ShowSubcommandHelp(context) + fmt.Println("subcommands") + } else { + ShowCommandHelp(ctx, context.Args().First()) + fmt.Println("commands") + } + fmt.Println("") + return cerr + } + if err != nil { fmt.Printf("Incorrect Usage.\n\n") ShowSubcommandHelp(context) diff --git a/command.go b/command.go index 5622b38..2ade6fe 100644 --- a/command.go +++ b/command.go @@ -88,6 +88,16 @@ func (c Command) Run(ctx *Context) error { fmt.Println("") return nerr } + + cerr := checkRequiredFlags(c.Flags, set) + if cerr != nil { + fmt.Println(cerr) + fmt.Println("") + ShowCommandHelp(ctx, c.Name) + fmt.Println("") + return cerr + } + context := NewContext(ctx.App, set, ctx.globalSet) if checkCommandCompletions(context, c.Name) { diff --git a/context.go b/context.go index c9f645b..b8fe7a6 100644 --- a/context.go +++ b/context.go @@ -3,6 +3,7 @@ package cli import ( "errors" "flag" + "fmt" "strconv" "strings" "time" @@ -337,3 +338,20 @@ func normalizeFlags(flags []Flag, set *flag.FlagSet) error { } return nil } + +func checkRequiredFlags(flags []Flag, set *flag.FlagSet) error { + visited := make(map[string]bool) + set.Visit(func(f *flag.Flag) { + visited[f.Name] = true + }) + + for _, f := range flags { + if f.IsRequired() { + key := strings.Split(f.getName(), ",")[0] + if !visited[key] { + return fmt.Errorf("Required flag %s not set", f.getName()) + } + } + } + return nil +} diff --git a/flag.go b/flag.go index b30bca3..83dbba1 100644 --- a/flag.go +++ b/flag.go @@ -34,6 +34,7 @@ type Flag interface { // Apply Flag settings to the given flag set Apply(*flag.FlagSet) getName() string + IsRequired() bool } func flagSet(name string, flags []Flag) *flag.FlagSet { @@ -61,14 +62,15 @@ type Generic interface { // GenericFlag is the flag type for types implementing Generic type GenericFlag struct { - Name string - Value Generic - Usage string - EnvVar string + Name string + Value Generic + Usage string + EnvVar string + Required bool } func (f GenericFlag) String() string { - return withEnvHint(f.EnvVar, fmt.Sprintf("%s%s %v\t`%v` %s", prefixFor(f.Name), f.Name, f.Value, "-"+f.Name+" option -"+f.Name+" option", f.Usage)) + return withHints(f.EnvVar, f.Required, fmt.Sprintf("%s%s %v\t`%v` %s", prefixFor(f.Name), f.Name, f.Value, "-"+f.Name+" option -"+f.Name+" option", f.Usage)) } func (f GenericFlag) Apply(set *flag.FlagSet) { @@ -88,6 +90,10 @@ func (f GenericFlag) getName() string { return f.Name } +func (f GenericFlag) IsRequired() bool { + return f.Required +} + type StringSlice []string func (f *StringSlice) Set(value string) error { @@ -104,16 +110,17 @@ func (f *StringSlice) Value() []string { } type StringSliceFlag struct { - Name string - Value *StringSlice - Usage string - EnvVar string + Name string + Value *StringSlice + Usage string + EnvVar string + Required bool } func (f StringSliceFlag) String() string { firstName := strings.Trim(strings.Split(f.Name, ",")[0], " ") pref := prefixFor(firstName) - return withEnvHint(f.EnvVar, fmt.Sprintf("%s '%v'\t%v", prefixedNames(f.Name), pref+firstName+" option "+pref+firstName+" option", f.Usage)) + return withHints(f.EnvVar, f.Required, fmt.Sprintf("%s '%v'\t%v", prefixedNames(f.Name), pref+firstName+" option "+pref+firstName+" option", f.Usage)) } func (f StringSliceFlag) Apply(set *flag.FlagSet) { @@ -136,6 +143,10 @@ func (f StringSliceFlag) getName() string { return f.Name } +func (f StringSliceFlag) IsRequired() bool { + return f.Required +} + type IntSlice []int func (f *IntSlice) Set(value string) error { @@ -158,16 +169,17 @@ func (f *IntSlice) Value() []int { } type IntSliceFlag struct { - Name string - Value *IntSlice - Usage string - EnvVar string + Name string + Value *IntSlice + Usage string + EnvVar string + Required bool } func (f IntSliceFlag) String() string { firstName := strings.Trim(strings.Split(f.Name, ",")[0], " ") pref := prefixFor(firstName) - return withEnvHint(f.EnvVar, fmt.Sprintf("%s '%v'\t%v", prefixedNames(f.Name), pref+firstName+" option "+pref+firstName+" option", f.Usage)) + return withHints(f.EnvVar, f.Required, fmt.Sprintf("%s '%v'\t%v", prefixedNames(f.Name), pref+firstName+" option "+pref+firstName+" option", f.Usage)) } func (f IntSliceFlag) Apply(set *flag.FlagSet) { @@ -193,14 +205,19 @@ func (f IntSliceFlag) getName() string { return f.Name } +func (f IntSliceFlag) IsRequired() bool { + return f.Required +} + type BoolFlag struct { - Name string - Usage string - EnvVar string + Name string + Usage string + EnvVar string + Required bool } func (f BoolFlag) String() string { - return withEnvHint(f.EnvVar, fmt.Sprintf("%s\t%v", prefixedNames(f.Name), f.Usage)) + return withHints(f.EnvVar, f.Required, fmt.Sprintf("%s\t%v", prefixedNames(f.Name), f.Usage)) } func (f BoolFlag) Apply(set *flag.FlagSet) { @@ -223,14 +240,19 @@ func (f BoolFlag) getName() string { return f.Name } +func (f BoolFlag) IsRequired() bool { + return f.Required +} + type BoolTFlag struct { - Name string - Usage string - EnvVar string + Name string + Usage string + EnvVar string + Required bool } func (f BoolTFlag) String() string { - return withEnvHint(f.EnvVar, fmt.Sprintf("%s\t%v", prefixedNames(f.Name), f.Usage)) + return withHints(f.EnvVar, f.Required, fmt.Sprintf("%s\t%v", prefixedNames(f.Name), f.Usage)) } func (f BoolTFlag) Apply(set *flag.FlagSet) { @@ -253,11 +275,16 @@ func (f BoolTFlag) getName() string { return f.Name } +func (f BoolTFlag) IsRequired() bool { + return f.Required +} + type StringFlag struct { - Name string - Value string - Usage string - EnvVar string + Name string + Value string + Usage string + EnvVar string + Required bool } func (f StringFlag) String() string { @@ -270,7 +297,7 @@ func (f StringFlag) String() string { fmtString = "%s %v\t%v" } - return withEnvHint(f.EnvVar, fmt.Sprintf(fmtString, prefixedNames(f.Name), f.Value, f.Usage)) + return withHints(f.EnvVar, f.Required, fmt.Sprintf(fmtString, prefixedNames(f.Name), f.Value, f.Usage)) } func (f StringFlag) Apply(set *flag.FlagSet) { @@ -289,15 +316,20 @@ func (f StringFlag) getName() string { return f.Name } +func (f StringFlag) IsRequired() bool { + return f.Required +} + type IntFlag struct { - Name string - Value int - Usage string - EnvVar string + Name string + Value int + Usage string + EnvVar string + Required bool } func (f IntFlag) String() string { - return withEnvHint(f.EnvVar, fmt.Sprintf("%s '%v'\t%v", prefixedNames(f.Name), f.Value, f.Usage)) + return withHints(f.EnvVar, f.Required, fmt.Sprintf("%s '%v'\t%v", prefixedNames(f.Name), f.Value, f.Usage)) } func (f IntFlag) Apply(set *flag.FlagSet) { @@ -319,15 +351,20 @@ func (f IntFlag) getName() string { return f.Name } +func (f IntFlag) IsRequired() bool { + return f.Required +} + type DurationFlag struct { - Name string - Value time.Duration - Usage string - EnvVar string + Name string + Value time.Duration + Usage string + EnvVar string + Required bool } func (f DurationFlag) String() string { - return withEnvHint(f.EnvVar, fmt.Sprintf("%s '%v'\t%v", prefixedNames(f.Name), f.Value, f.Usage)) + return withHints(f.EnvVar, f.Required, fmt.Sprintf("%s '%v'\t%v", prefixedNames(f.Name), f.Value, f.Usage)) } func (f DurationFlag) Apply(set *flag.FlagSet) { @@ -349,15 +386,20 @@ func (f DurationFlag) getName() string { return f.Name } +func (f DurationFlag) IsRequired() bool { + return f.Required +} + type Float64Flag struct { - Name string - Value float64 - Usage string - EnvVar string + Name string + Value float64 + Usage string + EnvVar string + Required bool } func (f Float64Flag) String() string { - return withEnvHint(f.EnvVar, fmt.Sprintf("%s '%v'\t%v", prefixedNames(f.Name), f.Value, f.Usage)) + return withHints(f.EnvVar, f.Required, fmt.Sprintf("%s '%v'\t%v", prefixedNames(f.Name), f.Value, f.Usage)) } func (f Float64Flag) Apply(set *flag.FlagSet) { @@ -379,6 +421,10 @@ func (f Float64Flag) getName() string { return f.Name } +func (f Float64Flag) IsRequired() bool { + return f.Required +} + func prefixFor(name string) (prefix string) { if len(name) == 1 { prefix = "-" @@ -408,3 +454,15 @@ func withEnvHint(envVar, str string) string { } return str + envText } + +func withRequiredHint(isRequired bool, str string) string { + if isRequired { + return str + " (required)" + } + + return str +} + +func withHints(envVar string, isRequired bool, str string) string { + return withRequiredHint(isRequired, withEnvHint(envVar, str)) +} diff --git a/required_flags_test.go b/required_flags_test.go new file mode 100644 index 0000000..f9abe7a --- /dev/null +++ b/required_flags_test.go @@ -0,0 +1,60 @@ +package cli + +import ( + "flag" + "testing" +) + +func TestContext_CheckRequiredFlagsSuccess(t *testing.T) { + flags := []Flag{ + StringFlag{ + Name: "required", + Required: true, + }, + StringFlag{ + Name: "optional", + }, + } + + set := flag.NewFlagSet("test", 0) + for _, f := range flags { + f.Apply(set) + } + + e := set.Parse([]string{"--required", "foo"}) + if e != nil { + t.Errorf("Expected no error parsing but there was one: %s", e) + } + + err := checkRequiredFlags(flags, set) + if err != nil { + t.Error("Expected flag parsing to be successful") + } +} + +func TestContext_CheckRequiredFlagsFailure(t *testing.T) { + flags := []Flag{ + StringFlag{ + Name: "required", + Required: true, + }, + StringFlag{ + Name: "optional", + }, + } + + set := flag.NewFlagSet("test", 0) + for _, f := range flags { + f.Apply(set) + } + + e := set.Parse([]string{"--optional", "foo"}) + if e != nil { + t.Errorf("Expected no error parsing but there was one: %s", e) + } + + err := checkRequiredFlags(flags, set) + if err == nil { + t.Error("Expected flag parsing to be unsuccessful") + } +}