Merge pull request #1443 from go-redis/fix/pool-panics

Port pool fixes
This commit is contained in:
Vladimir Mihailenco 2020-08-15 16:20:26 +03:00 committed by GitHub
commit 3eb3a1da7d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 244 additions and 303 deletions

View File

@ -18,6 +18,7 @@ func (bm poolGetPutBenchmark) String() string {
} }
func BenchmarkPoolGetPut(b *testing.B) { func BenchmarkPoolGetPut(b *testing.B) {
ctx := context.Background()
benchmarks := []poolGetPutBenchmark{ benchmarks := []poolGetPutBenchmark{
{1}, {1},
{2}, {2},
@ -40,11 +41,11 @@ func BenchmarkPoolGetPut(b *testing.B) {
b.RunParallel(func(pb *testing.PB) { b.RunParallel(func(pb *testing.PB) {
for pb.Next() { for pb.Next() {
cn, err := connPool.Get(context.Background()) cn, err := connPool.Get(ctx)
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }
connPool.Put(cn) connPool.Put(ctx, cn)
} }
}) })
}) })
@ -60,6 +61,7 @@ func (bm poolGetRemoveBenchmark) String() string {
} }
func BenchmarkPoolGetRemove(b *testing.B) { func BenchmarkPoolGetRemove(b *testing.B) {
ctx := context.Background()
benchmarks := []poolGetRemoveBenchmark{ benchmarks := []poolGetRemoveBenchmark{
{1}, {1},
{2}, {2},
@ -68,6 +70,7 @@ func BenchmarkPoolGetRemove(b *testing.B) {
{64}, {64},
{128}, {128},
} }
for _, bm := range benchmarks { for _, bm := range benchmarks {
b.Run(bm.String(), func(b *testing.B) { b.Run(bm.String(), func(b *testing.B) {
connPool := pool.NewConnPool(&pool.Options{ connPool := pool.NewConnPool(&pool.Options{
@ -82,11 +85,11 @@ func BenchmarkPoolGetRemove(b *testing.B) {
b.RunParallel(func(pb *testing.PB) { b.RunParallel(func(pb *testing.PB) {
for pb.Next() { for pb.Next() {
cn, err := connPool.Get(context.Background()) cn, err := connPool.Get(ctx)
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }
connPool.Remove(cn, nil) connPool.Remove(ctx, cn, nil)
} }
}) })
}) })

View File

@ -40,8 +40,8 @@ type Pooler interface {
CloseConn(*Conn) error CloseConn(*Conn) error
Get(context.Context) (*Conn, error) Get(context.Context) (*Conn, error)
Put(*Conn) Put(context.Context, *Conn)
Remove(*Conn, error) Remove(context.Context, *Conn, error)
Len() int Len() int
IdleLen() int IdleLen() int
@ -318,15 +318,15 @@ func (p *ConnPool) popIdle() *Conn {
return cn return cn
} }
func (p *ConnPool) Put(cn *Conn) { func (p *ConnPool) Put(ctx context.Context, cn *Conn) {
if cn.rd.Buffered() > 0 { if cn.rd.Buffered() > 0 {
internal.Logger.Printf(context.Background(), "Conn has unread data") internal.Logger.Printf(ctx, "Conn has unread data")
p.Remove(cn, BadConnError{}) p.Remove(ctx, cn, BadConnError{})
return return
} }
if !cn.pooled { if !cn.pooled {
p.Remove(cn, nil) p.Remove(ctx, cn, nil)
return return
} }
@ -337,7 +337,7 @@ func (p *ConnPool) Put(cn *Conn) {
p.freeTurn() p.freeTurn()
} }
func (p *ConnPool) Remove(cn *Conn, reason error) { func (p *ConnPool) Remove(ctx context.Context, cn *Conn, reason error) {
p.removeConnWithLock(cn) p.removeConnWithLock(cn)
p.freeTurn() p.freeTurn()
_ = p.closeConn(cn) _ = p.closeConn(cn)

View File

@ -1,64 +1,19 @@
package pool package pool
import ( import "context"
"context"
"fmt"
"sync/atomic"
)
const (
stateDefault = 0
stateInited = 1
stateClosed = 2
)
type BadConnError struct {
wrapped error
}
var _ error = (*BadConnError)(nil)
func (e BadConnError) Error() string {
s := "redis: Conn is in a bad state"
if e.wrapped != nil {
s += ": " + e.wrapped.Error()
}
return s
}
func (e BadConnError) Unwrap() error {
return e.wrapped
}
type SingleConnPool struct { type SingleConnPool struct {
pool Pooler pool Pooler
level int32 // atomic cn *Conn
stickyErr error
state uint32 // atomic
ch chan *Conn
_badConnError atomic.Value
} }
var _ Pooler = (*SingleConnPool)(nil) var _ Pooler = (*SingleConnPool)(nil)
func NewSingleConnPool(pool Pooler) *SingleConnPool { func NewSingleConnPool(pool Pooler, cn *Conn) *SingleConnPool {
p, ok := pool.(*SingleConnPool) return &SingleConnPool{
if !ok {
p = &SingleConnPool{
pool: pool, pool: pool,
ch: make(chan *Conn, 1), cn: cn,
}
}
atomic.AddInt32(&p.level, 1)
return p
}
func (p *SingleConnPool) SetConn(cn *Conn) {
if atomic.CompareAndSwapUint32(&p.state, stateDefault, stateInited) {
p.ch <- cn
} else {
panic("not reached")
} }
} }
@ -71,138 +26,33 @@ func (p *SingleConnPool) CloseConn(cn *Conn) error {
} }
func (p *SingleConnPool) Get(ctx context.Context) (*Conn, error) { func (p *SingleConnPool) Get(ctx context.Context) (*Conn, error) {
// In worst case this races with Close which is not a very common operation. if p.stickyErr != nil {
for i := 0; i < 1000; i++ { return nil, p.stickyErr
switch atomic.LoadUint32(&p.state) {
case stateDefault:
cn, err := p.pool.Get(ctx)
if err != nil {
return nil, err
} }
if atomic.CompareAndSwapUint32(&p.state, stateDefault, stateInited) { return p.cn, nil
return cn, nil
}
p.pool.Remove(cn, ErrClosed)
case stateInited:
if err := p.badConnError(); err != nil {
return nil, err
}
cn, ok := <-p.ch
if !ok {
return nil, ErrClosed
}
return cn, nil
case stateClosed:
return nil, ErrClosed
default:
panic("not reached")
}
}
return nil, fmt.Errorf("redis: SingleConnPool.Get: infinite loop")
} }
func (p *SingleConnPool) Put(cn *Conn) { func (p *SingleConnPool) Put(ctx context.Context, cn *Conn) {}
defer func() {
if recover() != nil { func (p *SingleConnPool) Remove(ctx context.Context, cn *Conn, reason error) {
p.freeConn(cn) p.cn = nil
} p.stickyErr = reason
}()
p.ch <- cn
} }
func (p *SingleConnPool) freeConn(cn *Conn) { func (p *SingleConnPool) Close() error {
if err := p.badConnError(); err != nil { p.cn = nil
p.pool.Remove(cn, err) p.stickyErr = ErrClosed
} else { return nil
p.pool.Put(cn)
}
}
func (p *SingleConnPool) Remove(cn *Conn, reason error) {
defer func() {
if recover() != nil {
p.pool.Remove(cn, ErrClosed)
}
}()
p._badConnError.Store(BadConnError{wrapped: reason})
p.ch <- cn
} }
func (p *SingleConnPool) Len() int { func (p *SingleConnPool) Len() int {
switch atomic.LoadUint32(&p.state) {
case stateDefault:
return 0 return 0
case stateInited:
return 1
case stateClosed:
return 0
default:
panic("not reached")
}
} }
func (p *SingleConnPool) IdleLen() int { func (p *SingleConnPool) IdleLen() int {
return len(p.ch) return 0
} }
func (p *SingleConnPool) Stats() *Stats { func (p *SingleConnPool) Stats() *Stats {
return &Stats{} return &Stats{}
} }
func (p *SingleConnPool) Close() error {
level := atomic.AddInt32(&p.level, -1)
if level > 0 {
return nil
}
for i := 0; i < 1000; i++ {
state := atomic.LoadUint32(&p.state)
if state == stateClosed {
return ErrClosed
}
if atomic.CompareAndSwapUint32(&p.state, state, stateClosed) {
close(p.ch)
cn, ok := <-p.ch
if ok {
p.freeConn(cn)
}
return nil
}
}
return fmt.Errorf("redis: SingleConnPool.Close: infinite loop")
}
func (p *SingleConnPool) Reset() error {
if p.badConnError() == nil {
return nil
}
select {
case cn, ok := <-p.ch:
if !ok {
return ErrClosed
}
p.pool.Remove(cn, ErrClosed)
p._badConnError.Store(BadConnError{wrapped: nil})
default:
return fmt.Errorf("redis: SingleConnPool does not have a Conn")
}
if !atomic.CompareAndSwapUint32(&p.state, stateInited, stateDefault) {
state := atomic.LoadUint32(&p.state)
return fmt.Errorf("redis: invalid SingleConnPool state: %d", state)
}
return nil
}
func (p *SingleConnPool) badConnError() error {
if v := p._badConnError.Load(); v != nil {
err := v.(BadConnError)
if err.wrapped != nil {
return err
}
}
return nil
}

View File

@ -2,111 +2,201 @@ package pool
import ( import (
"context" "context"
"sync" "errors"
"fmt"
"sync/atomic"
) )
type StickyConnPool struct { const (
pool *ConnPool stateDefault = 0
reusable bool stateInited = 1
stateClosed = 2
)
cn *Conn type BadConnError struct {
closed bool wrapped error
mu sync.Mutex }
var _ error = (*BadConnError)(nil)
func (e BadConnError) Error() string {
s := "redis: Conn is in a bad state"
if e.wrapped != nil {
s += ": " + e.wrapped.Error()
}
return s
}
func (e BadConnError) Unwrap() error {
return e.wrapped
}
//------------------------------------------------------------------------------
type StickyConnPool struct {
pool Pooler
shared int32 // atomic
state uint32 // atomic
ch chan *Conn
_badConnError atomic.Value
} }
var _ Pooler = (*StickyConnPool)(nil) var _ Pooler = (*StickyConnPool)(nil)
func NewStickyConnPool(pool *ConnPool, reusable bool) *StickyConnPool { func NewStickyConnPool(pool Pooler) *StickyConnPool {
return &StickyConnPool{ p, ok := pool.(*StickyConnPool)
if !ok {
p = &StickyConnPool{
pool: pool, pool: pool,
reusable: reusable, ch: make(chan *Conn, 1),
} }
} }
atomic.AddInt32(&p.shared, 1)
return p
}
func (p *StickyConnPool) NewConn(context.Context) (*Conn, error) { func (p *StickyConnPool) NewConn(ctx context.Context) (*Conn, error) {
panic("not implemented") return p.pool.NewConn(ctx)
} }
func (p *StickyConnPool) CloseConn(*Conn) error { func (p *StickyConnPool) CloseConn(cn *Conn) error {
panic("not implemented") return p.pool.CloseConn(cn)
} }
func (p *StickyConnPool) Get(ctx context.Context) (*Conn, error) { func (p *StickyConnPool) Get(ctx context.Context) (*Conn, error) {
p.mu.Lock() // In worst case this races with Close which is not a very common operation.
defer p.mu.Unlock() for i := 0; i < 1000; i++ {
switch atomic.LoadUint32(&p.state) {
if p.closed { case stateDefault:
return nil, ErrClosed
}
if p.cn != nil {
return p.cn, nil
}
cn, err := p.pool.Get(ctx) cn, err := p.pool.Get(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if atomic.CompareAndSwapUint32(&p.state, stateDefault, stateInited) {
p.cn = cn
return cn, nil return cn, nil
} }
p.pool.Remove(ctx, cn, ErrClosed)
func (p *StickyConnPool) putUpstream() { case stateInited:
p.pool.Put(p.cn) if err := p.badConnError(); err != nil {
p.cn = nil return nil, err
}
cn, ok := <-p.ch
if !ok {
return nil, ErrClosed
}
return cn, nil
case stateClosed:
return nil, ErrClosed
default:
panic("not reached")
}
}
return nil, fmt.Errorf("redis: StickyConnPool.Get: infinite loop")
} }
func (p *StickyConnPool) Put(cn *Conn) {} func (p *StickyConnPool) Put(ctx context.Context, cn *Conn) {
defer func() {
func (p *StickyConnPool) removeUpstream(reason error) { if recover() != nil {
p.pool.Remove(p.cn, reason) p.freeConn(ctx, cn)
p.cn = nil }
}()
p.ch <- cn
} }
func (p *StickyConnPool) Remove(cn *Conn, reason error) { func (p *StickyConnPool) freeConn(ctx context.Context, cn *Conn) {
p.removeUpstream(reason) if err := p.badConnError(); err != nil {
p.pool.Remove(ctx, cn, err)
} else {
p.pool.Put(ctx, cn)
}
} }
func (p *StickyConnPool) Len() int { func (p *StickyConnPool) Remove(ctx context.Context, cn *Conn, reason error) {
p.mu.Lock() defer func() {
defer p.mu.Unlock() if recover() != nil {
p.pool.Remove(ctx, cn, ErrClosed)
if p.cn == nil {
return 0
} }
return 1 }()
} p._badConnError.Store(BadConnError{wrapped: reason})
p.ch <- cn
func (p *StickyConnPool) IdleLen() int {
p.mu.Lock()
defer p.mu.Unlock()
if p.cn == nil {
return 1
}
return 0
}
func (p *StickyConnPool) Stats() *Stats {
return nil
} }
func (p *StickyConnPool) Close() error { func (p *StickyConnPool) Close() error {
p.mu.Lock() if shared := atomic.AddInt32(&p.shared, -1); shared > 0 {
defer p.mu.Unlock() return nil
}
if p.closed { for i := 0; i < 1000; i++ {
state := atomic.LoadUint32(&p.state)
if state == stateClosed {
return ErrClosed return ErrClosed
} }
p.closed = true if atomic.CompareAndSwapUint32(&p.state, state, stateClosed) {
close(p.ch)
if p.cn != nil { cn, ok := <-p.ch
if p.reusable { if ok {
p.putUpstream() p.freeConn(context.TODO(), cn)
} else {
p.removeUpstream(ErrClosed)
} }
return nil
}
}
return errors.New("redis: StickyConnPool.Close: infinite loop")
}
func (p *StickyConnPool) Reset(ctx context.Context) error {
if p.badConnError() == nil {
return nil
}
select {
case cn, ok := <-p.ch:
if !ok {
return ErrClosed
}
p.pool.Remove(ctx, cn, ErrClosed)
p._badConnError.Store(BadConnError{wrapped: nil})
default:
return errors.New("redis: StickyConnPool does not have a Conn")
}
if !atomic.CompareAndSwapUint32(&p.state, stateInited, stateDefault) {
state := atomic.LoadUint32(&p.state)
return fmt.Errorf("redis: invalid StickyConnPool state: %d", state)
} }
return nil return nil
} }
func (p *StickyConnPool) badConnError() error {
if v := p._badConnError.Load(); v != nil {
err := v.(BadConnError)
if err.wrapped != nil {
return err
}
}
return nil
}
func (p *StickyConnPool) Len() int {
switch atomic.LoadUint32(&p.state) {
case stateDefault:
return 0
case stateInited:
return 1
case stateClosed:
return 0
default:
panic("not reached")
}
}
func (p *StickyConnPool) IdleLen() int {
return len(p.ch)
}
func (p *StickyConnPool) Stats() *Stats {
return &Stats{}
}

View File

@ -13,7 +13,7 @@ import (
) )
var _ = Describe("ConnPool", func() { var _ = Describe("ConnPool", func() {
c := context.Background() ctx := context.Background()
var connPool *pool.ConnPool var connPool *pool.ConnPool
BeforeEach(func() { BeforeEach(func() {
@ -32,13 +32,13 @@ var _ = Describe("ConnPool", func() {
It("should unblock client when conn is removed", func() { It("should unblock client when conn is removed", func() {
// Reserve one connection. // Reserve one connection.
cn, err := connPool.Get(c) cn, err := connPool.Get(ctx)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
// Reserve all other connections. // Reserve all other connections.
var cns []*pool.Conn var cns []*pool.Conn
for i := 0; i < 9; i++ { for i := 0; i < 9; i++ {
cn, err := connPool.Get(c) cn, err := connPool.Get(ctx)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
cns = append(cns, cn) cns = append(cns, cn)
} }
@ -49,11 +49,11 @@ var _ = Describe("ConnPool", func() {
defer GinkgoRecover() defer GinkgoRecover()
started <- true started <- true
_, err := connPool.Get(c) _, err := connPool.Get(ctx)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
done <- true done <- true
connPool.Put(cn) connPool.Put(ctx, cn)
}() }()
<-started <-started
@ -65,7 +65,7 @@ var _ = Describe("ConnPool", func() {
// ok // ok
} }
connPool.Remove(cn, nil) connPool.Remove(ctx, cn, nil)
// Check that Get is unblocked. // Check that Get is unblocked.
select { select {
@ -76,14 +76,14 @@ var _ = Describe("ConnPool", func() {
} }
for _, cn := range cns { for _, cn := range cns {
connPool.Put(cn) connPool.Put(ctx, cn)
} }
}) })
}) })
var _ = Describe("MinIdleConns", func() { var _ = Describe("MinIdleConns", func() {
c := context.Background()
const poolSize = 100 const poolSize = 100
ctx := context.Background()
var minIdleConns int var minIdleConns int
var connPool *pool.ConnPool var connPool *pool.ConnPool
@ -113,7 +113,7 @@ var _ = Describe("MinIdleConns", func() {
BeforeEach(func() { BeforeEach(func() {
var err error var err error
cn, err = connPool.Get(c) cn, err = connPool.Get(ctx)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Eventually(func() int { Eventually(func() int {
@ -128,7 +128,7 @@ var _ = Describe("MinIdleConns", func() {
Context("after Remove", func() { Context("after Remove", func() {
BeforeEach(func() { BeforeEach(func() {
connPool.Remove(cn, nil) connPool.Remove(ctx, cn, nil)
}) })
It("has idle connections", func() { It("has idle connections", func() {
@ -148,7 +148,7 @@ var _ = Describe("MinIdleConns", func() {
perform(poolSize, func(_ int) { perform(poolSize, func(_ int) {
defer GinkgoRecover() defer GinkgoRecover()
cn, err := connPool.Get(c) cn, err := connPool.Get(ctx)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
mu.Lock() mu.Lock()
cns = append(cns, cn) cns = append(cns, cn)
@ -163,7 +163,7 @@ var _ = Describe("MinIdleConns", func() {
It("Get is blocked", func() { It("Get is blocked", func() {
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
connPool.Get(c) connPool.Get(ctx)
close(done) close(done)
}() }()
@ -186,7 +186,7 @@ var _ = Describe("MinIdleConns", func() {
BeforeEach(func() { BeforeEach(func() {
perform(len(cns), func(i int) { perform(len(cns), func(i int) {
mu.RLock() mu.RLock()
connPool.Put(cns[i]) connPool.Put(ctx, cns[i])
mu.RUnlock() mu.RUnlock()
}) })
@ -205,7 +205,7 @@ var _ = Describe("MinIdleConns", func() {
BeforeEach(func() { BeforeEach(func() {
perform(len(cns), func(i int) { perform(len(cns), func(i int) {
mu.RLock() mu.RLock()
connPool.Remove(cns[i], nil) connPool.Remove(ctx, cns[i], nil)
mu.RUnlock() mu.RUnlock()
}) })
@ -250,11 +250,10 @@ var _ = Describe("MinIdleConns", func() {
}) })
var _ = Describe("conns reaper", func() { var _ = Describe("conns reaper", func() {
c := context.Background()
const idleTimeout = time.Minute const idleTimeout = time.Minute
const maxAge = time.Hour const maxAge = time.Hour
ctx := context.Background()
var connPool *pool.ConnPool var connPool *pool.ConnPool
var conns, staleConns, closedConns []*pool.Conn var conns, staleConns, closedConns []*pool.Conn
@ -279,7 +278,7 @@ var _ = Describe("conns reaper", func() {
// add stale connections // add stale connections
staleConns = nil staleConns = nil
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
cn, err := connPool.Get(c) cn, err := connPool.Get(ctx)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
switch typ { switch typ {
case "idle": case "idle":
@ -293,13 +292,13 @@ var _ = Describe("conns reaper", func() {
// add fresh connections // add fresh connections
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
cn, err := connPool.Get(c) cn, err := connPool.Get(ctx)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
conns = append(conns, cn) conns = append(conns, cn)
} }
for _, cn := range conns { for _, cn := range conns {
connPool.Put(cn) connPool.Put(ctx, cn)
} }
Expect(connPool.Len()).To(Equal(6)) Expect(connPool.Len()).To(Equal(6))
@ -338,7 +337,7 @@ var _ = Describe("conns reaper", func() {
for j := 0; j < 3; j++ { for j := 0; j < 3; j++ {
var freeCns []*pool.Conn var freeCns []*pool.Conn
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
cn, err := connPool.Get(c) cn, err := connPool.Get(ctx)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(cn).NotTo(BeNil()) Expect(cn).NotTo(BeNil())
freeCns = append(freeCns, cn) freeCns = append(freeCns, cn)
@ -347,7 +346,7 @@ var _ = Describe("conns reaper", func() {
Expect(connPool.Len()).To(Equal(3)) Expect(connPool.Len()).To(Equal(3))
Expect(connPool.IdleLen()).To(Equal(0)) Expect(connPool.IdleLen()).To(Equal(0))
cn, err := connPool.Get(c) cn, err := connPool.Get(ctx)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(cn).NotTo(BeNil()) Expect(cn).NotTo(BeNil())
conns = append(conns, cn) conns = append(conns, cn)
@ -355,13 +354,13 @@ var _ = Describe("conns reaper", func() {
Expect(connPool.Len()).To(Equal(4)) Expect(connPool.Len()).To(Equal(4))
Expect(connPool.IdleLen()).To(Equal(0)) Expect(connPool.IdleLen()).To(Equal(0))
connPool.Remove(cn, nil) connPool.Remove(ctx, cn, nil)
Expect(connPool.Len()).To(Equal(3)) Expect(connPool.Len()).To(Equal(3))
Expect(connPool.IdleLen()).To(Equal(0)) Expect(connPool.IdleLen()).To(Equal(0))
for _, cn := range freeCns { for _, cn := range freeCns {
connPool.Put(cn) connPool.Put(ctx, cn)
} }
Expect(connPool.Len()).To(Equal(3)) Expect(connPool.Len()).To(Equal(3))
@ -375,7 +374,7 @@ var _ = Describe("conns reaper", func() {
}) })
var _ = Describe("race", func() { var _ = Describe("race", func() {
c := context.Background() ctx := context.Background()
var connPool *pool.ConnPool var connPool *pool.ConnPool
var C, N int var C, N int
@ -402,18 +401,18 @@ var _ = Describe("race", func() {
perform(C, func(id int) { perform(C, func(id int) {
for i := 0; i < N; i++ { for i := 0; i < N; i++ {
cn, err := connPool.Get(c) cn, err := connPool.Get(ctx)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
if err == nil { if err == nil {
connPool.Put(cn) connPool.Put(ctx, cn)
} }
} }
}, func(id int) { }, func(id int) {
for i := 0; i < N; i++ { for i := 0; i < N; i++ {
cn, err := connPool.Get(c) cn, err := connPool.Get(ctx)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
if err == nil { if err == nil {
connPool.Remove(cn, nil) connPool.Remove(ctx, cn, nil)
} }
} }
}) })

View File

@ -85,7 +85,7 @@ var _ = Describe("pool", func() {
cn, err := client.Pool().Get(context.Background()) cn, err := client.Pool().Get(context.Background())
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
cn.SetNetConn(&badConn{}) cn.SetNetConn(&badConn{})
client.Pool().Put(cn) client.Pool().Put(ctx, cn)
err = client.Ping(ctx).Err() err = client.Ping(ctx).Err()
Expect(err).To(MatchError("bad connection")) Expect(err).To(MatchError("bad connection"))

View File

@ -218,7 +218,7 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
return c.initConn(ctx, cn) return c.initConn(ctx, cn)
}) })
if err != nil { if err != nil {
c.connPool.Remove(cn, err) c.connPool.Remove(ctx, cn, err)
if err := internal.Unwrap(err); err != nil { if err := internal.Unwrap(err); err != nil {
return nil, err return nil, err
} }
@ -241,8 +241,7 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
return nil return nil
} }
connPool := pool.NewSingleConnPool(nil) connPool := pool.NewSingleConnPool(c.connPool, cn)
connPool.SetConn(cn)
conn := newConn(ctx, c.opt, connPool) conn := newConn(ctx, c.opt, connPool)
_, err := conn.Pipelined(ctx, func(pipe Pipeliner) error { _, err := conn.Pipelined(ctx, func(pipe Pipeliner) error {
@ -274,15 +273,15 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
return nil return nil
} }
func (c *baseClient) releaseConn(cn *pool.Conn, err error) { func (c *baseClient) releaseConn(ctx context.Context, cn *pool.Conn, err error) {
if c.opt.Limiter != nil { if c.opt.Limiter != nil {
c.opt.Limiter.ReportResult(err) c.opt.Limiter.ReportResult(err)
} }
if isBadConn(err, false) { if isBadConn(err, false) {
c.connPool.Remove(cn, err) c.connPool.Remove(ctx, cn, err)
} else { } else {
c.connPool.Put(cn) c.connPool.Put(ctx, cn)
} }
} }
@ -295,7 +294,7 @@ func (c *baseClient) withConn(
return err return err
} }
defer func() { defer func() {
c.releaseConn(cn, err) c.releaseConn(ctx, cn, err)
}() }()
err = fn(ctx, cn) err = fn(ctx, cn)
@ -585,7 +584,7 @@ func (c *Client) WithContext(ctx context.Context) *Client {
} }
func (c *Client) Conn(ctx context.Context) *Conn { func (c *Client) Conn(ctx context.Context) *Conn {
return newConn(ctx, c.opt, pool.NewSingleConnPool(c.connPool)) return newConn(ctx, c.opt, pool.NewStickyConnPool(c.connPool))
} }
// Do creates a Cmd from the args and processes the cmd. // Do creates a Cmd from the args and processes the cmd.

View File

@ -206,7 +206,7 @@ var _ = Describe("Client", func() {
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
cn.SetNetConn(&badConn{}) cn.SetNetConn(&badConn{})
client.Pool().Put(cn) client.Pool().Put(ctx, cn)
err = client.Ping(ctx).Err() err = client.Ping(ctx).Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
@ -245,7 +245,7 @@ var _ = Describe("Client", func() {
Expect(cn.UsedAt).NotTo(BeZero()) Expect(cn.UsedAt).NotTo(BeZero())
createdAt := cn.UsedAt() createdAt := cn.UsedAt()
client.Pool().Put(cn) client.Pool().Put(ctx, cn)
Expect(cn.UsedAt().Equal(createdAt)).To(BeTrue()) Expect(cn.UsedAt().Equal(createdAt)).To(BeTrue())
time.Sleep(time.Second) time.Sleep(time.Second)

2
tx.go
View File

@ -26,7 +26,7 @@ func (c *Client) newTx(ctx context.Context) *Tx {
tx := Tx{ tx := Tx{
baseClient: baseClient{ baseClient: baseClient{
opt: c.opt, opt: c.opt,
connPool: pool.NewStickyConnPool(c.connPool.(*pool.ConnPool), true), connPool: pool.NewStickyConnPool(c.connPool),
}, },
hooks: c.hooks.clone(), hooks: c.hooks.clone(),
ctx: ctx, ctx: ctx,

View File

@ -129,7 +129,7 @@ var _ = Describe("Tx", func() {
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
cn.SetNetConn(&badConn{}) cn.SetNetConn(&badConn{})
client.Pool().Put(cn) client.Pool().Put(ctx, cn)
do := func() error { do := func() error {
err := client.Watch(ctx, func(tx *redis.Tx) error { err := client.Watch(ctx, func(tx *redis.Tx) error {