diff --git a/internal/error.go b/internal/error.go index cd90d0d..a425574 100644 --- a/internal/error.go +++ b/internal/error.go @@ -2,6 +2,7 @@ package internal import ( "context" + "errors" "io" "net" "strings" @@ -9,6 +10,8 @@ import ( "github.com/go-redis/redis/internal/proto" ) +var ErrSingleConnPoolClosed = errors.New("redis: SingleConnPool is closed") + func IsRetryableError(err error, retryTimeout bool) bool { switch err { case nil, context.Canceled, context.DeadlineExceeded: @@ -22,6 +25,10 @@ func IsRetryableError(err error, retryTimeout bool) bool { } return true } + if err == ErrSingleConnPoolClosed { + return true + } + s := err.Error() if s == "ERR max number of clients reached" { return true diff --git a/internal/pool/pool_single.go b/internal/pool/pool_single.go index fda1c8a..12542c6 100644 --- a/internal/pool/pool_single.go +++ b/internal/pool/pool_single.go @@ -1,9 +1,14 @@ package pool -import "context" +import ( + "context" + + "github.com/go-redis/redis/internal" +) type SingleConnPool struct { - cn *Conn + cn *Conn + cnClosed bool } var _ Pooler = (*SingleConnPool)(nil) @@ -23,6 +28,9 @@ func (p *SingleConnPool) CloseConn(*Conn) error { } func (p *SingleConnPool) Get(ctx context.Context) (*Conn, error) { + if p.cnClosed { + return nil, internal.ErrSingleConnPoolClosed + } return p.cn, nil } @@ -36,9 +44,13 @@ func (p *SingleConnPool) Remove(cn *Conn) { if p.cn != cn { panic("p.cn != cn") } + p.cnClosed = true } func (p *SingleConnPool) Len() int { + if p.cnClosed { + return 0 + } return 1 }