diff --git a/command.go b/command.go index ab3cf69..fb60ebd 100644 --- a/command.go +++ b/command.go @@ -17,6 +17,7 @@ package cobra import ( "bytes" + "context" "errors" "fmt" "io" @@ -143,9 +144,11 @@ type Command struct { // TraverseChildren parses flags on all parents before executing child command. TraverseChildren bool - //FParseErrWhitelist flag parse errors to be ignored + // FParseErrWhitelist flag parse errors to be ignored FParseErrWhitelist FParseErrWhitelist + ctx context.Context + // commands is the list of commands supported by this program. commands []*Command // parent is a parent command for this command. @@ -205,6 +208,12 @@ type Command struct { errWriter io.Writer } +// Context returns underlying command context. If command wasn't +// executed with ExecuteContext Context returns Background context. +func (c *Command) Context() context.Context { + return c.ctx +} + // SetArgs sets arguments for the command. It is set to os.Args[1:] by default, if desired, can be overridden // particularly useful when testing. func (c *Command) SetArgs(a []string) { @@ -862,6 +871,13 @@ func (c *Command) preRun() { } } +// ExecuteContext is the same as Execute(), but sets the ctx on the command. +// Retrieve ctx by calling cmd.Context() inside your *Run lifecycle functions. +func (c *Command) ExecuteContext(ctx context.Context) error { + c.ctx = ctx + return c.Execute() +} + // Execute uses the args (os.Args[1:] by default) // and run through the command tree finding appropriate matches // for commands and then corresponding flags. @@ -872,6 +888,10 @@ func (c *Command) Execute() error { // ExecuteC executes the command. func (c *Command) ExecuteC() (cmd *Command, err error) { + if c.ctx == nil { + c.ctx = context.Background() + } + // Regardless of what command execute is called on, run on Root only if c.HasParent() { return c.Root().ExecuteC() @@ -916,6 +936,12 @@ func (c *Command) ExecuteC() (cmd *Command, err error) { cmd.commandCalledAs.name = cmd.Name() } + // We have to pass global context to children command + // if context is present on the parent command. + if cmd.ctx == nil { + cmd.ctx = c.ctx + } + err = cmd.execute(flags) if err != nil { // Always show help if requested, even if SilenceErrors is in @@ -1560,7 +1586,7 @@ func (c *Command) ParseFlags(args []string) error { beforeErrorBufLen := c.flagErrorBuf.Len() c.mergePersistentFlags() - //do it here after merging all flags and just before parse + // do it here after merging all flags and just before parse c.Flags().ParseErrorsWhitelist = flag.ParseErrorsWhitelist(c.FParseErrWhitelist) err := c.Flags().Parse(args) diff --git a/command_test.go b/command_test.go index b26bd4a..05a3c0f 100644 --- a/command_test.go +++ b/command_test.go @@ -2,6 +2,7 @@ package cobra import ( "bytes" + "context" "fmt" "os" "reflect" @@ -18,6 +19,16 @@ func executeCommand(root *Command, args ...string) (output string, err error) { return output, err } +func executeCommandWithContext(ctx context.Context, root *Command, args ...string) (output string, err error) { + buf := new(bytes.Buffer) + root.SetOutput(buf) + root.SetArgs(args) + + err = root.ExecuteContext(ctx) + + return buf.String(), err +} + func executeCommandC(root *Command, args ...string) (c *Command, output string, err error) { buf := new(bytes.Buffer) root.SetOutput(buf) @@ -135,6 +146,62 @@ func TestSubcommandExecuteC(t *testing.T) { } } +func TestExecuteContext(t *testing.T) { + ctx := context.TODO() + + ctxRun := func(cmd *Command, args []string) { + if cmd.Context() != ctx { + t.Errorf("Command %q must have context when called with ExecuteContext", cmd.Use) + } + } + + rootCmd := &Command{Use: "root", Run: ctxRun, PreRun: ctxRun} + childCmd := &Command{Use: "child", Run: ctxRun, PreRun: ctxRun} + granchildCmd := &Command{Use: "grandchild", Run: ctxRun, PreRun: ctxRun} + + childCmd.AddCommand(granchildCmd) + rootCmd.AddCommand(childCmd) + + if _, err := executeCommandWithContext(ctx, rootCmd, ""); err != nil { + t.Errorf("Root command must not fail: %+v", err) + } + + if _, err := executeCommandWithContext(ctx, rootCmd, "child"); err != nil { + t.Errorf("Subcommand must not fail: %+v", err) + } + + if _, err := executeCommandWithContext(ctx, rootCmd, "child", "grandchild"); err != nil { + t.Errorf("Command child must not fail: %+v", err) + } +} + +func TestExecute_NoContext(t *testing.T) { + run := func(cmd *Command, args []string) { + if cmd.Context() != context.Background() { + t.Errorf("Command %s must have background context", cmd.Use) + } + } + + rootCmd := &Command{Use: "root", Run: run, PreRun: run} + childCmd := &Command{Use: "child", Run: run, PreRun: run} + granchildCmd := &Command{Use: "grandchild", Run: run, PreRun: run} + + childCmd.AddCommand(granchildCmd) + rootCmd.AddCommand(childCmd) + + if _, err := executeCommand(rootCmd, ""); err != nil { + t.Errorf("Root command must not fail: %+v", err) + } + + if _, err := executeCommand(rootCmd, "child"); err != nil { + t.Errorf("Subcommand must not fail: %+v", err) + } + + if _, err := executeCommand(rootCmd, "child", "grandchild"); err != nil { + t.Errorf("Command child must not fail: %+v", err) + } +} + func TestRootUnknownCommandSilenced(t *testing.T) { rootCmd := &Command{Use: "root", Run: emptyRun} rootCmd.SilenceErrors = true