diff --git a/command.go b/command.go index b4a6e66..329bbea 100644 --- a/command.go +++ b/command.go @@ -1089,16 +1089,14 @@ func (c *Command) LocalFlags() *flag.FlagSet { c.lflags = flag.NewFlagSet(c.Name(), flag.ContinueOnError) c.lflags.SetOutput(c.OutOrStderr()) } - - flags := c.Flags() - c.lflags.SortFlags = flags.SortFlags + c.lflags.SortFlags = c.Flags().SortFlags addToLocal := func(f *flag.Flag) { if c.lflags.Lookup(f.Name) == nil && c.parentsPflags.Lookup(f.Name) == nil { c.lflags.AddFlag(f) } } - c.flags.VisitAll(addToLocal) + c.Flags().VisitAll(addToLocal) c.PersistentFlags().VisitAll(addToLocal) return c.lflags } @@ -1112,13 +1110,11 @@ func (c *Command) InheritedFlags() *flag.FlagSet { } local := c.LocalFlags() - c.parentsPflags.VisitAll(func(f *flag.Flag) { if c.iflags.Lookup(f.Name) == nil && local.Lookup(f.Name) == nil { c.iflags.AddFlag(f) } }) - return c.iflags } @@ -1204,8 +1200,9 @@ func (c *Command) persistentFlag(name string) (flag *flag.Flag) { flag = c.PersistentFlags().Lookup(name) } - if flag == nil && c.HasParent() { - flag = c.parent.persistentFlag(name) + if flag == nil { + c.updateParentsPflags() + flag = c.parentsPflags.Lookup(name) } return } @@ -1229,23 +1226,14 @@ func (c *Command) Parent() *Command { // and adds missing persistent flags of all parents. func (c *Command) mergePersistentFlags() { c.Flags().AddFlagSet(c.PersistentFlags()) - - added := c.updateParentsPflags() - if len(added) > 0 { - for _, f := range added { - if c.Flags().Lookup(f.Name) == nil { - c.Flags().AddFlag(f) - } - } - } + c.updateParentsPflags() + c.Flags().AddFlagSet(c.parentsPflags) } // updateParentsPflags updates c.parentsPflags by adding -// new persistent flags of all parents and returns added flags. +// new persistent flags of all parents. // If c.parentsPflags == nil, it makes new. -// -// This function must be used ONLY in mergePersistentFlags. -func (c *Command) updateParentsPflags() (added []*flag.Flag) { +func (c *Command) updateParentsPflags() { if c.parentsPflags == nil { c.parentsPflags = flag.NewFlagSet(c.Name(), flag.ContinueOnError) c.parentsPflags.SetOutput(c.OutOrStderr()) @@ -1254,16 +1242,7 @@ func (c *Command) updateParentsPflags() (added []*flag.Flag) { c.Root().PersistentFlags().AddFlagSet(flag.CommandLine) - c.VisitParents(func(x *Command) { - if x.HasPersistentFlags() { - x.PersistentFlags().VisitAll(func(f *flag.Flag) { - if c.parentsPflags.Lookup(f.Name) == nil { - c.parentsPflags.AddFlag(f) - added = append(added, f) - } - }) - } + c.VisitParents(func(parent *Command) { + c.parentsPflags.AddFlagSet(parent.PersistentFlags()) }) - - return added }