From 50258f15eb8a64e95e5d9ec571f17e550e37fa69 Mon Sep 17 00:00:00 2001 From: Luap99 <45212748+Luap99@users.noreply.github.com> Date: Wed, 9 Sep 2020 17:34:51 +0200 Subject: [PATCH] Complete subcommands when TraverseChildren is set (#1171) * Complete subcommands when TraverseChildren is true in custom completion The current custom completion logic does not complete subcommands when a local flag is set. This is good unless TraverseChildren is set to true where local flags can be set on parent commands. This commit allows subcommands to be completed if TraverseChildren is set to true on the root cmd. Closes #1170 Signed-off-by: Paul Holzinger * Complete subcommands when TraverseChildren is true in bash completion The current bash completion logic does not complete subcommands when a local flag is set. There is also a bug where subcommands are sometimes still getting completed. see: #1172 If TraverseChildren is true we should allow subcommands to be completed even if a local flag is set. Closes #1172 Signed-off-by: Paul Holzinger --- bash_completions.go | 12 +++++--- bash_completions_test.go | 25 ++++++++++++++++ custom_completions.go | 30 +++++++++++++------ custom_completions_test.go | 60 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 114 insertions(+), 13 deletions(-) diff --git a/bash_completions.go b/bash_completions.go index ab428cc..846636d 100644 --- a/bash_completions.go +++ b/bash_completions.go @@ -495,12 +495,14 @@ func writeFlag(buf *bytes.Buffer, flag *pflag.Flag, cmd *Command) { func writeLocalNonPersistentFlag(buf *bytes.Buffer, flag *pflag.Flag) { name := flag.Name - format := " local_nonpersistent_flags+=(\"--%s" + format := " local_nonpersistent_flags+=(\"--%[1]s\")\n" if len(flag.NoOptDefVal) == 0 { - format += "=" + format += " local_nonpersistent_flags+=(\"--%[1]s=\")\n" } - format += "\")\n" buf.WriteString(fmt.Sprintf(format, name)) + if len(flag.Shorthand) > 0 { + buf.WriteString(fmt.Sprintf(" local_nonpersistent_flags+=(\"-%s\")\n", flag.Shorthand)) + } } // Setup annotations for go completions for registered flags @@ -535,7 +537,9 @@ func writeFlags(buf *bytes.Buffer, cmd *Command) { if len(flag.Shorthand) > 0 { writeShortFlag(buf, flag, cmd) } - if localNonPersistentFlags.Lookup(flag.Name) != nil { + // localNonPersistentFlags are used to stop the completion of subcommands when one is set + // if TraverseChildren is true we should allow to complete subcommands + if localNonPersistentFlags.Lookup(flag.Name) != nil && !cmd.Root().TraverseChildren { writeLocalNonPersistentFlag(buf, flag) } }) diff --git a/bash_completions_test.go b/bash_completions_test.go index eefa3de..2c182ba 100644 --- a/bash_completions_test.go +++ b/bash_completions_test.go @@ -193,6 +193,13 @@ func TestBashCompletions(t *testing.T) { checkOmit(t, output, `two_word_flags+=("--two-w-default")`) checkOmit(t, output, `two_word_flags+=("-T")`) + // check local nonpersistent flag + check(t, output, `local_nonpersistent_flags+=("--two")`) + check(t, output, `local_nonpersistent_flags+=("--two=")`) + check(t, output, `local_nonpersistent_flags+=("-t")`) + check(t, output, `local_nonpersistent_flags+=("--two-w-default")`) + check(t, output, `local_nonpersistent_flags+=("-T")`) + checkOmit(t, output, deprecatedCmd.Name()) // If available, run shellcheck against the script. @@ -235,3 +242,21 @@ func TestBashCompletionDeprecatedFlag(t *testing.T) { t.Errorf("expected completion to not include %q flag: Got %v", flagName, output) } } + +func TestBashCompletionTraverseChildren(t *testing.T) { + c := &Command{Use: "c", Run: emptyRun, TraverseChildren: true} + + c.Flags().StringP("string-flag", "s", "", "string flag") + c.Flags().BoolP("bool-flag", "b", false, "bool flag") + + buf := new(bytes.Buffer) + c.GenBashCompletion(buf) + output := buf.String() + + // check that local nonpersistent flag are not set since we have TraverseChildren set to true + checkOmit(t, output, `local_nonpersistent_flags+=("--string-flag")`) + checkOmit(t, output, `local_nonpersistent_flags+=("--string-flag=")`) + checkOmit(t, output, `local_nonpersistent_flags+=("-s")`) + checkOmit(t, output, `local_nonpersistent_flags+=("--bool-flag")`) + checkOmit(t, output, `local_nonpersistent_flags+=("-b")`) +} diff --git a/custom_completions.go b/custom_completions.go index c25c03e..fc85290 100644 --- a/custom_completions.go +++ b/custom_completions.go @@ -175,8 +175,16 @@ func (c *Command) getCompletions(args []string) (*Command, []string, ShellCompDi toComplete := args[len(args)-1] trimmedArgs := args[:len(args)-1] + var finalCmd *Command + var finalArgs []string + var err error // Find the real command for which completion must be performed - finalCmd, finalArgs, err := c.Root().Find(trimmedArgs) + // check if we need to traverse here to parse local flags on parent commands + if c.Root().TraverseChildren { + finalCmd, finalArgs, err = c.Root().Traverse(trimmedArgs) + } else { + finalCmd, finalArgs, err = c.Root().Find(trimmedArgs) + } if err != nil { // Unable to find the real command. E.g., someInvalidCmd return c, []string{}, ShellCompDirectiveDefault, fmt.Errorf("Unable to find a command for arguments: %v", trimmedArgs) @@ -271,20 +279,24 @@ func (c *Command) getCompletions(args []string) (*Command, []string, ShellCompDi var completions []string directive := ShellCompDirectiveDefault if flag == nil { - // Check if there are any local, non-persistent flags on the command-line foundLocalNonPersistentFlag := false - localNonPersistentFlags := finalCmd.LocalNonPersistentFlags() - finalCmd.NonInheritedFlags().VisitAll(func(flag *pflag.Flag) { - if localNonPersistentFlags.Lookup(flag.Name) != nil && flag.Changed { - foundLocalNonPersistentFlag = true - } - }) + // If TraverseChildren is true on the root command we don't check for + // local flags because we can use a local flag on a parent command + if !finalCmd.Root().TraverseChildren { + // Check if there are any local, non-persistent flags on the command-line + localNonPersistentFlags := finalCmd.LocalNonPersistentFlags() + finalCmd.NonInheritedFlags().VisitAll(func(flag *pflag.Flag) { + if localNonPersistentFlags.Lookup(flag.Name) != nil && flag.Changed { + foundLocalNonPersistentFlag = true + } + }) + } // Complete subcommand names, including the help command if len(finalArgs) == 0 && !foundLocalNonPersistentFlag { // We only complete sub-commands if: // - there are no arguments on the command-line and - // - there are no local, non-peristent flag on the command-line + // - there are no local, non-peristent flag on the command-line or TraverseChildren is true for _, subCmd := range finalCmd.Commands() { if subCmd.IsAvailableCommand() || subCmd == finalCmd.helpCommand { if strings.HasPrefix(subCmd.Name(), toComplete) { diff --git a/custom_completions_test.go b/custom_completions_test.go index 276b8a7..fbc6a53 100644 --- a/custom_completions_test.go +++ b/custom_completions_test.go @@ -139,6 +139,8 @@ func TestNoCmdNameCompletionInGo(t *testing.T) { Use: "root", Run: emptyRun, } + rootCmd.Flags().String("localroot", "", "local root flag") + childCmd1 := &Command{ Use: "childCmd1", Short: "First command", @@ -187,6 +189,64 @@ func TestNoCmdNameCompletionInGo(t *testing.T) { t.Errorf("expected: %q, got: %q", expected, output) } + // Test that sub-command names are completed if a local non-persistent flag is present and TraverseChildren is set to true + // set TraverseChildren to true on the root cmd + rootCmd.TraverseChildren = true + + output, err = executeCommand(rootCmd, ShellCompNoDescRequestCmd, "--localroot", "value", "") + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + // Reset TraverseChildren for next command + rootCmd.TraverseChildren = false + + expected = strings.Join([]string{ + "childCmd1", + "help", + ":4", + "Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n") + + if output != expected { + t.Errorf("expected: %q, got: %q", expected, output) + } + + // Test that sub-command names from a child cmd are completed if a local non-persistent flag is present + // and TraverseChildren is set to true on the root cmd + rootCmd.TraverseChildren = true + + output, err = executeCommand(rootCmd, ShellCompNoDescRequestCmd, "--localroot", "value", "childCmd1", "--nonPersistent", "value", "") + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + // Reset TraverseChildren for next command + rootCmd.TraverseChildren = false + // Reset the flag for the next command + nonPersistentFlag.Changed = false + + expected = strings.Join([]string{ + "childCmd2", + ":4", + "Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n") + + if output != expected { + t.Errorf("expected: %q, got: %q", expected, output) + } + + // Test that we don't use Traverse when we shouldn't. + // This command should not return a completion since the command line is invalid without TraverseChildren. + output, err = executeCommand(rootCmd, ShellCompNoDescRequestCmd, "--localroot", "value", "childCmd1", "") + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + expected = strings.Join([]string{ + ":0", + "Completion ended with directive: ShellCompDirectiveDefault", ""}, "\n") + + if output != expected { + t.Errorf("expected: %q, got: %q", expected, output) + } + // Test that sub-command names are not completed if a local non-persistent short flag is present output, err = executeCommand(rootCmd, ShellCompNoDescRequestCmd, "childCmd1", "-n", "value", "") if err != nil {