diff --git a/cluster.go b/cluster.go
index b84c5790..53483332 100644
--- a/cluster.go
+++ b/cluster.go
@@ -647,9 +647,6 @@ func (c *clusterStateHolder) ReloadOrGet() (*clusterState, error) {
 //------------------------------------------------------------------------------
 
 type clusterClient struct {
-	cmdable
-	hooks
-
 	opt           *ClusterOptions
 	nodes         *clusterNodes
 	state         *clusterStateHolder //nolint:structcheck
@@ -661,6 +658,8 @@ type clusterClient struct {
 // multiple goroutines.
 type ClusterClient struct {
 	*clusterClient
+	cmdable
+	hooks
 	ctx context.Context
 }
 
@@ -678,8 +677,8 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient {
 	}
 	c.state = newClusterStateHolder(c.loadState)
 	c.cmdsInfoCache = newCmdsInfoCache(c.cmdsInfo)
+	c.cmdable = c.Process
 
-	c.init()
 	if opt.IdleCheckFrequency > 0 {
 		go c.reaper(opt.IdleCheckFrequency)
 	}
@@ -687,10 +686,6 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient {
 	return c
 }
 
-func (c *ClusterClient) init() {
-	c.cmdable = c.Process
-}
-
 func (c *ClusterClient) Context() context.Context {
 	return c.ctx
 }
@@ -700,8 +695,9 @@ func (c *ClusterClient) WithContext(ctx context.Context) *ClusterClient {
 		panic("nil context")
 	}
 	clone := *c
+	clone.cmdable = clone.Process
+	clone.hooks.Lock()
 	clone.ctx = ctx
-	clone.init()
 	return &clone
 }
 
diff --git a/race_test.go b/race_test.go
index a1dace6a..088b0086 100644
--- a/race_test.go
+++ b/race_test.go
@@ -2,6 +2,7 @@ package redis_test
 
 import (
 	"bytes"
+	"context"
 	"fmt"
 	"net"
 	"strconv"
@@ -283,6 +284,13 @@ var _ = Describe("races", func() {
 		wg.Wait()
 		Expect(received).To(Equal(uint32(C * N)))
 	})
+
+	It("should WithContext", func() {
+		perform(C, func(_ int) {
+			err := client.WithContext(context.Background()).Ping().Err()
+			Expect(err).NotTo(HaveOccurred())
+		})
+	})
 })
 
 var _ = Describe("cluster races", func() {
diff --git a/redis.go b/redis.go
index b87289cb..feb354e8 100644
--- a/redis.go
+++ b/redis.go
@@ -32,6 +32,10 @@ type hooks struct {
 	hooks []Hook
 }
 
+func (hs *hooks) Lock() {
+	hs.hooks = hs.hooks[:len(hs.hooks):len(hs.hooks)]
+}
+
 func (hs *hooks) AddHook(hook Hook) {
 	hs.hooks = append(hs.hooks, hook)
 }
@@ -466,17 +470,13 @@ func txPipelineReadQueued(rd *proto.Reader, cmds []Cmder) error {
 
 //------------------------------------------------------------------------------
 
-type client struct {
-	baseClient
-	cmdable
-	hooks
-}
-
 // Client is a Redis client representing a pool of zero or more
 // underlying connections. It's safe for concurrent use by multiple
 // goroutines.
 type Client struct {
-	*client
+	baseClient
+	cmdable
+	hooks
 	ctx context.Context
 }
 
@@ -485,23 +485,17 @@ func NewClient(opt *Options) *Client {
 	opt.init()
 
 	c := Client{
-		client: &client{
-			baseClient: baseClient{
-				opt:      opt,
-				connPool: newConnPool(opt),
-			},
+		baseClient: baseClient{
+			opt:      opt,
+			connPool: newConnPool(opt),
 		},
 		ctx: context.Background(),
 	}
-	c.init()
+	c.cmdable = c.Process
 
 	return &c
 }
 
-func (c *Client) init() {
-	c.cmdable = c.Process
-}
-
 func (c *Client) Context() context.Context {
 	return c.ctx
 }
@@ -511,8 +505,9 @@ func (c *Client) WithContext(ctx context.Context) *Client {
 		panic("nil context")
 	}
 	clone := *c
+	clone.cmdable = clone.Process
+	clone.hooks.Lock()
 	clone.ctx = ctx
-	clone.init()
 	return &clone
 }
 
diff --git a/ring.go b/ring.go
index 1c981910..20c5e11c 100644
--- a/ring.go
+++ b/ring.go
@@ -338,8 +338,6 @@ func (c *ringShards) Close() error {
 //------------------------------------------------------------------------------
 
 type ring struct {
-	cmdable
-	hooks
 	opt           *RingOptions
 	shards        *ringShards
 	cmdsInfoCache *cmdsInfoCache //nolint:structcheck
@@ -361,6 +359,8 @@ type ring struct {
 // Otherwise you should use Redis Cluster.
 type Ring struct {
 	*ring
+	cmdable
+	hooks
 	ctx context.Context
 }
 
@@ -374,9 +374,8 @@ func NewRing(opt *RingOptions) *Ring {
 		},
 		ctx: context.Background(),
 	}
-	ring.init()
-
 	ring.cmdsInfoCache = newCmdsInfoCache(ring.cmdsInfo)
+	ring.cmdable = ring.Process
 
 	for name, addr := range opt.Addrs {
 		shard := newRingShard(opt, name, addr)
@@ -398,10 +397,6 @@ func newRingShard(opt *RingOptions, name, addr string) *Client {
 	return shard
 }
 
-func (c *Ring) init() {
-	c.cmdable = c.Process
-}
-
 func (c *Ring) Context() context.Context {
 	return c.ctx
 }
@@ -411,8 +406,9 @@ func (c *Ring) WithContext(ctx context.Context) *Ring {
 		panic("nil context")
 	}
 	clone := *c
+	clone.cmdable = clone.Process
+	clone.hooks.Lock()
 	clone.ctx = ctx
-	clone.init()
 	return &clone
 }
 
diff --git a/sentinel.go b/sentinel.go
index 653e2028..b81e6b7c 100644
--- a/sentinel.go
+++ b/sentinel.go
@@ -90,16 +90,14 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client {
 	}
 
 	c := Client{
-		client: &client{
-			baseClient: baseClient{
-				opt:      opt,
-				connPool: failover.Pool(),
-				onClose:  failover.Close,
-			},
+		baseClient: baseClient{
+			opt:      opt,
+			connPool: failover.Pool(),
+			onClose:  failover.Close,
 		},
 		ctx: context.Background(),
 	}
-	c.init()
+	c.cmdable = c.Process
 
 	return &c
 }
diff --git a/tx.go b/tx.go
index b60036d3..0dda8023 100644
--- a/tx.go
+++ b/tx.go
@@ -15,10 +15,9 @@ const TxFailedErr = proto.RedisError("redis: transaction failed")
 // by multiple goroutines, because Exec resets list of watched keys.
 // If you don't need WATCH it is better to use Pipeline.
 type Tx struct {
+	baseClient
 	cmdable
 	statefulCmdable
-	baseClient
-
 	ctx context.Context
 }
 
@@ -49,6 +48,7 @@ func (c *Tx) WithContext(ctx context.Context) *Tx {
 	}
 	clone := *c
 	clone.ctx = ctx
+	clone.init()
 	return &clone
 }
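
Note on the recurring clone.cmdable = clone.Process line: in go-redis, cmdable is a function field that every command helper (Ping, Get, ...) invokes, and it is normally set to the client's own Process method. Assigning a method like that stores a method value, which stays bound to the receiver it was created from, so the shallow copy clone := *c would keep routing commands through the original client and its original context. The standalone sketch below isolates that mechanic; the runner and client types are illustrative stand-ins, not the library's real ones.

package main

import (
	"context"
	"fmt"
)

// runner mimics go-redis's cmdable: a function field that every
// command helper would call.
type runner func() context.Context

type client struct {
	run runner // stand-in for the embedded cmdable
	ctx context.Context
}

// Process stands in for the real Process method, which executes
// commands under the receiver's context.
func (c *client) Process() context.Context { return c.ctx }

type ctxKey struct{}

func main() {
	orig := &client{ctx: context.Background()}
	orig.run = orig.Process // method value: bound to orig from now on

	clone := *orig // the shallow copy WithContext starts from
	clone.ctx = context.WithValue(context.Background(), ctxKey{}, "clone")

	fmt.Println(clone.run() == clone.ctx) // false: still routed through orig
	clone.run = clone.Process             // the clone.cmdable = clone.Process fix
	fmt.Println(clone.run() == clone.ctx) // true: clone now uses its own ctx
}

The same reasoning explains why Tx.WithContext gains a clone.init() call instead: Tx has both cmdable and statefulCmdable to rebind, and its init helper already does both.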
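
Similarly, hooks.Lock relies on Go's full slice expression: s[:len(s):len(s)] keeps the length and contents but caps the capacity, so the next AddHook/append on a clone is forced to allocate a fresh backing array instead of writing into memory shared with the parent client. It only works with a pointer receiver, since the re-sliced header must be stored back. A minimal sketch of the idiom, using a hypothetical string-based Hook rather than the library's Hook interface:

package main

import "fmt"

// Hook is a hypothetical stand-in for redis.Hook, which is really
// an interface; a string keeps the sketch self-contained.
type Hook string

// lock applies the full slice expression used by hooks.Lock: same
// length and contents, capacity capped at the length. It takes a
// pointer so the re-sliced header is stored back.
func lock(hooks *[]Hook) {
	*hooks = (*hooks)[:len(*hooks):len(*hooks)]
}

func main() {
	parent := make([]Hook, 1, 4) // spare capacity, as after several AddHook calls
	parent[0] = "trace"

	clone := parent // copies the header, still shares the backing array
	lock(&clone)    // now cap(clone) == len(clone) == 1

	// With the capacity capped, this append must allocate a new
	// array; without lock it would write into parent's array and
	// race with the parent's own appends.
	clone = append(clone, "metrics")
	parent = append(parent, "auth")

	fmt.Println(parent, clone) // [trace auth] [trace metrics]
}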