diff --git a/redis.go b/redis.go index 9430eb75..9792af76 100644 --- a/redis.go +++ b/redis.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net" + "sync" "sync/atomic" "time" @@ -40,12 +41,15 @@ type ( ) type hooksMixin struct { + hooksMu *sync.Mutex + slice []Hook initial hooks current hooks } func (hs *hooksMixin) initHooks(hooks hooks) { + hs.hooksMu = new(sync.Mutex) hs.initial = hooks hs.chain() } @@ -116,6 +120,9 @@ func (hs *hooksMixin) AddHook(hook Hook) { func (hs *hooksMixin) chain() { hs.initial.setDefaults() + hs.hooksMu.Lock() + defer hs.hooksMu.Unlock() + hs.current.dial = hs.initial.dial hs.current.process = hs.initial.process hs.current.pipeline = hs.initial.pipeline @@ -138,9 +145,13 @@ func (hs *hooksMixin) chain() { } func (hs *hooksMixin) clone() hooksMixin { + hs.hooksMu.Lock() + defer hs.hooksMu.Unlock() + clone := *hs l := len(clone.slice) clone.slice = clone.slice[:l:l] + clone.hooksMu = new(sync.Mutex) return clone } @@ -165,6 +176,8 @@ func (hs *hooksMixin) withProcessPipelineHook( } func (hs *hooksMixin) dialHook(ctx context.Context, network, addr string) (net.Conn, error) { + hs.hooksMu.Lock() + defer hs.hooksMu.Unlock() return hs.current.dial(ctx, network, addr) } diff --git a/redis_test.go b/redis_test.go index 5870fad7..728f6c29 100644 --- a/redis_test.go +++ b/redis_test.go @@ -579,3 +579,53 @@ var _ = Describe("Hook", func() { Expect(cmd.Val()).To(Equal("Script and hook")) }) }) + +var _ = Describe("Hook with MinIdleConns", func() { + var client *redis.Client + + BeforeEach(func() { + options := redisOptions() + options.MinIdleConns = 1 + client = redis.NewClient(options) + Expect(client.FlushDB(ctx).Err()).NotTo(HaveOccurred()) + }) + + AfterEach(func() { + err := client.Close() + Expect(err).NotTo(HaveOccurred()) + }) + + It("fifo", func() { + var res []string + client.AddHook(&hook{ + processHook: func(hook redis.ProcessHook) redis.ProcessHook { + return func(ctx context.Context, cmd redis.Cmder) error { + res = append(res, "hook-1-process-start") + err := hook(ctx, cmd) + res = append(res, "hook-1-process-end") + return err + } + }, + }) + client.AddHook(&hook{ + processHook: func(hook redis.ProcessHook) redis.ProcessHook { + return func(ctx context.Context, cmd redis.Cmder) error { + res = append(res, "hook-2-process-start") + err := hook(ctx, cmd) + res = append(res, "hook-2-process-end") + return err + } + }, + }) + + err := client.Ping(ctx).Err() + Expect(err).NotTo(HaveOccurred()) + + Expect(res).To(Equal([]string{ + "hook-1-process-start", + "hook-2-process-start", + "hook-2-process-end", + "hook-1-process-end", + })) + }) +})