package pool

import (
	"context"
	"errors"
	"net"
	"sync"
	"sync/atomic"
	"time"

	"github.com/go-redis/redis/internal"
)

// ErrClosed is returned by pool operations once the pool has been closed.
var ErrClosed = errors.New("redis: client is closed")

// ErrPoolTimeout is returned when no turn in the pool queue becomes
// available within Options.PoolTimeout.
var ErrPoolTimeout = errors.New("redis: connection pool timeout")

// timers recycles timers used while waiting for a turn in the pool queue.
// New timers are handed out in the stopped state.
var timers = sync.Pool{
	New: func() interface{} {
		t := time.NewTimer(time.Hour)
		t.Stop()
		return t
	},
}

// Stats contains pool state information and accumulated stats.
type Stats struct {
	Hits     uint32 // number of times free connection was found in the pool
	Misses   uint32 // number of times free connection was NOT found in the pool
	Timeouts uint32 // number of times a wait timeout occurred

	TotalConns uint32 // number of total connections in the pool
	IdleConns  uint32 // number of idle connections in the pool
	StaleConns uint32 // number of stale connections removed from the pool
}

// Pooler is the set of operations a connection pool exposes.
// ConnPool is the implementation provided in this file.
type Pooler interface {
	NewConn(context.Context) (*Conn, error)
	CloseConn(*Conn) error

	Get(context.Context) (*Conn, error)
	Put(*Conn)
	Remove(*Conn, error)

	Len() int
	IdleLen() int
	Stats() *Stats

	Close() error
}

// Options configures a ConnPool.
type Options struct {
	// Dialer establishes a new network connection.
	Dialer func(c context.Context) (net.Conn, error)
	// OnClose, when non-nil, is invoked before a connection is closed.
	OnClose func(*Conn) error

	// PoolSize bounds the number of pooled connections and the number of
	// turns that can be held at the same time.
	PoolSize int
	// MinIdleConns is the number of idle connections the pool tries to
	// keep available (never exceeding PoolSize); zero disables this.
	MinIdleConns int
	// MaxConnAge, when positive, marks connections older than this as stale.
	MaxConnAge time.Duration
	// PoolTimeout is how long Get waits for a free turn before failing
	// with ErrPoolTimeout.
	PoolTimeout time.Duration
	// IdleTimeout, when positive, marks connections unused for at least
	// this long as stale.
	IdleTimeout time.Duration
	// IdleCheckFrequency is the period of the background reaper; the
	// reaper only runs when both it and IdleTimeout are positive.
	IdleCheckFrequency time.Duration
}

// ConnPool is a Pooler backed by a bounded queue of turns and a list of
// idle connections.
type ConnPool struct {
	opt *Options

	dialErrorsNum uint32 // atomic

	lastDialErrorMu sync.RWMutex
	lastDialError   error

	// queue holds one token per turn currently taken; its capacity is
	// Options.PoolSize.
	queue chan struct{}

	// connsMu guards conns, idleConns, poolSize and idleConnsLen.
	connsMu      sync.Mutex
	conns        []*Conn
	idleConns    []*Conn
	poolSize     int
	idleConnsLen int

	stats Stats

	_closed uint32 // atomic
}

var _ Pooler = (*ConnPool)(nil)

// NewConnPool creates a pool configured by opt, starts filling it with
// MinIdleConns idle connections and, when both IdleTimeout and
// IdleCheckFrequency are positive, starts the background reaper.
func NewConnPool(opt *Options) *ConnPool {
	p := &ConnPool{
		opt: opt,

		queue:     make(chan struct{}, opt.PoolSize),
		conns:     make([]*Conn, 0, opt.PoolSize),
		idleConns: make([]*Conn, 0, opt.PoolSize),
	}

	// checkMinIdleConns reads and writes counters that the goroutines it
	// spawns also modify under connsMu, so the lock must be held here too.
	p.connsMu.Lock()
	p.checkMinIdleConns()
	p.connsMu.Unlock()

	if opt.IdleTimeout > 0 && opt.IdleCheckFrequency > 0 {
		go p.reaper(opt.IdleCheckFrequency)
	}

	return p
}

// checkMinIdleConns starts background dials until the pool has
// MinIdleConns idle connections or PoolSize pooled connections.
// The caller must hold connsMu.
func (p *ConnPool) checkMinIdleConns() {
	if p.opt.MinIdleConns == 0 {
		return
	}
	for p.poolSize < p.opt.PoolSize && p.idleConnsLen < p.opt.MinIdleConns {
		// Reserve the slot before dialing so concurrent callers do not
		// over-provision; the reservation is rolled back on failure.
		p.poolSize++
		p.idleConnsLen++
		go func() {
			err := p.addIdleConn()
			if err != nil {
				p.connsMu.Lock()
				p.poolSize--
				p.idleConnsLen--
				p.connsMu.Unlock()
			}
		}()
	}
}

