Merge pull request #1479 from go-redis/fix/hook-call-after

Make sure to call after hook on error
This commit is contained in:
Vladimir Mihailenco 2020-09-11 09:25:48 +03:00 committed by GitHub
commit b67982d210
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 36 additions and 56 deletions

View File

@ -48,83 +48,63 @@ func (hs *hooks) AddHook(hook Hook) {
func (hs hooks) process( func (hs hooks) process(
ctx context.Context, cmd Cmder, fn func(context.Context, Cmder) error, ctx context.Context, cmd Cmder, fn func(context.Context, Cmder) error,
) error { ) error {
ctx, err := hs.beforeProcess(ctx, cmd) if len(hs.hooks) == 0 {
if err != nil { return fn(ctx, cmd)
cmd.SetErr(err)
return err
} }
cmdErr := fn(ctx, cmd) var hookIndex int
var retErr error
if err := hs.afterProcess(ctx, cmd); err != nil { for ; hookIndex < len(hs.hooks) && retErr == nil; hookIndex++ {
cmd.SetErr(err) ctx, retErr = hs.hooks[hookIndex].BeforeProcess(ctx, cmd)
return err if retErr != nil {
cmd.SetErr(retErr)
}
} }
return cmdErr if retErr == nil {
} retErr = fn(ctx, cmd)
}
func (hs hooks) beforeProcess(ctx context.Context, cmd Cmder) (context.Context, error) { for hookIndex--; hookIndex >= 0; hookIndex-- {
for _, h := range hs.hooks { if err := hs.hooks[hookIndex].AfterProcess(ctx, cmd); err != nil {
var err error retErr = err
ctx, err = h.BeforeProcess(ctx, cmd) cmd.SetErr(retErr)
if err != nil {
return nil, err
} }
} }
return ctx, nil
}
func (hs hooks) afterProcess(ctx context.Context, cmd Cmder) error { return retErr
var firstErr error
for i := len(hs.hooks) - 1; i >= 0; i-- {
h := hs.hooks[i]
if err := h.AfterProcess(ctx, cmd); err != nil && firstErr == nil {
firstErr = err
}
}
return firstErr
} }
func (hs hooks) processPipeline( func (hs hooks) processPipeline(
ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error, ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error,
) error { ) error {
ctx, err := hs.beforeProcessPipeline(ctx, cmds) if len(hs.hooks) == 0 {
if err != nil { return fn(ctx, cmds)
setCmdsErr(cmds, err)
return err
} }
cmdsErr := fn(ctx, cmds) var hookIndex int
var retErr error
if err := hs.afterProcessPipeline(ctx, cmds); err != nil { for ; hookIndex < len(hs.hooks) && retErr == nil; hookIndex++ {
setCmdsErr(cmds, err) ctx, retErr = hs.hooks[hookIndex].BeforeProcessPipeline(ctx, cmds)
return err if retErr != nil {
setCmdsErr(cmds, retErr)
}
} }
return cmdsErr if retErr == nil {
} retErr = fn(ctx, cmds)
}
func (hs hooks) beforeProcessPipeline(ctx context.Context, cmds []Cmder) (context.Context, error) { for hookIndex--; hookIndex >= 0; hookIndex-- {
for _, h := range hs.hooks { if err := hs.hooks[hookIndex].AfterProcessPipeline(ctx, cmds); err != nil {
var err error retErr = err
ctx, err = h.BeforeProcessPipeline(ctx, cmds) setCmdsErr(cmds, retErr)
if err != nil {
return nil, err
} }
} }
return ctx, nil
}
func (hs hooks) afterProcessPipeline(ctx context.Context, cmds []Cmder) error { return retErr
var firstErr error
for i := len(hs.hooks) - 1; i >= 0; i-- {
h := hs.hooks[i]
if err := h.AfterProcessPipeline(ctx, cmds); err != nil && firstErr == nil {
firstErr = err
}
}
return firstErr
} }
func (hs hooks) processTxPipeline( func (hs hooks) processTxPipeline(