Ensure all completion maps are initialized before use

This commit is contained in:
maxlandon 2023-09-30 18:30:43 +02:00
parent f1f260eb59
commit 53fb4ebbd1
No known key found for this signature in database
GPG Key ID: 2DE5C14975A86900
2 changed files with 23 additions and 0 deletions

View File

@ -534,6 +534,7 @@ func writeLocalNonPersistentFlag(buf io.StringWriter, flag *pflag.Flag) {
// prepareCustomAnnotationsForFlags setup annotations for go completions for registered flags
func prepareCustomAnnotationsForFlags(cmd *Command) {
cmd.initializeCompletionStorage()
cmd.flagCompletionMutex.RLock()
defer cmd.flagCompletionMutex.RUnlock()
for flag := range cmd.flagCompletionFunctions {

View File

@ -18,8 +18,10 @@ import (
"fmt"
"os"
"strings"
"sync"
"github.com/spf13/pflag"
flag "github.com/spf13/pflag"
)
const (
@ -134,9 +136,13 @@ func (c *Command) RegisterFlagCompletionFunc(flagName string, f func(cmd *Comman
if flag == nil {
return fmt.Errorf("RegisterFlagCompletionFunc: flag '%s' does not exist", flagName)
}
// Ensure none of our relevant fields are nil.
c.initializeCompletionStorage()
c.flagCompletionMutex.Lock()
defer c.flagCompletionMutex.Unlock()
// And attempt to bind the completion.
if _, exists := c.flagCompletionFunctions[flag]; exists {
return fmt.Errorf("RegisterFlagCompletionFunc: flag '%s' already registered", flagName)
}
@ -144,6 +150,20 @@ func (c *Command) RegisterFlagCompletionFunc(flagName string, f func(cmd *Comman
return nil
}
// initializeCompletionStorage is (and should be) called in all
// functions that make use of the command's flag completion functions.
func (c *Command) initializeCompletionStorage() {
if c.flagCompletionMutex == nil {
c.flagCompletionMutex = new(sync.RWMutex)
}
var completionFn func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective)
if c.flagCompletionFunctions == nil {
c.flagCompletionFunctions = make(map[*flag.Flag]completionFn, 0)
}
}
// Returns a string listing the different directive enabled in the specified parameter
func (d ShellCompDirective) string() string {
var directives []string
@ -478,6 +498,8 @@ func (c *Command) getCompletions(args []string) (*Command, []string, ShellCompDi
// Find the completion function for the flag or command
var completionFn func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective)
if flag != nil && flagCompletion {
c.initializeCompletionStorage()
finalCmd.flagCompletionMutex.RLock()
completionFn = finalCmd.flagCompletionFunctions[flag]
finalCmd.flagCompletionMutex.RUnlock()