Merge pull request #1 from ivey/required_flags

Required flags
This commit is contained in:
Jesse Howarth 2014-12-02 15:23:01 -08:00
commit 8f1fb06a58
5 changed files with 219 additions and 45 deletions

32
app.go
View File

@ -87,14 +87,25 @@ func (a *App) Run(arguments []string) error {
set.SetOutput(ioutil.Discard) set.SetOutput(ioutil.Discard)
err := set.Parse(arguments[1:]) err := set.Parse(arguments[1:])
nerr := normalizeFlags(a.Flags, set) nerr := normalizeFlags(a.Flags, set)
cerr := checkRequiredFlags(a.Flags, set)
context := NewContext(a, set, set)
if nerr != nil { if nerr != nil {
fmt.Println(nerr) fmt.Println(nerr)
context := NewContext(a, set, set) fmt.Println("")
ShowAppHelp(context) ShowAppHelp(context)
fmt.Println("") fmt.Println("")
return nerr return nerr
} }
context := NewContext(a, set, set)
if cerr != nil {
fmt.Println(cerr)
fmt.Println("")
ShowAppHelp(context)
fmt.Println("")
return cerr
}
if err != nil { if err != nil {
fmt.Printf("Incorrect Usage.\n\n") fmt.Printf("Incorrect Usage.\n\n")
@ -164,10 +175,13 @@ func (a *App) RunAsSubcommand(ctx *Context) error {
set.SetOutput(ioutil.Discard) set.SetOutput(ioutil.Discard)
err := set.Parse(ctx.Args().Tail()) err := set.Parse(ctx.Args().Tail())
nerr := normalizeFlags(a.Flags, set) nerr := normalizeFlags(a.Flags, set)
cerr := checkRequiredFlags(a.Flags, set)
context := NewContext(a, set, ctx.globalSet) context := NewContext(a, set, ctx.globalSet)
if nerr != nil { if nerr != nil {
fmt.Println(nerr) fmt.Println(nerr)
fmt.Println("")
if len(a.Commands) > 0 { if len(a.Commands) > 0 {
ShowSubcommandHelp(context) ShowSubcommandHelp(context)
} else { } else {
@ -177,6 +191,20 @@ func (a *App) RunAsSubcommand(ctx *Context) error {
return nerr return nerr
} }
if cerr != nil {
fmt.Println(cerr)
fmt.Println("")
if len(a.Commands) > 0 {
ShowSubcommandHelp(context)
fmt.Println("subcommands")
} else {
ShowCommandHelp(ctx, context.Args().First())
fmt.Println("commands")
}
fmt.Println("")
return cerr
}
if err != nil { if err != nil {
fmt.Printf("Incorrect Usage.\n\n") fmt.Printf("Incorrect Usage.\n\n")
ShowSubcommandHelp(context) ShowSubcommandHelp(context)

View File

@ -88,6 +88,16 @@ func (c Command) Run(ctx *Context) error {
fmt.Println("") fmt.Println("")
return nerr return nerr
} }
cerr := checkRequiredFlags(c.Flags, set)
if cerr != nil {
fmt.Println(cerr)
fmt.Println("")
ShowCommandHelp(ctx, c.Name)
fmt.Println("")
return cerr
}
context := NewContext(ctx.App, set, ctx.globalSet) context := NewContext(ctx.App, set, ctx.globalSet)
if checkCommandCompletions(context, c.Name) { if checkCommandCompletions(context, c.Name) {

View File

@ -3,6 +3,7 @@ package cli
import ( import (
"errors" "errors"
"flag" "flag"
"fmt"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -337,3 +338,20 @@ func normalizeFlags(flags []Flag, set *flag.FlagSet) error {
} }
return nil return nil
} }
func checkRequiredFlags(flags []Flag, set *flag.FlagSet) error {
visited := make(map[string]bool)
set.Visit(func(f *flag.Flag) {
visited[f.Name] = true
})
for _, f := range flags {
if f.IsRequired() {
key := strings.Split(f.getName(), ",")[0]
if !visited[key] {
return fmt.Errorf("Required flag %s not set", f.getName())
}
}
}
return nil
}

144
flag.go
View File

@ -34,6 +34,7 @@ type Flag interface {
// Apply Flag settings to the given flag set // Apply Flag settings to the given flag set
Apply(*flag.FlagSet) Apply(*flag.FlagSet)
getName() string getName() string
IsRequired() bool
} }
func flagSet(name string, flags []Flag) *flag.FlagSet { func flagSet(name string, flags []Flag) *flag.FlagSet {
@ -61,14 +62,15 @@ type Generic interface {
// GenericFlag is the flag type for types implementing Generic // GenericFlag is the flag type for types implementing Generic
type GenericFlag struct { type GenericFlag struct {
Name string Name string
Value Generic Value Generic
Usage string Usage string
EnvVar string EnvVar string
Required bool
} }
func (f GenericFlag) String() string { func (f GenericFlag) String() string {
return withEnvHint(f.EnvVar, fmt.Sprintf("%s%s %v\t`%v` %s", prefixFor(f.Name), f.Name, f.Value, "-"+f.Name+" option -"+f.Name+" option", f.Usage)) return withHints(f.EnvVar, f.Required, fmt.Sprintf("%s%s %v\t`%v` %s", prefixFor(f.Name), f.Name, f.Value, "-"+f.Name+" option -"+f.Name+" option", f.Usage))
} }
func (f GenericFlag) Apply(set *flag.FlagSet) { func (f GenericFlag) Apply(set *flag.FlagSet) {
@ -88,6 +90,10 @@ func (f GenericFlag) getName() string {
return f.Name return f.Name
} }
func (f GenericFlag) IsRequired() bool {
return f.Required
}
type StringSlice []string type StringSlice []string
func (f *StringSlice) Set(value string) error { func (f *StringSlice) Set(value string) error {
@ -104,16 +110,17 @@ func (f *StringSlice) Value() []string {
} }
type StringSliceFlag struct { type StringSliceFlag struct {
Name string Name string
Value *StringSlice Value *StringSlice
Usage string Usage string
EnvVar string EnvVar string
Required bool
} }
func (f StringSliceFlag) String() string { func (f StringSliceFlag) String() string {
firstName := strings.Trim(strings.Split(f.Name, ",")[0], " ") firstName := strings.Trim(strings.Split(f.Name, ",")[0], " ")
pref := prefixFor(firstName) pref := prefixFor(firstName)
return withEnvHint(f.EnvVar, fmt.Sprintf("%s '%v'\t%v", prefixedNames(f.Name), pref+firstName+" option "+pref+firstName+" option", f.Usage)) return withHints(f.EnvVar, f.Required, fmt.Sprintf("%s '%v'\t%v", prefixedNames(f.Name), pref+firstName+" option "+pref+firstName+" option", f.Usage))
} }
func (f StringSliceFlag) Apply(set *flag.FlagSet) { func (f StringSliceFlag) Apply(set *flag.FlagSet) {
@ -136,6 +143,10 @@ func (f StringSliceFlag) getName() string {
return f.Name return f.Name
} }
func (f StringSliceFlag) IsRequired() bool {
return f.Required
}
type IntSlice []int type IntSlice []int
func (f *IntSlice) Set(value string) error { func (f *IntSlice) Set(value string) error {
@ -158,16 +169,17 @@ func (f *IntSlice) Value() []int {
} }
type IntSliceFlag struct { type IntSliceFlag struct {
Name string Name string
Value *IntSlice Value *IntSlice
Usage string Usage string
EnvVar string EnvVar string
Required bool
} }
func (f IntSliceFlag) String() string { func (f IntSliceFlag) String() string {
firstName := strings.Trim(strings.Split(f.Name, ",")[0], " ") firstName := strings.Trim(strings.Split(f.Name, ",")[0], " ")
pref := prefixFor(firstName) pref := prefixFor(firstName)
return withEnvHint(f.EnvVar, fmt.Sprintf("%s '%v'\t%v", prefixedNames(f.Name), pref+firstName+" option "+pref+firstName+" option", f.Usage)) return withHints(f.EnvVar, f.Required, fmt.Sprintf("%s '%v'\t%v", prefixedNames(f.Name), pref+firstName+" option "+pref+firstName+" option", f.Usage))
} }
func (f IntSliceFlag) Apply(set *flag.FlagSet) { func (f IntSliceFlag) Apply(set *flag.FlagSet) {
@ -193,14 +205,19 @@ func (f IntSliceFlag) getName() string {
return f.Name return f.Name
} }
func (f IntSliceFlag) IsRequired() bool {
return f.Required
}
type BoolFlag struct { type BoolFlag struct {
Name string Name string
Usage string Usage string
EnvVar string EnvVar string
Required bool
} }
func (f BoolFlag) String() string { func (f BoolFlag) String() string {
return withEnvHint(f.EnvVar, fmt.Sprintf("%s\t%v", prefixedNames(f.Name), f.Usage)) return withHints(f.EnvVar, f.Required, fmt.Sprintf("%s\t%v", prefixedNames(f.Name), f.Usage))
} }
func (f BoolFlag) Apply(set *flag.FlagSet) { func (f BoolFlag) Apply(set *flag.FlagSet) {
@ -223,14 +240,19 @@ func (f BoolFlag) getName() string {
return f.Name return f.Name
} }
func (f BoolFlag) IsRequired() bool {
return f.Required
}
type BoolTFlag struct { type BoolTFlag struct {
Name string Name string
Usage string Usage string
EnvVar string EnvVar string
Required bool
} }
func (f BoolTFlag) String() string { func (f BoolTFlag) String() string {
return withEnvHint(f.EnvVar, fmt.Sprintf("%s\t%v", prefixedNames(f.Name), f.Usage)) return withHints(f.EnvVar, f.Required, fmt.Sprintf("%s\t%v", prefixedNames(f.Name), f.Usage))
} }
func (f BoolTFlag) Apply(set *flag.FlagSet) { func (f BoolTFlag) Apply(set *flag.FlagSet) {
@ -253,11 +275,16 @@ func (f BoolTFlag) getName() string {
return f.Name return f.Name
} }
func (f BoolTFlag) IsRequired() bool {
return f.Required
}
type StringFlag struct { type StringFlag struct {
Name string Name string
Value string Value string
Usage string Usage string
EnvVar string EnvVar string
Required bool
} }
func (f StringFlag) String() string { func (f StringFlag) String() string {
@ -270,7 +297,7 @@ func (f StringFlag) String() string {
fmtString = "%s %v\t%v" fmtString = "%s %v\t%v"
} }
return withEnvHint(f.EnvVar, fmt.Sprintf(fmtString, prefixedNames(f.Name), f.Value, f.Usage)) return withHints(f.EnvVar, f.Required, fmt.Sprintf(fmtString, prefixedNames(f.Name), f.Value, f.Usage))
} }
func (f StringFlag) Apply(set *flag.FlagSet) { func (f StringFlag) Apply(set *flag.FlagSet) {
@ -289,15 +316,20 @@ func (f StringFlag) getName() string {
return f.Name return f.Name
} }
func (f StringFlag) IsRequired() bool {
return f.Required
}
type IntFlag struct { type IntFlag struct {
Name string Name string
Value int Value int
Usage string Usage string
EnvVar string EnvVar string
Required bool
} }
func (f IntFlag) String() string { func (f IntFlag) String() string {
return withEnvHint(f.EnvVar, fmt.Sprintf("%s '%v'\t%v", prefixedNames(f.Name), f.Value, f.Usage)) return withHints(f.EnvVar, f.Required, fmt.Sprintf("%s '%v'\t%v", prefixedNames(f.Name), f.Value, f.Usage))
} }
func (f IntFlag) Apply(set *flag.FlagSet) { func (f IntFlag) Apply(set *flag.FlagSet) {
@ -319,15 +351,20 @@ func (f IntFlag) getName() string {
return f.Name return f.Name
} }
func (f IntFlag) IsRequired() bool {
return f.Required
}
type DurationFlag struct { type DurationFlag struct {
Name string Name string
Value time.Duration Value time.Duration
Usage string Usage string
EnvVar string EnvVar string
Required bool
} }
func (f DurationFlag) String() string { func (f DurationFlag) String() string {
return withEnvHint(f.EnvVar, fmt.Sprintf("%s '%v'\t%v", prefixedNames(f.Name), f.Value, f.Usage)) return withHints(f.EnvVar, f.Required, fmt.Sprintf("%s '%v'\t%v", prefixedNames(f.Name), f.Value, f.Usage))
} }
func (f DurationFlag) Apply(set *flag.FlagSet) { func (f DurationFlag) Apply(set *flag.FlagSet) {
@ -349,15 +386,20 @@ func (f DurationFlag) getName() string {
return f.Name return f.Name
} }
func (f DurationFlag) IsRequired() bool {
return f.Required
}
type Float64Flag struct { type Float64Flag struct {
Name string Name string
Value float64 Value float64
Usage string Usage string
EnvVar string EnvVar string
Required bool
} }
func (f Float64Flag) String() string { func (f Float64Flag) String() string {
return withEnvHint(f.EnvVar, fmt.Sprintf("%s '%v'\t%v", prefixedNames(f.Name), f.Value, f.Usage)) return withHints(f.EnvVar, f.Required, fmt.Sprintf("%s '%v'\t%v", prefixedNames(f.Name), f.Value, f.Usage))
} }
func (f Float64Flag) Apply(set *flag.FlagSet) { func (f Float64Flag) Apply(set *flag.FlagSet) {
@ -379,6 +421,10 @@ func (f Float64Flag) getName() string {
return f.Name return f.Name
} }
func (f Float64Flag) IsRequired() bool {
return f.Required
}
func prefixFor(name string) (prefix string) { func prefixFor(name string) (prefix string) {
if len(name) == 1 { if len(name) == 1 {
prefix = "-" prefix = "-"
@ -408,3 +454,15 @@ func withEnvHint(envVar, str string) string {
} }
return str + envText return str + envText
} }
func withRequiredHint(isRequired bool, str string) string {
if isRequired {
return str + " (required)"
}
return str
}
func withHints(envVar string, isRequired bool, str string) string {
return withRequiredHint(isRequired, withEnvHint(envVar, str))
}

60
required_flags_test.go Normal file
View File

@ -0,0 +1,60 @@
package cli
import (
"flag"
"testing"
)
func TestContext_CheckRequiredFlagsSuccess(t *testing.T) {
flags := []Flag{
StringFlag{
Name: "required",
Required: true,
},
StringFlag{
Name: "optional",
},
}
set := flag.NewFlagSet("test", 0)
for _, f := range flags {
f.Apply(set)
}
e := set.Parse([]string{"--required", "foo"})
if e != nil {
t.Errorf("Expected no error parsing but there was one: %s", e)
}
err := checkRequiredFlags(flags, set)
if err != nil {
t.Error("Expected flag parsing to be successful")
}
}
func TestContext_CheckRequiredFlagsFailure(t *testing.T) {
flags := []Flag{
StringFlag{
Name: "required",
Required: true,
},
StringFlag{
Name: "optional",
},
}
set := flag.NewFlagSet("test", 0)
for _, f := range flags {
f.Apply(set)
}
e := set.Parse([]string{"--optional", "foo"})
if e != nil {
t.Errorf("Expected no error parsing but there was one: %s", e)
}
err := checkRequiredFlags(flags, set)
if err == nil {
t.Error("Expected flag parsing to be unsuccessful")
}
}