Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add Setter and Setv() #240

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions bool.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,18 @@ func (b *boolValue) Set(s string) error {
return err
}

func (b *boolValue) Setv(v interface{}) error {
switch tv := v.(type) {
case bool:
*b = boolValue(tv)
case string:
return b.Set(tv)
default:
return ErrSetv
}
return nil
}

func (b *boolValue) Type() string {
return "bool"
}
Expand Down
11 changes: 11 additions & 0 deletions bool_slice.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,17 @@ func (s *boolSliceValue) Set(val string) error {
return nil
}

// Setv tries its best to set the value from an arbitrary type
func (s *boolSliceValue) Setv(v interface{}) error {
switch tv := v.(type) {
case []bool:
*s.value = tv
default:
return ErrSetv
}
return nil
}

// Type returns a string that uniquely represents this flag's type.
func (s *boolSliceValue) Type() string {
return "boolSlice"
Expand Down
25 changes: 25 additions & 0 deletions bytes.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,18 @@ func (bytesHex *bytesHexValue) Set(value string) error {
return nil
}

func (bytesHex *bytesHexValue) Setv(v interface{}) error {
switch tv := v.(type) {
case []byte:
*bytesHex = tv
case string:
return bytesHex.Set(tv)
default:
return ErrSetv
}
return nil
}

// Type implements pflag.Value.Type.
func (*bytesHexValue) Type() string {
return "bytesHex"
Expand Down Expand Up @@ -129,6 +141,19 @@ func (bytesBase64 *bytesBase64Value) Set(value string) error {
return nil
}

// Setv tries its best to set the value from an arbitrary type
func (bytesBase64 *bytesBase64Value) Setv(v interface{}) error {
switch tv := v.(type) {
case []byte:
*bytesBase64 = tv
case string:
return bytesBase64.Set(tv)
default:
return ErrSetv
}
return nil
}

// Type implements pflag.Value.Type.
func (*bytesBase64Value) Type() string {
return "bytesBase64"
Expand Down
12 changes: 12 additions & 0 deletions duration.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,18 @@ func (d *durationValue) Set(s string) error {
return err
}

func (d *durationValue) Setv(v interface{}) error {
switch tv := v.(type) {
case time.Duration:
*d = durationValue(tv)
case string:
return d.Set(tv)
default:
return ErrSetv
}
return nil
}

func (d *durationValue) Type() string {
return "duration"
}
Expand Down
10 changes: 10 additions & 0 deletions duration_slice.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,16 @@ func (s *durationSliceValue) Set(val string) error {
return nil
}

func (s *durationSliceValue) Setv(v interface{}) error {
switch tv := v.(type) {
case []time.Duration:
*s.value = tv
default:
return ErrSetv
}
return nil
}

func (s *durationSliceValue) Type() string {
return "durationSlice"
}
Expand Down
31 changes: 28 additions & 3 deletions flag.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ import (
// ErrHelp is the error returned if the flag -help is invoked but no such flag is defined.
var ErrHelp = errors.New("pflag: help requested")

// ErrSetv is the error returned if Setv(interface{}) couldn't do a sane conversion.
var ErrSetv = errors.New("pflag: could not convert type for Setv()") // TODO: error wrapping once go 1.13 is the minimal version

// ErrorHandling defines how to handle flag parsing errors.
type ErrorHandling int

Expand Down Expand Up @@ -190,6 +193,11 @@ type Value interface {
Type() string
}

// Setter is the interface that represents a Value that can be set directly.
type Setter interface {
Setv(interface{}) error
}

// SliceValue is a secondary interface to all flags which hold a list
// of values. This allows full control over the value of list flags,
// and avoids complicated marshalling and unmarshalling to csv.
Expand Down Expand Up @@ -452,15 +460,27 @@ func ShorthandLookup(name string) *Flag {
return CommandLine.ShorthandLookup(name)
}

// Set sets the value of the named flag.
// Set sets the value of the named flag to a string.
func (f *FlagSet) Set(name, value string) error {
return f.Setv(name, value)
}

// Setv sets the value of the named flag.
func (f *FlagSet) Setv(name string, value interface{}) error {
normalName := f.normalizeFlagName(name)
flag, ok := f.formal[normalName]
if !ok {
return fmt.Errorf("no such flag -%v", name)
}

err := flag.Value.Set(value)
var err error
if s, ok := value.(string); ok {
err = flag.Value.Set(s) // guaranteed backwards-compat
} else if s, ok := flag.Value.(Setter); ok {
err = s.Setv(value)
} else {
err = ErrSetv
}
if err != nil {
var flagName string
if flag.Shorthand != "" && flag.ShorthandDeprecated == "" {
Expand Down Expand Up @@ -514,11 +534,16 @@ func (f *FlagSet) Changed(name string) bool {
return flag.Changed
}

// Set sets the value of the named command-line flag.
// Set sets the string value of the named command-line flag.
func Set(name, value string) error {
return CommandLine.Set(name, value)
}

// Setv sets the value of the named command-line flag.
func Setv(name string, value interface{}) error {
return CommandLine.Setv(name, value)
}

// PrintDefaults prints, to standard error unless configured
// otherwise, the default values of all defined flags in the set.
func (f *FlagSet) PrintDefaults() {
Expand Down
28 changes: 28 additions & 0 deletions flag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,34 @@ func TestEverything(t *testing.T) {
t.Log(k, *v)
}
}
// Now set all flags to real values
if Setv("test_bool", true) != nil {
t.Error("Could not Setv bool to true")
}
if Setv("test_int", 1) != nil {
t.Error("Could not Setv int to 1")
}
if Setv("test_int64", 1) != nil {
t.Error("Could not Setv int64 to 1")
}
if Setv("test_uint", 1) != nil {
t.Error("Could not Setv uint to 1")
}
if Setv("test_uint64", 1) != nil {
t.Error("Could not Setv uint64 to 1")
}
if Setv("test_string", "1") != nil {
t.Error("Could not Setv string to \"1\"")
}
if Setv("test_float64", 1) != nil {
t.Error("Could not Setv float64 to 1")
}
if Setv("test_duration", time.Duration(1)) != nil {
t.Error("Could not Setv time.Duration to 1")
}
if Setv("test_optional_int", 1) != nil {
t.Error("Could not Setv optional int to 1")
}
// Now test they're visited in sort order.
var flagNames []string
Visit(func(f *Flag) { flagNames = append(flagNames, f.Name) })
Expand Down
20 changes: 20 additions & 0 deletions float32.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,26 @@ func (f *float32Value) Set(s string) error {
return err
}

func (f *float32Value) Setv(v interface{}) error {
switch tv := v.(type) {
case float32:
*f = float32Value(tv)
case int16:
*f = float32Value(tv)
case int8:
*f = float32Value(tv)
case uint16:
*f = float32Value(tv)
case uint8:
*f = float32Value(tv)
case string:
return f.Set(tv)
default:
return ErrSetv
}
return nil
}

func (f *float32Value) Type() string {
return "float32"
}
Expand Down
10 changes: 10 additions & 0 deletions float32_slice.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,16 @@ func (s *float32SliceValue) Set(val string) error {
return nil
}

func (s *float32SliceValue) Setv(v interface{}) error {
switch tv := v.(type) {
case []float32:
*s.value = tv
default:
return ErrSetv
}
return nil
}

func (s *float32SliceValue) Type() string {
return "float32Slice"
}
Expand Down
30 changes: 30 additions & 0 deletions float64.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,36 @@ func (f *float64Value) Set(s string) error {
return err
}

func (f *float64Value) Setv(v interface{}) error {
switch tv := v.(type) {
case float64:
*f = float64Value(tv)
case float32:
*f = float64Value(tv)
case int:
*f = float64Value(tv)
case int32:
*f = float64Value(tv)
case int16:
*f = float64Value(tv)
case int8:
*f = float64Value(tv)
case uint:
*f = float64Value(tv)
case uint32:
*f = float64Value(tv)
case uint16:
*f = float64Value(tv)
case uint8:
*f = float64Value(tv)
case string:
return f.Set(tv)
default:
return ErrSetv
}
return nil
}

func (f *float64Value) Type() string {
return "float64"
}
Expand Down
10 changes: 10 additions & 0 deletions float64_slice.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,16 @@ func (s *float64SliceValue) Set(val string) error {
return nil
}

func (s *float64SliceValue) Setv(v interface{}) error {
switch tv := v.(type) {
case []float64:
*s.value = tv
default:
return ErrSetv
}
return nil
}

func (s *float64SliceValue) Type() string {
return "float64Slice"
}
Expand Down
22 changes: 22 additions & 0 deletions int.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,28 @@ func (i *intValue) Set(s string) error {
return err
}

func (i *intValue) Setv(v interface{}) error {
switch tv := v.(type) {
case int:
*i = intValue(tv)
case int32:
*i = intValue(tv)
case int16:
*i = intValue(tv)
case int8:
*i = intValue(tv)
case uint16:
*i = intValue(tv)
case uint8:
*i = intValue(tv)
case string:
return i.Set(tv)
default:
return ErrSetv
}
return nil
}

func (i *intValue) Type() string {
return "int"
}
Expand Down
16 changes: 16 additions & 0 deletions int16.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,22 @@ func (i *int16Value) Set(s string) error {
return err
}

func (i *int16Value) Setv(v interface{}) error {
switch tv := v.(type) {
case int16:
*i = int16Value(tv)
case int8:
*i = int16Value(tv)
case uint8:
*i = int16Value(tv)
case string:
return i.Set(tv)
default:
return ErrSetv
}
return nil
}

func (i *int16Value) Type() string {
return "int16"
}
Expand Down
22 changes: 22 additions & 0 deletions int32.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,28 @@ func (i *int32Value) Set(s string) error {
return err
}

func (i *int32Value) Setv(v interface{}) error {
switch tv := v.(type) {
case int:
*i = int32Value(tv)
case int32:
*i = int32Value(tv)
case int16:
*i = int32Value(tv)
case int8:
*i = int32Value(tv)
case uint16:
*i = int32Value(tv)
case uint8:
*i = int32Value(tv)
case string:
return i.Set(tv)
default:
return ErrSetv
}
return nil
}

func (i *int32Value) Type() string {
return "int32"
}
Expand Down
Loading