diff --git a/cobra.go b/cobra.go new file mode 100644 index 0000000..295d832 --- /dev/null +++ b/cobra.go @@ -0,0 +1,304 @@ +// Copyright © 2013 Steve Francia . +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Commands similar to git, go tools and other modern CLI tools +// inspired by go, go-Commander, gh and subcommand + +package cobra + +import ( + "bytes" + "fmt" + flag "github.com/ogier/pflag" + "os" + "strings" +) + +var _ = flag.ContinueOnError + +type Flag interface { + Args() []string +} + +type Flags interface { + Lookup(string) *Flag + VisitAll(fn func(*Flag)) + Parse(arguments []string) error +} + +// A Commander holds the configuration for the command line tool. +type Commander struct { + // A Commander is also a Command for top level and global help & flags + Command + //ExitOnError, ContinueOnError or PanicOnError + behavior flag.ErrorHandling + + args []string +} + +func NewCommander() (c *Commander) { + c = new(Commander) + return +} + +func (c *Commander) setFlagBehavior(b flag.ErrorHandling) error { + if b == flag.ExitOnError || b == flag.ContinueOnError || b == flag.PanicOnError { + c.behavior = b + return nil + } + return fmt.Errorf("%v is not a valid behavior", b) +} + +func (c *Commander) SetName(name string) { + c.name = name +} + +func (c *Commander) SetArgs(a []string) { + c.args = a +} + +func (c *Commander) Execute() { + if len(c.args) == 0 { + c.execute(os.Args[1:]) + } else { + c.execute(c.args) + } +} + +// Command is just that, a command for your application. +// eg. 'go run' ... 'run' is the command. Cobra requires +// you to define the usage and description as part of your command +// definition to ensure usability. +type Command struct { + // Name is the command name, usually the executable's name. + name string + // The one-line usage message. + Use string + // The short description shown in the 'help' output. + Short string + // The long message shown in the 'help ' output. + Long string + // Set of flags specific to this command. + flags *flag.FlagSet + // Set of flags children commands will inherit + pflags *flag.FlagSet + // Run runs the command. + // The args are the arguments after the command name. + Run func(cmd *Command, args []string) + // Commands is the list of commands supported by this Commander program. + commands []*Command + // Parent Command for this command + parent *Command + // Commander + //cmdr *Commander + flagErrorBuf *bytes.Buffer +} + +// find the target command given the args and command tree +func (c *Command) Find(args []string) (cmd *Command, a []string, err error) { + if c == nil { + return nil, nil, fmt.Errorf("Called find() on a nil Command") + } + + validSubCommand := false + if len(args) > 1 && c.HasSubCommands() { + for _, cmd := range c.commands { + if cmd.Name() == args[0] { + validSubCommand = true + return cmd.Find(args[1:]) + } + } + } + if !validSubCommand && c.Runnable() { + return c, args, nil + } + + return nil, nil, nil +} + +// execute the command determined by args and the command tree +func (c *Command) execute(args []string) (err error) { + err = fmt.Errorf("unknown subcommand %q\nRun 'help' for usage.\n", args[0]) + + if c == nil { + return fmt.Errorf("Called Execute() on a nil Command") + } + + cmd, a, e := c.Find(args) + if e == nil { + cmd.Flags().Parse(a) + argWoFlags := cmd.Flags().Args() + cmd.Run(cmd, argWoFlags) + return nil + } + err = e + return err +} + +// Add one or many commands as children of this +func (c *Command) AddCommand(cmds ...*Command) { + for i, x := range cmds { + cmds[i].parent = c + c.commands = append(c.commands, x) + } +} + +// The full usage for a given command (including parents) +func (c *Command) Usage(depth ...int) string { + i := 0 + if len(depth) > 0 { + i = depth[0] + } + + if c.HasParent() { + return c.parent.Usage(i+1) + " " + c.Use + } else if i > 0 { + return c.Name() + } else { + return c.Use + } +} + +// Usage prints the usage details to the standard output. +func (c *Command) PrintUsage() { + if c.Runnable() { + fmt.Printf("usage: %s\n\n", c.Usage()) + } + + fmt.Println(strings.Trim(c.Long, "\n")) +} + +// Name returns the command's name: the first word in the use line. +func (c *Command) Name() string { + if c.name != "" { + return c.name + } + name := c.Use + i := strings.Index(name, " ") + if i >= 0 { + name = name[:i] + } + return name +} + +// Determine if the command is itself runnable +func (c *Command) Runnable() bool { + return c.Run != nil +} + +// Determine if the command has children commands +func (c *Command) HasSubCommands() bool { + return len(c.commands) > 0 +} + +// Determine if the command has children commands +func (c *Command) HasParent() bool { + return c.parent != nil +} + +// Get the Commands FlagSet +func (c *Command) Flags() *flag.FlagSet { + if c.flags == nil { + c.flags = flag.NewFlagSet(c.Name(), flag.ContinueOnError) + c.flags.SetOutput(c.flagErrorBuf) + } + return c.flags +} + +// Get the Commands Persistent FlagSet +func (c *Command) PersistentFlags() *flag.FlagSet { + if c.pflags == nil { + c.pflags = flag.NewFlagSet(c.Name(), flag.ContinueOnError) + c.pflags.SetOutput(c.flagErrorBuf) + } + return c.flags +} + +// Intended for use in testing +func (c *Command) ResetFlags() { + c.flags = flag.NewFlagSet(c.Name(), flag.ContinueOnError) + c.flags.SetOutput(c.flagErrorBuf) + c.pflags = flag.NewFlagSet(c.Name(), flag.ContinueOnError) + c.pflags.SetOutput(c.flagErrorBuf) +} + +func (c *Command) HasFlags() bool { + return hasFlags(c.flags) +} + +func (c *Command) HasPersistentFlags() bool { + return hasFlags(c.pflags) +} + +// Is this set of flags not empty +func hasFlags(f *flag.FlagSet) bool { + if f == nil { + return false + } + if f.NFlag() != 0 { + return true + } + return false +} + +// Climbs up the command tree looking for matching flag +func (c *Command) Flag(name string) (flag *flag.Flag) { + flag = c.Flags().Lookup(name) + + if flag == nil { + flag = c.persistentFlag(name) + } + + return +} + +// recursively find matching persistent flag +func (c *Command) persistentFlag(name string) (flag *flag.Flag) { + if c.HasPersistentFlags() { + flag = c.PersistentFlags().Lookup(name) + } + + if flag == nil && c.HasParent() { + flag = c.parent.persistentFlag(name) + } + return +} + +// Parses persistent flag tree & local flags +func (c *Command) ParseFlags(args []string) (err error) { + err = c.ParsePersistentFlags(args) + if err != nil { + return err + } + err = c.Flags().Parse(args) + if err != nil { + return err + } + return nil +} + +// Climbs up the command tree parsing flags from top to bottom +func (c *Command) ParsePersistentFlags(args []string) (err error) { + if !c.HasParent() || c.parent.PersistentFlags().Parsed() { + err = c.PersistentFlags().Parse(args) + if err != nil { + return err + } + } else { + err = c.parent.ParsePersistentFlags(args) + if err != nil { + return err + } + } + return nil +} diff --git a/cobra_test.go b/cobra_test.go new file mode 100644 index 0000000..4c4e6de --- /dev/null +++ b/cobra_test.go @@ -0,0 +1,210 @@ +package cobra_test + +import ( + . "cobra" + "strings" + "testing" +) + +var tp, te, tt, t1 []string +var flagb1, flagb2, flagb3 bool +var flags1, flags2, flags3 string +var flagi1, flagi2, flagi3 int +var globalFlag1 bool +var flagEcho bool + +var cmdPrint = &Command{ + Use: "print [string to print]", + Short: "Print anything to the screen", + Long: `an utterly useless command for testing.`, + Run: func(cmd *Command, args []string) { + tp = args + }, +} + +var cmdEcho = &Command{ + Use: "echo [string to echo]", + Short: "Echo anything to the screen", + Long: `an utterly useless command for testing.`, + Run: func(cmd *Command, args []string) { + te = args + }, +} + +var cmdTimes = &Command{ + Use: "times [string to echo]", + Short: "Echo anything to the screen more times", + Long: `an slightly useless command for testing.`, + Run: func(cmd *Command, args []string) { + tt = args + }, +} + +func flagInit() { + cmdEcho.ResetFlags() + cmdPrint.ResetFlags() + cmdTimes.ResetFlags() + cmdEcho.Flags().IntVarP(&flagi1, "intone", "i", 123, "help message for flag intone") + cmdTimes.Flags().IntVarP(&flagi2, "inttwo", "j", 234, "help message for flag inttwo") + cmdPrint.Flags().IntVarP(&flagi3, "intthree", "i", 345, "help message for flag intthree") + cmdEcho.PersistentFlags().StringVarP(&flags1, "strone", "s", "one", "help message for flag strone") + cmdTimes.PersistentFlags().StringVarP(&flags2, "strtwo", "t", "two", "help message for flag strtwo") + cmdPrint.PersistentFlags().StringVarP(&flags3, "strthree", "s", "three", "help message for flag strthree") + cmdEcho.Flags().BoolVarP(&flagb1, "boolone", "b", true, "help message for flag boolone") + cmdTimes.Flags().BoolVarP(&flagb2, "booltwo", "c", false, "help message for flag booltwo") + cmdPrint.Flags().BoolVarP(&flagb3, "boolthree", "b", true, "help message for flag boolthree") +} + +func initialize() *Commander { + tt, tp, te = nil, nil, nil + var c = NewCommander() + c.SetName("cobra test") + + return c +} + +func TestSingleCommand(t *testing.T) { + c := initialize() + c.AddCommand(cmdPrint, cmdEcho) + c.SetArgs(strings.Split("print one two", " ")) + c.Execute() + + if te != nil || tt != nil { + t.Error("Wrong command called") + } + if tp == nil { + t.Error("Wrong command called") + } + if strings.Join(tp, " ") != "one two" { + t.Error("Command didn't parse correctly") + } +} + +func TestChildCommand(t *testing.T) { + c := initialize() + cmdEcho.AddCommand(cmdTimes) + c.AddCommand(cmdPrint, cmdEcho) + c.SetArgs(strings.Split("echo times one two", " ")) + c.Execute() + + if te != nil || tp != nil { + t.Error("Wrong command called") + } + if tt == nil { + t.Error("Wrong command called") + } + if strings.Join(tt, " ") != "one two" { + t.Error("Command didn't parse correctly") + } +} + +func TestFlagLong(t *testing.T) { + c := initialize() + c.AddCommand(cmdPrint, cmdEcho, cmdTimes) + flagInit() + c.SetArgs(strings.Split("echo --intone=13 something here", " ")) + c.Execute() + + if strings.Join(te, " ") != "something here" { + t.Errorf("flags didn't leave proper args remaining..%s given", te) + } + if flagi1 != 13 { + t.Errorf("int flag didn't get correct value, had %d", flagi1) + } + if flagi2 != 234 { + t.Errorf("default flag value changed, 234 expected, %d given", flagi2) + } +} + +func TestFlagShort(t *testing.T) { + c := initialize() + c.AddCommand(cmdPrint, cmdEcho, cmdTimes) + flagInit() + c.SetArgs(strings.Split("echo -i13 something here", " ")) + c.Execute() + + if strings.Join(te, " ") != "something here" { + t.Errorf("flags didn't leave proper args remaining..%s given", te) + } + if flagi1 != 13 { + t.Errorf("int flag didn't get correct value, had %d", flagi1) + } + if flagi2 != 234 { + t.Errorf("default flag value changed, 234 expected, %d given", flagi2) + } + + c = initialize() + c.AddCommand(cmdPrint, cmdEcho, cmdTimes) + flagInit() + c.SetArgs(strings.Split("echo -i 13 something here", " ")) + c.Execute() + + if strings.Join(te, " ") != "something here" { + t.Errorf("flags didn't leave proper args remaining..%s given", te) + } + if flagi1 != 13 { + t.Errorf("int flag didn't get correct value, had %d", flagi1) + } + if flagi2 != 234 { + t.Errorf("default flag value changed, 234 expected, %d given", flagi2) + } + + // Testing same shortcode, different command + c = initialize() + c.AddCommand(cmdPrint, cmdEcho, cmdTimes) + flagInit() + c.SetArgs(strings.Split("print -i99 one two", " ")) + c.Execute() + + if strings.Join(tp, " ") != "one two" { + t.Errorf("flags didn't leave proper args remaining..%s given", tp) + } + if flagi3 != 99 { + t.Errorf("int flag didn't get correct value, had %d", flagi3) + } + if flagi1 != 123 { + t.Errorf("default flag value changed on different comamnd with same shortname, 234 expected, %d given", flagi2) + } +} + +func TestChildCommandFlags(t *testing.T) { + c := initialize() + cmdEcho.AddCommand(cmdTimes) + c.AddCommand(cmdPrint, cmdEcho) + c.SetArgs(strings.Split("echo times -j 99 one two", " ")) + c.Execute() + + if strings.Join(tt, " ") != "one two" { + t.Errorf("flags didn't leave proper args remaining..%s given", tt) + } + + //c = initialize() + //cmdEcho.AddCommand(cmdTimes) + //c.AddCommand(cmdPrint, cmdEcho) + //c.SetArgs(strings.Split("echo times -j 99 -i 77 one two", " ")) + //c.Execute() + + //if strings.Join(tt, " ") != "one two" { + //t.Errorf("flags didn't leave proper args remaining..%s given", tt) + //} +} + +func TestPersistentFlags(t *testing.T) { + c := initialize() + cmdEcho.AddCommand(cmdTimes) + c.AddCommand(cmdPrint, cmdEcho) + flagInit() + c.SetArgs(strings.Split("echo -s something more here", " ")) + c.Execute() + + // persistentFlag should act like normal flag on it's own command + if strings.Join(te, " ") != "more here" { + t.Errorf("flags didn't leave proper args remaining..%s given", te) + } + + // persistentFlag should act like normal flag on it's own command + if flags1 != "something" { + t.Errorf("string flag didn't get correct value, had %v", flags1) + } + +}