From 2ec03d9b370ea6f42b4ce4054121fccf31649b00 Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Wed, 12 Oct 2022 15:00:06 +0300 Subject: [PATCH] fix: late binding for dial hook --- cluster.go | 6 ++-- extra/redisotel/tracing.go | 8 +---- redis.go | 70 +++++++++++++++++++++----------------- ring.go | 10 +++--- sentinel.go | 7 ++-- 5 files changed, 51 insertions(+), 50 deletions(-) diff --git a/cluster.go b/cluster.go index 88a0b0e8..a7cc541a 100644 --- a/cluster.go +++ b/cluster.go @@ -845,9 +845,9 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient { c.cmdsInfoCache = newCmdsInfoCache(c.cmdsInfo) c.cmdable = c.Process - c.hooks.process = c.process - c.hooks.processPipeline = c._processPipeline - c.hooks.processTxPipeline = c._processTxPipeline + c.hooks.setProcess(c.process) + c.hooks.setProcessPipeline(c._processPipeline) + c.hooks.setProcessTxPipeline(c._processTxPipeline) return c } diff --git a/extra/redisotel/tracing.go b/extra/redisotel/tracing.go index db74c728..be1a283f 100644 --- a/extra/redisotel/tracing.go +++ b/extra/redisotel/tracing.go @@ -89,13 +89,7 @@ func (th *tracingHook) DialHook(hook redis.DialHook) redis.DialHook { return hook(ctx, network, addr) } - spanOpts := th.spanOpts - spanOpts = append(spanOpts, trace.WithAttributes( - attribute.String("network", network), - attribute.String("addr", addr), - )) - - ctx, span := th.conf.tracer.Start(ctx, "redis.dial", spanOpts...) + ctx, span := th.conf.tracer.Start(ctx, "redis.dial", th.spanOpts...) defer span.End() conn, err := hook(ctx, network, addr) diff --git a/redis.go b/redis.go index 2f3585ec..4d683f75 100644 --- a/redis.go +++ b/redis.go @@ -37,29 +37,19 @@ type ( ) type hooks struct { - slice []Hook - dial DialHook - process ProcessHook - processPipeline ProcessPipelineHook - processTxPipeline ProcessPipelineHook + slice []Hook + dialHook DialHook + processHook ProcessHook + processPipelineHook ProcessPipelineHook + processTxPipelineHook ProcessPipelineHook } func (hs *hooks) AddHook(hook Hook) { - if hs.process == nil { - panic("hs.process == nil") - } - if hs.processPipeline == nil { - panic("hs.processPipeline == nil") - } - if hs.processTxPipeline == nil { - panic("hs.processTxPipeline == nil") - } - hs.slice = append(hs.slice, hook) - hs.dial = hook.DialHook(hs.dial) - hs.process = hook.ProcessHook(hs.process) - hs.processPipeline = hook.ProcessPipelineHook(hs.processPipeline) - hs.processTxPipeline = hook.ProcessPipelineHook(hs.processTxPipeline) + hs.dialHook = hook.DialHook(hs.dialHook) + hs.processHook = hook.ProcessHook(hs.processHook) + hs.processPipelineHook = hook.ProcessPipelineHook(hs.processPipelineHook) + hs.processTxPipelineHook = hook.ProcessPipelineHook(hs.processTxPipelineHook) } func (hs *hooks) clone() hooks { @@ -70,37 +60,37 @@ func (hs *hooks) clone() hooks { } func (hs *hooks) setDial(dial DialHook) { - hs.dial = dial + hs.dialHook = dial for _, h := range hs.slice { - if wrapped := h.DialHook(hs.dial); wrapped != nil { - hs.dial = wrapped + if wrapped := h.DialHook(hs.dialHook); wrapped != nil { + hs.dialHook = wrapped } } } func (hs *hooks) setProcess(process ProcessHook) { - hs.process = process + hs.processHook = process for _, h := range hs.slice { - if wrapped := h.ProcessHook(hs.process); wrapped != nil { - hs.process = wrapped + if wrapped := h.ProcessHook(hs.processHook); wrapped != nil { + hs.processHook = wrapped } } } func (hs *hooks) setProcessPipeline(processPipeline ProcessPipelineHook) { - hs.processPipeline = processPipeline + hs.processPipelineHook = processPipeline for _, h := range hs.slice { - if wrapped := h.ProcessPipelineHook(hs.processPipeline); wrapped != nil { - hs.processPipeline = wrapped + if wrapped := h.ProcessPipelineHook(hs.processPipelineHook); wrapped != nil { + hs.processPipelineHook = wrapped } } } func (hs *hooks) setProcessTxPipeline(processTxPipeline ProcessPipelineHook) { - hs.processTxPipeline = processTxPipeline + hs.processTxPipelineHook = processTxPipeline for _, h := range hs.slice { - if wrapped := h.ProcessPipelineHook(hs.processTxPipeline); wrapped != nil { - hs.processTxPipeline = wrapped + if wrapped := h.ProcessPipelineHook(hs.processTxPipelineHook); wrapped != nil { + hs.processTxPipelineHook = wrapped } } } @@ -125,6 +115,22 @@ func (hs *hooks) withProcessPipelineHook( return hook(ctx, cmds) } +func (hs *hooks) dial(ctx context.Context, network, addr string) (net.Conn, error) { + return hs.dialHook(ctx, network, addr) +} + +func (hs *hooks) process(ctx context.Context, cmd Cmder) error { + return hs.processHook(ctx, cmd) +} + +func (hs *hooks) processPipeline(ctx context.Context, cmds []Cmder) error { + return hs.processPipelineHook(ctx, cmds) +} + +func (hs *hooks) processTxPipeline(ctx context.Context, cmds []Cmder) error { + return hs.processTxPipelineHook(ctx, cmds) +} + //------------------------------------------------------------------------------ type baseClient struct { @@ -538,8 +544,8 @@ func NewClient(opt *Options) *Client { opt: opt, }, } - c.connPool = newConnPool(opt, c.baseClient.dial) c.init() + c.connPool = newConnPool(opt, c.hooks.dial) return &c } diff --git a/ring.go b/ring.go index a4f0b06d..1cb0d221 100644 --- a/ring.go +++ b/ring.go @@ -495,13 +495,13 @@ func NewRing(opt *RingOptions) *Ring { ring.cmdsInfoCache = newCmdsInfoCache(ring.cmdsInfo) ring.cmdable = ring.Process - ring.hooks.process = ring.process - ring.hooks.processPipeline = func(ctx context.Context, cmds []Cmder) error { + ring.hooks.setProcess(ring.process) + ring.hooks.setProcessPipeline(func(ctx context.Context, cmds []Cmder) error { return ring.generalProcessPipeline(ctx, cmds, false) - } - ring.hooks.processTxPipeline = func(ctx context.Context, cmds []Cmder) error { + }) + ring.hooks.setProcessTxPipeline(func(ctx context.Context, cmds []Cmder) error { return ring.generalProcessPipeline(ctx, cmds, true) - } + }) go ring.sharding.Heartbeat(hbCtx, opt.HeartbeatFrequency) diff --git a/sentinel.go b/sentinel.go index 41228d45..44b073f4 100644 --- a/sentinel.go +++ b/sentinel.go @@ -205,10 +205,11 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client { opt: opt, }, } - connPool = newConnPool(opt, rdb.baseClient.dial) + rdb.init() + + connPool = newConnPool(opt, rdb.hooks.dial) rdb.connPool = connPool rdb.onClose = failover.Close - rdb.init() failover.mu.Lock() failover.onFailover = func(ctx context.Context, addr string) { @@ -269,10 +270,10 @@ func NewSentinelClient(opt *Options) *SentinelClient { opt: opt, }, } - c.connPool = newConnPool(opt, c.baseClient.dial) c.hooks.setDial(c.baseClient.dial) c.hooks.setProcess(c.baseClient.process) + c.connPool = newConnPool(opt, c.hooks.dial) return c }