diff --git a/bash_completions.go b/bash_completions.go index 716146f..236dee6 100644 --- a/bash_completions.go +++ b/bash_completions.go @@ -167,6 +167,11 @@ __handle_flag() must_have_one_flag=() fi + # if you set a flag which only applies to this command, don't show subcommands + if __contains_word "${flagname}" "${local_nonpersistent_flags[@]}"; then + commands=() + fi + # keep flag value with flagname as flaghash if [ -n "${flagvalue}" ] ; then flaghash[${flagname}]=${flagvalue} @@ -263,6 +268,7 @@ func postscript(w io.Writer, name string) error { local c=0 local flags=() local two_word_flags=() + local local_nonpersistent_flags=() local flags_with_completion=() local flags_completion=() local commands=("%s") @@ -360,7 +366,7 @@ func writeFlagHandler(name string, annotations map[string][]string, w io.Writer) } func writeShortFlag(flag *pflag.Flag, w io.Writer) error { - b := (flag.Value.Type() == "bool") + b := (len(flag.NoOptDefVal) > 0) name := flag.Shorthand format := " " if !b { @@ -374,7 +380,7 @@ func writeShortFlag(flag *pflag.Flag, w io.Writer) error { } func writeFlag(flag *pflag.Flag, w io.Writer) error { - b := (flag.Value.Type() == "bool") + b := (len(flag.NoOptDefVal) > 0) name := flag.Name format := " flags+=(\"--%s" if !b { @@ -387,9 +393,24 @@ func writeFlag(flag *pflag.Flag, w io.Writer) error { return writeFlagHandler("--"+name, flag.Annotations, w) } +func writeLocalNonPersistentFlag(flag *pflag.Flag, w io.Writer) error { + b := (len(flag.NoOptDefVal) > 0) + name := flag.Name + format := " local_nonpersistent_flags+=(\"--%s" + if !b { + format += "=" + } + format += "\")\n" + if _, err := fmt.Fprintf(w, format, name); err != nil { + return err + } + return nil +} + func writeFlags(cmd *Command, w io.Writer) error { _, err := fmt.Fprintf(w, ` flags=() two_word_flags=() + local_nonpersistent_flags=() flags_with_completion=() flags_completion=() @@ -397,6 +418,7 @@ func writeFlags(cmd *Command, w io.Writer) error { if err != nil { return err } + localNonPersistentFlags := cmd.LocalNonPersistentFlags() var visitErr error cmd.NonInheritedFlags().VisitAll(func(flag *pflag.Flag) { if err := writeFlag(flag, w); err != nil { @@ -409,6 +431,12 @@ func writeFlags(cmd *Command, w io.Writer) error { return } } + if localNonPersistentFlags.Lookup(flag.Name) != nil { + if err := writeLocalNonPersistentFlag(flag, w); err != nil { + visitErr = err + return + } + } }) if visitErr != nil { return visitErr diff --git a/command.go b/command.go index b92c95c..e86fee9 100644 --- a/command.go +++ b/command.go @@ -1025,6 +1025,19 @@ func (c *Command) Flags() *flag.FlagSet { return c.flags } +// LocalNonPersistentFlags are flags specific to this command which will NOT persist to subcommands +func (c *Command) LocalNonPersistentFlags() *flag.FlagSet { + persistentFlags := c.PersistentFlags() + + out := flag.NewFlagSet(c.Name(), flag.ContinueOnError) + c.LocalFlags().VisitAll(func(f *flag.Flag) { + if persistentFlags.Lookup(f.Name) == nil { + out.AddFlag(f) + } + }) + return out +} + // Get the local FlagSet specifically set in the current command func (c *Command) LocalFlags() *flag.FlagSet { c.mergePersistentFlags()