diff --git a/cobra_test.go b/cobra_test.go index 10a8a1e..71243b5 100644 --- a/cobra_test.go +++ b/cobra_test.go @@ -4,6 +4,7 @@ import ( "bytes" "fmt" "os" + "reflect" "runtime" "strings" "testing" @@ -12,7 +13,7 @@ import ( var _ = fmt.Println var _ = os.Stderr -var tp, te, tt, t1 []string +var tp, te, tt, t1, tr []string var rootPersPre, echoPre, echoPersPre, timesPersPre []string var flagb1, flagb2, flagb3, flagbr, flagbp bool var flags1, flags2a, flags2b, flags3 string @@ -99,6 +100,7 @@ var cmdRootWithRun = &Command{ Short: "The root can run it's own function", Long: "The root description for help", Run: func(cmd *Command, args []string) { + tr = args rootcalled = true }, } @@ -181,7 +183,7 @@ func initializeWithSameName() *Command { func initializeWithRootCmd() *Command { cmdRootWithRun.ResetCommands() - tt, tp, te, rootcalled = nil, nil, nil, false + tt, tp, te, tr, rootcalled = nil, nil, nil, nil, false flagInit() cmdRootWithRun.Flags().BoolVarP(&flagbr, "boolroot", "b", false, "help message for flag boolroot") cmdRootWithRun.Flags().IntVarP(&flagir, "introot", "i", 321, "help message for flag introot") @@ -494,7 +496,7 @@ func TestTrailingCommandFlags(t *testing.T) { } } -func TestInvalidSubCommandFlags(t *testing.T) { +func TestInvalidSubcommandFlags(t *testing.T) { cmd := initializeWithRootCmd() cmd.AddCommand(cmdTimes) @@ -508,7 +510,7 @@ func TestInvalidSubCommandFlags(t *testing.T) { } -func TestSubCommandArgEvaluation(t *testing.T) { +func TestSubcommandArgEvaluation(t *testing.T) { cmd := initializeWithRootCmd() first := &Command{ @@ -819,6 +821,31 @@ func TestRemoveCommand(t *testing.T) { } } +func TestCommandWithoutSubcommands(t *testing.T) { + c := initializeWithRootCmd() + + x := simpleTester(c, "") + if x.Error != nil { + t.Errorf("Calling command without subcommands should not have error: %v", x.Error) + return + } +} + +func TestCommandWithoutSubcommandsWithArg(t *testing.T) { + c := initializeWithRootCmd() + expectedArgs := []string{"arg"} + + x := simpleTester(c, "arg") + if x.Error != nil { + t.Errorf("Calling command without subcommands but with arg should not have error: %v", x.Error) + return + } + if !reflect.DeepEqual(expectedArgs, tr) { + t.Errorf("Calling command without subcommands but with arg has wrong args: expected: %v, actual: %v", expectedArgs, tr) + return + } +} + func TestReplaceCommandWithRemove(t *testing.T) { versionUsed = 0 c := initializeWithRootCmd() diff --git a/command.go b/command.go index 978c4b0..7dd34f6 100644 --- a/command.go +++ b/command.go @@ -18,13 +18,14 @@ package cobra import ( "bytes" "fmt" - "github.com/inconshreveable/mousetrap" - flag "github.com/spf13/pflag" "io" "os" "runtime" "strings" "time" + + "github.com/inconshreveable/mousetrap" + flag "github.com/spf13/pflag" ) // Command is just that, a command for your application. @@ -360,25 +361,28 @@ func argsMinusFirstX(args []string, x string) []string { // find the target command given the args and command tree // Meant to be run on the highest node. Only searches down. -func (c *Command) Find(arrs []string) (*Command, []string, error) { +func (c *Command) Find(args []string) (*Command, []string, error) { if c == nil { return nil, nil, fmt.Errorf("Called find() on a nil Command") } - if len(arrs) == 0 { - return c.Root(), arrs, nil + // If there are no arguments, return the root command. If the root has no + // subcommands, args reflects arguments that should actually be passed to + // the root command, so also return the root command. + if len(args) == 0 || !c.Root().HasSubCommands() { + return c.Root(), args, nil } var innerfind func(*Command, []string) (*Command, []string) - innerfind = func(c *Command, args []string) (*Command, []string) { - if len(args) > 0 && c.HasSubCommands() { - argsWOflags := stripFlags(args, c) + innerfind = func(c *Command, innerArgs []string) (*Command, []string) { + if len(innerArgs) > 0 && c.HasSubCommands() { + argsWOflags := stripFlags(innerArgs, c) if len(argsWOflags) > 0 { matches := make([]*Command, 0) for _, cmd := range c.commands { if cmd.Name() == argsWOflags[0] || cmd.HasAlias(argsWOflags[0]) { // exact name or alias match - return innerfind(cmd, argsMinusFirstX(args, argsWOflags[0])) + return innerfind(cmd, argsMinusFirstX(innerArgs, argsWOflags[0])) } else if EnablePrefixMatching { if strings.HasPrefix(cmd.Name(), argsWOflags[0]) { // prefix match matches = append(matches, cmd) @@ -393,18 +397,18 @@ func (c *Command) Find(arrs []string) (*Command, []string, error) { // only accept a single prefix match - multiple matches would be ambiguous if len(matches) == 1 { - return innerfind(matches[0], argsMinusFirstX(args, argsWOflags[0])) + return innerfind(matches[0], argsMinusFirstX(innerArgs, argsWOflags[0])) } } } - return c, args + return c, innerArgs } - commandFound, a := innerfind(c, arrs) + commandFound, a := innerfind(c, args) // If we matched on the root, but we asked for a subcommand, return an error - if commandFound.Name() == c.Name() && len(stripFlags(arrs, c)) > 0 && commandFound.Name() != arrs[0] { + if commandFound.Name() == c.Name() && len(stripFlags(args, c)) > 0 && commandFound.Name() != args[0] { return nil, a, fmt.Errorf("unknown command %q", a[0]) }