diff --git a/command.go b/command.go index 6ff47dd..1b0b17a 100644 --- a/command.go +++ b/command.go @@ -655,13 +655,37 @@ Loop: // argsMinusFirstX removes only the first x from args. Otherwise, commands that look like // openshift admin policy add-role-to-user admin my-user, lose the admin argument (arg[4]). -func argsMinusFirstX(args []string, x string) []string { - for i, y := range args { - if x == y { - ret := []string{} - ret = append(ret, args[:i]...) - ret = append(ret, args[i+1:]...) - return ret +// Special care needs to be taken not to remove a flag value. +func (c *Command) argsMinusFirstX(args []string, x string) []string { + if len(args) == 0 { + return args + } + c.mergePersistentFlags() + flags := c.Flags() + +Loop: + for pos := 0; pos < len(args); pos++ { + s := args[pos] + switch { + case s == "--": + // -- means we have reached the end of the parseable args. Break out of the loop now. + break Loop + case strings.HasPrefix(s, "--") && !strings.Contains(s, "=") && !hasNoOptDefVal(s[2:], flags): + fallthrough + case strings.HasPrefix(s, "-") && !strings.Contains(s, "=") && len(s) == 2 && !shortHasNoOptDefVal(s[1:], flags): + // This is a flag without a default value, and an equal sign is not used. Increment pos in order to skip + // over the next arg, because that is the value of this flag. + pos++ + continue + case !strings.HasPrefix(s, "-"): + // This is not a flag or a flag value. Check to see if it matches what we're looking for, and if so, + // return the args, excluding the one at this position. + if s == x { + ret := []string{} + ret = append(ret, args[:pos]...) + ret = append(ret, args[pos+1:]...) + return ret + } } } return args @@ -686,7 +710,7 @@ func (c *Command) Find(args []string) (*Command, []string, error) { cmd := c.findNext(nextSubCmd) if cmd != nil { - return innerfind(cmd, argsMinusFirstX(innerArgs, nextSubCmd)) + return innerfind(cmd, c.argsMinusFirstX(innerArgs, nextSubCmd)) } return c, innerArgs } diff --git a/command_test.go b/command_test.go index c023bd6..3fd36c7 100644 --- a/command_test.go +++ b/command_test.go @@ -2603,3 +2603,92 @@ func TestHelpflagCommandExecutedWithoutVersionSet(t *testing.T) { checkStringContains(t, output, HelpFlag) checkStringOmits(t, output, VersionFlag) } + +func TestFind(t *testing.T) { + var foo, bar string + root := &Command{ + Use: "root", + } + root.PersistentFlags().StringVarP(&foo, "foo", "f", "", "") + root.PersistentFlags().StringVarP(&bar, "bar", "b", "something", "") + + child := &Command{ + Use: "child", + } + root.AddCommand(child) + + testCases := []struct { + args []string + expectedFoundArgs []string + }{ + { + []string{"child"}, + []string{}, + }, + { + []string{"child", "child"}, + []string{"child"}, + }, + { + []string{"child", "foo", "child", "bar", "child", "baz", "child"}, + []string{"foo", "child", "bar", "child", "baz", "child"}, + }, + { + []string{"-f", "child", "child"}, + []string{"-f", "child"}, + }, + { + []string{"child", "-f", "child"}, + []string{"-f", "child"}, + }, + { + []string{"-b", "child", "child"}, + []string{"-b", "child"}, + }, + { + []string{"child", "-b", "child"}, + []string{"-b", "child"}, + }, + { + []string{"child", "-b"}, + []string{"-b"}, + }, + { + []string{"-b", "-f", "child", "child"}, + []string{"-b", "-f", "child"}, + }, + { + []string{"-f", "child", "-b", "something", "child"}, + []string{"-f", "child", "-b", "something"}, + }, + { + []string{"-f", "child", "child", "-b"}, + []string{"-f", "child", "-b"}, + }, + { + []string{"-f=child", "-b=something", "child"}, + []string{"-f=child", "-b=something"}, + }, + { + []string{"--foo", "child", "--bar", "something", "child"}, + []string{"--foo", "child", "--bar", "something"}, + }, + } + + for _, tc := range testCases { + t.Run(fmt.Sprintf("%v", tc.args), func(t *testing.T) { + cmd, foundArgs, err := root.Find(tc.args) + if err != nil { + t.Fatal(err) + } + + if cmd != child { + t.Fatal("Expected cmd to be child, but it was not") + } + + if !reflect.DeepEqual(tc.expectedFoundArgs, foundArgs) { + t.Fatalf("Wrong args\nExpected: %v\nGot: %v", tc.expectedFoundArgs, foundArgs) + } + }) + } +}