diff --git a/redis.go b/redis.go index 198d4a7..617bf97 100644 --- a/redis.go +++ b/redis.go @@ -48,83 +48,63 @@ func (hs *hooks) AddHook(hook Hook) { func (hs hooks) process( ctx context.Context, cmd Cmder, fn func(context.Context, Cmder) error, ) error { - ctx, err := hs.beforeProcess(ctx, cmd) - if err != nil { - cmd.SetErr(err) - return err + if len(hs.hooks) == 0 { + return fn(ctx, cmd) } - cmdErr := fn(ctx, cmd) + var hookIndex int + var retErr error - if err := hs.afterProcess(ctx, cmd); err != nil { - cmd.SetErr(err) - return err - } - - return cmdErr -} - -func (hs hooks) beforeProcess(ctx context.Context, cmd Cmder) (context.Context, error) { - for _, h := range hs.hooks { - var err error - ctx, err = h.BeforeProcess(ctx, cmd) - if err != nil { - return nil, err + for ; hookIndex < len(hs.hooks) && retErr == nil; hookIndex++ { + ctx, retErr = hs.hooks[hookIndex].BeforeProcess(ctx, cmd) + if retErr != nil { + cmd.SetErr(retErr) } } - return ctx, nil -} -func (hs hooks) afterProcess(ctx context.Context, cmd Cmder) error { - var firstErr error - for i := len(hs.hooks) - 1; i >= 0; i-- { - h := hs.hooks[i] - if err := h.AfterProcess(ctx, cmd); err != nil && firstErr == nil { - firstErr = err + if retErr == nil { + retErr = fn(ctx, cmd) + } + + for hookIndex--; hookIndex >= 0; hookIndex-- { + if err := hs.hooks[hookIndex].AfterProcess(ctx, cmd); err != nil { + retErr = err + cmd.SetErr(retErr) } } - return firstErr + + return retErr } func (hs hooks) processPipeline( ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error, ) error { - ctx, err := hs.beforeProcessPipeline(ctx, cmds) - if err != nil { - setCmdsErr(cmds, err) - return err + if len(hs.hooks) == 0 { + return fn(ctx, cmds) } - cmdsErr := fn(ctx, cmds) + var hookIndex int + var retErr error - if err := hs.afterProcessPipeline(ctx, cmds); err != nil { - setCmdsErr(cmds, err) - return err - } - - return cmdsErr -} - -func (hs hooks) beforeProcessPipeline(ctx context.Context, cmds []Cmder) (context.Context, error) { - for _, h := range hs.hooks { - var err error - ctx, err = h.BeforeProcessPipeline(ctx, cmds) - if err != nil { - return nil, err + for ; hookIndex < len(hs.hooks) && retErr == nil; hookIndex++ { + ctx, retErr = hs.hooks[hookIndex].BeforeProcessPipeline(ctx, cmds) + if retErr != nil { + setCmdsErr(cmds, retErr) } } - return ctx, nil -} -func (hs hooks) afterProcessPipeline(ctx context.Context, cmds []Cmder) error { - var firstErr error - for i := len(hs.hooks) - 1; i >= 0; i-- { - h := hs.hooks[i] - if err := h.AfterProcessPipeline(ctx, cmds); err != nil && firstErr == nil { - firstErr = err + if retErr == nil { + retErr = fn(ctx, cmds) + } + + for hookIndex--; hookIndex >= 0; hookIndex-- { + if err := hs.hooks[hookIndex].AfterProcessPipeline(ctx, cmds); err != nil { + retErr = err + setCmdsErr(cmds, retErr) } } - return firstErr + + return retErr } func (hs hooks) processTxPipeline(