diff --git a/app.go b/app.go index f4c4af8..31a9070 100644 --- a/app.go +++ b/app.go @@ -87,14 +87,25 @@ func (a *App) Run(arguments []string) error { set.SetOutput(ioutil.Discard) err := set.Parse(arguments[1:]) nerr := normalizeFlags(a.Flags, set) + cerr := checkRequiredFlags(a.Flags, set) + + context := NewContext(a, set, set) + if nerr != nil { fmt.Println(nerr) - context := NewContext(a, set, set) + fmt.Println("") ShowAppHelp(context) fmt.Println("") return nerr } - context := NewContext(a, set, set) + + if cerr != nil { + fmt.Println(cerr) + fmt.Println("") + ShowAppHelp(context) + fmt.Println("") + return cerr + } if err != nil { fmt.Printf("Incorrect Usage.\n\n") @@ -164,10 +175,13 @@ func (a *App) RunAsSubcommand(ctx *Context) error { set.SetOutput(ioutil.Discard) err := set.Parse(ctx.Args().Tail()) nerr := normalizeFlags(a.Flags, set) + cerr := checkRequiredFlags(a.Flags, set) + context := NewContext(a, set, ctx.globalSet) if nerr != nil { fmt.Println(nerr) + fmt.Println("") if len(a.Commands) > 0 { ShowSubcommandHelp(context) } else { @@ -177,6 +191,20 @@ func (a *App) RunAsSubcommand(ctx *Context) error { return nerr } + if cerr != nil { + fmt.Println(cerr) + fmt.Println("") + if len(a.Commands) > 0 { + ShowSubcommandHelp(context) + fmt.Println("subcommands") + } else { + ShowCommandHelp(ctx, context.Args().First()) + fmt.Println("commands") + } + fmt.Println("") + return cerr + } + if err != nil { fmt.Printf("Incorrect Usage.\n\n") ShowSubcommandHelp(context) diff --git a/command.go b/command.go index 5622b38..2ade6fe 100644 --- a/command.go +++ b/command.go @@ -88,6 +88,16 @@ func (c Command) Run(ctx *Context) error { fmt.Println("") return nerr } + + cerr := checkRequiredFlags(c.Flags, set) + if cerr != nil { + fmt.Println(cerr) + fmt.Println("") + ShowCommandHelp(ctx, c.Name) + fmt.Println("") + return cerr + } + context := NewContext(ctx.App, set, ctx.globalSet) if checkCommandCompletions(context, c.Name) { diff --git a/context.go b/context.go index c9f645b..b8fe7a6 100644 --- a/context.go +++ b/context.go @@ -3,6 +3,7 @@ package cli import ( "errors" "flag" + "fmt" "strconv" "strings" "time" @@ -337,3 +338,20 @@ func normalizeFlags(flags []Flag, set *flag.FlagSet) error { } return nil } + +func checkRequiredFlags(flags []Flag, set *flag.FlagSet) error { + visited := make(map[string]bool) + set.Visit(func(f *flag.Flag) { + visited[f.Name] = true + }) + + for _, f := range flags { + if f.IsRequired() { + key := strings.Split(f.getName(), ",")[0] + if !visited[key] { + return fmt.Errorf("Required flag %s not set", f.getName()) + } + } + } + return nil +} diff --git a/flag.go b/flag.go index b30bca3..eb10368 100644 --- a/flag.go +++ b/flag.go @@ -34,6 +34,7 @@ type Flag interface { // Apply Flag settings to the given flag set Apply(*flag.FlagSet) getName() string + IsRequired() bool } func flagSet(name string, flags []Flag) *flag.FlagSet { @@ -61,10 +62,11 @@ type Generic interface { // GenericFlag is the flag type for types implementing Generic type GenericFlag struct { - Name string - Value Generic - Usage string - EnvVar string + Name string + Value Generic + Usage string + EnvVar string + Required bool } func (f GenericFlag) String() string { @@ -88,6 +90,10 @@ func (f GenericFlag) getName() string { return f.Name } +func (f GenericFlag) IsRequired() bool { + return f.Required +} + type StringSlice []string func (f *StringSlice) Set(value string) error { @@ -104,10 +110,11 @@ func (f *StringSlice) Value() []string { } type StringSliceFlag struct { - Name string - Value *StringSlice - Usage string - EnvVar string + Name string + Value *StringSlice + Usage string + EnvVar string + Required bool } func (f StringSliceFlag) String() string { @@ -136,6 +143,10 @@ func (f StringSliceFlag) getName() string { return f.Name } +func (f StringSliceFlag) IsRequired() bool { + return f.Required +} + type IntSlice []int func (f *IntSlice) Set(value string) error { @@ -158,10 +169,11 @@ func (f *IntSlice) Value() []int { } type IntSliceFlag struct { - Name string - Value *IntSlice - Usage string - EnvVar string + Name string + Value *IntSlice + Usage string + EnvVar string + Required bool } func (f IntSliceFlag) String() string { @@ -193,10 +205,15 @@ func (f IntSliceFlag) getName() string { return f.Name } +func (f IntSliceFlag) IsRequired() bool { + return f.Required +} + type BoolFlag struct { - Name string - Usage string - EnvVar string + Name string + Usage string + EnvVar string + Required bool } func (f BoolFlag) String() string { @@ -223,10 +240,15 @@ func (f BoolFlag) getName() string { return f.Name } +func (f BoolFlag) IsRequired() bool { + return f.Required +} + type BoolTFlag struct { - Name string - Usage string - EnvVar string + Name string + Usage string + EnvVar string + Required bool } func (f BoolTFlag) String() string { @@ -253,11 +275,16 @@ func (f BoolTFlag) getName() string { return f.Name } +func (f BoolTFlag) IsRequired() bool { + return f.Required +} + type StringFlag struct { - Name string - Value string - Usage string - EnvVar string + Name string + Value string + Usage string + EnvVar string + Required bool } func (f StringFlag) String() string { @@ -289,11 +316,16 @@ func (f StringFlag) getName() string { return f.Name } +func (f StringFlag) IsRequired() bool { + return f.Required +} + type IntFlag struct { - Name string - Value int - Usage string - EnvVar string + Name string + Value int + Usage string + EnvVar string + Required bool } func (f IntFlag) String() string { @@ -319,11 +351,16 @@ func (f IntFlag) getName() string { return f.Name } +func (f IntFlag) IsRequired() bool { + return f.Required +} + type DurationFlag struct { - Name string - Value time.Duration - Usage string - EnvVar string + Name string + Value time.Duration + Usage string + EnvVar string + Required bool } func (f DurationFlag) String() string { @@ -349,11 +386,16 @@ func (f DurationFlag) getName() string { return f.Name } +func (f DurationFlag) IsRequired() bool { + return f.Required +} + type Float64Flag struct { - Name string - Value float64 - Usage string - EnvVar string + Name string + Value float64 + Usage string + EnvVar string + Required bool } func (f Float64Flag) String() string { @@ -379,6 +421,10 @@ func (f Float64Flag) getName() string { return f.Name } +func (f Float64Flag) IsRequired() bool { + return f.Required +} + func prefixFor(name string) (prefix string) { if len(name) == 1 { prefix = "-"