// addIdleConn dials a pooled connection and registers it as idle.
// It returns ErrClosed, after closing the new connection, if the pool was
// closed while the dial was in progress.
func (p *ConnPool) addIdleConn() error {
	cn, err := p.dialConn(context.TODO(), true)
	if err != nil {
		return err
	}

	p.connsMu.Lock()
	defer p.connsMu.Unlock()

	// Do not add connections to a closed pool: Close has already run (or
	// will not see this connection), so it would never be closed.
	if p.closed() {
		_ = cn.Close()
		return ErrClosed
	}

	p.conns = append(p.conns, cn)
	p.idleConns = append(p.idleConns, cn)
	return nil
}

// NewConn dials a new connection that is tracked by the pool but is not
// marked as pooled.
func (p *ConnPool) NewConn(ctx context.Context) (*Conn, error) {
	return p.newConn(ctx, false)
}

// newConn dials a connection and records it in conns. When pooled is true
// the connection counts against PoolSize, unless the pool is already full.
// It returns ErrClosed, after closing the new connection, if the pool was
// closed while the dial was in progress.
func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) {
	cn, err := p.dialConn(ctx, pooled)
	if err != nil {
		return nil, err
	}

	p.connsMu.Lock()
	defer p.connsMu.Unlock()

	// Do not add connections to a closed pool; nothing would close them.
	if p.closed() {
		_ = cn.Close()
		return nil, ErrClosed
	}

	p.conns = append(p.conns, cn)
	if pooled {
		// If pool is full remove the cn on next Put.
		if p.poolSize >= p.opt.PoolSize {
			cn.pooled = false
		} else {
			p.poolSize++
		}
	}
	return cn, nil
}

// dialConn dials a connection using Options.Dialer. After PoolSize dial
// errors have accumulated it fails immediately with the last dial error
// until a background probe (tryDial) succeeds and resets the counter.
func (p *ConnPool) dialConn(ctx context.Context, pooled bool) (*Conn, error) {
	if p.closed() {
		return nil, ErrClosed
	}

	if atomic.LoadUint32(&p.dialErrorsNum) >= uint32(p.opt.PoolSize) {
		return nil, p.getLastDialError()
	}

	netConn, err := p.opt.Dialer(ctx)
	if err != nil {
		p.setLastDialError(err)
		// Start exactly one probe: only the increment that reaches
		// PoolSize launches tryDial.
		if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.opt.PoolSize) {
			go p.tryDial()
		}
		return nil, err
	}

	cn := NewConn(netConn)
	cn.pooled = pooled
	return cn, nil
}

// tryDial retries dialing once per second until a dial succeeds or the
// pool is closed. On success it resets the dial error counter and closes
// the probe connection.
func (p *ConnPool) tryDial() {
	for {
		if p.closed() {
			return
		}

		conn, err := p.opt.Dialer(context.Background())
		if err != nil {
			p.setLastDialError(err)
			time.Sleep(time.Second)
			continue
		}

		atomic.StoreUint32(&p.dialErrorsNum, 0)
		// The connection was only a probe; its close error is irrelevant.
		_ = conn.Close()
		return
	}
}

// setLastDialError records err as the most recent dial error.
func (p *ConnPool) setLastDialError(err error) {
	p.lastDialErrorMu.Lock()
	p.lastDialError = err
	p.lastDialErrorMu.Unlock()
}

// getLastDialError returns the most recently recorded dial error.
func (p *ConnPool) getLastDialError() error {
	p.lastDialErrorMu.RLock()
	err := p.lastDialError
	p.lastDialErrorMu.RUnlock()
	return err
}

