Merge pull request #96 from ghigt/master

Add time.Duration flag type
This commit is contained in:
Jeremy Saenz 2014-08-02 14:44:59 -07:00
commit aa19913be3
4 changed files with 99 additions and 3 deletions

View File

@ -5,6 +5,7 @@ import (
"flag" "flag"
"strconv" "strconv"
"strings" "strings"
"time"
) )
// Context is a type that is passed through to // Context is a type that is passed through to
@ -29,6 +30,11 @@ func (c *Context) Int(name string) int {
return lookupInt(name, c.flagSet) return lookupInt(name, c.flagSet)
} }
// Looks up the value of a local time.Duration flag, returns 0 if no time.Duration flag exists
func (c *Context) Duration(name string) time.Duration {
return lookupDuration(name, c.flagSet)
}
// Looks up the value of a local float64 flag, returns 0 if no float64 flag exists // Looks up the value of a local float64 flag, returns 0 if no float64 flag exists
func (c *Context) Float64(name string) float64 { func (c *Context) Float64(name string) float64 {
return lookupFloat64(name, c.flagSet) return lookupFloat64(name, c.flagSet)
@ -69,6 +75,11 @@ func (c *Context) GlobalInt(name string) int {
return lookupInt(name, c.globalSet) return lookupInt(name, c.globalSet)
} }
// Looks up the value of a global time.Duration flag, returns 0 if no time.Duration flag exists
func (c *Context) GlobalDuration(name string) time.Duration {
return lookupDuration(name, c.globalSet)
}
// Looks up the value of a global bool flag, returns false if no bool flag exists // Looks up the value of a global bool flag, returns false if no bool flag exists
func (c *Context) GlobalBool(name string) bool { func (c *Context) GlobalBool(name string) bool {
return lookupBool(name, c.globalSet) return lookupBool(name, c.globalSet)
@ -174,6 +185,18 @@ func lookupInt(name string, set *flag.FlagSet) int {
return 0 return 0
} }
func lookupDuration(name string, set *flag.FlagSet) time.Duration {
f := set.Lookup(name)
if f != nil {
val, err := time.ParseDuration(f.Value.String())
if err == nil {
return val
}
}
return 0
}
func lookupFloat64(name string, set *flag.FlagSet) float64 { func lookupFloat64(name string, set *flag.FlagSet) float64 {
f := set.Lookup(name) f := set.Lookup(name)
if f != nil { if f != nil {

View File

@ -2,8 +2,10 @@ package cli_test
import ( import (
"flag" "flag"
"github.com/codegangsta/cli"
"testing" "testing"
"time"
"github.com/codegangsta/cli"
) )
func TestNewContext(t *testing.T) { func TestNewContext(t *testing.T) {
@ -26,6 +28,13 @@ func TestContext_Int(t *testing.T) {
expect(t, c.Int("myflag"), 12) expect(t, c.Int("myflag"), 12)
} }
func TestContext_Duration(t *testing.T) {
set := flag.NewFlagSet("test", 0)
set.Duration("myflag", time.Duration(12*time.Second), "doc")
c := cli.NewContext(nil, set, set)
expect(t, c.Duration("myflag"), time.Duration(12*time.Second))
}
func TestContext_String(t *testing.T) { func TestContext_String(t *testing.T) {
set := flag.NewFlagSet("test", 0) set := flag.NewFlagSet("test", 0)
set.String("myflag", "hello world", "doc") set.String("myflag", "hello world", "doc")

31
flag.go
View File

@ -6,6 +6,7 @@ import (
"os" "os"
"strconv" "strconv"
"strings" "strings"
"time"
) )
// This flag enables bash-completion for all commands and subcommands // This flag enables bash-completion for all commands and subcommands
@ -318,6 +319,36 @@ func (f IntFlag) getName() string {
return f.Name return f.Name
} }
type DurationFlag struct {
Name string
Value time.Duration
Usage string
EnvVar string
}
func (f DurationFlag) String() string {
return withEnvHint(f.EnvVar, fmt.Sprintf("%s '%v'\t%v", prefixedNames(f.Name), f.Value, f.Usage))
}
func (f DurationFlag) Apply(set *flag.FlagSet) {
if f.EnvVar != "" {
if envVal := os.Getenv(f.EnvVar); envVal != "" {
envValDuration, err := time.ParseDuration(envVal)
if err == nil {
f.Value = envValDuration
}
}
}
eachName(f.Name, func(name string) {
set.Duration(name, f.Value, f.Usage)
})
}
func (f DurationFlag) getName() string {
return f.Name
}
type Float64Flag struct { type Float64Flag struct {
Name string Name string
Value float64 Value float64

View File

@ -1,13 +1,13 @@
package cli_test package cli_test
import ( import (
"github.com/codegangsta/cli"
"fmt" "fmt"
"os" "os"
"reflect" "reflect"
"strings" "strings"
"testing" "testing"
"github.com/codegangsta/cli"
) )
var boolFlagTests = []struct { var boolFlagTests = []struct {
@ -151,6 +151,39 @@ func TestIntFlagWithEnvVarHelpOutput(t *testing.T) {
} }
} }
var durationFlagTests = []struct {
name string
expected string
}{
{"help", "--help '0'\t"},
{"h", "-h '0'\t"},
}
func TestDurationFlagHelpOutput(t *testing.T) {
for _, test := range durationFlagTests {
flag := cli.DurationFlag{Name: test.name}
output := flag.String()
if output != test.expected {
t.Errorf("%s does not match %s", output, test.expected)
}
}
}
func TestDurationFlagWithEnvVarHelpOutput(t *testing.T) {
os.Setenv("APP_BAR", "2h3m6s")
for _, test := range durationFlagTests {
flag := cli.DurationFlag{Name: test.name, EnvVar: "APP_BAR"}
output := flag.String()
if !strings.HasSuffix(output, " [$APP_BAR]") {
t.Errorf("%s does not end with [$APP_BAR]", output)
}
}
}
var intSliceFlagTests = []struct { var intSliceFlagTests = []struct {
name string name string
value *cli.IntSlice value *cli.IntSlice