add typed error assertions
This commit is contained in:
parent
d4740d10d0
commit
78db152323
@ -1002,6 +1002,9 @@ func TestRequiredFlagAppRunBehavior(t *testing.T) {
|
|||||||
if test.expectedAnError && err == nil {
|
if test.expectedAnError && err == nil {
|
||||||
t.Errorf("expected an error, but there was none")
|
t.Errorf("expected an error, but there was none")
|
||||||
}
|
}
|
||||||
|
if _, ok := err.(requiredFlagsErr); test.expectedAnError && !ok {
|
||||||
|
t.Errorf("expected a requiredFlagsErr, but got: %s", err)
|
||||||
|
}
|
||||||
if !test.expectedAnError && err != nil {
|
if !test.expectedAnError && err != nil {
|
||||||
t.Errorf("did not expected an error, but there was one: %s", err)
|
t.Errorf("did not expected an error, but there was one: %s", err)
|
||||||
}
|
}
|
||||||
|
33
context.go
33
context.go
@ -287,7 +287,29 @@ func normalizeFlags(flags []Flag, set *flag.FlagSet) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkRequiredFlags(flags []Flag, set *flag.FlagSet) error {
|
type requiredFlagsErr interface {
|
||||||
|
error
|
||||||
|
getMissingFlags() []string
|
||||||
|
}
|
||||||
|
|
||||||
|
type errRequiredFlags struct {
|
||||||
|
missingFlags []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *errRequiredFlags) Error() string {
|
||||||
|
numberOfMissingFlags := len(e.missingFlags)
|
||||||
|
if numberOfMissingFlags == 1 {
|
||||||
|
return fmt.Sprintf("Required flag %q not set", e.missingFlags[0])
|
||||||
|
}
|
||||||
|
joinedMissingFlags := strings.Join(e.missingFlags, ", ")
|
||||||
|
return fmt.Sprintf("Required flags %q not set", joinedMissingFlags)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *errRequiredFlags) getMissingFlags() []string {
|
||||||
|
return e.missingFlags
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkRequiredFlags(flags []Flag, set *flag.FlagSet) requiredFlagsErr {
|
||||||
visited := make(map[string]bool)
|
visited := make(map[string]bool)
|
||||||
set.Visit(func(f *flag.Flag) {
|
set.Visit(func(f *flag.Flag) {
|
||||||
visited[f.Name] = true
|
visited[f.Name] = true
|
||||||
@ -303,13 +325,8 @@ func checkRequiredFlags(flags []Flag, set *flag.FlagSet) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
numberOfMissingFlags := len(missingFlags)
|
if len(missingFlags) != 0 {
|
||||||
if numberOfMissingFlags == 1 {
|
return &errRequiredFlags{missingFlags: missingFlags}
|
||||||
return fmt.Errorf("Required flag %q not set", missingFlags[0])
|
|
||||||
}
|
|
||||||
if numberOfMissingFlags >= 2 {
|
|
||||||
joinedMissingFlags := strings.Join(missingFlags, ", ")
|
|
||||||
return fmt.Errorf("Required flags %q not set", joinedMissingFlags)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
Loading…
Reference in New Issue
Block a user