// Get returns an existing connection from the pool or creates a new one.
func (p *ConnPool) Get(ctx context.Context) (*Conn, error) { if p.closed() { return nil, ErrClosed } err := p.waitTurn(ctx) if err != nil { return nil, err } for { p.connsMu.Lock() cn := p.popIdle() p.connsMu.Unlock() if cn == nil { break } if p.isStaleConn(cn) { _ = p.CloseConn(cn) continue } atomic.AddUint32(&p.stats.Hits, 1) return cn, nil } atomic.AddUint32(&p.stats.Misses, 1) newcn, err := p.newConn(ctx, true) if err != nil { p.freeTurn() return nil, err } return newcn, nil } func (p *ConnPool) getTurn() { p.queue <- struct{}{} } func (p *ConnPool) waitTurn(ctx context.Context) error { select { case <-ctx.Done(): return ctx.Err() default: } select { case p.queue <- struct{}{}: return nil default: } timer := timers.Get().(*time.Timer) timer.Reset(p.opt.PoolTimeout) select { case <-ctx.Done(): if !timer.Stop() { <-timer.C } timers.Put(timer) return ctx.Err() case p.queue <- struct{}{}: if !timer.Stop() { <-timer.C } timers.Put(timer) return nil case <-timer.C: timers.Put(timer) atomic.AddUint32(&p.stats.Timeouts, 1) return ErrPoolTimeout } } func (p *ConnPool) freeTurn() { <-p.queue } func (p *ConnPool) popIdle() *Conn { if len(p.idleConns) == 0 { return nil } idx := len(p.idleConns) - 1 cn := p.idleConns[idx] p.idleConns = p.idleConns[:idx] p.idleConnsLen-- p.checkMinIdleConns() return cn } func (p *ConnPool) Put(cn *Conn) { if !cn.pooled { p.Remove(cn, nil) return } p.connsMu.Lock() p.idleConns = append(p.idleConns, cn) p.idleConnsLen++ p.connsMu.Unlock() p.freeTurn() } func (p *ConnPool) Remove(cn *Conn, reason error) { p.removeConnWithLock(cn) p.freeTurn() _ = p.closeConn(cn) } func (p *ConnPool) CloseConn(cn *Conn) error { p.removeConnWithLock(cn) return p.closeConn(cn) } func (p *ConnPool) removeConnWithLock(cn *Conn) { p.connsMu.Lock() p.removeConn(cn) p.connsMu.Unlock() } func (p *ConnPool) removeConn(cn *Conn) { for i, c := range p.conns { if c == cn { p.conns = append(p.conns[:i], p.conns[i+1:]...) 
if cn.pooled { p.poolSize-- p.checkMinIdleConns() } return } } } func (p *ConnPool) closeConn(cn *Conn) error { if p.opt.OnClose != nil { _ = p.opt.OnClose(cn) } return cn.Close() } // Len returns total number of connections. func (p *ConnPool) Len() int { p.connsMu.Lock() n := len(p.conns) p.connsMu.Unlock() return n } // IdleLen returns number of idle connections. func (p *ConnPool) IdleLen() int { p.connsMu.Lock() n := p.idleConnsLen p.connsMu.Unlock() return n } func (p *ConnPool) Stats() *Stats { idleLen := p.IdleLen() return &Stats{ Hits: atomic.LoadUint32(&p.stats.Hits), Misses: atomic.LoadUint32(&p.stats.Misses), Timeouts: atomic.LoadUint32(&p.stats.Timeouts), TotalConns: uint32(p.Len()), IdleConns: uint32(idleLen), StaleConns: atomic.LoadUint32(&p.stats.StaleConns), } } func (p *ConnPool) closed() bool { return atomic.LoadUint32(&p._closed) == 1 } func (p *ConnPool) Filter(fn func(*Conn) bool) error { var firstErr error p.connsMu.Lock() for _, cn := range p.conns { if fn(cn) { if err := p.closeConn(cn); err != nil && firstErr == nil { firstErr = err } } } p.connsMu.Unlock() return firstErr } func (p *ConnPool) Close() error { if !atomic.CompareAndSwapUint32(&p._closed, 0, 1) { return ErrClosed } var firstErr error p.connsMu.Lock() for _, cn := range p.conns { if err := p.closeConn(cn); err != nil && firstErr == nil { firstErr = err } } p.conns = nil p.poolSize = 0 p.idleConns = nil p.idleConnsLen = 0 p.connsMu.Unlock() return firstErr } func (p *ConnPool) reaper(frequency time.Duration) { ticker := time.NewTicker(frequency) defer ticker.Stop() for range ticker.C { if p.closed() { break } _, err := p.ReapStaleConns() if err != nil { internal.Logger.Printf("ReapStaleConns failed: %s", err) continue } } } func (p *ConnPool) ReapStaleConns() (int, error) { var n int for { p.getTurn() p.connsMu.Lock() cn := p.reapStaleConn() p.connsMu.Unlock() p.freeTurn() if cn != nil { _ = p.closeConn(cn) n++ } else { break } } atomic.AddUint32(&p.stats.StaleConns, uint32(n)) 
return n, nil } func (p *ConnPool) reapStaleConn() *Conn { if len(p.idleConns) == 0 { return nil } cn := p.idleConns[0] if !p.isStaleConn(cn) { return nil } p.idleConns = append(p.idleConns[:0], p.idleConns[1:]...) p.idleConnsLen-- p.removeConn(cn) return cn } func (p *ConnPool) isStaleConn(cn *Conn) bool { if p.opt.IdleTimeout == 0 && p.opt.MaxConnAge == 0 { return false } now := time.Now() if p.opt.IdleTimeout > 0 && now.Sub(cn.UsedAt()) >= p.opt.IdleTimeout { return true } if p.opt.MaxConnAge > 0 && now.Sub(cn.createdAt) >= p.opt.MaxConnAge { return true } return false }