Merge pull request #1479 from go-redis/fix/hook-call-after

Make sure to call after hook on error
This commit is contained in:
Vladimir Mihailenco 2020-09-11 09:25:48 +03:00 committed by GitHub
commit b67982d210
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 36 additions and 56 deletions

View File

@ -48,83 +48,63 @@ func (hs *hooks) AddHook(hook Hook) {
func (hs hooks) process(
ctx context.Context, cmd Cmder, fn func(context.Context, Cmder) error,
) error {
ctx, err := hs.beforeProcess(ctx, cmd)
if err != nil {
cmd.SetErr(err)
return err
if len(hs.hooks) == 0 {
return fn(ctx, cmd)
}
cmdErr := fn(ctx, cmd)
var hookIndex int
var retErr error
if err := hs.afterProcess(ctx, cmd); err != nil {
cmd.SetErr(err)
return err
for ; hookIndex < len(hs.hooks) && retErr == nil; hookIndex++ {
ctx, retErr = hs.hooks[hookIndex].BeforeProcess(ctx, cmd)
if retErr != nil {
cmd.SetErr(retErr)
}
}
return cmdErr
if retErr == nil {
retErr = fn(ctx, cmd)
}
func (hs hooks) beforeProcess(ctx context.Context, cmd Cmder) (context.Context, error) {
for _, h := range hs.hooks {
var err error
ctx, err = h.BeforeProcess(ctx, cmd)
if err != nil {
return nil, err
for hookIndex--; hookIndex >= 0; hookIndex-- {
if err := hs.hooks[hookIndex].AfterProcess(ctx, cmd); err != nil {
retErr = err
cmd.SetErr(retErr)
}
}
return ctx, nil
}
func (hs hooks) afterProcess(ctx context.Context, cmd Cmder) error {
var firstErr error
for i := len(hs.hooks) - 1; i >= 0; i-- {
h := hs.hooks[i]
if err := h.AfterProcess(ctx, cmd); err != nil && firstErr == nil {
firstErr = err
}
}
return firstErr
return retErr
}
func (hs hooks) processPipeline(
ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error,
) error {
ctx, err := hs.beforeProcessPipeline(ctx, cmds)
if err != nil {
setCmdsErr(cmds, err)
return err
if len(hs.hooks) == 0 {
return fn(ctx, cmds)
}
cmdsErr := fn(ctx, cmds)
var hookIndex int
var retErr error
if err := hs.afterProcessPipeline(ctx, cmds); err != nil {
setCmdsErr(cmds, err)
return err
for ; hookIndex < len(hs.hooks) && retErr == nil; hookIndex++ {
ctx, retErr = hs.hooks[hookIndex].BeforeProcessPipeline(ctx, cmds)
if retErr != nil {
setCmdsErr(cmds, retErr)
}
}
return cmdsErr
if retErr == nil {
retErr = fn(ctx, cmds)
}
func (hs hooks) beforeProcessPipeline(ctx context.Context, cmds []Cmder) (context.Context, error) {
for _, h := range hs.hooks {
var err error
ctx, err = h.BeforeProcessPipeline(ctx, cmds)
if err != nil {
return nil, err
for hookIndex--; hookIndex >= 0; hookIndex-- {
if err := hs.hooks[hookIndex].AfterProcessPipeline(ctx, cmds); err != nil {
retErr = err
setCmdsErr(cmds, retErr)
}
}
return ctx, nil
}
func (hs hooks) afterProcessPipeline(ctx context.Context, cmds []Cmder) error {
var firstErr error
for i := len(hs.hooks) - 1; i >= 0; i-- {
h := hs.hooks[i]
if err := h.AfterProcessPipeline(ctx, cmds); err != nil && firstErr == nil {
firstErr = err
}
}
return firstErr
return retErr
}
func (hs hooks) processTxPipeline(