Merge pull request #2245 from go-redis/fix/dial-hook

fix: late binding for dial hook
This commit is contained in:
Vladimir Mihailenco 2022-10-12 17:18:27 +03:00 committed by GitHub
commit d01dc36c09
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 53 additions and 52 deletions

View File

@ -845,9 +845,9 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient {
c.cmdsInfoCache = newCmdsInfoCache(c.cmdsInfo) c.cmdsInfoCache = newCmdsInfoCache(c.cmdsInfo)
c.cmdable = c.Process c.cmdable = c.Process
c.hooks.process = c.process c.hooks.setProcess(c.process)
c.hooks.processPipeline = c._processPipeline c.hooks.setProcessPipeline(c._processPipeline)
c.hooks.processTxPipeline = c._processTxPipeline c.hooks.setProcessTxPipeline(c._processTxPipeline)
return c return c
} }

View File

@ -89,13 +89,7 @@ func (th *tracingHook) DialHook(hook redis.DialHook) redis.DialHook {
return hook(ctx, network, addr) return hook(ctx, network, addr)
} }
spanOpts := th.spanOpts ctx, span := th.conf.tracer.Start(ctx, "redis.dial", th.spanOpts...)
spanOpts = append(spanOpts, trace.WithAttributes(
attribute.String("network", network),
attribute.String("addr", addr),
))
ctx, span := th.conf.tracer.Start(ctx, "redis.dial", spanOpts...)
defer span.End() defer span.End()
conn, err := hook(ctx, network, addr) conn, err := hook(ctx, network, addr)

View File

@ -38,28 +38,18 @@ type (
type hooks struct { type hooks struct {
slice []Hook slice []Hook
dial DialHook dialHook DialHook
process ProcessHook processHook ProcessHook
processPipeline ProcessPipelineHook processPipelineHook ProcessPipelineHook
processTxPipeline ProcessPipelineHook processTxPipelineHook ProcessPipelineHook
} }
func (hs *hooks) AddHook(hook Hook) { func (hs *hooks) AddHook(hook Hook) {
if hs.process == nil {
panic("hs.process == nil")
}
if hs.processPipeline == nil {
panic("hs.processPipeline == nil")
}
if hs.processTxPipeline == nil {
panic("hs.processTxPipeline == nil")
}
hs.slice = append(hs.slice, hook) hs.slice = append(hs.slice, hook)
hs.dial = hook.DialHook(hs.dial) hs.dialHook = hook.DialHook(hs.dialHook)
hs.process = hook.ProcessHook(hs.process) hs.processHook = hook.ProcessHook(hs.processHook)
hs.processPipeline = hook.ProcessPipelineHook(hs.processPipeline) hs.processPipelineHook = hook.ProcessPipelineHook(hs.processPipelineHook)
hs.processTxPipeline = hook.ProcessPipelineHook(hs.processTxPipeline) hs.processTxPipelineHook = hook.ProcessPipelineHook(hs.processTxPipelineHook)
} }
func (hs *hooks) clone() hooks { func (hs *hooks) clone() hooks {
@ -70,37 +60,37 @@ func (hs *hooks) clone() hooks {
} }
func (hs *hooks) setDial(dial DialHook) { func (hs *hooks) setDial(dial DialHook) {
hs.dial = dial hs.dialHook = dial
for _, h := range hs.slice { for _, h := range hs.slice {
if wrapped := h.DialHook(hs.dial); wrapped != nil { if wrapped := h.DialHook(hs.dialHook); wrapped != nil {
hs.dial = wrapped hs.dialHook = wrapped
} }
} }
} }
func (hs *hooks) setProcess(process ProcessHook) { func (hs *hooks) setProcess(process ProcessHook) {
hs.process = process hs.processHook = process
for _, h := range hs.slice { for _, h := range hs.slice {
if wrapped := h.ProcessHook(hs.process); wrapped != nil { if wrapped := h.ProcessHook(hs.processHook); wrapped != nil {
hs.process = wrapped hs.processHook = wrapped
} }
} }
} }
func (hs *hooks) setProcessPipeline(processPipeline ProcessPipelineHook) { func (hs *hooks) setProcessPipeline(processPipeline ProcessPipelineHook) {
hs.processPipeline = processPipeline hs.processPipelineHook = processPipeline
for _, h := range hs.slice { for _, h := range hs.slice {
if wrapped := h.ProcessPipelineHook(hs.processPipeline); wrapped != nil { if wrapped := h.ProcessPipelineHook(hs.processPipelineHook); wrapped != nil {
hs.processPipeline = wrapped hs.processPipelineHook = wrapped
} }
} }
} }
func (hs *hooks) setProcessTxPipeline(processTxPipeline ProcessPipelineHook) { func (hs *hooks) setProcessTxPipeline(processTxPipeline ProcessPipelineHook) {
hs.processTxPipeline = processTxPipeline hs.processTxPipelineHook = processTxPipeline
for _, h := range hs.slice { for _, h := range hs.slice {
if wrapped := h.ProcessPipelineHook(hs.processTxPipeline); wrapped != nil { if wrapped := h.ProcessPipelineHook(hs.processTxPipelineHook); wrapped != nil {
hs.processTxPipeline = wrapped hs.processTxPipelineHook = wrapped
} }
} }
} }
@ -125,6 +115,22 @@ func (hs *hooks) withProcessPipelineHook(
return hook(ctx, cmds) return hook(ctx, cmds)
} }
func (hs *hooks) dial(ctx context.Context, network, addr string) (net.Conn, error) {
return hs.dialHook(ctx, network, addr)
}
func (hs *hooks) process(ctx context.Context, cmd Cmder) error {
return hs.processHook(ctx, cmd)
}
func (hs *hooks) processPipeline(ctx context.Context, cmds []Cmder) error {
return hs.processPipelineHook(ctx, cmds)
}
func (hs *hooks) processTxPipeline(ctx context.Context, cmds []Cmder) error {
return hs.processTxPipelineHook(ctx, cmds)
}
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
type baseClient struct { type baseClient struct {
@ -538,8 +544,8 @@ func NewClient(opt *Options) *Client {
opt: opt, opt: opt,
}, },
} }
c.connPool = newConnPool(opt, c.baseClient.dial)
c.init() c.init()
c.connPool = newConnPool(opt, c.hooks.dial)
return &c return &c
} }

