diff --git a/command.go b/command.go index 9ae9836..7d34138 100644 --- a/command.go +++ b/command.go @@ -109,10 +109,11 @@ type Command struct { flagErrorBuf *bytes.Buffer - args []string // actual args parsed from flags - output *io.Writer // out writer if set in SetOutput(w) - usageFunc func(*Command) error // Usage can be defined by application - usageTemplate string // Can be defined by Application + args []string // actual args parsed from flags + output *io.Writer // out writer if set in SetOutput(w) + usageFunc func(*Command) error // Usage can be defined by application + usageTemplate string // Can be defined by Application + flagErrorFunc func(*Command, error) error helpTemplate string // Can be defined by Application helpFunc func(*Command, []string) // Help can be defined by application helpCommand *Command // The help command @@ -150,7 +151,13 @@ func (c *Command) SetUsageTemplate(s string) { c.usageTemplate = s } -// Can be defined by Application. +// SetFlagErrorFunc sets a function to generate an error when flag parsing +// fails +func (c *Command) SetFlagErrorFunc(f func(*Command, error) error) { + c.flagErrorFunc = f +} + +// Can be defined by Application func (c *Command) SetHelpFunc(f func(*Command, []string)) { c.helpFunc = f } @@ -257,6 +264,22 @@ func (c *Command) UsageString() string { return bb.String() } +// FlagErrorFunc returns either the function set by SetFlagErrorFunc for this +// command or a parent, or it returns a function which returns the original +// error. +func (c *Command) FlagErrorFunc() (f func(*Command, error) error) { + if c.flagErrorFunc != nil { + return c.flagErrorFunc + } + + if c.HasParent() { + return c.parent.FlagErrorFunc() + } + return func(c *Command, err error) error { + return err + } +} + var minUsagePadding = 25 func (c *Command) UsagePadding() int { @@ -553,7 +576,7 @@ func (c *Command) execute(a []string) (err error) { err = c.ParseFlags(a) if err != nil { - return err + return c.FlagErrorFunc()(c, err) } // If help is called, regardless of other flags, return we want help // Also say we need help if the command isn't runnable. diff --git a/command_test.go b/command_test.go index b0c17e2..b8354ec 100644 --- a/command_test.go +++ b/command_test.go @@ -1,6 +1,8 @@ package cobra import ( + "bytes" + "fmt" "os" "reflect" "testing" @@ -174,3 +176,27 @@ func TestEnableCommandSortingIsDisabled(t *testing.T) { EnableCommandSorting = true } + +func TestFlagErrorFunc(t *testing.T) { + + cmd := &Command{ + Use: "print", + RunE: func(cmd *Command, args []string) error { + return nil + }, + } + expectedFmt := "This is expected: %s" + + cmd.SetFlagErrorFunc(func(c *Command, err error) error { + return fmt.Errorf(expectedFmt, err) + }) + cmd.SetArgs([]string{"--bogus-flag"}) + cmd.SetOutput(new(bytes.Buffer)) + + err := cmd.Execute() + + expected := fmt.Sprintf(expectedFmt, "unknown flag: --bogus-flag") + if err.Error() != expected { + t.Errorf("expected %v, got %v", expected, err.Error()) + } +} diff --git a/doc/md_docs.md b/doc/md_docs.md index 0c3b96e..480b152 100644 --- a/doc/md_docs.md +++ b/doc/md_docs.md @@ -101,4 +101,3 @@ linkHandler := func(name string) string { return "/commands/" + strings.ToLower(base) + "/" } ``` -