From 8e2c9596e14ab6052f2bd47e11908a19d99e5e83 Mon Sep 17 00:00:00 2001 From: Pieter Noordhuis Date: Sun, 11 Dec 2022 08:36:49 +0100 Subject: [PATCH] Always propagate root context to child command The context passed to the root command should propagate to its children not only on the first execution but also subsequent calls. Calling the same command multiple times is common when testing cobra applications. --- command.go | 9 ++--- command_test.go | 90 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 93 insertions(+), 6 deletions(-) diff --git a/command.go b/command.go index 1b0b17a..9403a5c 100644 --- a/command.go +++ b/command.go @@ -1009,7 +1009,7 @@ func (c *Command) ExecuteC() (cmd *Command, err error) { // Regardless of what command execute is called on, run on Root only if c.HasParent() { - return c.Root().ExecuteC() + return c.Root().ExecuteContextC(c.ctx) } // windows hook @@ -1059,11 +1059,8 @@ func (c *Command) ExecuteC() (cmd *Command, err error) { cmd.commandCalledAs.name = cmd.Name() } - // We have to pass global context to children command - // if context is present on the parent command. - if cmd.ctx == nil { - cmd.ctx = c.ctx - } + // Pass context of root command to child command. + cmd.ctx = c.ctx err = cmd.execute(flags) if err != nil { diff --git a/command_test.go b/command_test.go index 1801132..94a512b 100644 --- a/command_test.go +++ b/command_test.go @@ -232,6 +232,96 @@ func TestExecuteContextC(t *testing.T) { } } +// This tests that the context passed to the root command propagates to its children +// not only on the first execution but also subsequent calls. +// Calling the same command multiple times is common when testing cobra applications. +func TestExecuteContextMultiple(t *testing.T) { + var key string + + // Define unique contexts so we can tell them apart below. + ctxs := []context.Context{ + context.WithValue(context.Background(), &key, "1"), + context.WithValue(context.Background(), &key, "2"), + } + + // Shared reference to the context in the current iteration. + var currentCtx context.Context + + ctxRun := func(cmd *Command, args []string) { + if cmd.Context() != currentCtx { + t.Errorf("Command %q must have context with value %s", cmd.Use, currentCtx.Value(&key)) + } + } + + rootCmd := &Command{Use: "root", Run: ctxRun, PreRun: ctxRun} + childCmd := &Command{Use: "child", Run: ctxRun, PreRun: ctxRun} + granchildCmd := &Command{Use: "grandchild", Run: ctxRun, PreRun: ctxRun} + + childCmd.AddCommand(granchildCmd) + rootCmd.AddCommand(childCmd) + + for i := 0; i < 2; i++ { + currentCtx = ctxs[i] + + if _, err := executeCommandWithContext(currentCtx, rootCmd, ""); err != nil { + t.Errorf("Root command must not fail: %+v", err) + } + + if _, err := executeCommandWithContext(currentCtx, rootCmd, "child"); err != nil { + t.Errorf("Subcommand must not fail: %+v", err) + } + + if _, err := executeCommandWithContext(currentCtx, rootCmd, "child", "grandchild"); err != nil { + t.Errorf("Command child must not fail: %+v", err) + } + } +} + +// This tests that the context passed to a subcommand propagates to the root. +// If the entry point happens to be different from the root command, the +// context should still propagate throughout the execution. +func TestExecuteContextOnSubcommand(t *testing.T) { + var key string + + // Define unique contexts so we can tell them apart below. + ctxs := []context.Context{ + context.WithValue(context.Background(), &key, "1"), + context.WithValue(context.Background(), &key, "2"), + context.WithValue(context.Background(), &key, "3"), + } + + // Shared reference to the context in the current iteration. + var currentCtx context.Context + + ctxRun := func(cmd *Command, args []string) { + if cmd.Context() != currentCtx { + t.Errorf("Command %q must have context with value %s", cmd.Use, currentCtx.Value(&key)) + } + } + + rootCmd := &Command{Use: "root", Run: ctxRun, PreRun: ctxRun} + childCmd := &Command{Use: "child", Run: ctxRun, PreRun: ctxRun} + granchildCmd := &Command{Use: "grandchild", Run: ctxRun, PreRun: ctxRun} + + childCmd.AddCommand(granchildCmd) + rootCmd.AddCommand(childCmd) + + currentCtx = ctxs[0] + if _, err := executeCommandWithContext(currentCtx, rootCmd, ""); err != nil { + t.Errorf("Root command must not fail: %+v", err) + } + + currentCtx = ctxs[1] + if _, err := executeCommandWithContext(currentCtx, childCmd, "child"); err != nil { + t.Errorf("Subcommand must not fail: %+v", err) + } + + currentCtx = ctxs[2] + if _, err := executeCommandWithContext(currentCtx, granchildCmd, "child", "grandchild"); err != nil { + t.Errorf("Command child must not fail: %+v", err) + } +} + func TestExecute_NoContext(t *testing.T) { run := func(cmd *Command, args []string) { if cmd.Context() != context.Background() {