diff --git a/internal/pool/bench_test.go b/internal/pool/bench_test.go index a43c19e2..dec5d3f2 100644 --- a/internal/pool/bench_test.go +++ b/internal/pool/bench_test.go @@ -18,6 +18,7 @@ func (bm poolGetPutBenchmark) String() string { } func BenchmarkPoolGetPut(b *testing.B) { + ctx := context.Background() benchmarks := []poolGetPutBenchmark{ {1}, {2}, @@ -40,11 +41,11 @@ func BenchmarkPoolGetPut(b *testing.B) { b.RunParallel(func(pb *testing.PB) { for pb.Next() { - cn, err := connPool.Get(context.Background()) + cn, err := connPool.Get(ctx) if err != nil { b.Fatal(err) } - connPool.Put(cn) + connPool.Put(ctx, cn) } }) }) @@ -60,6 +61,7 @@ func (bm poolGetRemoveBenchmark) String() string { } func BenchmarkPoolGetRemove(b *testing.B) { + ctx := context.Background() benchmarks := []poolGetRemoveBenchmark{ {1}, {2}, @@ -68,6 +70,7 @@ func BenchmarkPoolGetRemove(b *testing.B) { {64}, {128}, } + for _, bm := range benchmarks { b.Run(bm.String(), func(b *testing.B) { connPool := pool.NewConnPool(&pool.Options{ @@ -82,11 +85,11 @@ func BenchmarkPoolGetRemove(b *testing.B) { b.RunParallel(func(pb *testing.PB) { for pb.Next() { - cn, err := connPool.Get(context.Background()) + cn, err := connPool.Get(ctx) if err != nil { b.Fatal(err) } - connPool.Remove(cn, nil) + connPool.Remove(ctx, cn, nil) } }) }) diff --git a/internal/pool/pool.go b/internal/pool/pool.go index 064efa43..d01a4e1e 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -40,8 +40,8 @@ type Pooler interface { CloseConn(*Conn) error Get(context.Context) (*Conn, error) - Put(*Conn) - Remove(*Conn, error) + Put(context.Context, *Conn) + Remove(context.Context, *Conn, error) Len() int IdleLen() int @@ -318,15 +318,15 @@ func (p *ConnPool) popIdle() *Conn { return cn } -func (p *ConnPool) Put(cn *Conn) { +func (p *ConnPool) Put(ctx context.Context, cn *Conn) { if cn.rd.Buffered() > 0 { - internal.Logger.Printf(context.Background(), "Conn has unread data") - p.Remove(cn, BadConnError{}) + internal.Logger.Printf(ctx, "Conn has unread data") + p.Remove(ctx, cn, BadConnError{}) return } if !cn.pooled { - p.Remove(cn, nil) + p.Remove(ctx, cn, nil) return } @@ -337,7 +337,7 @@ func (p *ConnPool) Put(cn *Conn) { p.freeTurn() } -func (p *ConnPool) Remove(cn *Conn, reason error) { +func (p *ConnPool) Remove(ctx context.Context, cn *Conn, reason error) { p.removeConnWithLock(cn) p.freeTurn() _ = p.closeConn(cn) diff --git a/internal/pool/pool_single.go b/internal/pool/pool_single.go index 04758a00..5a3fde19 100644 --- a/internal/pool/pool_single.go +++ b/internal/pool/pool_single.go @@ -1,64 +1,19 @@ package pool -import ( - "context" - "fmt" - "sync/atomic" -) - -const ( - stateDefault = 0 - stateInited = 1 - stateClosed = 2 -) - -type BadConnError struct { - wrapped error -} - -var _ error = (*BadConnError)(nil) - -func (e BadConnError) Error() string { - s := "redis: Conn is in a bad state" - if e.wrapped != nil { - s += ": " + e.wrapped.Error() - } - return s -} - -func (e BadConnError) Unwrap() error { - return e.wrapped -} +import "context" type SingleConnPool struct { - pool Pooler - level int32 // atomic - - state uint32 // atomic - ch chan *Conn - - _badConnError atomic.Value + pool Pooler + cn *Conn + stickyErr error } var _ Pooler = (*SingleConnPool)(nil) -func NewSingleConnPool(pool Pooler) *SingleConnPool { - p, ok := pool.(*SingleConnPool) - if !ok { - p = &SingleConnPool{ - pool: pool, - ch: make(chan *Conn, 1), - } - } - atomic.AddInt32(&p.level, 1) - return p -} - -func (p *SingleConnPool) SetConn(cn *Conn) { - if atomic.CompareAndSwapUint32(&p.state, stateDefault, stateInited) { - p.ch <- cn - } else { - panic("not reached") +func NewSingleConnPool(pool Pooler, cn *Conn) *SingleConnPool { + return &SingleConnPool{ + pool: pool, + cn: cn, } } @@ -71,138 +26,33 @@ func (p *SingleConnPool) CloseConn(cn *Conn) error { } func (p *SingleConnPool) Get(ctx context.Context) (*Conn, error) { - // In worst case this races with Close which is not a very common operation. - for i := 0; i < 1000; i++ { - switch atomic.LoadUint32(&p.state) { - case stateDefault: - cn, err := p.pool.Get(ctx) - if err != nil { - return nil, err - } - if atomic.CompareAndSwapUint32(&p.state, stateDefault, stateInited) { - return cn, nil - } - p.pool.Remove(cn, ErrClosed) - case stateInited: - if err := p.badConnError(); err != nil { - return nil, err - } - cn, ok := <-p.ch - if !ok { - return nil, ErrClosed - } - return cn, nil - case stateClosed: - return nil, ErrClosed - default: - panic("not reached") - } + if p.stickyErr != nil { + return nil, p.stickyErr } - return nil, fmt.Errorf("redis: SingleConnPool.Get: infinite loop") + return p.cn, nil } -func (p *SingleConnPool) Put(cn *Conn) { - defer func() { - if recover() != nil { - p.freeConn(cn) - } - }() - p.ch <- cn +func (p *SingleConnPool) Put(ctx context.Context, cn *Conn) {} + +func (p *SingleConnPool) Remove(ctx context.Context, cn *Conn, reason error) { + p.cn = nil + p.stickyErr = reason } -func (p *SingleConnPool) freeConn(cn *Conn) { - if err := p.badConnError(); err != nil { - p.pool.Remove(cn, err) - } else { - p.pool.Put(cn) - } -} - -func (p *SingleConnPool) Remove(cn *Conn, reason error) { - defer func() { - if recover() != nil { - p.pool.Remove(cn, ErrClosed) - } - }() - p._badConnError.Store(BadConnError{wrapped: reason}) - p.ch <- cn +func (p *SingleConnPool) Close() error { + p.cn = nil + p.stickyErr = ErrClosed + return nil } func (p *SingleConnPool) Len() int { - switch atomic.LoadUint32(&p.state) { - case stateDefault: - return 0 - case stateInited: - return 1 - case stateClosed: - return 0 - default: - panic("not reached") - } + return 0 } func (p *SingleConnPool) IdleLen() int { - return len(p.ch) + return 0 } func (p *SingleConnPool) Stats() *Stats { return &Stats{} } - -func (p *SingleConnPool) Close() error { - level := atomic.AddInt32(&p.level, -1) - if level > 0 { - return nil - } - - for i := 0; i < 1000; i++ { - state := atomic.LoadUint32(&p.state) - if state == stateClosed { - return ErrClosed - } - if atomic.CompareAndSwapUint32(&p.state, state, stateClosed) { - close(p.ch) - cn, ok := <-p.ch - if ok { - p.freeConn(cn) - } - return nil - } - } - - return fmt.Errorf("redis: SingleConnPool.Close: infinite loop") -} - -func (p *SingleConnPool) Reset() error { - if p.badConnError() == nil { - return nil - } - - select { - case cn, ok := <-p.ch: - if !ok { - return ErrClosed - } - p.pool.Remove(cn, ErrClosed) - p._badConnError.Store(BadConnError{wrapped: nil}) - default: - return fmt.Errorf("redis: SingleConnPool does not have a Conn") - } - - if !atomic.CompareAndSwapUint32(&p.state, stateInited, stateDefault) { - state := atomic.LoadUint32(&p.state) - return fmt.Errorf("redis: invalid SingleConnPool state: %d", state) - } - - return nil -} - -func (p *SingleConnPool) badConnError() error { - if v := p._badConnError.Load(); v != nil { - err := v.(BadConnError) - if err.wrapped != nil { - return err - } - } - return nil -} diff --git a/internal/pool/pool_sticky.go b/internal/pool/pool_sticky.go index d4a355a4..c3e7e7c0 100644 --- a/internal/pool/pool_sticky.go +++ b/internal/pool/pool_sticky.go @@ -2,111 +2,201 @@ package pool import ( "context" - "sync" + "errors" + "fmt" + "sync/atomic" ) -type StickyConnPool struct { - pool *ConnPool - reusable bool +const ( + stateDefault = 0 + stateInited = 1 + stateClosed = 2 +) - cn *Conn - closed bool - mu sync.Mutex +type BadConnError struct { + wrapped error +} + +var _ error = (*BadConnError)(nil) + +func (e BadConnError) Error() string { + s := "redis: Conn is in a bad state" + if e.wrapped != nil { + s += ": " + e.wrapped.Error() + } + return s +} + +func (e BadConnError) Unwrap() error { + return e.wrapped +} + +//------------------------------------------------------------------------------ + +type StickyConnPool struct { + pool Pooler + shared int32 // atomic + + state uint32 // atomic + ch chan *Conn + + _badConnError atomic.Value } var _ Pooler = (*StickyConnPool)(nil) -func NewStickyConnPool(pool *ConnPool, reusable bool) *StickyConnPool { - return &StickyConnPool{ - pool: pool, - reusable: reusable, +func NewStickyConnPool(pool Pooler) *StickyConnPool { + p, ok := pool.(*StickyConnPool) + if !ok { + p = &StickyConnPool{ + pool: pool, + ch: make(chan *Conn, 1), + } } + atomic.AddInt32(&p.shared, 1) + return p } -func (p *StickyConnPool) NewConn(context.Context) (*Conn, error) { - panic("not implemented") +func (p *StickyConnPool) NewConn(ctx context.Context) (*Conn, error) { + return p.pool.NewConn(ctx) } -func (p *StickyConnPool) CloseConn(*Conn) error { - panic("not implemented") +func (p *StickyConnPool) CloseConn(cn *Conn) error { + return p.pool.CloseConn(cn) } func (p *StickyConnPool) Get(ctx context.Context) (*Conn, error) { - p.mu.Lock() - defer p.mu.Unlock() - - if p.closed { - return nil, ErrClosed + // In worst case this races with Close which is not a very common operation. + for i := 0; i < 1000; i++ { + switch atomic.LoadUint32(&p.state) { + case stateDefault: + cn, err := p.pool.Get(ctx) + if err != nil { + return nil, err + } + if atomic.CompareAndSwapUint32(&p.state, stateDefault, stateInited) { + return cn, nil + } + p.pool.Remove(ctx, cn, ErrClosed) + case stateInited: + if err := p.badConnError(); err != nil { + return nil, err + } + cn, ok := <-p.ch + if !ok { + return nil, ErrClosed + } + return cn, nil + case stateClosed: + return nil, ErrClosed + default: + panic("not reached") + } } - if p.cn != nil { - return p.cn, nil + return nil, fmt.Errorf("redis: StickyConnPool.Get: infinite loop") +} + +func (p *StickyConnPool) Put(ctx context.Context, cn *Conn) { + defer func() { + if recover() != nil { + p.freeConn(ctx, cn) + } + }() + p.ch <- cn +} + +func (p *StickyConnPool) freeConn(ctx context.Context, cn *Conn) { + if err := p.badConnError(); err != nil { + p.pool.Remove(ctx, cn, err) + } else { + p.pool.Put(ctx, cn) } - - cn, err := p.pool.Get(ctx) - if err != nil { - return nil, err - } - - p.cn = cn - return cn, nil } -func (p *StickyConnPool) putUpstream() { - p.pool.Put(p.cn) - p.cn = nil -} - -func (p *StickyConnPool) Put(cn *Conn) {} - -func (p *StickyConnPool) removeUpstream(reason error) { - p.pool.Remove(p.cn, reason) - p.cn = nil -} - -func (p *StickyConnPool) Remove(cn *Conn, reason error) { - p.removeUpstream(reason) -} - -func (p *StickyConnPool) Len() int { - p.mu.Lock() - defer p.mu.Unlock() - - if p.cn == nil { - return 0 - } - return 1 -} - -func (p *StickyConnPool) IdleLen() int { - p.mu.Lock() - defer p.mu.Unlock() - - if p.cn == nil { - return 1 - } - return 0 -} - -func (p *StickyConnPool) Stats() *Stats { - return nil +func (p *StickyConnPool) Remove(ctx context.Context, cn *Conn, reason error) { + defer func() { + if recover() != nil { + p.pool.Remove(ctx, cn, ErrClosed) + } + }() + p._badConnError.Store(BadConnError{wrapped: reason}) + p.ch <- cn } func (p *StickyConnPool) Close() error { - p.mu.Lock() - defer p.mu.Unlock() - - if p.closed { - return ErrClosed + if shared := atomic.AddInt32(&p.shared, -1); shared > 0 { + return nil } - p.closed = true - if p.cn != nil { - if p.reusable { - p.putUpstream() - } else { - p.removeUpstream(ErrClosed) + for i := 0; i < 1000; i++ { + state := atomic.LoadUint32(&p.state) + if state == stateClosed { + return ErrClosed } + if atomic.CompareAndSwapUint32(&p.state, state, stateClosed) { + close(p.ch) + cn, ok := <-p.ch + if ok { + p.freeConn(context.TODO(), cn) + } + return nil + } + } + + return errors.New("redis: StickyConnPool.Close: infinite loop") +} + +func (p *StickyConnPool) Reset(ctx context.Context) error { + if p.badConnError() == nil { + return nil + } + + select { + case cn, ok := <-p.ch: + if !ok { + return ErrClosed + } + p.pool.Remove(ctx, cn, ErrClosed) + p._badConnError.Store(BadConnError{wrapped: nil}) + default: + return errors.New("redis: StickyConnPool does not have a Conn") + } + + if !atomic.CompareAndSwapUint32(&p.state, stateInited, stateDefault) { + state := atomic.LoadUint32(&p.state) + return fmt.Errorf("redis: invalid StickyConnPool state: %d", state) } return nil } + +func (p *StickyConnPool) badConnError() error { + if v := p._badConnError.Load(); v != nil { + err := v.(BadConnError) + if err.wrapped != nil { + return err + } + } + return nil +} + +func (p *StickyConnPool) Len() int { + switch atomic.LoadUint32(&p.state) { + case stateDefault: + return 0 + case stateInited: + return 1 + case stateClosed: + return 0 + default: + panic("not reached") + } +} + +func (p *StickyConnPool) IdleLen() int { + return len(p.ch) +} + +func (p *StickyConnPool) Stats() *Stats { + return &Stats{} +} diff --git a/internal/pool/pool_test.go b/internal/pool/pool_test.go index cf31d345..795aef30 100644 --- a/internal/pool/pool_test.go +++ b/internal/pool/pool_test.go @@ -13,7 +13,7 @@ import ( ) var _ = Describe("ConnPool", func() { - c := context.Background() + ctx := context.Background() var connPool *pool.ConnPool BeforeEach(func() { @@ -32,13 +32,13 @@ var _ = Describe("ConnPool", func() { It("should unblock client when conn is removed", func() { // Reserve one connection. - cn, err := connPool.Get(c) + cn, err := connPool.Get(ctx) Expect(err).NotTo(HaveOccurred()) // Reserve all other connections. var cns []*pool.Conn for i := 0; i < 9; i++ { - cn, err := connPool.Get(c) + cn, err := connPool.Get(ctx) Expect(err).NotTo(HaveOccurred()) cns = append(cns, cn) } @@ -49,11 +49,11 @@ var _ = Describe("ConnPool", func() { defer GinkgoRecover() started <- true - _, err := connPool.Get(c) + _, err := connPool.Get(ctx) Expect(err).NotTo(HaveOccurred()) done <- true - connPool.Put(cn) + connPool.Put(ctx, cn) }() <-started @@ -65,7 +65,7 @@ var _ = Describe("ConnPool", func() { // ok } - connPool.Remove(cn, nil) + connPool.Remove(ctx, cn, nil) // Check that Get is unblocked. select { @@ -76,14 +76,14 @@ var _ = Describe("ConnPool", func() { } for _, cn := range cns { - connPool.Put(cn) + connPool.Put(ctx, cn) } }) }) var _ = Describe("MinIdleConns", func() { - c := context.Background() const poolSize = 100 + ctx := context.Background() var minIdleConns int var connPool *pool.ConnPool @@ -113,7 +113,7 @@ var _ = Describe("MinIdleConns", func() { BeforeEach(func() { var err error - cn, err = connPool.Get(c) + cn, err = connPool.Get(ctx) Expect(err).NotTo(HaveOccurred()) Eventually(func() int { @@ -128,7 +128,7 @@ var _ = Describe("MinIdleConns", func() { Context("after Remove", func() { BeforeEach(func() { - connPool.Remove(cn, nil) + connPool.Remove(ctx, cn, nil) }) It("has idle connections", func() { @@ -148,7 +148,7 @@ var _ = Describe("MinIdleConns", func() { perform(poolSize, func(_ int) { defer GinkgoRecover() - cn, err := connPool.Get(c) + cn, err := connPool.Get(ctx) Expect(err).NotTo(HaveOccurred()) mu.Lock() cns = append(cns, cn) @@ -163,7 +163,7 @@ var _ = Describe("MinIdleConns", func() { It("Get is blocked", func() { done := make(chan struct{}) go func() { - connPool.Get(c) + connPool.Get(ctx) close(done) }() @@ -186,7 +186,7 @@ var _ = Describe("MinIdleConns", func() { BeforeEach(func() { perform(len(cns), func(i int) { mu.RLock() - connPool.Put(cns[i]) + connPool.Put(ctx, cns[i]) mu.RUnlock() }) @@ -205,7 +205,7 @@ var _ = Describe("MinIdleConns", func() { BeforeEach(func() { perform(len(cns), func(i int) { mu.RLock() - connPool.Remove(cns[i], nil) + connPool.Remove(ctx, cns[i], nil) mu.RUnlock() }) @@ -250,11 +250,10 @@ var _ = Describe("MinIdleConns", func() { }) var _ = Describe("conns reaper", func() { - c := context.Background() - const idleTimeout = time.Minute const maxAge = time.Hour + ctx := context.Background() var connPool *pool.ConnPool var conns, staleConns, closedConns []*pool.Conn @@ -279,7 +278,7 @@ var _ = Describe("conns reaper", func() { // add stale connections staleConns = nil for i := 0; i < 3; i++ { - cn, err := connPool.Get(c) + cn, err := connPool.Get(ctx) Expect(err).NotTo(HaveOccurred()) switch typ { case "idle": @@ -293,13 +292,13 @@ var _ = Describe("conns reaper", func() { // add fresh connections for i := 0; i < 3; i++ { - cn, err := connPool.Get(c) + cn, err := connPool.Get(ctx) Expect(err).NotTo(HaveOccurred()) conns = append(conns, cn) } for _, cn := range conns { - connPool.Put(cn) + connPool.Put(ctx, cn) } Expect(connPool.Len()).To(Equal(6)) @@ -338,7 +337,7 @@ var _ = Describe("conns reaper", func() { for j := 0; j < 3; j++ { var freeCns []*pool.Conn for i := 0; i < 3; i++ { - cn, err := connPool.Get(c) + cn, err := connPool.Get(ctx) Expect(err).NotTo(HaveOccurred()) Expect(cn).NotTo(BeNil()) freeCns = append(freeCns, cn) @@ -347,7 +346,7 @@ var _ = Describe("conns reaper", func() { Expect(connPool.Len()).To(Equal(3)) Expect(connPool.IdleLen()).To(Equal(0)) - cn, err := connPool.Get(c) + cn, err := connPool.Get(ctx) Expect(err).NotTo(HaveOccurred()) Expect(cn).NotTo(BeNil()) conns = append(conns, cn) @@ -355,13 +354,13 @@ var _ = Describe("conns reaper", func() { Expect(connPool.Len()).To(Equal(4)) Expect(connPool.IdleLen()).To(Equal(0)) - connPool.Remove(cn, nil) + connPool.Remove(ctx, cn, nil) Expect(connPool.Len()).To(Equal(3)) Expect(connPool.IdleLen()).To(Equal(0)) for _, cn := range freeCns { - connPool.Put(cn) + connPool.Put(ctx, cn) } Expect(connPool.Len()).To(Equal(3)) @@ -375,7 +374,7 @@ var _ = Describe("conns reaper", func() { }) var _ = Describe("race", func() { - c := context.Background() + ctx := context.Background() var connPool *pool.ConnPool var C, N int @@ -402,18 +401,18 @@ var _ = Describe("race", func() { perform(C, func(id int) { for i := 0; i < N; i++ { - cn, err := connPool.Get(c) + cn, err := connPool.Get(ctx) Expect(err).NotTo(HaveOccurred()) if err == nil { - connPool.Put(cn) + connPool.Put(ctx, cn) } } }, func(id int) { for i := 0; i < N; i++ { - cn, err := connPool.Get(c) + cn, err := connPool.Get(ctx) Expect(err).NotTo(HaveOccurred()) if err == nil { - connPool.Remove(cn, nil) + connPool.Remove(ctx, cn, nil) } } }) diff --git a/pool_test.go b/pool_test.go index 4cf8d4f7..781c5df9 100644 --- a/pool_test.go +++ b/pool_test.go @@ -85,7 +85,7 @@ var _ = Describe("pool", func() { cn, err := client.Pool().Get(context.Background()) Expect(err).NotTo(HaveOccurred()) cn.SetNetConn(&badConn{}) - client.Pool().Put(cn) + client.Pool().Put(ctx, cn) err = client.Ping(ctx).Err() Expect(err).To(MatchError("bad connection")) diff --git a/redis.go b/redis.go index 9cd081a9..2894e6ab 100644 --- a/redis.go +++ b/redis.go @@ -218,7 +218,7 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) { return c.initConn(ctx, cn) }) if err != nil { - c.connPool.Remove(cn, err) + c.connPool.Remove(ctx, cn, err) if err := internal.Unwrap(err); err != nil { return nil, err } @@ -241,8 +241,7 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { return nil } - connPool := pool.NewSingleConnPool(nil) - connPool.SetConn(cn) + connPool := pool.NewSingleConnPool(c.connPool, cn) conn := newConn(ctx, c.opt, connPool) _, err := conn.Pipelined(ctx, func(pipe Pipeliner) error { @@ -274,15 +273,15 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { return nil } -func (c *baseClient) releaseConn(cn *pool.Conn, err error) { +func (c *baseClient) releaseConn(ctx context.Context, cn *pool.Conn, err error) { if c.opt.Limiter != nil { c.opt.Limiter.ReportResult(err) } if isBadConn(err, false) { - c.connPool.Remove(cn, err) + c.connPool.Remove(ctx, cn, err) } else { - c.connPool.Put(cn) + c.connPool.Put(ctx, cn) } } @@ -295,7 +294,7 @@ func (c *baseClient) withConn( return err } defer func() { - c.releaseConn(cn, err) + c.releaseConn(ctx, cn, err) }() err = fn(ctx, cn) @@ -585,7 +584,7 @@ func (c *Client) WithContext(ctx context.Context) *Client { } func (c *Client) Conn(ctx context.Context) *Conn { - return newConn(ctx, c.opt, pool.NewSingleConnPool(c.connPool)) + return newConn(ctx, c.opt, pool.NewStickyConnPool(c.connPool)) } // Do creates a Cmd from the args and processes the cmd. diff --git a/redis_test.go b/redis_test.go index 27ce3517..45cfac7d 100644 --- a/redis_test.go +++ b/redis_test.go @@ -206,7 +206,7 @@ var _ = Describe("Client", func() { Expect(err).NotTo(HaveOccurred()) cn.SetNetConn(&badConn{}) - client.Pool().Put(cn) + client.Pool().Put(ctx, cn) err = client.Ping(ctx).Err() Expect(err).NotTo(HaveOccurred()) @@ -245,7 +245,7 @@ var _ = Describe("Client", func() { Expect(cn.UsedAt).NotTo(BeZero()) createdAt := cn.UsedAt() - client.Pool().Put(cn) + client.Pool().Put(ctx, cn) Expect(cn.UsedAt().Equal(createdAt)).To(BeTrue()) time.Sleep(time.Second) diff --git a/tx.go b/tx.go index 6da04342..ad825c61 100644 --- a/tx.go +++ b/tx.go @@ -26,7 +26,7 @@ func (c *Client) newTx(ctx context.Context) *Tx { tx := Tx{ baseClient: baseClient{ opt: c.opt, - connPool: pool.NewStickyConnPool(c.connPool.(*pool.ConnPool), true), + connPool: pool.NewStickyConnPool(c.connPool), }, hooks: c.hooks.clone(), ctx: ctx, diff --git a/tx_test.go b/tx_test.go index f10cf218..4681122a 100644 --- a/tx_test.go +++ b/tx_test.go @@ -129,7 +129,7 @@ var _ = Describe("Tx", func() { Expect(err).NotTo(HaveOccurred()) cn.SetNetConn(&badConn{}) - client.Pool().Put(cn) + client.Pool().Put(ctx, cn) do := func() error { err := client.Watch(ctx, func(tx *redis.Tx) error {