diff --git a/.travis.yml b/.travis.yml index 6ef52f7..22cc206 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,4 +1,5 @@ language: go +sudo: false services: - redis-server diff --git a/cluster_pipeline.go b/cluster_pipeline.go index 90687a8..3c93bbf 100644 --- a/cluster_pipeline.go +++ b/cluster_pipeline.go @@ -63,7 +63,7 @@ func (pipe *ClusterPipeline) Exec() (cmds []Cmder, retErr error) { continue } - cn, err := client.conn() + cn, _, err := client.conn() if err != nil { setCmdsErr(cmds, err) retErr = err diff --git a/cluster_test.go b/cluster_test.go index 136340c..bc395b4 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -1,8 +1,10 @@ package redis_test import ( + "fmt" "math/rand" "net" + "strings" "testing" "time" @@ -53,7 +55,7 @@ func (s *clusterScenario) clusterClient(opt *redis.ClusterOptions) *redis.Cluste } func startCluster(scenario *clusterScenario) error { - // Start processes, connect individual clients + // Start processes and collect node ids for pos, port := range scenario.ports { process, err := startRedis(port, "--cluster-enabled", "yes") if err != nil { @@ -81,44 +83,48 @@ func startCluster(scenario *clusterScenario) error { // Bootstrap masters slots := []int{0, 5000, 10000, 16384} - for pos, client := range scenario.masters() { - err := client.ClusterAddSlotsRange(slots[pos], slots[pos+1]-1).Err() + for pos, master := range scenario.masters() { + err := master.ClusterAddSlotsRange(slots[pos], slots[pos+1]-1).Err() if err != nil { return err } } // Bootstrap slaves - for pos, client := range scenario.slaves() { - masterId := scenario.nodeIds[pos] + for idx, slave := range scenario.slaves() { + masterId := scenario.nodeIds[idx] - // Wait for masters - err := waitForSubstring(func() string { - return client.ClusterNodes().Val() - }, masterId, 10*time.Second) + // Wait until master is available + err := eventually(func() error { + s := slave.ClusterNodes().Val() + wanted := masterId + if !strings.Contains(s, wanted) { + return fmt.Errorf("%q does not contain %q", s, wanted) + } + return nil + }, 10*time.Second) if err != nil { return err } - err = client.ClusterReplicate(masterId).Err() - if err != nil { - return err - } - - // Wait for slaves - err = waitForSubstring(func() string { - return scenario.primary().ClusterNodes().Val() - }, "slave "+masterId, 10*time.Second) + err = slave.ClusterReplicate(masterId).Err() if err != nil { return err } } - // Wait for cluster state to turn OK + // Wait until all nodes have consistent info for _, client := range scenario.clients { - err := waitForSubstring(func() string { - return client.ClusterInfo().Val() - }, "cluster_state:ok", 10*time.Second) + err := eventually(func() error { + for _, masterId := range scenario.nodeIds[:3] { + s := client.ClusterNodes().Val() + wanted := "slave " + masterId + if !strings.Contains(s, wanted) { + return fmt.Errorf("%q does not contain %q", s, wanted) + } + } + return nil + }, 10*time.Second) if err != nil { return err } @@ -260,7 +266,6 @@ var _ = Describe("Cluster", func() { It("should perform multi-pipelines", func() { slot := redis.HashSlot("A") - Expect(client.SlotAddrs(slot)).To(Equal([]string{"127.0.0.1:8221", "127.0.0.1:8224"})) Expect(client.SwapSlot(slot)).To(Equal([]string{"127.0.0.1:8224", "127.0.0.1:8221"})) pipe := client.Pipeline() @@ -288,6 +293,7 @@ var _ = Describe("Cluster", func() { }) It("should return error when there are no attempts left", func() { + Expect(client.Close()).NotTo(HaveOccurred()) client = cluster.clusterClient(&redis.ClusterOptions{ MaxRedirects: -1, }) diff --git a/error.go b/error.go index 9e5d973..1365ca1 100644 --- a/error.go +++ b/error.go @@ -26,10 +26,24 @@ func (err redisError) Error() string { } func isNetworkError(err error) bool { - if _, ok := err.(net.Error); ok || err == io.EOF { + if err == io.EOF { return true } - return false + _, ok := err.(net.Error) + return ok +} + +func isBadConn(cn *conn, ei error) bool { + if cn.rd.Buffered() > 0 { + return true + } + if ei == nil { + return false + } + if _, ok := ei.(redisError); ok { + return false + } + return true } func isMovedError(err error) (moved bool, ask bool, addr string) { diff --git a/main_test.go b/main_test.go index eafbeee..806d7d3 100644 --- a/main_test.go +++ b/main_test.go @@ -1,12 +1,10 @@ package redis_test import ( - "fmt" "net" "os" "os/exec" "path/filepath" - "strings" "sync/atomic" "syscall" "testing" @@ -100,17 +98,14 @@ func TestGinkgoSuite(t *testing.T) { //------------------------------------------------------------------------------ -// Replaces ginkgo's Eventually. -func waitForSubstring(fn func() string, substr string, timeout time.Duration) error { - var s string - - found := make(chan struct{}) +func eventually(fn func() error, timeout time.Duration) (err error) { + done := make(chan struct{}) var exit int32 go func() { for atomic.LoadInt32(&exit) == 0 { - s = fn() - if strings.Contains(s, substr) { - found <- struct{}{} + err = fn() + if err == nil { + close(done) return } time.Sleep(timeout / 100) @@ -118,12 +113,12 @@ func waitForSubstring(fn func() string, substr string, timeout time.Duration) er }() select { - case <-found: + case <-done: return nil case <-time.After(timeout): atomic.StoreInt32(&exit, 1) + return err } - return fmt.Errorf("%q does not contain %q", s, substr) } func execCmd(name string, args ...string) (*os.Process, error) { diff --git a/multi.go b/multi.go index d0859e6..86be83f 100644 --- a/multi.go +++ b/multi.go @@ -10,10 +10,10 @@ var errDiscard = errors.New("redis: Discard can be used only inside Exec") // Multi implements Redis transactions as described in // http://redis.io/topics/transactions. It's NOT safe for concurrent use -// by multiple goroutines, because Exec resets connection state. +// by multiple goroutines, because Exec resets list of watched keys. // If you don't need WATCH it is better to use Pipeline. // -// TODO(vmihailenco): rename to Tx +// TODO(vmihailenco): rename to Tx and rework API type Multi struct { commandable @@ -34,6 +34,18 @@ func (c *Client) Multi() *Multi { return multi } +func (c *Multi) putConn(cn *conn, ei error) { + var err error + if isBadConn(cn, ei) { + err = c.base.connPool.Remove(nil) // nil to force removal + } else { + err = c.base.connPool.Put(cn) + } + if err != nil { + log.Printf("redis: putConn failed: %s", err) + } +} + func (c *Multi) process(cmd Cmder) { if c.cmds == nil { c.base.process(cmd) @@ -112,15 +124,18 @@ func (c *Multi) Exec(f func() error) ([]Cmder, error) { return []Cmder{}, nil } - cn, err := c.base.conn() + // Strip MULTI and EXEC commands. + retCmds := cmds[1 : len(cmds)-1] + + cn, _, err := c.base.conn() if err != nil { - setCmdsErr(cmds[1:len(cmds)-1], err) - return cmds[1 : len(cmds)-1], err + setCmdsErr(retCmds, err) + return retCmds, err } err = c.execCmds(cn, cmds) - c.base.putConn(cn, err) - return cmds[1 : len(cmds)-1], err + c.putConn(cn, err) + return retCmds, err } func (c *Multi) execCmds(cn *conn, cmds []Cmder) error { diff --git a/multi_test.go b/multi_test.go index b481a52..c95c703 100644 --- a/multi_test.go +++ b/multi_test.go @@ -119,4 +119,30 @@ var _ = Describe("Multi", func() { Expect(get.Val()).To(Equal("20000")) }) + It("should recover from bad connection", func() { + // Put bad connection in the pool. + cn, _, err := client.Pool().Get() + Expect(err).NotTo(HaveOccurred()) + + cn.SetNetConn(&badConn{}) + err = client.Pool().Put(cn) + Expect(err).NotTo(HaveOccurred()) + + multi := client.Multi() + defer func() { + Expect(multi.Close()).NotTo(HaveOccurred()) + }() + + _, err = multi.Exec(func() error { + multi.Ping() + return nil + }) + Expect(err).To(MatchError("bad connection")) + + _, err = multi.Exec(func() error { + multi.Ping() + return nil + }) + Expect(err).NotTo(HaveOccurred()) + }) }) diff --git a/pipeline.go b/pipeline.go index 27c9f8e..56ff965 100644 --- a/pipeline.go +++ b/pipeline.go @@ -88,7 +88,7 @@ func (pipe *Pipeline) Exec() (cmds []Cmder, retErr error) { failedCmds := cmds for i := 0; i <= pipe.client.opt.MaxRetries; i++ { - cn, err := pipe.client.conn() + cn, _, err := pipe.client.conn() if err != nil { setCmdsErr(failedCmds, err) return cmds, err diff --git a/pool.go b/pool.go index d048468..a6ab257 100644 --- a/pool.go +++ b/pool.go @@ -18,7 +18,7 @@ var ( type pool interface { First() *conn - Get() (*conn, error) + Get() (*conn, bool, error) Put(*conn) error Remove(*conn) error Len() int @@ -212,33 +212,36 @@ func (p *connPool) new() (*conn, error) { } // Get returns existed connection from the pool or creates a new one. -func (p *connPool) Get() (*conn, error) { +func (p *connPool) Get() (cn *conn, isNew bool, err error) { if p.closed() { - return nil, errClosed + err = errClosed + return } // Fetch first non-idle connection, if available. - if cn := p.First(); cn != nil { - return cn, nil + if cn = p.First(); cn != nil { + return } // Try to create a new one. if p.conns.Reserve() { - cn, err := p.new() + cn, err = p.new() if err != nil { p.conns.Remove(nil) - return nil, err + return } p.conns.Add(cn) - return cn, nil + isNew = true + return } // Otherwise, wait for the available connection. - if cn := p.wait(); cn != nil { - return cn, nil + if cn = p.wait(); cn != nil { + return } - return nil, errPoolTimeout + err = errPoolTimeout + return } func (p *connPool) Put(cn *conn) error { @@ -327,8 +330,8 @@ func (p *singleConnPool) First() *conn { return p.cn } -func (p *singleConnPool) Get() (*conn, error) { - return p.cn, nil +func (p *singleConnPool) Get() (*conn, bool, error) { + return p.cn, false, nil } func (p *singleConnPool) Put(cn *conn) error { @@ -382,24 +385,25 @@ func (p *stickyConnPool) First() *conn { return cn } -func (p *stickyConnPool) Get() (*conn, error) { +func (p *stickyConnPool) Get() (cn *conn, isNew bool, err error) { defer p.mx.Unlock() p.mx.Lock() if p.closed { - return nil, errClosed + err = errClosed + return } if p.cn != nil { - return p.cn, nil + cn = p.cn + return } - cn, err := p.pool.Get() + cn, isNew, err = p.pool.Get() if err != nil { - return nil, err + return } p.cn = cn - - return p.cn, nil + return } func (p *stickyConnPool) put() (err error) { diff --git a/pool_test.go b/pool_test.go index d59c7d2..9eb2c99 100644 --- a/pool_test.go +++ b/pool_test.go @@ -107,7 +107,7 @@ var _ = Describe("pool", func() { }) It("should remove broken connections", func() { - cn, err := client.Pool().Get() + cn, _, err := client.Pool().Get() Expect(err).NotTo(HaveOccurred()) Expect(cn.Close()).NotTo(HaveOccurred()) Expect(client.Pool().Put(cn)).NotTo(HaveOccurred()) @@ -141,12 +141,12 @@ var _ = Describe("pool", func() { pool := client.Pool() // Reserve one connection. - cn, err := client.Pool().Get() + cn, _, err := client.Pool().Get() Expect(err).NotTo(HaveOccurred()) // Reserve the rest of connections. for i := 0; i < 9; i++ { - _, err := client.Pool().Get() + _, _, err := client.Pool().Get() Expect(err).NotTo(HaveOccurred()) } @@ -191,7 +191,7 @@ func BenchmarkPool(b *testing.B) { b.RunParallel(func(pb *testing.PB) { for pb.Next() { - conn, err := pool.Get() + conn, _, err := pool.Get() if err != nil { b.Fatalf("no error expected on pool.Get but received: %s", err.Error()) } diff --git a/pubsub.go b/pubsub.go index 8096c93..a3792d3 100644 --- a/pubsub.go +++ b/pubsub.go @@ -47,7 +47,7 @@ func (c *Client) PSubscribe(channels ...string) (*PubSub, error) { } func (c *PubSub) subscribe(cmd string, channels ...string) error { - cn, err := c.conn() + cn, _, err := c.conn() if err != nil { return err } @@ -112,7 +112,7 @@ func (c *PubSub) PUnsubscribe(patterns ...string) error { } func (c *PubSub) Ping(payload string) error { - cn, err := c.conn() + cn, _, err := c.conn() if err != nil { return err } @@ -208,14 +208,16 @@ func newMessage(reply []interface{}) (interface{}, error) { // is not received in time. This is low-level API and most clients // should use ReceiveMessage. func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) { - cn, err := c.conn() + cn, _, err := c.conn() if err != nil { return nil, err } cn.ReadTimeout = timeout cmd := NewSliceCmd() - if err := cmd.readReply(cn); err != nil { + err = cmd.readReply(cn) + c.putConn(cn, err) + if err != nil { return nil, err } return newMessage(cmd.Val()) @@ -229,7 +231,7 @@ func (c *PubSub) Receive() (interface{}, error) { } func (c *PubSub) reconnect() { - c.connPool.Remove(nil) // close current connection + c.connPool.Remove(nil) // nil to force removal if len(c.channels) > 0 { if err := c.Subscribe(c.channels...); err != nil { log.Printf("redis: Subscribe failed: %s", err) diff --git a/pubsub_test.go b/pubsub_test.go index 5a7b0da..411c643 100644 --- a/pubsub_test.go +++ b/pubsub_test.go @@ -254,7 +254,7 @@ var _ = Describe("PubSub", func() { Expect(err).NotTo(HaveOccurred()) defer pubsub.Close() - cn, err := pubsub.Pool().Get() + cn, _, err := pubsub.Pool().Get() Expect(err).NotTo(HaveOccurred()) cn.SetNetConn(&badConn{ readErr: errTimeout, diff --git a/redis.go b/redis.go index 2d1076d..6d654b1 100644 --- a/redis.go +++ b/redis.go @@ -16,20 +16,16 @@ func (c *baseClient) String() string { return fmt.Sprintf("Redis<%s db:%d>", c.opt.Addr, c.opt.DB) } -func (c *baseClient) conn() (*conn, error) { +func (c *baseClient) conn() (*conn, bool, error) { return c.connPool.Get() } func (c *baseClient) putConn(cn *conn, ei error) { var err error - if cn.rd.Buffered() > 0 { + if isBadConn(cn, ei) { err = c.connPool.Remove(cn) - } else if ei == nil { - err = c.connPool.Put(cn) - } else if _, ok := ei.(redisError); ok { - err = c.connPool.Put(cn) } else { - err = c.connPool.Remove(cn) + err = c.connPool.Put(cn) } if err != nil { log.Printf("redis: putConn failed: %s", err) @@ -42,7 +38,7 @@ func (c *baseClient) process(cmd Cmder) { cmd.reset() } - cn, err := c.conn() + cn, _, err := c.conn() if err != nil { cmd.setErr(err) return diff --git a/redis_test.go b/redis_test.go index 3ad4ae2..f1ebf62 100644 --- a/redis_test.go +++ b/redis_test.go @@ -157,7 +157,7 @@ var _ = Describe("Client", func() { }) // Put bad connection in the pool. - cn, err := client.Pool().Get() + cn, _, err := client.Pool().Get() Expect(err).NotTo(HaveOccurred()) cn.SetNetConn(&badConn{}) diff --git a/ring.go b/ring.go index 8005af9..facf3e6 100644 --- a/ring.go +++ b/ring.go @@ -313,7 +313,7 @@ func (pipe *RingPipeline) Exec() (cmds []Cmder, retErr error) { for name, cmds := range cmdsMap { client := pipe.ring.shards[name].Client - cn, err := client.conn() + cn, _, err := client.conn() if err != nil { setCmdsErr(cmds, err) if retErr == nil {