diff --git a/completions.go b/completions.go index c00daec..52e4266 100644 --- a/completions.go +++ b/completions.go @@ -144,7 +144,7 @@ func (c *Command) RegisterFlagCompletionFunc(flagName string, f func(cmd *Comman } // GetFlagCompletion returns the completion function for the given flag, if available. -func (c *Command) GetFlagCompletion(flag *pflag.Flag) (func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective), bool) { +func (c *Command) GetFlagCompletionFunc(flag *pflag.Flag) (func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective), bool) { c.initializeCompletionStorage() c.flagCompletionMutex.RLock() @@ -163,17 +163,17 @@ func (c *Command) GetFlagCompletion(flag *pflag.Flag) (func(cmd *Command, args [ } // Or walk up the command tree. - return c.Parent().GetFlagCompletion(flag) + return c.Parent().GetFlagCompletionFunc(flag) } // GetFlagCompletionByName returns the completion function for the given flag in the command by name, if available. -func (c *Command) GetFlagCompletionByName(flagName string) (func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective), bool) { +func (c *Command) GetFlagCompletionFuncByName(flagName string) (func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective), bool) { flag := c.Flags().Lookup(flagName) if flag == nil { return nil, false } - return c.GetFlagCompletion(flag) + return c.GetFlagCompletionFunc(flag) } // initializeCompletionStorage is (and should be) called in all @@ -531,7 +531,7 @@ func (c *Command) getCompletions(args []string) (*Command, []string, ShellCompDi // Find the completion function for the flag or command var completionFn func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective) if flag != nil && flagCompletion { - completionFn, _ = finalCmd.GetFlagCompletion(flag) + completionFn, _ = finalCmd.GetFlagCompletionFunc(flag) } else { completionFn = finalCmd.ValidArgsFunction } diff --git a/completions_test.go b/completions_test.go index 7585f88..d5aee25 100644 --- a/completions_test.go +++ b/completions_test.go @@ -3427,3 +3427,93 @@ Completion ended with directive: ShellCompDirectiveNoFileComp }) } } + +func TestGetFlagCompletion(t *testing.T) { + rootCmd := &Command{Use: "root", Run: emptyRun} + + rootCmd.Flags().String("rootflag", "", "root flag") + _ = rootCmd.RegisterFlagCompletionFunc("rootflag", func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective) { + return []string{"rootvalue"}, ShellCompDirectiveKeepOrder + }) + + rootCmd.PersistentFlags().String("persistentflag", "", "persistent flag") + _ = rootCmd.RegisterFlagCompletionFunc("persistentflag", func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective) { + return []string{"persistentvalue"}, ShellCompDirectiveDefault + }) + + childCmd := &Command{Use: "child", Run: emptyRun} + + childCmd.Flags().String("childflag", "", "child flag") + _ = childCmd.RegisterFlagCompletionFunc("childflag", func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective) { + return []string{"childvalue"}, ShellCompDirectiveNoFileComp | ShellCompDirectiveNoSpace + }) + + rootCmd.AddCommand(childCmd) + + testcases := []struct { + desc string + cmd *Command + flagName string + exists bool + comps []string + directive ShellCompDirective + }{ + { + desc: "get flag completion function for command", + cmd: rootCmd, + flagName: "rootflag", + exists: true, + comps: []string{"rootvalue"}, + directive: ShellCompDirectiveKeepOrder, + }, + { + desc: "get persistent flag completion function for command", + cmd: rootCmd, + flagName: "persistentflag", + exists: true, + comps: []string{"persistentvalue"}, + directive: ShellCompDirectiveDefault, + }, + { + desc: "get flag completion function for child command", + cmd: childCmd, + flagName: "childflag", + exists: true, + comps: []string{"childvalue"}, + directive: ShellCompDirectiveNoFileComp | ShellCompDirectiveNoSpace, + }, + { + desc: "get persistent flag completion function for child command", + cmd: childCmd, + flagName: "persistentflag", + exists: true, + comps: []string{"persistentvalue"}, + directive: ShellCompDirectiveDefault, + }, + { + desc: "cannot get flag completion function for local parent flag", + cmd: childCmd, + flagName: "rootflag", + exists: false, + }, + } + + for _, tc := range testcases { + t.Run(tc.desc, func(t *testing.T) { + compFunc, exists := tc.cmd.GetFlagCompletionFunc(tc.flagName) + if tc.exists != exists { + t.Errorf("Unexpected result looking for flag completion function") + } + + if exists { + comps, directive := compFunc(tc.cmd, []string{}, "") + if strings.Join(tc.comps, " ") != strings.Join(comps, " ") { + t.Errorf("Unexpected completions %q", comps) + } + if tc.directive != directive { + t.Errorf("Unexpected directive %q", directive) + } + } + }) + } +}