diff --git a/conn.go b/conn.go index 36ba99ad..ec4abd18 100644 --- a/conn.go +++ b/conn.go @@ -43,11 +43,8 @@ func (cn *conn) init(opt *Options) error { return nil } - // Use connection to connect to Redis. - pool := newSingleConnPoolConn(cn) - - // Client is not closed because we want to reuse underlying connection. - client := newClient(opt, pool) + // Temp client for Auth and Select. + client := newClient(opt, newSingleConnPool(cn)) if opt.Password != "" { if err := client.Auth(opt.Password).Err(); err != nil { diff --git a/conn_test.go b/conn_test.go new file mode 100644 index 00000000..4f1eef6d --- /dev/null +++ b/conn_test.go @@ -0,0 +1,25 @@ +package redis_test + +import ( + "net" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + + "gopkg.in/redis.v3" +) + +var _ = Describe("newConnDialer with bad connection", func() { + It("should return an error", func() { + dialer := redis.NewConnDialer(&redis.Options{ + Dialer: func() (net.Conn, error) { + return &badConn{}, nil + }, + MaxRetries: 3, + Password: "password", + DB: 1, + }) + _, err := dialer() + Expect(err).To(MatchError("bad connection")) + }) +}) diff --git a/export_test.go b/export_test.go index f4687296..66ccec25 100644 --- a/export_test.go +++ b/export_test.go @@ -6,6 +6,8 @@ func (c *baseClient) Pool() pool { return c.connPool } +var NewConnDialer = newConnDialer + func (cn *conn) SetNetConn(netcn net.Conn) { cn.netcn = netcn } diff --git a/main_test.go b/main_test.go index d2f8d2ad..eafbeee5 100644 --- a/main_test.go +++ b/main_test.go @@ -232,7 +232,15 @@ func startSentinel(port, masterName, masterPort string) (*redisProcess, error) { //------------------------------------------------------------------------------ -var errTimeout = syscall.ETIMEDOUT +var ( + errTimeout = syscall.ETIMEDOUT +) + +type badConnError string + +func (e badConnError) Error() string { return string(e) } +func (e badConnError) Timeout() bool { return false } +func (e badConnError) Temporary() bool { return false } type badConn struct { net.TCPConn @@ -250,7 +258,7 @@ func (cn *badConn) Read([]byte) (int, error) { if cn.readErr != nil { return 0, cn.readErr } - return 0, net.UnknownNetworkError("badConn") + return 0, badConnError("bad connection") } func (cn *badConn) Write([]byte) (int, error) { @@ -260,5 +268,5 @@ func (cn *badConn) Write([]byte) (int, error) { if cn.writeErr != nil { return 0, cn.writeErr } - return 0, net.UnknownNetworkError("badConn") + return 0, badConnError("bad connection") } diff --git a/multi.go b/multi.go index e3d628fd..00dc9991 100644 --- a/multi.go +++ b/multi.go @@ -22,7 +22,7 @@ func (c *Client) Multi() *Multi { multi := &Multi{ base: &baseClient{ opt: c.opt, - connPool: newSingleConnPool(c.connPool, true), + connPool: newStickyConnPool(c.connPool, true), }, } multi.commandable.process = multi.process diff --git a/pool.go b/pool.go index bd494d85..d048468a 100644 --- a/pool.go +++ b/pool.go @@ -314,6 +314,52 @@ func (p *connPool) reaper() { //------------------------------------------------------------------------------ type singleConnPool struct { + cn *conn +} + +func newSingleConnPool(cn *conn) *singleConnPool { + return &singleConnPool{ + cn: cn, + } +} + +func (p *singleConnPool) First() *conn { + return p.cn +} + +func (p *singleConnPool) Get() (*conn, error) { + return p.cn, nil +} + +func (p *singleConnPool) Put(cn *conn) error { + if p.cn != cn { + panic("p.cn != cn") + } + return nil +} + +func (p *singleConnPool) Remove(cn *conn) error { + if p.cn != cn { + panic("p.cn != cn") + } + return nil +} + +func (p *singleConnPool) Len() int { + return 1 +} + +func (p *singleConnPool) FreeLen() int { + return 0 +} + +func (p *singleConnPool) Close() error { + return nil +} + +//------------------------------------------------------------------------------ + +type stickyConnPool struct { pool pool reusable bool @@ -322,27 +368,21 @@ type singleConnPool struct { mx sync.Mutex } -func newSingleConnPool(pool pool, reusable bool) *singleConnPool { - return &singleConnPool{ +func newStickyConnPool(pool pool, reusable bool) *stickyConnPool { + return &stickyConnPool{ pool: pool, reusable: reusable, } } -func newSingleConnPoolConn(cn *conn) *singleConnPool { - return &singleConnPool{ - cn: cn, - } -} - -func (p *singleConnPool) First() *conn { +func (p *stickyConnPool) First() *conn { p.mx.Lock() cn := p.cn p.mx.Unlock() return cn } -func (p *singleConnPool) Get() (*conn, error) { +func (p *stickyConnPool) Get() (*conn, error) { defer p.mx.Unlock() p.mx.Lock() @@ -362,15 +402,13 @@ func (p *singleConnPool) Get() (*conn, error) { return p.cn, nil } -func (p *singleConnPool) put() (err error) { - if p.pool != nil { - err = p.pool.Put(p.cn) - } +func (p *stickyConnPool) put() (err error) { + err = p.pool.Put(p.cn) p.cn = nil return err } -func (p *singleConnPool) Put(cn *conn) error { +func (p *stickyConnPool) Put(cn *conn) error { defer p.mx.Unlock() p.mx.Lock() if p.cn != cn { @@ -382,30 +420,32 @@ func (p *singleConnPool) Put(cn *conn) error { return nil } -func (p *singleConnPool) remove() (err error) { - if p.pool != nil { - err = p.pool.Remove(p.cn) - } +func (p *stickyConnPool) remove() (err error) { + err = p.pool.Remove(p.cn) p.cn = nil return err } -func (p *singleConnPool) Remove(cn *conn) error { +func (p *stickyConnPool) Remove(cn *conn) error { defer p.mx.Unlock() p.mx.Lock() if p.cn == nil { panic("p.cn == nil") } - if cn != nil && cn != p.cn { - panic("cn != p.cn") + if cn != nil && p.cn != cn { + panic("p.cn != cn") } if p.closed { return errClosed } - return p.remove() + if cn == nil { + return p.remove() + } else { + return nil + } } -func (p *singleConnPool) Len() int { +func (p *stickyConnPool) Len() int { defer p.mx.Unlock() p.mx.Lock() if p.cn == nil { @@ -414,7 +454,7 @@ func (p *singleConnPool) Len() int { return 1 } -func (p *singleConnPool) FreeLen() int { +func (p *stickyConnPool) FreeLen() int { defer p.mx.Unlock() p.mx.Lock() if p.cn == nil { @@ -423,7 +463,7 @@ func (p *singleConnPool) FreeLen() int { return 0 } -func (p *singleConnPool) Close() error { +func (p *stickyConnPool) Close() error { defer p.mx.Unlock() p.mx.Lock() if p.closed { diff --git a/pool_test.go b/pool_test.go index bff892cc..d59c7d2d 100644 --- a/pool_test.go +++ b/pool_test.go @@ -11,7 +11,7 @@ import ( "gopkg.in/redis.v3" ) -var _ = Describe("Pool", func() { +var _ = Describe("pool", func() { var client *redis.Client var perform = func(n int, cb func()) { diff --git a/pubsub.go b/pubsub.go index ba053e47..8096c938 100644 --- a/pubsub.go +++ b/pubsub.go @@ -29,7 +29,7 @@ func (c *Client) PubSub() *PubSub { return &PubSub{ baseClient: &baseClient{ opt: c.opt, - connPool: newSingleConnPool(c.connPool, false), + connPool: newStickyConnPool(c.connPool, false), }, } } diff --git a/redis_test.go b/redis_test.go index acc8ca1d..3ad4ae25 100644 --- a/redis_test.go +++ b/redis_test.go @@ -161,7 +161,8 @@ var _ = Describe("Client", func() { Expect(err).NotTo(HaveOccurred()) cn.SetNetConn(&badConn{}) - Expect(client.Pool().Put(cn)).NotTo(HaveOccurred()) + err = client.Pool().Put(cn) + Expect(err).NotTo(HaveOccurred()) err = client.Ping().Err() Expect(err).NotTo(HaveOccurred()) diff --git a/sentinel.go b/sentinel.go index 04a821bd..63c011d4 100644 --- a/sentinel.go +++ b/sentinel.go @@ -90,7 +90,7 @@ func (c *sentinelClient) PubSub() *PubSub { return &PubSub{ baseClient: &baseClient{ opt: c.opt, - connPool: newSingleConnPool(c.connPool, false), + connPool: newStickyConnPool(c.connPool, false), }, } }