diff --git a/cobra_test.go b/cobra_test.go index 46887fd..6c443e7 100644 --- a/cobra_test.go +++ b/cobra_test.go @@ -3,7 +3,6 @@ package cobra import ( "bytes" "fmt" - "os" "reflect" "runtime" "strings" @@ -13,9 +12,6 @@ import ( "github.com/spf13/pflag" ) -var _ = fmt.Println -var _ = os.Stderr - var tp, te, tt, t1, tr []string var rootPersPre, echoPre, echoPersPre, timesPersPre []string var flagb1, flagb2, flagb3, flagbr, flagbp bool @@ -166,20 +162,22 @@ func flagInit() { cmdRootWithRun.ResetFlags() cmdSubNoRun.ResetFlags() cmdCustomFlags.ResetFlags() - cmdRootNoRun.PersistentFlags().StringVarP(&flags2a, "strtwo", "t", "two", strtwoParentHelp) - cmdEcho.Flags().IntVarP(&flagi1, "intone", "i", 123, "help message for flag intone") - cmdTimes.Flags().IntVarP(&flagi2, "inttwo", "j", 234, "help message for flag inttwo") - cmdPrint.Flags().IntVarP(&flagi3, "intthree", "i", 345, "help message for flag intthree") - cmdCustomFlags.Flags().IntVar(&flagi4, "intfour", 456, "help message for flag intfour") - cmdEcho.PersistentFlags().StringVarP(&flags1, "strone", "s", "one", "help message for flag strone") - cmdEcho.PersistentFlags().BoolVarP(&flagbp, "persistentbool", "p", false, "help message for flag persistentbool") - cmdTimes.PersistentFlags().StringVarP(&flags2b, "strtwo", "t", "2", strtwoChildHelp) - cmdPrint.PersistentFlags().StringVarP(&flags3, "strthree", "s", "three", "help message for flag strthree") - cmdEcho.Flags().BoolVarP(&flagb1, "boolone", "b", true, "help message for flag boolone") - cmdTimes.Flags().BoolVarP(&flagb2, "booltwo", "c", false, "help message for flag booltwo") - cmdPrint.Flags().BoolVarP(&flagb3, "boolthree", "b", true, "help message for flag boolthree") cmdVersion1.ResetFlags() cmdVersion2.ResetFlags() + + cmdRootNoRun.PersistentFlags().StringVarP(&flags2a, "strtwo", "t", "two", strtwoParentHelp) + cmdCustomFlags.Flags().IntVar(&flagi4, "intfour", 456, "help message for flag intfour") + cmdEcho.Flags().BoolVarP(&flagb1, "boolone", "b", true, "help message for flag boolone") + cmdEcho.Flags().IntVarP(&flagi1, "intone", "i", 123, "help message for flag intone") + cmdEcho.PersistentFlags().BoolVarP(&flagbp, "persistentbool", "p", false, "help message for flag persistentbool") + cmdEcho.PersistentFlags().StringVarP(&flags1, "strone", "s", "one", "help message for flag strone") + cmdPrint.Flags().IntVarP(&flagi3, "intthree", "i", 345, "help message for flag intthree") + cmdTimes.Flags().BoolVarP(&flagb2, "booltwo", "c", false, "help message for flag booltwo") + cmdTimes.Flags().IntVarP(&flagi2, "inttwo", "j", 234, "help message for flag inttwo") + cmdTimes.Flags().StringVarP(&flags2b, "strtwo", "t", "2", strtwoChildHelp) + cmdTimes.PersistentFlags().StringVarP(&flags2b, "strtwo", "t", "2", strtwoChildHelp) + cmdPrint.Flags().BoolVarP(&flagb3, "boolthree", "b", true, "help message for flag boolthree") + cmdPrint.PersistentFlags().StringVarP(&flags3, "strthree", "s", "three", "help message for flag strthree") } func commandInit() { @@ -858,7 +856,6 @@ func TestFlagAccess(t *testing.T) { } if inherited.Lookup("strtwo") != nil { t.Errorf("InheritedFlags shouldn not contain overwritten flag strtwo") - } } @@ -1150,14 +1147,6 @@ func TestGlobalNormFuncPropagation(t *testing.T) { } } -func TestFlagOnPflagCommandLine(t *testing.T) { - flagName := "flagOnCommandLine" - pflag.CommandLine.String(flagName, "", "about my flag") - r := fullSetupTest("--help") - - checkResultContains(t, r, flagName) -} - func TestAddTemplateFunctions(t *testing.T) { AddTemplateFunc("t", func() bool { return true }) AddTemplateFuncs(template.FuncMap{ @@ -1185,3 +1174,23 @@ func TestUsageIsNotPrintedTwice(t *testing.T) { t.Error("Usage output is not printed exactly once") } } + +func BenchmarkInheritedFlags(b *testing.B) { + initialize() + cmdEcho.AddCommand(cmdTimes) + b.ResetTimer() + + for i := 0; i < b.N; i++ { + cmdTimes.InheritedFlags() + } +} + +func BenchmarkLocalFlags(b *testing.B) { + initialize() + cmdEcho.AddCommand(cmdTimes) + b.ResetTimer() + + for i := 0; i < b.N; i++ { + cmdTimes.LocalFlags() + } +} diff --git a/command.go b/command.go index 50bd3bb..a436656 100644 --- a/command.go +++ b/command.go @@ -66,6 +66,10 @@ type Command struct { pflags *flag.FlagSet // Flags that are declared specifically by this command (not inherited). lflags *flag.FlagSet + // Inherited flags. + iflags *flag.FlagSet + // All persistent flags of cmd's parents. + parentsPflags *flag.FlagSet // SilenceErrors is an option to quiet errors down stream SilenceErrors bool // Silence Usage is an option to silence usage when an error occurs. @@ -544,32 +548,18 @@ func (c *Command) SuggestionsFor(typedName string) []string { // VisitParents visits all parents of the command and invokes fn on each parent. func (c *Command) VisitParents(fn func(*Command)) { - var traverse func(*Command) *Command - - traverse = func(x *Command) *Command { - if x != c { - fn(x) - } - if x.HasParent() { - return traverse(x.parent) - } - return x + if c.HasParent() { + fn(c.Parent()) + c.Parent().VisitParents(fn) } - traverse(c) } // Root finds root command. func (c *Command) Root() *Command { - var findRoot func(*Command) *Command - - findRoot = func(x *Command) *Command { - if x.HasParent() { - return findRoot(x.parent) - } - return x + if c.HasParent() { + return c.Parent().Root() } - - return findRoot(c) + return c } // ArgsLenAtDash will return the length of f.Args at the moment when a -- was @@ -913,12 +903,8 @@ func (c *Command) DebugFlags() { } if x.HasFlags() { x.flags.VisitAll(func(f *flag.Flag) { - if x.HasPersistentFlags() { - if x.persistentFlag(f.Name) == nil { - c.Println(" -"+f.Shorthand+",", "--"+f.Name, "["+f.DefValue+"]", "", f.Value, " [L]") - } else { - c.Println(" -"+f.Shorthand+",", "--"+f.Name, "["+f.DefValue+"]", "", f.Value, " [LP]") - } + if x.HasPersistentFlags() && x.persistentFlag(f.Name) != nil { + c.Println(" -"+f.Shorthand+",", "--"+f.Name, "["+f.DefValue+"]", "", f.Value, " [LP]") } else { c.Println(" -"+f.Shorthand+",", "--"+f.Name, "["+f.DefValue+"]", "", f.Value, " [L]") } @@ -1077,6 +1063,7 @@ func (c *Command) Flags() *flag.FlagSet { c.flags = flag.NewFlagSet(c.Name(), flag.ContinueOnError) c.flags.SetOutput(c.OutOrStderr()) } + return c.flags } @@ -1096,48 +1083,44 @@ func (c *Command) LocalNonPersistentFlags() *flag.FlagSet { // LocalFlags returns the local FlagSet specifically set in the current command. func (c *Command) LocalFlags() *flag.FlagSet { c.mergePersistentFlags() + c.updateParentsPersistentFlags() - local := flag.NewFlagSet(c.Name(), flag.ContinueOnError) - c.lflags.VisitAll(func(f *flag.Flag) { - local.AddFlag(f) - }) - if !c.HasParent() { - flag.CommandLine.VisitAll(func(f *flag.Flag) { - if local.Lookup(f.Name) == nil { - local.AddFlag(f) - } - }) + if c.lflags == nil { + c.lflags = flag.NewFlagSet(c.Name(), flag.ContinueOnError) + c.lflags.SetOutput(c.OutOrStderr()) } - return local + + flags := c.Flags() + c.lflags.SortFlags = flags.SortFlags + + addToLocal := func(f *flag.Flag) { + if c.lflags.Lookup(f.Name) == nil && c.parentsPflags.Lookup(f.Name) == nil { + c.lflags.AddFlag(f) + } + } + c.flags.VisitAll(addToLocal) + c.PersistentFlags().VisitAll(addToLocal) + return c.lflags } // InheritedFlags returns all flags which were inherited from parents commands. func (c *Command) InheritedFlags() *flag.FlagSet { c.mergePersistentFlags() + c.updateParentsPersistentFlags() + + if c.iflags == nil { + c.iflags = flag.NewFlagSet(c.Name(), flag.ContinueOnError) + } - inherited := flag.NewFlagSet(c.Name(), flag.ContinueOnError) local := c.LocalFlags() - var rmerge func(x *Command) - - rmerge = func(x *Command) { - if x.HasPersistentFlags() { - x.PersistentFlags().VisitAll(func(f *flag.Flag) { - if inherited.Lookup(f.Name) == nil && local.Lookup(f.Name) == nil { - inherited.AddFlag(f) - } - }) + c.parentsPflags.VisitAll(func(f *flag.Flag) { + if c.iflags.Lookup(f.Name) == nil && local.Lookup(f.Name) == nil { + c.iflags.AddFlag(f) } - if x.HasParent() { - rmerge(x.parent) - } - } + }) - if c.HasParent() { - rmerge(c.parent) - } - - return inherited + return c.iflags } // NonInheritedFlags returns all flags which were not inherited from parent commands. @@ -1243,38 +1226,35 @@ func (c *Command) Parent() *Command { return c.parent } -func (c *Command) mergePersistentFlags() { - var rmerge func(x *Command) - - // Save the set of local flags - if c.lflags == nil { - c.lflags = flag.NewFlagSet(c.Name(), flag.ContinueOnError) - c.lflags.SetOutput(c.OutOrStderr()) - addtolocal := func(f *flag.Flag) { - c.lflags.AddFlag(f) - } - c.Flags().VisitAll(addtolocal) - c.PersistentFlags().VisitAll(addtolocal) +func (c *Command) updateParentsPersistentFlags() { + if c.parentsPflags == nil { + c.parentsPflags = flag.NewFlagSet(c.Name(), flag.ContinueOnError) + c.parentsPflags.SetOutput(c.OutOrStderr()) } - rmerge = func(x *Command) { - if !x.HasParent() { - flag.CommandLine.VisitAll(func(f *flag.Flag) { - if x.PersistentFlags().Lookup(f.Name) == nil { - x.PersistentFlags().AddFlag(f) - } - }) - } + + c.VisitParents(func(x *Command) { if x.HasPersistentFlags() { x.PersistentFlags().VisitAll(func(f *flag.Flag) { - if c.Flags().Lookup(f.Name) == nil { - c.Flags().AddFlag(f) + if c.parentsPflags.Lookup(f.Name) == nil { + c.parentsPflags.AddFlag(f) } }) } - if x.HasParent() { - rmerge(x.parent) + }) +} + +func (c *Command) mergePersistentFlags() { + flags := c.Flags() + + merge := func(x *Command) { + if x.HasPersistentFlags() { + x.PersistentFlags().VisitAll(func(f *flag.Flag) { + if flags.Lookup(f.Name) == nil { + flags.AddFlag(f) + } + }) } } - - rmerge(c) + merge(c) + c.VisitParents(merge) }