Skip to content

Commit

Permalink
spf13#199 with some modifications (#1)
Browse files Browse the repository at this point in the history
  • Loading branch information
cornfeedhobo authored Aug 29, 2020
1 parent e43c76f commit ce4d9a3
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 5 deletions.
53 changes: 49 additions & 4 deletions flag.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,8 @@ const (

// ParseErrorsWhitelist defines the parsing errors that can be ignored
type ParseErrorsWhitelist struct {
// UnknownFlags will ignore unknown flags errors and continue parsing rest of the flags
// UnknownFlags will ignore unknown flags errors and continue parsing the rest of the flags.
// Consider using SetUnknownFlags/GetUnknownFlags if you need to know which unknown flags occured.
UnknownFlags bool
}

Expand Down Expand Up @@ -162,6 +163,7 @@ type FlagSet struct {
output io.Writer // nil means stderr; use Output() accessor
interspersed bool // allow interspersed option/non-option args
normalizeNameFunc func(f *FlagSet, name string) NormalizedName
unknownFlags *[]string

addedGoFlagSets []*goflag.FlagSet
}
Expand Down Expand Up @@ -964,10 +966,17 @@ func (f *FlagSet) usage() {
}
}

func (f *FlagSet) addUnknownFlag(s string) {
if f.unknownFlags == nil {
f.unknownFlags = new([]string)
}
*f.unknownFlags = append(*f.unknownFlags, s)
}

//--unknown (args will be empty)
//--unknown --next-flag ... (args will be --next-flag ...)
//--unknown arg ... (args will be arg ...)
func stripUnknownFlagValue(args []string) []string {
func (f *FlagSet) stripUnknownFlagValue(args []string) []string {
if len(args) == 0 {
//--unknown
return args
Expand All @@ -981,6 +990,7 @@ func stripUnknownFlagValue(args []string) []string {

//--unknown arg ... (args will be arg ...)
if len(args) > 1 {
f.addUnknownFlag(args[0])
return args[1:]
}
return nil
Expand All @@ -1007,13 +1017,14 @@ func (f *FlagSet) parseLongArg(s string, args []string, fn parseFunc) (a []strin
}
return
case f.ParseErrorsWhitelist.UnknownFlags:
f.addUnknownFlag(s)
// --unknown=unknownval arg ...
// we do not want to lose arg in this case
if len(split) >= 2 {
return a, nil
}

return stripUnknownFlagValue(a), nil
return f.stripUnknownFlagValue(a), nil
default:
err = f.failf("unknown flag: --%s", name)
return
Expand Down Expand Up @@ -1063,11 +1074,15 @@ func (f *FlagSet) parseSingleShortArg(shorthands string, args []string, fn parse
// '-f=arg arg ...'
// we do not want to lose arg in this case
if len(shorthands) > 2 && shorthands[1] == '=' {
f.addUnknownFlag("-" + shorthands)
outShorts = ""
return
}

outArgs = stripUnknownFlagValue(outArgs)
f.addUnknownFlag("-" + string(c))
if len(outShorts) == 0 {
outArgs = f.stripUnknownFlagValue(outArgs)
}
return
default:
err = f.failf("unknown shorthand flag: %q in -%s", c, shorthands)
Expand Down Expand Up @@ -1223,6 +1238,21 @@ func (f *FlagSet) Parsed() bool {
return f.parsed
}

// SetUnknownFlags sets the store for unknown flags found during Parse.
// The argument s points to a slice variable in which to store the values.
// This requires ParseErrorsWhitelist.UnknownFlags to be set so that
// parsing does not abort on the first unknown flag.
func (f *FlagSet) SetUnknownFlags(s *[]string) {
f.unknownFlags = s
}

// GetUnknownFlags returns unknown flags found during Parse.
// This requires ParseErrorsWhitelist.UnknownFlags to be set so that
// parsing does not abort on the first unknown flag.
func (f *FlagSet) GetUnknownFlags() *[]string {
return f.unknownFlags
}

// Parse parses the command-line flags from os.Args[1:]. Must be called
// after all flags are defined and before flags are accessed by the program.
func Parse() {
Expand All @@ -1248,6 +1278,21 @@ func Parsed() bool {
return CommandLine.Parsed()
}

// SetUnknownFlags sets the store for unknown flags found during Parse.
// The argument s points to a slice variable in which to store the values.
// This requires ParseErrorsWhitelist.UnknownFlags to be set so that
// parsing does not abort on the first unknown flag.
func SetUnknownFlags(s *[]string) {
CommandLine.SetUnknownFlags(s)
}

// GetUnknownFlags returns unknown flags found during Parse.
// This requires ParseErrorsWhitelist.UnknownFlags to be set so that
// parsing does not abort on the first unknown flag.
func GetUnknownFlags() *[]string {
return CommandLine.GetUnknownFlags()
}

// CommandLine is the default set of command-line flags, parsed from os.Args.
var CommandLine = NewFlagSet(os.Args[0], ExitOnError)

Expand Down
22 changes: 21 additions & 1 deletion flag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,8 @@ func testParseWithUnknownFlags(f *FlagSet, t *testing.T) {
t.Error("f.Parse() = true before Parse")
}
f.ParseErrorsWhitelist.UnknownFlags = true
var unknownFlags []string
f.SetUnknownFlags(&unknownFlags)

f.BoolP("boola", "a", false, "bool value")
f.BoolP("boolb", "b", false, "bool2 value")
Expand Down Expand Up @@ -455,6 +457,19 @@ func testParseWithUnknownFlags(f *FlagSet, t *testing.T) {
"stringo", "ovalue",
"boole", "true",
}
wantUnknowns := []string{
"--unknown1", "unknown1Value",
"--unknown2=unknown2Value",
"-u=unknown3Value",
"-p", "unknown4Value",
"-q",
"--unknown7=unknown7value",
"--unknown8=unknown8value",
"--unknown6", "",
"-u", "-u", "-u", "-u", "-u", "",
"--unknown10",
"--unknown11",
}
got := []string{}
store := func(flag *Flag, value string) error {
got = append(got, flag.Name)
Expand All @@ -470,10 +485,15 @@ func testParseWithUnknownFlags(f *FlagSet, t *testing.T) {
t.Errorf("f.Parse() = false after Parse")
}
if !reflect.DeepEqual(got, want) {
t.Errorf("f.ParseAll() fail to restore the args")
t.Errorf("f.Parse() failed to parse with unknown flags")
t.Errorf("Got: %v", got)
t.Errorf("Want: %v", want)
}
if !reflect.DeepEqual(unknownFlags, wantUnknowns) {
t.Errorf("f.Parse() failed to enumerate the unknown flags")
t.Errorf("Got: %v", unknownFlags)
t.Errorf("Want: %v", wantUnknowns)
}
}

func TestShorthand(t *testing.T) {
Expand Down

0 comments on commit ce4d9a3

Please sign in to comment.