This commit is contained in:
Chris Suszynski 2024-11-11 10:22:27 -05:00 committed by GitHub
commit 3e671475f7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 94 additions and 2 deletions

View File

@ -188,6 +188,11 @@ type Command struct {
// versionTemplate is the version template defined by user. // versionTemplate is the version template defined by user.
versionTemplate string versionTemplate string
// errHandler is a function that a user can define to handle errors in a
// custom way. The function will be called only if the error is received and
// SilenceErrors isn't set. By default, it prints the error with a prefix
// on standard error stream.
errHandler func(err error)
// errPrefix is the error message prefix defined by user. // errPrefix is the error message prefix defined by user.
errPrefix string errPrefix string
@ -357,6 +362,12 @@ func (c *Command) SetVersionTemplate(s string) {
c.versionTemplate = s c.versionTemplate = s
} }
// SetErrHandler sets a custom error handler to be used. The function will be
// called only if the error is received and SilenceErrors isn't set.
func (c *Command) SetErrHandler(fn func(err error)) {
c.errHandler = fn
}
// SetErrPrefix sets error message prefix to be used. Application can use it to set custom prefix. // SetErrPrefix sets error message prefix to be used. Application can use it to set custom prefix.
func (c *Command) SetErrPrefix(s string) { func (c *Command) SetErrPrefix(s string) {
c.errPrefix = s c.errPrefix = s
@ -611,6 +622,31 @@ func (c *Command) VersionTemplate() string {
` `
} }
// ErrHandler return the error handler for the command.
func (c *Command) ErrHandler() func(err error) {
if handler := c.errHandlerOrNil(); handler != nil {
return handler
}
return c.defaultErrHandler
}
func (c *Command) errHandlerOrNil() func(err error) {
if c.errHandler != nil {
return c.errHandler
}
if c.HasParent() {
return c.parent.errHandlerOrNil()
}
return nil
}
func (c *Command) defaultErrHandler(err error) {
c.PrintErrln(c.ErrPrefix(), err.Error())
}
// ErrPrefix return error message prefix for the command // ErrPrefix return error message prefix for the command
func (c *Command) ErrPrefix() string { func (c *Command) ErrPrefix() string {
if c.errPrefix != "" { if c.errPrefix != "" {
@ -1098,7 +1134,7 @@ func (c *Command) ExecuteC() (cmd *Command, err error) {
c = cmd c = cmd
} }
if !c.SilenceErrors { if !c.SilenceErrors {
c.PrintErrln(c.ErrPrefix(), err.Error()) c.ErrHandler()(err)
c.PrintErrf("Run '%v --help' for usage.\n", c.CommandPath()) c.PrintErrf("Run '%v --help' for usage.\n", c.CommandPath())
} }
return c, err return c, err
@ -1127,7 +1163,7 @@ func (c *Command) ExecuteC() (cmd *Command, err error) {
// If root command has SilenceErrors flagged, // If root command has SilenceErrors flagged,
// all subcommands should respect it // all subcommands should respect it
if !cmd.SilenceErrors && !c.SilenceErrors { if !cmd.SilenceErrors && !c.SilenceErrors {
c.PrintErrln(cmd.ErrPrefix(), err.Error()) cmd.ErrHandler()(err)
} }
// If root command has SilenceUsage flagged, // If root command has SilenceUsage flagged,

View File

@ -17,6 +17,7 @@ package cobra
import ( import (
"bytes" "bytes"
"context" "context"
"errors"
"fmt" "fmt"
"io" "io"
"os" "os"
@ -2839,3 +2840,58 @@ func TestUnknownFlagShouldReturnSameErrorRegardlessOfArgPosition(t *testing.T) {
}) })
} }
} }
func TestErrHandler(t *testing.T) {
type testErrHandlerTestCase struct {
name string
root bool
sub bool
}
testCases := []testErrHandlerTestCase{
{"CustomOnRootAndSub", true, true},
{"CustomOnRoot", true, false},
{"CustomOnSub", false, true},
{"DefaultOnBoth", false, false},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
want := errors.New("expected")
root := &Command{
SilenceUsage: true,
}
sub := &Command{
Use: "sub",
SilenceUsage: true,
RunE: func(cmd *Command, args []string) error {
return fmt.Errorf("%w: foo", want)
},
}
root.AddCommand(sub)
called := false
handler := func(got error) {
called = true
if !errors.Is(got, want) {
t.Errorf("error missmatch,\nwant = %#v\n got = %#v",
want, got)
}
}
if tc.root {
root.SetErrHandler(handler)
}
if tc.sub {
sub.SetErrHandler(handler)
}
output, got := executeCommand(root, "sub")
if (tc.root || tc.sub) && !called {
t.Error("expecting the custom error handler be called")
}
if !errors.Is(got, want) {
t.Errorf("error missmatch,\nwant = %#v\n got = %#v",
want, got)
}
if (tc.root || tc.sub) && output != "" {
t.Error("unexpected output: ", output)
}
})
}
}