diff --git a/cluster.go b/cluster.go index 12417d1f..219bdb6b 100644 --- a/cluster.go +++ b/cluster.go @@ -838,7 +838,7 @@ type ClusterClient struct { state *clusterStateHolder cmdsInfoCache *cmdsInfoCache cmdable - hooksMixin + *hooksMixin } // NewClusterClient returns a Redis Cluster client as described in @@ -847,8 +847,9 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient { opt.init() c := &ClusterClient{ - opt: opt, - nodes: newClusterNodes(opt), + opt: opt, + nodes: newClusterNodes(opt), + hooksMixin: &hooksMixin{}, } c.state = newClusterStateHolder(c.loadState) diff --git a/redis.go b/redis.go index 6eed8424..976ff0c0 100644 --- a/redis.go +++ b/redis.go @@ -6,6 +6,7 @@ import ( "fmt" "net" "strings" + "sync" "sync/atomic" "time" @@ -44,6 +45,8 @@ type hooksMixin struct { slice []Hook initial hooks current hooks + + hooksMu sync.RWMutex } func (hs *hooksMixin) initHooks(hooks hooks) { @@ -117,6 +120,9 @@ func (hs *hooksMixin) AddHook(hook Hook) { func (hs *hooksMixin) chain() { hs.initial.setDefaults() + hs.hooksMu.Lock() + defer hs.hooksMu.Unlock() + hs.current.dial = hs.initial.dial hs.current.process = hs.initial.process hs.current.pipeline = hs.initial.pipeline @@ -138,8 +144,15 @@ func (hs *hooksMixin) chain() { } } -func (hs *hooksMixin) clone() hooksMixin { - clone := *hs +func (hs *hooksMixin) clone() *hooksMixin { + hs.hooksMu.Lock() + defer hs.hooksMu.Unlock() + + clone := &hooksMixin{ + slice: hs.slice, + initial: hs.initial, + current: hs.current, + } l := len(clone.slice) clone.slice = clone.slice[:l:l] return clone @@ -166,7 +179,11 @@ func (hs *hooksMixin) withProcessPipelineHook( } func (hs *hooksMixin) dialHook(ctx context.Context, network, addr string) (net.Conn, error) { - return hs.current.dial(ctx, network, addr) + hs.hooksMu.RLock() + conn, err := hs.current.dial(ctx, network, addr) + hs.hooksMu.RUnlock() + + return conn, err } func (hs *hooksMixin) processHook(ctx context.Context, cmd Cmder) error { @@ -588,8 +605,8 @@ func (c *baseClient) context(ctx context.Context) context.Context { // of idle connections. You can control the pool size with Config.PoolSize option. type Client struct { *baseClient + *hooksMixin cmdable - hooksMixin } // NewClient returns a client to the Redis Server specified by Options. @@ -600,6 +617,7 @@ func NewClient(opt *Options) *Client { baseClient: &baseClient{ opt: opt, }, + hooksMixin: &hooksMixin{}, } c.init() c.connPool = newConnPool(opt, c.dialHook) @@ -620,6 +638,7 @@ func (c *Client) init() { func (c *Client) WithTimeout(timeout time.Duration) *Client { clone := *c clone.baseClient = c.baseClient.withTimeout(timeout) + clone.hooksMixin = c.hooksMixin.clone() clone.init() return &clone } @@ -758,7 +777,7 @@ type Conn struct { baseClient cmdable statefulCmdable - hooksMixin + *hooksMixin } func newConn(opt *Options, connPool pool.Pooler) *Conn { @@ -767,6 +786,7 @@ func newConn(opt *Options, connPool pool.Pooler) *Conn { opt: opt, connPool: connPool, }, + hooksMixin: &hooksMixin{}, } c.cmdable = c.Process diff --git a/ring.go b/ring.go index f924ac0a..bbe3a0ff 100644 --- a/ring.go +++ b/ring.go @@ -487,7 +487,7 @@ func (c *ringSharding) Close() error { // Otherwise you should use Redis Cluster. type Ring struct { cmdable - hooksMixin + *hooksMixin opt *RingOptions sharding *ringSharding @@ -504,6 +504,7 @@ func NewRing(opt *RingOptions) *Ring { opt: opt, sharding: newRingSharding(opt), heartbeatCancelFn: hbCancel, + hooksMixin: &hooksMixin{}, } ring.cmdsInfoCache = newCmdsInfoCache(ring.cmdsInfo) diff --git a/sentinel.go b/sentinel.go index 5ea41f17..1afd0fca 100644 --- a/sentinel.go +++ b/sentinel.go @@ -211,6 +211,7 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client { baseClient: &baseClient{ opt: opt, }, + hooksMixin: &hooksMixin{}, } rdb.init() @@ -267,7 +268,7 @@ func masterReplicaDialer( // SentinelClient is a client for a Redis Sentinel. type SentinelClient struct { *baseClient - hooksMixin + *hooksMixin } func NewSentinelClient(opt *Options) *SentinelClient { @@ -276,6 +277,7 @@ func NewSentinelClient(opt *Options) *SentinelClient { baseClient: &baseClient{ opt: opt, }, + hooksMixin: &hooksMixin{}, } c.initHooks(hooks{ diff --git a/tx.go b/tx.go index 039eaf35..a16b99e3 100644 --- a/tx.go +++ b/tx.go @@ -19,7 +19,7 @@ type Tx struct { baseClient cmdable statefulCmdable - hooksMixin + *hooksMixin } func (c *Client) newTx() *Tx {