diff --git a/completions_test.go b/completions_test.go index 4c6f41b..7585f88 100644 --- a/completions_test.go +++ b/completions_test.go @@ -17,7 +17,9 @@ package cobra import ( "bytes" "context" + "fmt" "strings" + "sync" "testing" ) @@ -2040,6 +2042,114 @@ func TestFlagCompletionWorksRootCommandAddedAfterFlags(t *testing.T) { } } +func TestFlagCompletionForPersistentFlagsCalledFromSubCmd(t *testing.T) { + rootCmd := &Command{Use: "root", Run: emptyRun} + rootCmd.PersistentFlags().String("string", "", "test string flag") + _ = rootCmd.RegisterFlagCompletionFunc("string", func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective) { + return []string{"myval"}, ShellCompDirectiveDefault + }) + + childCmd := &Command{ + Use: "child", + Run: emptyRun, + ValidArgsFunction: func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective) { + return []string{"--validarg", "test"}, ShellCompDirectiveDefault + }, + } + childCmd.Flags().Bool("bool", false, "test bool flag") + rootCmd.AddCommand(childCmd) + + // Test that persistent flag completion works for the subcmd + output, err := executeCommand(rootCmd, ShellCompRequestCmd, "child", "--string", "") + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + expected := strings.Join([]string{ + "myval", + ":0", + "Completion ended with directive: ShellCompDirectiveDefault", ""}, "\n") + + if output != expected { + t.Errorf("expected: %q, got: %q", expected, output) + } +} + +// This test tries to register flag completion concurrently to make sure the +// code handles concurrency properly. +// This was reported as a problem when tests are run concurrently: +// https://github.com/spf13/cobra/issues/1320 +// +// NOTE: this test can sometimes pass even if the code were to not handle +// concurrency properly. This is not great but the important part is that +// it should never fail. Therefore, if the tests fails sometimes, we will +// still be able to know there is a problem. +func TestFlagCompletionConcurrentRegistration(t *testing.T) { + rootCmd := &Command{Use: "root", Run: emptyRun} + const maxFlags = 50 + for i := 1; i < maxFlags; i += 2 { + flagName := fmt.Sprintf("flag%d", i) + rootCmd.Flags().String(flagName, "", fmt.Sprintf("test %s flag on root", flagName)) + } + + childCmd := &Command{ + Use: "child", + Run: emptyRun, + } + for i := 2; i <= maxFlags; i += 2 { + flagName := fmt.Sprintf("flag%d", i) + childCmd.Flags().String(flagName, "", fmt.Sprintf("test %s flag on child", flagName)) + } + + rootCmd.AddCommand(childCmd) + + // Register completion in different threads to test concurrency. + var wg sync.WaitGroup + for i := 1; i <= maxFlags; i++ { + index := i + flagName := fmt.Sprintf("flag%d", i) + wg.Add(1) + go func() { + defer wg.Done() + cmd := rootCmd + if index%2 == 0 { + cmd = childCmd + } + _ = cmd.RegisterFlagCompletionFunc(flagName, func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective) { + return []string{fmt.Sprintf("flag%d", index)}, ShellCompDirectiveDefault + }) + }() + } + + wg.Wait() + + // Test that flag completion works for each flag + for i := 1; i <= 6; i++ { + var output string + var err error + flagName := fmt.Sprintf("flag%d", i) + + if i%2 == 1 { + output, err = executeCommand(rootCmd, ShellCompRequestCmd, "--"+flagName, "") + } else { + output, err = executeCommand(rootCmd, ShellCompRequestCmd, "child", "--"+flagName, "") + } + + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + expected := strings.Join([]string{ + flagName, + ":0", + "Completion ended with directive: ShellCompDirectiveDefault", ""}, "\n") + + if output != expected { + t.Errorf("expected: %q, got: %q", expected, output) + } + } +} + func TestFlagCompletionInGoWithDesc(t *testing.T) { rootCmd := &Command{ Use: "root",