Add Command.SetContext (#1551)

Increases flexibility in how Contexts can be used with Cobra.
This commit is contained in:
Joshua Carpeggiani 2022-03-18 21:01:58 +11:00 committed by GitHub
parent 5d066b77b5
commit f848943afd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 109 additions and 0 deletions

View File

@ -230,6 +230,12 @@ func (c *Command) Context() context.Context {
return c.ctx return c.ctx
} }
// SetContext sets context for the command. It is set to context.Background by default and will be overwritten by
// Command.ExecuteContext or Command.ExecuteContextC
func (c *Command) SetContext(ctx context.Context) {
c.ctx = ctx
}
// SetArgs sets arguments for the command. It is set to os.Args[1:] by default, if desired, can be overridden // SetArgs sets arguments for the command. It is set to os.Args[1:] by default, if desired, can be overridden
// particularly useful when testing. // particularly useful when testing.
func (c *Command) SetArgs(a []string) { func (c *Command) SetArgs(a []string) {

View File

@ -2058,3 +2058,106 @@ func TestFParseErrWhitelistSiblingCommand(t *testing.T) {
} }
checkStringContains(t, output, "unknown flag: --unknown") checkStringContains(t, output, "unknown flag: --unknown")
} }
func TestSetContext(t *testing.T) {
type key struct{}
val := "foobar"
root := &Command{
Use: "root",
Run: func(cmd *Command, args []string) {
key := cmd.Context().Value(key{})
got, ok := key.(string)
if !ok {
t.Error("key not found in context")
}
if got != val {
t.Errorf("Expected value: \n %v\nGot:\n %v\n", val, got)
}
},
}
ctx := context.WithValue(context.Background(), key{}, val)
root.SetContext(ctx)
err := root.Execute()
if err != nil {
t.Error(err)
}
}
func TestSetContextPreRun(t *testing.T) {
type key struct{}
val := "barr"
root := &Command{
Use: "root",
PreRun: func(cmd *Command, args []string) {
ctx := context.WithValue(cmd.Context(), key{}, val)
cmd.SetContext(ctx)
},
Run: func(cmd *Command, args []string) {
val := cmd.Context().Value(key{})
got, ok := val.(string)
if !ok {
t.Error("key not found in context")
}
if got != val {
t.Errorf("Expected value: \n %v\nGot:\n %v\n", val, got)
}
},
}
err := root.Execute()
if err != nil {
t.Error(err)
}
}
func TestSetContextPreRunOverwrite(t *testing.T) {
type key struct{}
val := "blah"
root := &Command{
Use: "root",
Run: func(cmd *Command, args []string) {
key := cmd.Context().Value(key{})
_, ok := key.(string)
if ok {
t.Error("key found in context when not expected")
}
},
}
ctx := context.WithValue(context.Background(), key{}, val)
root.SetContext(ctx)
err := root.ExecuteContext(context.Background())
if err != nil {
t.Error(err)
}
}
func TestSetContextPersistentPreRun(t *testing.T) {
type key struct{}
val := "barbar"
root := &Command{
Use: "root",
PersistentPreRun: func(cmd *Command, args []string) {
ctx := context.WithValue(cmd.Context(), key{}, val)
cmd.SetContext(ctx)
},
}
child := &Command{
Use: "child",
Run: func(cmd *Command, args []string) {
key := cmd.Context().Value(key{})
got, ok := key.(string)
if !ok {
t.Error("key not found in context")
}
if got != val {
t.Errorf("Expected value: \n %v\nGot:\n %v\n", val, got)
}
},
}
root.AddCommand(child)
root.SetArgs([]string{"child"})
err := root.Execute()
if err != nil {
t.Error(err)
}
}