From 8bd96bc3a2bcfda4223206f34523eb4d67d20555 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Chris=20Suszy=C5=84ski?= Date: Fri, 18 Oct 2024 15:29:04 +0200 Subject: [PATCH] Custom error handler --- command.go | 40 +++++++++++++++++++++++++++++++++-- command_test.go | 56 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 94 insertions(+), 2 deletions(-) diff --git a/command.go b/command.go index 4cd712b..d8abe92 100644 --- a/command.go +++ b/command.go @@ -188,6 +188,11 @@ type Command struct { // versionTemplate is the version template defined by user. versionTemplate string + // errHandler is a function that a user can define to handle errors in a + // custom way. The function will be called only if the error is received and + // SilenceErrors isn't set. By default, it prints the error with a prefix + // on standard error stream. + errHandler func(err error) // errPrefix is the error message prefix defined by user. errPrefix string @@ -357,6 +362,12 @@ func (c *Command) SetVersionTemplate(s string) { c.versionTemplate = s } +// SetErrHandler sets a custom error handler to be used. The function will be +// called only if the error is received and SilenceErrors isn't set. +func (c *Command) SetErrHandler(fn func(err error)) { + c.errHandler = fn +} + // SetErrPrefix sets error message prefix to be used. Application can use it to set custom prefix. func (c *Command) SetErrPrefix(s string) { c.errPrefix = s @@ -611,6 +622,31 @@ func (c *Command) VersionTemplate() string { ` } +// ErrHandler return the error handler for the command. +func (c *Command) ErrHandler() func(err error) { + if handler := c.errHandlerOrNil(); handler != nil { + return handler + } + + return c.defaultErrHandler +} + +func (c *Command) errHandlerOrNil() func(err error) { + if c.errHandler != nil { + return c.errHandler + } + + if c.HasParent() { + return c.parent.errHandlerOrNil() + } + + return nil +} + +func (c *Command) defaultErrHandler(err error) { + c.PrintErrln(c.ErrPrefix(), err.Error()) +} + // ErrPrefix return error message prefix for the command func (c *Command) ErrPrefix() string { if c.errPrefix != "" { @@ -1098,7 +1134,7 @@ func (c *Command) ExecuteC() (cmd *Command, err error) { c = cmd } if !c.SilenceErrors { - c.PrintErrln(c.ErrPrefix(), err.Error()) + c.ErrHandler()(err) c.PrintErrf("Run '%v --help' for usage.\n", c.CommandPath()) } return c, err @@ -1127,7 +1163,7 @@ func (c *Command) ExecuteC() (cmd *Command, err error) { // If root command has SilenceErrors flagged, // all subcommands should respect it if !cmd.SilenceErrors && !c.SilenceErrors { - c.PrintErrln(cmd.ErrPrefix(), err.Error()) + cmd.ErrHandler()(err) } // If root command has SilenceUsage flagged, diff --git a/command_test.go b/command_test.go index cd44992..9fe75dd 100644 --- a/command_test.go +++ b/command_test.go @@ -17,6 +17,7 @@ package cobra import ( "bytes" "context" + "errors" "fmt" "io" "os" @@ -2839,3 +2840,58 @@ func TestUnknownFlagShouldReturnSameErrorRegardlessOfArgPosition(t *testing.T) { }) } } + +func TestErrHandler(t *testing.T) { + type testErrHandlerTestCase struct { + name string + root bool + sub bool + } + testCases := []testErrHandlerTestCase{ + {"CustomOnRootAndSub", true, true}, + {"CustomOnRoot", true, false}, + {"CustomOnSub", false, true}, + {"DefaultOnBoth", false, false}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + want := errors.New("expected") + root := &Command{ + SilenceUsage: true, + } + sub := &Command{ + Use: "sub", + SilenceUsage: true, + RunE: func(cmd *Command, args []string) error { + return fmt.Errorf("%w: foo", want) + }, + } + root.AddCommand(sub) + called := false + handler := func(got error) { + called = true + if !errors.Is(got, want) { + t.Errorf("error missmatch,\nwant = %#v\n got = %#v", + want, got) + } + } + if tc.root { + root.SetErrHandler(handler) + } + if tc.sub { + sub.SetErrHandler(handler) + } + output, got := executeCommand(root, "sub") + if (tc.root || tc.sub) && !called { + t.Error("expecting the custom error handler be called") + } + if !errors.Is(got, want) { + t.Errorf("error missmatch,\nwant = %#v\n got = %#v", + want, got) + } + if (tc.root || tc.sub) && output != "" { + t.Error("unexpected output: ", output) + } + }) + } +}