Merge branch 'master' into master

main
Yogesh Lonkar 5 years ago committed by GitHub
commit 01ab016427
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -228,6 +228,12 @@ func (a *App) Run(arguments []string) (err error) {
return nil
}
cerr := checkRequiredFlags(a.Flags, context)
if cerr != nil {
ShowAppHelp(context)
return cerr
}
if a.After != nil {
defer func() {
if afterErr := a.After(context); afterErr != nil {
@ -352,6 +358,12 @@ func (a *App) RunAsSubcommand(ctx *Context) (err error) {
}
}
cerr := checkRequiredFlags(a.Flags, context)
if cerr != nil {
ShowSubcommandHelp(context)
return cerr
}
if a.After != nil {
defer func() {
afterErr := a.After(context)

@ -931,6 +931,145 @@ func TestAppNoHelpFlag(t *testing.T) {
}
}
func TestRequiredFlagAppRunBehavior(t *testing.T) {
tdata := []struct {
testCase string
appFlags []Flag
appRunInput []string
appCommands []Command
expectedAnError bool
}{
// assertion: empty input, when a required flag is present, errors
{
testCase: "error_case_empty_input_with_required_flag_on_app",
appRunInput: []string{"myCLI"},
appFlags: []Flag{StringFlag{Name: "requiredFlag", Required: true}},
expectedAnError: true,
},
{
testCase: "error_case_empty_input_with_required_flag_on_command",
appRunInput: []string{"myCLI", "myCommand"},
appCommands: []Command{Command{
Name: "myCommand",
Flags: []Flag{StringFlag{Name: "requiredFlag", Required: true}},
}},
expectedAnError: true,
},
{
testCase: "error_case_empty_input_with_required_flag_on_subcommand",
appRunInput: []string{"myCLI", "myCommand", "mySubCommand"},
appCommands: []Command{Command{
Name: "myCommand",
Subcommands: []Command{Command{
Name: "mySubCommand",
Flags: []Flag{StringFlag{Name: "requiredFlag", Required: true}},
}},
}},
expectedAnError: true,
},
// assertion: inputing --help, when a required flag is present, does not error
{
testCase: "valid_case_help_input_with_required_flag_on_app",
appRunInput: []string{"myCLI", "--help"},
appFlags: []Flag{StringFlag{Name: "requiredFlag", Required: true}},
},
{
testCase: "valid_case_help_input_with_required_flag_on_command",
appRunInput: []string{"myCLI", "myCommand", "--help"},
appCommands: []Command{Command{
Name: "myCommand",
Flags: []Flag{StringFlag{Name: "requiredFlag", Required: true}},
}},
},
{
testCase: "valid_case_help_input_with_required_flag_on_subcommand",
appRunInput: []string{"myCLI", "myCommand", "mySubCommand", "--help"},
appCommands: []Command{Command{
Name: "myCommand",
Subcommands: []Command{Command{
Name: "mySubCommand",
Flags: []Flag{StringFlag{Name: "requiredFlag", Required: true}},
}},
}},
},
// assertion: giving optional input, when a required flag is present, errors
{
testCase: "error_case_optional_input_with_required_flag_on_app",
appRunInput: []string{"myCLI", "--optional", "cats"},
appFlags: []Flag{StringFlag{Name: "requiredFlag", Required: true}, StringFlag{Name: "optional"}},
expectedAnError: true,
},
{
testCase: "error_case_optional_input_with_required_flag_on_command",
appRunInput: []string{"myCLI", "myCommand", "--optional", "cats"},
appCommands: []Command{Command{
Name: "myCommand",
Flags: []Flag{StringFlag{Name: "requiredFlag", Required: true}, StringFlag{Name: "optional"}},
}},
expectedAnError: true,
},
{
testCase: "error_case_optional_input_with_required_flag_on_subcommand",
appRunInput: []string{"myCLI", "myCommand", "mySubCommand", "--optional", "cats"},
appCommands: []Command{Command{
Name: "myCommand",
Subcommands: []Command{Command{
Name: "mySubCommand",
Flags: []Flag{StringFlag{Name: "requiredFlag", Required: true}, StringFlag{Name: "optional"}},
}},
}},
expectedAnError: true,
},
// assertion: when a required flag is present, inputting that required flag does not error
{
testCase: "valid_case_required_flag_input_on_app",
appRunInput: []string{"myCLI", "--requiredFlag", "cats"},
appFlags: []Flag{StringFlag{Name: "requiredFlag", Required: true}},
},
{
testCase: "valid_case_required_flag_input_on_command",
appRunInput: []string{"myCLI", "myCommand", "--requiredFlag", "cats"},
appCommands: []Command{Command{
Name: "myCommand",
Flags: []Flag{StringFlag{Name: "requiredFlag", Required: true}},
}},
},
{
testCase: "valid_case_required_flag_input_on_subcommand",
appRunInput: []string{"myCLI", "myCommand", "mySubCommand", "--requiredFlag", "cats"},
appCommands: []Command{Command{
Name: "myCommand",
Subcommands: []Command{Command{
Name: "mySubCommand",
Flags: []Flag{StringFlag{Name: "requiredFlag", Required: true}},
}},
}},
},
}
for _, test := range tdata {
t.Run(test.testCase, func(t *testing.T) {
// setup
app := NewApp()
app.Flags = test.appFlags
app.Commands = test.appCommands
// logic under test
err := app.Run(test.appRunInput)
// assertions
if test.expectedAnError && err == nil {
t.Errorf("expected an error, but there was none")
}
if _, ok := err.(requiredFlagsErr); test.expectedAnError && !ok {
t.Errorf("expected a requiredFlagsErr, but got: %s", err)
}
if !test.expectedAnError && err != nil {
t.Errorf("did not expected an error, but there was one: %s", err)
}
})
}
}
func TestAppHelpPrinter(t *testing.T) {
oldPrinter := HelpPrinter
defer func() {

@ -135,6 +135,12 @@ func (c Command) Run(ctx *Context) (err error) {
return nil
}
cerr := checkRequiredFlags(c.Flags, context)
if cerr != nil {
ShowCommandHelp(context, c.Name)
return cerr
}
if c.After != nil {
defer func() {
afterErr := c.After(context)

@ -3,6 +3,7 @@ package cli
import (
"errors"
"flag"
"fmt"
"os"
"reflect"
"strings"
@ -285,3 +286,43 @@ func normalizeFlags(flags []Flag, set *flag.FlagSet) error {
}
return nil
}
type requiredFlagsErr interface {
error
getMissingFlags() []string
}
type errRequiredFlags struct {
missingFlags []string
}
func (e *errRequiredFlags) Error() string {
numberOfMissingFlags := len(e.missingFlags)
if numberOfMissingFlags == 1 {
return fmt.Sprintf("Required flag %q not set", e.missingFlags[0])
}
joinedMissingFlags := strings.Join(e.missingFlags, ", ")
return fmt.Sprintf("Required flags %q not set", joinedMissingFlags)
}
func (e *errRequiredFlags) getMissingFlags() []string {
return e.missingFlags
}
func checkRequiredFlags(flags []Flag, context *Context) requiredFlagsErr {
var missingFlags []string
for _, f := range flags {
if rf, ok := f.(RequiredFlag); ok && rf.IsRequired() {
key := strings.Split(f.GetName(), ",")[0]
if !context.IsSet(key) {
missingFlags = append(missingFlags, key)
}
}
}
if len(missingFlags) != 0 {
return &errRequiredFlags{missingFlags: missingFlags}
}
return nil
}

@ -3,6 +3,7 @@ package cli
import (
"flag"
"os"
"strings"
"testing"
"time"
)
@ -253,18 +254,23 @@ func TestContext_GlobalIsSet_fromEnv(t *testing.T) {
timeoutIsSet, tIsSet bool
noEnvVarIsSet, nIsSet bool
passwordIsSet, pIsSet bool
passwordValue string
unparsableIsSet, uIsSet bool
overrideIsSet, oIsSet bool
overrideValue string
)
clearenv()
os.Setenv("APP_TIMEOUT_SECONDS", "15.5")
os.Setenv("APP_PASSWORD", "")
os.Setenv("APP_PASSWORD", "badpass")
os.Setenv("APP_OVERRIDE", "overridden")
a := App{
Flags: []Flag{
Float64Flag{Name: "timeout, t", EnvVar: "APP_TIMEOUT_SECONDS"},
StringFlag{Name: "password, p", EnvVar: "APP_PASSWORD"},
Float64Flag{Name: "no-env-var, n"},
Float64Flag{Name: "unparsable, u", EnvVar: "APP_UNPARSABLE"},
StringFlag{Name: "overrides-default, o", Value: "default", EnvVar: "APP_OVERRIDE"},
},
Commands: []Command{
{
@ -274,10 +280,14 @@ func TestContext_GlobalIsSet_fromEnv(t *testing.T) {
tIsSet = ctx.GlobalIsSet("t")
passwordIsSet = ctx.GlobalIsSet("password")
pIsSet = ctx.GlobalIsSet("p")
passwordValue = ctx.GlobalString("password")
unparsableIsSet = ctx.GlobalIsSet("unparsable")
uIsSet = ctx.GlobalIsSet("u")
noEnvVarIsSet = ctx.GlobalIsSet("no-env-var")
nIsSet = ctx.GlobalIsSet("n")
overrideIsSet = ctx.GlobalIsSet("overrides-default")
oIsSet = ctx.GlobalIsSet("o")
overrideValue = ctx.GlobalString("overrides-default")
return nil
},
},
@ -290,8 +300,13 @@ func TestContext_GlobalIsSet_fromEnv(t *testing.T) {
expect(t, tIsSet, true)
expect(t, passwordIsSet, true)
expect(t, pIsSet, true)
expect(t, passwordValue, "badpass")
expect(t, unparsableIsSet, false)
expect(t, noEnvVarIsSet, false)
expect(t, nIsSet, false)
expect(t, overrideIsSet, true)
expect(t, oIsSet, true)
expect(t, overrideValue, "overridden")
os.Setenv("APP_UNPARSABLE", "foobar")
if err := a.Run([]string{"run"}); err != nil {
@ -401,3 +416,139 @@ func TestContext_GlobalSet(t *testing.T) {
expect(t, c.GlobalInt("int"), 1)
expect(t, c.GlobalIsSet("int"), true)
}
func TestCheckRequiredFlags(t *testing.T) {
tdata := []struct {
testCase string
parseInput []string
envVarInput [2]string
flags []Flag
expectedAnError bool
expectedErrorContents []string
}{
{
testCase: "empty",
},
{
testCase: "optional",
flags: []Flag{
StringFlag{Name: "optionalFlag"},
},
},
{
testCase: "required",
flags: []Flag{
StringFlag{Name: "requiredFlag", Required: true},
},
expectedAnError: true,
expectedErrorContents: []string{"requiredFlag"},
},
{
testCase: "required_and_present",
flags: []Flag{
StringFlag{Name: "requiredFlag", Required: true},
},
parseInput: []string{"--requiredFlag", "myinput"},
},
{
testCase: "required_and_present_via_env_var",
flags: []Flag{
StringFlag{Name: "requiredFlag", Required: true, EnvVar: "REQUIRED_FLAG"},
},
envVarInput: [2]string{"REQUIRED_FLAG", "true"},
},
{
testCase: "required_and_optional",
flags: []Flag{
StringFlag{Name: "requiredFlag", Required: true},
StringFlag{Name: "optionalFlag"},
},
expectedAnError: true,
},
{
testCase: "required_and_optional_and_optional_present",
flags: []Flag{
StringFlag{Name: "requiredFlag", Required: true},
StringFlag{Name: "optionalFlag"},
},
parseInput: []string{"--optionalFlag", "myinput"},
expectedAnError: true,
},
{
testCase: "required_and_optional_and_optional_present_via_env_var",
flags: []Flag{
StringFlag{Name: "requiredFlag", Required: true},
StringFlag{Name: "optionalFlag", EnvVar: "OPTIONAL_FLAG"},
},
envVarInput: [2]string{"OPTIONAL_FLAG", "true"},
expectedAnError: true,
},
{
testCase: "required_and_optional_and_required_present",
flags: []Flag{
StringFlag{Name: "requiredFlag", Required: true},
StringFlag{Name: "optionalFlag"},
},
parseInput: []string{"--requiredFlag", "myinput"},
},
{
testCase: "two_required",
flags: []Flag{
StringFlag{Name: "requiredFlagOne", Required: true},
StringFlag{Name: "requiredFlagTwo", Required: true},
},
expectedAnError: true,
expectedErrorContents: []string{"requiredFlagOne", "requiredFlagTwo"},
},
{
testCase: "two_required_and_one_present",
flags: []Flag{
StringFlag{Name: "requiredFlag", Required: true},
StringFlag{Name: "requiredFlagTwo", Required: true},
},
parseInput: []string{"--requiredFlag", "myinput"},
expectedAnError: true,
},
{
testCase: "two_required_and_both_present",
flags: []Flag{
StringFlag{Name: "requiredFlag", Required: true},
StringFlag{Name: "requiredFlagTwo", Required: true},
},
parseInput: []string{"--requiredFlag", "myinput", "--requiredFlagTwo", "myinput"},
},
}
for _, test := range tdata {
t.Run(test.testCase, func(t *testing.T) {
// setup
set := flag.NewFlagSet("test", 0)
for _, flags := range test.flags {
flags.Apply(set)
}
set.Parse(test.parseInput)
if test.envVarInput[0] != "" {
os.Clearenv()
os.Setenv(test.envVarInput[0], test.envVarInput[1])
}
ctx := &Context{}
context := NewContext(ctx.App, set, ctx)
context.Command.Flags = test.flags
// logic under test
err := checkRequiredFlags(test.flags, context)
// assertions
if test.expectedAnError && err == nil {
t.Errorf("expected an error, but there was none")
}
if !test.expectedAnError && err != nil {
t.Errorf("did not expected an error, but there was one: %s", err)
}
for _, errString := range test.expectedErrorContents {
if !strings.Contains(err.Error(), errString) {
t.Errorf("expected error %q to contain %q, but it didn't!", err.Error(), errString)
}
}
})
}
}

@ -75,6 +75,14 @@ type Flag interface {
GetName() string
}
// RequiredFlag is an interface that allows us to mark flags as required
// it allows flags required flags to be backwards compatible with the Flag interface
type RequiredFlag interface {
Flag
IsRequired() bool
}
// errorableFlag is an interface that allows us to return errors during apply
// it allows flags defined in this library to return errors in a fashion backwards compatible
// TODO remove in v2 and modify the existing Flag interface to return errors

@ -14,6 +14,7 @@ type BoolFlag struct {
Usage string
EnvVar string
FilePath string
Required bool
Hidden bool
Destination *bool
}
@ -29,6 +30,11 @@ func (f BoolFlag) GetName() string {
return f.Name
}
// IsRequired returns the whether or not the flag is required
func (f BoolFlag) IsRequired() bool {
return f.Required
}
// Bool looks up the value of a local BoolFlag, returns
// false if not found
func (c *Context) Bool(name string) bool {
@ -62,6 +68,7 @@ type BoolTFlag struct {
Usage string
EnvVar string
FilePath string
Required bool
Hidden bool
Destination *bool
}
@ -77,6 +84,11 @@ func (f BoolTFlag) GetName() string {
return f.Name
}
// IsRequired returns the whether or not the flag is required
func (f BoolTFlag) IsRequired() bool {
return f.Required
}
// BoolT looks up the value of a local BoolTFlag, returns
// false if not found
func (c *Context) BoolT(name string) bool {
@ -110,6 +122,7 @@ type DurationFlag struct {
Usage string
EnvVar string
FilePath string
Required bool
Hidden bool
Value time.Duration
Destination *time.Duration
@ -126,6 +139,11 @@ func (f DurationFlag) GetName() string {
return f.Name
}
// IsRequired returns the whether or not the flag is required
func (f DurationFlag) IsRequired() bool {
return f.Required
}
// Duration looks up the value of a local DurationFlag, returns
// 0 if not found
func (c *Context) Duration(name string) time.Duration {
@ -159,6 +177,7 @@ type Float64Flag struct {
Usage string
EnvVar string
FilePath string
Required bool
Hidden bool
Value float64
Destination *float64
@ -175,6 +194,11 @@ func (f Float64Flag) GetName() string {
return f.Name
}
// IsRequired returns the whether or not the flag is required
func (f Float64Flag) IsRequired() bool {
return f.Required
}
// Float64 looks up the value of a local Float64Flag, returns
// 0 if not found
func (c *Context) Float64(name string) float64 {
@ -208,6 +232,7 @@ type GenericFlag struct {
Usage string
EnvVar string
FilePath string
Required bool
Hidden bool
Value Generic
}
@ -223,6 +248,11 @@ func (f GenericFlag) GetName() string {
return f.Name
}
// IsRequired returns the whether or not the flag is required
func (f GenericFlag) IsRequired() bool {
return f.Required
}
// Generic looks up the value of a local GenericFlag, returns
// nil if not found
func (c *Context) Generic(name string) interface{} {
@ -256,6 +286,7 @@ type Int64Flag struct {
Usage string
EnvVar string
FilePath string
Required bool
Hidden bool
Value int64
Destination *int64
@ -272,6 +303,11 @@ func (f Int64Flag) GetName() string {
return f.Name
}
// IsRequired returns the whether or not the flag is required
func (f Int64Flag) IsRequired() bool {
return f.Required
}
// Int64 looks up the value of a local Int64Flag, returns
// 0 if not found
func (c *Context) Int64(name string) int64 {
@ -305,6 +341,7 @@ type IntFlag struct {
Usage string
EnvVar string
FilePath string
Required bool
Hidden bool
Value int
Destination *int
@ -321,6 +358,11 @@ func (f IntFlag) GetName() string {
return f.Name
}
// IsRequired returns the whether or not the flag is required
func (f IntFlag) IsRequired() bool {
return f.Required
}
// Int looks up the value of a local IntFlag, returns
// 0 if not found
func (c *Context) Int(name string) int {
@ -354,6 +396,7 @@ type IntSliceFlag struct {
Usage string
EnvVar string
FilePath string
Required bool
Hidden bool
Value *IntSlice
}
@ -369,6 +412,11 @@ func (f IntSliceFlag) GetName() string {
return f.Name
}
// IsRequired returns the whether or not the flag is required
func (f IntSliceFlag) IsRequired() bool {
return f.Required
}
// IntSlice looks up the value of a local IntSliceFlag, returns
// nil if not found
func (c *Context) IntSlice(name string) []int {
@ -402,6 +450,7 @@ type Int64SliceFlag struct {
Usage string
EnvVar string
FilePath string
Required bool
Hidden bool
Value *Int64Slice
}
@ -417,6 +466,11 @@ func (f Int64SliceFlag) GetName() string {
return f.Name
}
// IsRequired returns the whether or not the flag is required
func (f Int64SliceFlag) IsRequired() bool {
return f.Required
}
// Int64Slice looks up the value of a local Int64SliceFlag, returns
// nil if not found
func (c *Context) Int64Slice(name string) []int64 {
@ -450,6 +504,7 @@ type StringFlag struct {
Usage string
EnvVar string
FilePath string
Required bool
Hidden bool
Value string
Destination *string
@ -466,6 +521,11 @@ func (f StringFlag) GetName() string {
return f.Name
}
// IsRequired returns the whether or not the flag is required
func (f StringFlag) IsRequired() bool {
return f.Required
}
// String looks up the value of a local StringFlag, returns
// "" if not found
func (c *Context) String(name string) string {
@ -499,6 +559,7 @@ type StringSliceFlag struct {
Usage string
EnvVar string
FilePath string
Required bool
Hidden bool
Value *StringSlice
}
@ -514,6 +575,11 @@ func (f StringSliceFlag) GetName() string {
return f.Name
}
// IsRequired returns the whether or not the flag is required
func (f StringSliceFlag) IsRequired() bool {
return f.Required
}
// StringSlice looks up the value of a local StringSliceFlag, returns
// nil if not found
func (c *Context) StringSlice(name string) []string {
@ -547,6 +613,7 @@ type Uint64Flag struct {
Usage string
EnvVar string
FilePath string
Required bool
Hidden bool
Value uint64
Destination *uint64
@ -563,6 +630,11 @@ func (f Uint64Flag) GetName() string {
return f.Name
}
// IsRequired returns the whether or not the flag is required
func (f Uint64Flag) IsRequired() bool {
return f.Required
}
// Uint64 looks up the value of a local Uint64Flag, returns
// 0 if not found
func (c *Context) Uint64(name string) uint64 {
@ -596,6 +668,7 @@ type UintFlag struct {
Usage string
EnvVar string
FilePath string
Required bool
Hidden bool
Value uint
Destination *uint
@ -612,6 +685,11 @@ func (f UintFlag) GetName() string {
return f.Name
}
// IsRequired returns the whether or not the flag is required
func (f UintFlag) IsRequired() bool {
return f.Required
}
// Uint looks up the value of a local UintFlag, returns
// 0 if not found
func (c *Context) Uint(name string) uint {

@ -143,6 +143,7 @@ def _write_cli_flag_types(outfile, types):
Usage string
EnvVar string
FilePath string
Required bool
Hidden bool
""".format(**typedef))
@ -170,6 +171,11 @@ def _write_cli_flag_types(outfile, types):
return f.Name
}}
// IsRequired returns the whether or not the flag is required
func (f {name}Flag) IsRequired() bool {{
return f.Required
}}
// {name} looks up the value of a local {name}Flag, returns
// {context_default} if not found
func (c *Context) {name}(name string) {context_type} {{

Loading…
Cancel
Save