Merge pull request #1251 from go-redis/fix/hook-error-cmd

Set an error returned from the hook on the Cmd
This commit is contained in:
Vladimir Mihailenco 2020-02-02 12:13:49 +02:00 committed by GitHub
commit 5edc4c8384
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 41 additions and 4 deletions

View File

@ -51,13 +51,14 @@ func (hs hooks) process(
) error { ) error {
ctx, err := hs.beforeProcess(ctx, cmd) ctx, err := hs.beforeProcess(ctx, cmd)
if err != nil { if err != nil {
cmd.setErr(err)
return err return err
} }
cmdErr := fn(ctx, cmd) cmdErr := fn(ctx, cmd)
err = hs.afterProcess(ctx, cmd) if err := hs.afterProcess(ctx, cmd); err != nil {
if err != nil { cmd.setErr(err)
return err return err
} }
@ -91,13 +92,14 @@ func (hs hooks) processPipeline(
) error { ) error {
ctx, err := hs.beforeProcessPipeline(ctx, cmds) ctx, err := hs.beforeProcessPipeline(ctx, cmds)
if err != nil { if err != nil {
setCmdsErr(cmds, err)
return err return err
} }
cmdsErr := fn(ctx, cmds) cmdsErr := fn(ctx, cmds)
err = hs.afterProcessPipeline(ctx, cmds) if err := hs.afterProcessPipeline(ctx, cmds); err != nil {
if err != nil { setCmdsErr(cmds, err)
return err return err
} }

View File

@ -3,7 +3,9 @@ package redis_test
import ( import (
"bytes" "bytes"
"context" "context"
"errors"
"net" "net"
"testing"
"time" "time"
"github.com/go-redis/redis/v7" "github.com/go-redis/redis/v7"
@ -12,6 +14,39 @@ import (
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
) )
type redisHookError struct {
redis.Hook
}
var _ redis.Hook = redisHookError{}
func (redisHookError) BeforeProcess(ctx context.Context, cmd redis.Cmder) (context.Context, error) {
return ctx, nil
}
func (redisHookError) AfterProcess(ctx context.Context, cmd redis.Cmder) error {
return errors.New("hook error")
}
func TestHookError(t *testing.T) {
rdb := redis.NewClient(&redis.Options{
Addr: ":6379",
})
rdb.AddHook(redisHookError{})
err := rdb.Ping().Err()
if err == nil {
t.Fatalf("got nil, expected an error")
}
wanted := "hook error"
if err.Error() != wanted {
t.Fatalf(`got %q, wanted %q`, err, wanted)
}
}
//------------------------------------------------------------------------------
var _ = Describe("Client", func() { var _ = Describe("Client", func() {
var client *redis.Client var client *redis.Client