diff --git a/go.mod b/go.mod index 57f2bc0..eb4969d 100644 --- a/go.mod +++ b/go.mod @@ -4,3 +4,5 @@ require ( github.com/onsi/ginkgo v1.8.0 github.com/onsi/gomega v1.5.0 ) + +go 1.13 diff --git a/internal/pool/bench_test.go b/internal/pool/bench_test.go index 54fbf91..c60c66e 100644 --- a/internal/pool/bench_test.go +++ b/internal/pool/bench_test.go @@ -85,7 +85,7 @@ func BenchmarkPoolGetRemove(b *testing.B) { if err != nil { b.Fatal(err) } - connPool.Remove(cn) + connPool.Remove(cn, nil) } }) }) diff --git a/internal/pool/pool.go b/internal/pool/pool.go index 88b059c..cd4c8d6 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -38,7 +38,7 @@ type Pooler interface { Get() (*Conn, error) Put(*Conn) - Remove(*Conn) + Remove(*Conn, error) Len() int IdleLen() int @@ -289,7 +289,7 @@ func (p *ConnPool) popIdle() *Conn { func (p *ConnPool) Put(cn *Conn) { if !cn.pooled { - p.Remove(cn) + p.Remove(cn, nil) return } @@ -300,7 +300,7 @@ func (p *ConnPool) Put(cn *Conn) { p.freeTurn() } -func (p *ConnPool) Remove(cn *Conn) { +func (p *ConnPool) Remove(cn *Conn, reason error) { p.removeConn(cn) p.freeTurn() _ = p.closeConn(cn) diff --git a/internal/pool/pool_single.go b/internal/pool/pool_single.go index b35b78a..cd0289b 100644 --- a/internal/pool/pool_single.go +++ b/internal/pool/pool_single.go @@ -1,53 +1,203 @@ package pool +import ( + "fmt" + "sync/atomic" +) + +const ( + stateDefault = 0 + stateInited = 1 + stateClosed = 2 +) + +type BadConnError struct { + wrapped error +} + +var _ error = (*BadConnError)(nil) + +func (e BadConnError) Error() string { + return "pg: Conn is in a bad state" +} + +func (e BadConnError) Unwrap() error { + return e.wrapped +} + type SingleConnPool struct { - cn *Conn + pool Pooler + level int32 // atomic + + state uint32 // atomic + ch chan *Conn + + _badConnError atomic.Value } var _ Pooler = (*SingleConnPool)(nil) -func NewSingleConnPool(cn *Conn) *SingleConnPool { - return &SingleConnPool{ - cn: cn, +func NewSingleConnPool(pool Pooler) *SingleConnPool { + p, ok := pool.(*SingleConnPool) + if !ok { + p = &SingleConnPool{ + pool: pool, + ch: make(chan *Conn, 1), + } + } + atomic.AddInt32(&p.level, 1) + return p +} + +func (p *SingleConnPool) SetConn(cn *Conn) { + if atomic.CompareAndSwapUint32(&p.state, stateDefault, stateInited) { + p.ch <- cn + } else { + panic("not reached") } } func (p *SingleConnPool) NewConn() (*Conn, error) { - panic("not implemented") + return p.pool.NewConn() } -func (p *SingleConnPool) CloseConn(*Conn) error { - panic("not implemented") +func (p *SingleConnPool) CloseConn(cn *Conn) error { + return p.pool.CloseConn(cn) } func (p *SingleConnPool) Get() (*Conn, error) { - return p.cn, nil + // In worst case this races with Close which is not a very common operation. + for i := 0; i < 1000; i++ { + switch atomic.LoadUint32(&p.state) { + case stateDefault: + cn, err := p.pool.Get() + if err != nil { + return nil, err + } + if atomic.CompareAndSwapUint32(&p.state, stateDefault, stateInited) { + return cn, nil + } + p.pool.Remove(cn, ErrClosed) + case stateInited: + if err := p.badConnError(); err != nil { + return nil, err + } + cn, ok := <-p.ch + if !ok { + return nil, ErrClosed + } + return cn, nil + case stateClosed: + return nil, ErrClosed + default: + panic("not reached") + } + } + return nil, fmt.Errorf("pg: SingleConnPool.Get: infinite loop") } func (p *SingleConnPool) Put(cn *Conn) { - if p.cn != cn { - panic("p.cn != cn") + defer func() { + if recover() != nil { + p.freeConn(cn) + } + }() + p.ch <- cn +} + +func (p *SingleConnPool) freeConn(cn *Conn) { + if err := p.badConnError(); err != nil { + p.pool.Remove(cn, err) + } else { + p.pool.Put(cn) } } -func (p *SingleConnPool) Remove(cn *Conn) { - if p.cn != cn { - panic("p.cn != cn") - } +func (p *SingleConnPool) Remove(cn *Conn, reason error) { + defer func() { + if recover() != nil { + p.pool.Remove(cn, ErrClosed) + } + }() + p._badConnError.Store(BadConnError{wrapped: reason}) + p.ch <- cn } func (p *SingleConnPool) Len() int { - return 1 + switch atomic.LoadUint32(&p.state) { + case stateDefault: + return 0 + case stateInited: + return 1 + case stateClosed: + return 0 + default: + panic("not reached") + } } func (p *SingleConnPool) IdleLen() int { - return 0 + return len(p.ch) } func (p *SingleConnPool) Stats() *Stats { - return nil + return &Stats{} } func (p *SingleConnPool) Close() error { + level := atomic.AddInt32(&p.level, -1) + if level > 0 { + return nil + } + + for i := 0; i < 1000; i++ { + state := atomic.LoadUint32(&p.state) + if state == stateClosed { + return ErrClosed + } + if atomic.CompareAndSwapUint32(&p.state, state, stateClosed) { + close(p.ch) + cn, ok := <-p.ch + if ok { + p.freeConn(cn) + } + return nil + } + } + + return fmt.Errorf("pg: SingleConnPool.Close: infinite loop") +} + +func (p *SingleConnPool) Reset() error { + if p.badConnError() == nil { + return nil + } + + select { + case cn, ok := <-p.ch: + if !ok { + return ErrClosed + } + p.pool.Remove(cn, ErrClosed) + p._badConnError.Store(BadConnError{wrapped: nil}) + default: + return fmt.Errorf("pg: SingleConnPool does not have a Conn") + } + + if !atomic.CompareAndSwapUint32(&p.state, stateInited, stateDefault) { + state := atomic.LoadUint32(&p.state) + return fmt.Errorf("pg: invalid SingleConnPool state: %d", state) + } + + return nil +} + +func (p *SingleConnPool) badConnError() error { + if v := p._badConnError.Load(); v != nil { + err := v.(BadConnError) + if err.wrapped != nil { + return err + } + } return nil } diff --git a/internal/pool/pool_sticky.go b/internal/pool/pool_sticky.go index 91bd913..3e8f503 100644 --- a/internal/pool/pool_sticky.go +++ b/internal/pool/pool_sticky.go @@ -55,13 +55,13 @@ func (p *StickyConnPool) putUpstream() { func (p *StickyConnPool) Put(cn *Conn) {} -func (p *StickyConnPool) removeUpstream() { - p.pool.Remove(p.cn) +func (p *StickyConnPool) removeUpstream(reason error) { + p.pool.Remove(p.cn, reason) p.cn = nil } -func (p *StickyConnPool) Remove(cn *Conn) { - p.removeUpstream() +func (p *StickyConnPool) Remove(cn *Conn, reason error) { + p.removeUpstream(reason) } func (p *StickyConnPool) Len() int { @@ -101,7 +101,7 @@ func (p *StickyConnPool) Close() error { if p.reusable { p.putUpstream() } else { - p.removeUpstream() + p.removeUpstream(ErrClosed) } } diff --git a/internal/pool/pool_test.go b/internal/pool/pool_test.go index 07fb48a..ae50ac9 100644 --- a/internal/pool/pool_test.go +++ b/internal/pool/pool_test.go @@ -63,7 +63,7 @@ var _ = Describe("ConnPool", func() { // ok } - connPool.Remove(cn) + connPool.Remove(cn, nil) // Check that Get is unblocked. select { @@ -125,7 +125,7 @@ var _ = Describe("MinIdleConns", func() { Context("after Remove", func() { BeforeEach(func() { - connPool.Remove(cn) + connPool.Remove(cn, nil) }) It("has idle connections", func() { @@ -202,7 +202,7 @@ var _ = Describe("MinIdleConns", func() { BeforeEach(func() { perform(len(cns), func(i int) { mu.RLock() - connPool.Remove(cns[i]) + connPool.Remove(cns[i], nil) mu.RUnlock() }) @@ -350,7 +350,7 @@ var _ = Describe("conns reaper", func() { Expect(connPool.Len()).To(Equal(4)) Expect(connPool.IdleLen()).To(Equal(0)) - connPool.Remove(cn) + connPool.Remove(cn, nil) Expect(connPool.Len()).To(Equal(3)) Expect(connPool.IdleLen()).To(Equal(0)) @@ -407,7 +407,7 @@ var _ = Describe("race", func() { cn, err := connPool.Get() Expect(err).NotTo(HaveOccurred()) if err == nil { - connPool.Remove(cn) + connPool.Remove(cn, nil) } } }) diff --git a/internal/util.go b/internal/util.go index ffd2353..80a6003 100644 --- a/internal/util.go +++ b/internal/util.go @@ -27,3 +27,13 @@ func isLower(s string) bool { } return true } + +func Unwrap(err error) error { + u, ok := err.(interface { + Unwrap() error + }) + if !ok { + return nil + } + return u.Unwrap() +} diff --git a/redis.go b/redis.go index a767376..2a6013c 100644 --- a/redis.go +++ b/redis.go @@ -86,7 +86,10 @@ func (c *baseClient) _getConn() (*pool.Conn, error) { err = c.initConn(cn) if err != nil { - c.connPool.Remove(cn) + c.connPool.Remove(cn, err) + if err := internal.Unwrap(err); err != nil { + return nil, err + } return nil, err } @@ -99,7 +102,7 @@ func (c *baseClient) releaseConn(cn *pool.Conn, err error) { } if internal.IsBadConn(err, false) { - c.connPool.Remove(cn) + c.connPool.Remove(cn, err) } else { c.connPool.Put(cn) } @@ -113,7 +116,7 @@ func (c *baseClient) releaseConnStrict(cn *pool.Conn, err error) { if err == nil || internal.IsRedisError(err) { c.connPool.Put(cn) } else { - c.connPool.Remove(cn) + c.connPool.Remove(cn, err) } } @@ -541,10 +544,12 @@ type Conn struct { } func newConn(opt *Options, cn *pool.Conn) *Conn { + connPool := pool.NewSingleConnPool(nil) + connPool.SetConn(cn) c := Conn{ baseClient: baseClient{ opt: opt, - connPool: pool.NewSingleConnPool(cn), + connPool: connPool, }, } c.baseClient.init()