From 69287d7ea9c3b26c85ad06a33b5f7f01134f287b Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Thu, 10 Sep 2020 11:06:15 +0300 Subject: [PATCH] Make sure to call after hook on error --- redis.go | 92 ++++++++++++++++++++++---------------------------------- 1 file changed, 36 insertions(+), 56 deletions(-) diff --git a/redis.go b/redis.go index 198d4a7e..617bf973 100644 --- a/redis.go +++ b/redis.go @@ -48,83 +48,63 @@ func (hs *hooks) AddHook(hook Hook) { func (hs hooks) process( ctx context.Context, cmd Cmder, fn func(context.Context, Cmder) error, ) error { - ctx, err := hs.beforeProcess(ctx, cmd) - if err != nil { - cmd.SetErr(err) - return err + if len(hs.hooks) == 0 { + return fn(ctx, cmd) } - cmdErr := fn(ctx, cmd) + var hookIndex int + var retErr error - if err := hs.afterProcess(ctx, cmd); err != nil { - cmd.SetErr(err) - return err - } - - return cmdErr -} - -func (hs hooks) beforeProcess(ctx context.Context, cmd Cmder) (context.Context, error) { - for _, h := range hs.hooks { - var err error - ctx, err = h.BeforeProcess(ctx, cmd) - if err != nil { - return nil, err + for ; hookIndex < len(hs.hooks) && retErr == nil; hookIndex++ { + ctx, retErr = hs.hooks[hookIndex].BeforeProcess(ctx, cmd) + if retErr != nil { + cmd.SetErr(retErr) } } - return ctx, nil -} -func (hs hooks) afterProcess(ctx context.Context, cmd Cmder) error { - var firstErr error - for i := len(hs.hooks) - 1; i >= 0; i-- { - h := hs.hooks[i] - if err := h.AfterProcess(ctx, cmd); err != nil && firstErr == nil { - firstErr = err + if retErr == nil { + retErr = fn(ctx, cmd) + } + + for hookIndex--; hookIndex >= 0; hookIndex-- { + if err := hs.hooks[hookIndex].AfterProcess(ctx, cmd); err != nil { + retErr = err + cmd.SetErr(retErr) } } - return firstErr + + return retErr } func (hs hooks) processPipeline( ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error, ) error { - ctx, err := hs.beforeProcessPipeline(ctx, cmds) - if err != nil { - setCmdsErr(cmds, err) - return err + if len(hs.hooks) == 0 { + return fn(ctx, cmds) } - cmdsErr := fn(ctx, cmds) + var hookIndex int + var retErr error - if err := hs.afterProcessPipeline(ctx, cmds); err != nil { - setCmdsErr(cmds, err) - return err - } - - return cmdsErr -} - -func (hs hooks) beforeProcessPipeline(ctx context.Context, cmds []Cmder) (context.Context, error) { - for _, h := range hs.hooks { - var err error - ctx, err = h.BeforeProcessPipeline(ctx, cmds) - if err != nil { - return nil, err + for ; hookIndex < len(hs.hooks) && retErr == nil; hookIndex++ { + ctx, retErr = hs.hooks[hookIndex].BeforeProcessPipeline(ctx, cmds) + if retErr != nil { + setCmdsErr(cmds, retErr) } } - return ctx, nil -} -func (hs hooks) afterProcessPipeline(ctx context.Context, cmds []Cmder) error { - var firstErr error - for i := len(hs.hooks) - 1; i >= 0; i-- { - h := hs.hooks[i] - if err := h.AfterProcessPipeline(ctx, cmds); err != nil && firstErr == nil { - firstErr = err + if retErr == nil { + retErr = fn(ctx, cmds) + } + + for hookIndex--; hookIndex >= 0; hookIndex-- { + if err := hs.hooks[hookIndex].AfterProcessPipeline(ctx, cmds); err != nil { + retErr = err + setCmdsErr(cmds, retErr) } } - return firstErr + + return retErr } func (hs hooks) processTxPipeline(