diff --git a/redis.go b/redis.go index 123e64d..66dc72f 100644 --- a/redis.go +++ b/redis.go @@ -51,13 +51,14 @@ func (hs hooks) process( ) error { ctx, err := hs.beforeProcess(ctx, cmd) if err != nil { + cmd.setErr(err) return err } cmdErr := fn(ctx, cmd) - err = hs.afterProcess(ctx, cmd) - if err != nil { + if err := hs.afterProcess(ctx, cmd); err != nil { + cmd.setErr(err) return err } @@ -91,13 +92,14 @@ func (hs hooks) processPipeline( ) error { ctx, err := hs.beforeProcessPipeline(ctx, cmds) if err != nil { + setCmdsErr(cmds, err) return err } cmdsErr := fn(ctx, cmds) - err = hs.afterProcessPipeline(ctx, cmds) - if err != nil { + if err := hs.afterProcessPipeline(ctx, cmds); err != nil { + setCmdsErr(cmds, err) return err } diff --git a/redis_test.go b/redis_test.go index b74281d..8a1a149 100644 --- a/redis_test.go +++ b/redis_test.go @@ -3,7 +3,9 @@ package redis_test import ( "bytes" "context" + "errors" "net" + "testing" "time" "github.com/go-redis/redis/v7" @@ -12,6 +14,39 @@ import ( . "github.com/onsi/gomega" ) +type redisHookError struct { + redis.Hook +} + +var _ redis.Hook = redisHookError{} + +func (redisHookError) BeforeProcess(ctx context.Context, cmd redis.Cmder) (context.Context, error) { + return ctx, nil +} + +func (redisHookError) AfterProcess(ctx context.Context, cmd redis.Cmder) error { + return errors.New("hook error") +} + +func TestHookError(t *testing.T) { + rdb := redis.NewClient(&redis.Options{ + Addr: ":6379", + }) + rdb.AddHook(redisHookError{}) + + err := rdb.Ping().Err() + if err == nil { + t.Fatalf("got nil, expected an error") + } + + wanted := "hook error" + if err.Error() != wanted { + t.Fatalf(`got %q, wanted %q`, err, wanted) + } +} + +//------------------------------------------------------------------------------ + var _ = Describe("Client", func() { var client *redis.Client