diff --git a/README.md b/README.md index e1b83a2..373a056 100644 --- a/README.md +++ b/README.md @@ -343,6 +343,19 @@ A flag can also be assigned locally which will only apply to that specific comma RootCmd.Flags().StringVarP(&Source, "source", "s", "", "Source directory to read from") ``` +### Local Flag on Parent Commands + +By default Cobra only parses local flags on the target command, any local flags on +parent commands are ignored. By enabling `Command.TraverseChildren` Cobra will +parse local flags on each command before executing the target command. + +```go +command := cobra.Command{ + Use: "print [OPTIONS] [COMMANDS]", + TraverseChildren: true, +} +``` + ### Bind Flags with Config You can also bind your flags with [viper](https://github.com/spf13/viper): diff --git a/command.go b/command.go index a0a633d..a23bb6d 100644 --- a/command.go +++ b/command.go @@ -125,6 +125,9 @@ type Command struct { // Must be > 0. SuggestionsMinimumDistance int + // TraverseChildren parses flags on all parents before executing child command. + TraverseChildren bool + // name is the command name, usually the executable's name. name string // commands is the list of commands supported by this program. @@ -475,13 +478,14 @@ func argsMinusFirstX(args []string, x string) []string { return args } +func isFlagArg(arg string) bool { + return ((len(arg) >= 3 && arg[1] == '-') || + (len(arg) >= 2 && arg[0] == '-' && arg[1] != '-')) +} + // Find the target command given the args and command tree // Meant to be run on the highest node. Only searches down. func (c *Command) Find(args []string) (*Command, []string, error) { - if c == nil { - return nil, nil, fmt.Errorf("Called find() on a nil Command") - } - var innerfind func(*Command, []string) (*Command, []string) innerfind = func(c *Command, innerArgs []string) (*Command, []string) { @@ -490,28 +494,11 @@ func (c *Command) Find(args []string) (*Command, []string, error) { return c, innerArgs } nextSubCmd := argsWOflags[0] - matches := make([]*Command, 0) - for _, cmd := range c.commands { - if cmd.Name() == nextSubCmd || cmd.HasAlias(nextSubCmd) { // exact name or alias match - return innerfind(cmd, argsMinusFirstX(innerArgs, nextSubCmd)) - } - if EnablePrefixMatching { - if strings.HasPrefix(cmd.Name(), nextSubCmd) { // prefix match - matches = append(matches, cmd) - } - for _, x := range cmd.Aliases { - if strings.HasPrefix(x, nextSubCmd) { - matches = append(matches, cmd) - } - } - } - } - // only accept a single prefix match - multiple matches would be ambiguous - if len(matches) == 1 { - return innerfind(matches[0], argsMinusFirstX(innerArgs, argsWOflags[0])) + cmd := c.findNext(nextSubCmd) + if cmd != nil { + return innerfind(cmd, argsMinusFirstX(innerArgs, nextSubCmd)) } - return c, innerArgs } @@ -539,6 +526,66 @@ func (c *Command) findSuggestions(arg string) string { return suggestionsString } +func (c *Command) findNext(next string) *Command { + matches := make([]*Command, 0) + for _, cmd := range c.commands { + if cmd.Name() == next || cmd.HasAlias(next) { + return cmd + } + if EnablePrefixMatching && cmd.hasNameOrAliasPrefix(next) { + matches = append(matches, cmd) + } + } + + if len(matches) == 1 { + return matches[0] + } + return nil +} + +// Traverse the command tree to find the command, and parse args for +// each parent. +func (c *Command) Traverse(args []string) (*Command, []string, error) { + flags := []string{} + inFlag := false + + for i, arg := range args { + switch { + // A long flag with a space separated value + case strings.HasPrefix(arg, "--") && !strings.Contains(arg, "="): + // TODO: this isn't quite right, we should really check ahead for 'true' or 'false' + inFlag = !hasNoOptDefVal(arg[2:], c.Flags()) + flags = append(flags, arg) + continue + // A short flag with a space separated value + case strings.HasPrefix(arg, "-") && !strings.Contains(arg, "=") && len(arg) == 2 && !shortHasNoOptDefVal(arg[1:], c.Flags()): + inFlag = true + flags = append(flags, arg) + continue + // The value for a flag + case inFlag: + inFlag = false + flags = append(flags, arg) + continue + // A flag without a value, or with an `=` separated value + case isFlagArg(arg): + flags = append(flags, arg) + continue + } + + cmd := c.findNext(arg) + if cmd == nil { + return c, args, nil + } + + if err := c.ParseFlags(flags); err != nil { + return nil, args, err + } + return cmd.Traverse(args[i+1:]) + } + return c, args, nil +} + // SuggestionsFor provides suggestions for the typedName. func (c *Command) SuggestionsFor(typedName string) []string { suggestions := []string{} @@ -714,7 +761,12 @@ func (c *Command) ExecuteC() (cmd *Command, err error) { args = c.args } - cmd, flags, err := c.Find(args) + var flags []string + if c.TraverseChildren { + cmd, flags, err = c.Traverse(args) + } else { + cmd, flags, err = c.Find(args) + } if err != nil { // If found parse to a subcommand and then failed, talk about the subcommand if cmd != nil { @@ -726,6 +778,7 @@ func (c *Command) ExecuteC() (cmd *Command, err error) { } return c, err } + err = cmd.execute(flags) if err != nil { // Always show help if requested, even if SilenceErrors is in @@ -993,7 +1046,21 @@ func (c *Command) HasAlias(s string) bool { return false } -// NameAndAliases returns string containing name and all aliases +// hasNameOrAliasPrefix returns true if the Name or any of aliases start +// with prefix +func (c *Command) hasNameOrAliasPrefix(prefix string) bool { + if strings.HasPrefix(c.Name(), prefix) { + return true + } + for _, alias := range c.Aliases { + if strings.HasPrefix(alias, prefix) { + return true + } + } + return false +} + +// NameAndAliases returns a list of the command name and all aliases func (c *Command) NameAndAliases() string { return strings.Join(append([]string{c.Name()}, c.Aliases...), ", ") } diff --git a/command_test.go b/command_test.go index aa6658f..938a741 100644 --- a/command_test.go +++ b/command_test.go @@ -347,3 +347,94 @@ func TestSetHelpCommand(t *testing.T) { t.Errorf("Expected to contain %q message, but got %q", correctMessage, output.String()) } } + +func TestTraverseWithParentFlags(t *testing.T) { + cmd := &Command{ + Use: "do", + TraverseChildren: true, + } + cmd.Flags().String("foo", "", "foo things") + cmd.Flags().BoolP("goo", "g", false, "foo things") + + sub := &Command{Use: "next"} + sub.Flags().String("add", "", "add things") + cmd.AddCommand(sub) + + c, args, err := cmd.Traverse([]string{"-g", "--foo", "ok", "next", "--add"}) + if err != nil { + t.Fatalf("Expected no error: %s", err) + } + if len(args) != 1 && args[0] != "--add" { + t.Fatalf("wrong args %s", args) + } + if c.Name() != sub.Name() { + t.Fatalf("wrong command %q expected %q", c.Name(), sub.Name()) + } +} + +func TestTraverseNoParentFlags(t *testing.T) { + cmd := &Command{ + Use: "do", + TraverseChildren: true, + } + cmd.Flags().String("foo", "", "foo things") + + sub := &Command{Use: "next"} + sub.Flags().String("add", "", "add things") + cmd.AddCommand(sub) + + c, args, err := cmd.Traverse([]string{"next"}) + if err != nil { + t.Fatalf("Expected no error: %s", err) + } + if len(args) != 0 { + t.Fatalf("wrong args %s", args) + } + if c.Name() != sub.Name() { + t.Fatalf("wrong command %q expected %q", c.Name(), sub.Name()) + } +} + +func TestTraverseWithBadParentFlags(t *testing.T) { + cmd := &Command{ + Use: "do", + TraverseChildren: true, + } + sub := &Command{Use: "next"} + sub.Flags().String("add", "", "add things") + cmd.AddCommand(sub) + + expected := "got unknown flag: --add" + + c, _, err := cmd.Traverse([]string{"--add", "ok", "next"}) + if err == nil || strings.Contains(err.Error(), expected) { + t.Fatalf("Expected error %s got %s", expected, err) + } + if c != nil { + t.Fatalf("Expected nil command") + } +} + +func TestTraverseWithBadChildFlag(t *testing.T) { + cmd := &Command{ + Use: "do", + TraverseChildren: true, + } + cmd.Flags().String("foo", "", "foo things") + + sub := &Command{Use: "next"} + cmd.AddCommand(sub) + + // Expect no error because the last commands args shouldn't be parsed in + // Traverse + c, args, err := cmd.Traverse([]string{"next", "--add"}) + if err != nil { + t.Fatalf("Expected no error: %s", err) + } + if len(args) != 1 && args[0] != "--add" { + t.Fatalf("wrong args %s", args) + } + if c.Name() != sub.Name() { + t.Fatalf("wrong command %q expected %q", c.Name(), sub.Name()) + } +}