14
ring.go
View File

@ -495,13 +495,13 @@ func NewRing(opt *RingOptions) *Ring {
ring.cmdsInfoCache = newCmdsInfoCache(ring.cmdsInfo) ring.cmdsInfoCache = newCmdsInfoCache(ring.cmdsInfo)
ring.cmdable = ring.Process ring.cmdable = ring.Process
ring.hooks.process = ring.process ring.hooks.setProcess(ring.process)
ring.hooks.processPipeline = func(ctx context.Context, cmds []Cmder) error { ring.hooks.setProcessPipeline(func(ctx context.Context, cmds []Cmder) error {
return ring.generalProcessPipeline(ctx, cmds, false) return ring.generalProcessPipeline(ctx, cmds, false)
} })
ring.hooks.processTxPipeline = func(ctx context.Context, cmds []Cmder) error { ring.hooks.setProcessTxPipeline(func(ctx context.Context, cmds []Cmder) error {
return ring.generalProcessPipeline(ctx, cmds, true) return ring.generalProcessPipeline(ctx, cmds, true)
} })
go ring.sharding.Heartbeat(hbCtx, opt.HeartbeatFrequency) go ring.sharding.Heartbeat(hbCtx, opt.HeartbeatFrequency)
@ -758,9 +758,9 @@ func (c *Ring) generalProcessPipeline(
if tx { if tx {
cmds = wrapMultiExec(ctx, cmds) cmds = wrapMultiExec(ctx, cmds)
shard.Client.hooks.processTxPipeline(ctx, cmds) _ = shard.Client.hooks.processTxPipeline(ctx, cmds)
} else { } else {
shard.Client.hooks.processPipeline(ctx, cmds) _ = shard.Client.hooks.processPipeline(ctx, cmds)
} }
}(hash, cmds) }(hash, cmds)
} }

View File

@ -205,10 +205,11 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client {
opt: opt, opt: opt,
}, },
} }
connPool = newConnPool(opt, rdb.baseClient.dial) rdb.init()
connPool = newConnPool(opt, rdb.hooks.dial)
rdb.connPool = connPool rdb.connPool = connPool
rdb.onClose = failover.Close rdb.onClose = failover.Close
rdb.init()
failover.mu.Lock() failover.mu.Lock()
failover.onFailover = func(ctx context.Context, addr string) { failover.onFailover = func(ctx context.Context, addr string) {
@ -269,10 +270,10 @@ func NewSentinelClient(opt *Options) *SentinelClient {
opt: opt, opt: opt,
}, },
} }
c.connPool = newConnPool(opt, c.baseClient.dial)
c.hooks.setDial(c.baseClient.dial) c.hooks.setDial(c.baseClient.dial)
c.hooks.setProcess(c.baseClient.process) c.hooks.setProcess(c.baseClient.process)
c.connPool = newConnPool(opt, c.hooks.dial)
return c return c
} }