Make sure to call after hook on error

This commit is contained in:
Vladimir Mihailenco 2020-09-10 11:06:15 +03:00
parent fb80d4211a
commit 69287d7ea9
1 changed files with 36 additions and 56 deletions

View File

@ -48,83 +48,63 @@ func (hs *hooks) AddHook(hook Hook) {
func (hs hooks) process( func (hs hooks) process(
ctx context.Context, cmd Cmder, fn func(context.Context, Cmder) error, ctx context.Context, cmd Cmder, fn func(context.Context, Cmder) error,
) error { ) error {
ctx, err := hs.beforeProcess(ctx, cmd) if len(hs.hooks) == 0 {
if err != nil { return fn(ctx, cmd)
cmd.SetErr(err)
return err
} }
cmdErr := fn(ctx, cmd) var hookIndex int
var retErr error
if err := hs.afterProcess(ctx, cmd); err != nil { for ; hookIndex < len(hs.hooks) && retErr == nil; hookIndex++ {
cmd.SetErr(err) ctx, retErr = hs.hooks[hookIndex].BeforeProcess(ctx, cmd)
return err if retErr != nil {
} cmd.SetErr(retErr)
return cmdErr
}
func (hs hooks) beforeProcess(ctx context.Context, cmd Cmder) (context.Context, error) {
for _, h := range hs.hooks {
var err error
ctx, err = h.BeforeProcess(ctx, cmd)
if err != nil {
return nil, err
} }
} }
return ctx, nil
}
func (hs hooks) afterProcess(ctx context.Context, cmd Cmder) error { if retErr == nil {
var firstErr error retErr = fn(ctx, cmd)
for i := len(hs.hooks) - 1; i >= 0; i-- { }
h := hs.hooks[i]
if err := h.AfterProcess(ctx, cmd); err != nil && firstErr == nil { for hookIndex--; hookIndex >= 0; hookIndex-- {
firstErr = err if err := hs.hooks[hookIndex].AfterProcess(ctx, cmd); err != nil {
retErr = err
cmd.SetErr(retErr)
} }
} }
return firstErr
return retErr
} }
func (hs hooks) processPipeline( func (hs hooks) processPipeline(
ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error, ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error,
) error { ) error {
ctx, err := hs.beforeProcessPipeline(ctx, cmds) if len(hs.hooks) == 0 {
if err != nil { return fn(ctx, cmds)
setCmdsErr(cmds, err)
return err
} }
cmdsErr := fn(ctx, cmds) var hookIndex int
var retErr error
if err := hs.afterProcessPipeline(ctx, cmds); err != nil { for ; hookIndex < len(hs.hooks) && retErr == nil; hookIndex++ {
setCmdsErr(cmds, err) ctx, retErr = hs.hooks[hookIndex].BeforeProcessPipeline(ctx, cmds)
return err if retErr != nil {
} setCmdsErr(cmds, retErr)
return cmdsErr
}
func (hs hooks) beforeProcessPipeline(ctx context.Context, cmds []Cmder) (context.Context, error) {
for _, h := range hs.hooks {
var err error
ctx, err = h.BeforeProcessPipeline(ctx, cmds)
if err != nil {
return nil, err
} }
} }
return ctx, nil
}
func (hs hooks) afterProcessPipeline(ctx context.Context, cmds []Cmder) error { if retErr == nil {
var firstErr error retErr = fn(ctx, cmds)
for i := len(hs.hooks) - 1; i >= 0; i-- { }
h := hs.hooks[i]
if err := h.AfterProcessPipeline(ctx, cmds); err != nil && firstErr == nil { for hookIndex--; hookIndex >= 0; hookIndex-- {
firstErr = err if err := hs.hooks[hookIndex].AfterProcessPipeline(ctx, cmds); err != nil {
retErr = err
setCmdsErr(cmds, retErr)
} }
} }
return firstErr
return retErr
} }
func (hs hooks) processTxPipeline( func (hs hooks) processTxPipeline(