Merge pull request #188 from go-redis/fix/multi-bad-conn

multi: fix recovering from bad connection.
This commit is contained in:
Vladimir Mihailenco 2015-11-14 16:05:22 +02:00
commit 0ddff681c2
15 changed files with 145 additions and 86 deletions

View File

@ -1,4 +1,5 @@
language: go language: go
sudo: false
services: services:
- redis-server - redis-server

View File

@ -63,7 +63,7 @@ func (pipe *ClusterPipeline) Exec() (cmds []Cmder, retErr error) {
continue continue
} }
cn, err := client.conn() cn, _, err := client.conn()
if err != nil { if err != nil {
setCmdsErr(cmds, err) setCmdsErr(cmds, err)
retErr = err retErr = err

View File

@ -1,8 +1,10 @@
package redis_test package redis_test
import ( import (
"fmt"
"math/rand" "math/rand"
"net" "net"
"strings"
"testing" "testing"
"time" "time"
@ -53,7 +55,7 @@ func (s *clusterScenario) clusterClient(opt *redis.ClusterOptions) *redis.Cluste
} }
func startCluster(scenario *clusterScenario) error { func startCluster(scenario *clusterScenario) error {
// Start processes, connect individual clients // Start processes and collect node ids
for pos, port := range scenario.ports { for pos, port := range scenario.ports {
process, err := startRedis(port, "--cluster-enabled", "yes") process, err := startRedis(port, "--cluster-enabled", "yes")
if err != nil { if err != nil {
@ -81,44 +83,48 @@ func startCluster(scenario *clusterScenario) error {
// Bootstrap masters // Bootstrap masters
slots := []int{0, 5000, 10000, 16384} slots := []int{0, 5000, 10000, 16384}
for pos, client := range scenario.masters() { for pos, master := range scenario.masters() {
err := client.ClusterAddSlotsRange(slots[pos], slots[pos+1]-1).Err() err := master.ClusterAddSlotsRange(slots[pos], slots[pos+1]-1).Err()
if err != nil { if err != nil {
return err return err
} }
} }
// Bootstrap slaves // Bootstrap slaves
for pos, client := range scenario.slaves() { for idx, slave := range scenario.slaves() {
masterId := scenario.nodeIds[pos] masterId := scenario.nodeIds[idx]
// Wait for masters // Wait until master is available
err := waitForSubstring(func() string { err := eventually(func() error {
return client.ClusterNodes().Val() s := slave.ClusterNodes().Val()
}, masterId, 10*time.Second) wanted := masterId
if !strings.Contains(s, wanted) {
return fmt.Errorf("%q does not contain %q", s, wanted)
}
return nil
}, 10*time.Second)
if err != nil { if err != nil {
return err return err
} }
err = client.ClusterReplicate(masterId).Err() err = slave.ClusterReplicate(masterId).Err()
if err != nil {
return err
}
// Wait for slaves
err = waitForSubstring(func() string {
return scenario.primary().ClusterNodes().Val()
}, "slave "+masterId, 10*time.Second)
if err != nil { if err != nil {
return err return err
} }
} }
// Wait for cluster state to turn OK // Wait until all nodes have consistent info
for _, client := range scenario.clients { for _, client := range scenario.clients {
err := waitForSubstring(func() string { err := eventually(func() error {
return client.ClusterInfo().Val() for _, masterId := range scenario.nodeIds[:3] {
}, "cluster_state:ok", 10*time.Second) s := client.ClusterNodes().Val()
wanted := "slave " + masterId
if !strings.Contains(s, wanted) {
return fmt.Errorf("%q does not contain %q", s, wanted)
}
}
return nil
}, 10*time.Second)
if err != nil { if err != nil {
return err return err
} }
@ -260,7 +266,6 @@ var _ = Describe("Cluster", func() {
It("should perform multi-pipelines", func() { It("should perform multi-pipelines", func() {
slot := redis.HashSlot("A") slot := redis.HashSlot("A")
Expect(client.SlotAddrs(slot)).To(Equal([]string{"127.0.0.1:8221", "127.0.0.1:8224"}))
Expect(client.SwapSlot(slot)).To(Equal([]string{"127.0.0.1:8224", "127.0.0.1:8221"})) Expect(client.SwapSlot(slot)).To(Equal([]string{"127.0.0.1:8224", "127.0.0.1:8221"}))
pipe := client.Pipeline() pipe := client.Pipeline()
@ -288,6 +293,7 @@ var _ = Describe("Cluster", func() {
}) })
It("should return error when there are no attempts left", func() { It("should return error when there are no attempts left", func() {
Expect(client.Close()).NotTo(HaveOccurred())
client = cluster.clusterClient(&redis.ClusterOptions{ client = cluster.clusterClient(&redis.ClusterOptions{
MaxRedirects: -1, MaxRedirects: -1,
}) })

View File

@ -26,11 +26,25 @@ func (err redisError) Error() string {
} }
func isNetworkError(err error) bool { func isNetworkError(err error) bool {
if _, ok := err.(net.Error); ok || err == io.EOF { if err == io.EOF {
return true return true
} }
_, ok := err.(net.Error)
return ok
}
func isBadConn(cn *conn, ei error) bool {
if cn.rd.Buffered() > 0 {
return true
}
if ei == nil {
return false return false
} }
if _, ok := ei.(redisError); ok {
return false
}
return true
}
func isMovedError(err error) (moved bool, ask bool, addr string) { func isMovedError(err error) (moved bool, ask bool, addr string) {
if _, ok := err.(redisError); !ok { if _, ok := err.(redisError); !ok {

View File

@ -1,12 +1,10 @@
package redis_test package redis_test
import ( import (
"fmt"
"net" "net"
"os" "os"
"os/exec" "os/exec"
"path/filepath" "path/filepath"
"strings"
"sync/atomic" "sync/atomic"
"syscall" "syscall"
"testing" "testing"
@ -100,17 +98,14 @@ func TestGinkgoSuite(t *testing.T) {
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
// Replaces ginkgo's Eventually. func eventually(fn func() error, timeout time.Duration) (err error) {
func waitForSubstring(fn func() string, substr string, timeout time.Duration) error { done := make(chan struct{})
var s string
found := make(chan struct{})
var exit int32 var exit int32
go func() { go func() {
for atomic.LoadInt32(&exit) == 0 { for atomic.LoadInt32(&exit) == 0 {
s = fn() err = fn()
if strings.Contains(s, substr) { if err == nil {
found <- struct{}{} close(done)
return return
} }
time.Sleep(timeout / 100) time.Sleep(timeout / 100)
@ -118,12 +113,12 @@ func waitForSubstring(fn func() string, substr string, timeout time.Duration) er
}() }()
select { select {
case <-found: case <-done:
return nil return nil
case <-time.After(timeout): case <-time.After(timeout):
atomic.StoreInt32(&exit, 1) atomic.StoreInt32(&exit, 1)
return err
} }
return fmt.Errorf("%q does not contain %q", s, substr)
} }
func execCmd(name string, args ...string) (*os.Process, error) { func execCmd(name string, args ...string) (*os.Process, error) {

View File

@ -10,10 +10,10 @@ var errDiscard = errors.New("redis: Discard can be used only inside Exec")
// Multi implements Redis transactions as described in // Multi implements Redis transactions as described in
// http://redis.io/topics/transactions. It's NOT safe for concurrent use // http://redis.io/topics/transactions. It's NOT safe for concurrent use
// by multiple goroutines, because Exec resets connection state. // by multiple goroutines, because Exec resets list of watched keys.
// If you don't need WATCH it is better to use Pipeline. // If you don't need WATCH it is better to use Pipeline.
// //
// TODO(vmihailenco): rename to Tx // TODO(vmihailenco): rename to Tx and rework API
type Multi struct { type Multi struct {
commandable commandable
@ -34,6 +34,18 @@ func (c *Client) Multi() *Multi {
return multi return multi
} }
func (c *Multi) putConn(cn *conn, ei error) {
var err error
if isBadConn(cn, ei) {
err = c.base.connPool.Remove(nil) // nil to force removal
} else {
err = c.base.connPool.Put(cn)
}
if err != nil {
log.Printf("redis: putConn failed: %s", err)
}
}
func (c *Multi) process(cmd Cmder) { func (c *Multi) process(cmd Cmder) {
if c.cmds == nil { if c.cmds == nil {
c.base.process(cmd) c.base.process(cmd)
@ -112,15 +124,18 @@ func (c *Multi) Exec(f func() error) ([]Cmder, error) {
return []Cmder{}, nil return []Cmder{}, nil
} }
cn, err := c.base.conn() // Strip MULTI and EXEC commands.
retCmds := cmds[1 : len(cmds)-1]
cn, _, err := c.base.conn()
if err != nil { if err != nil {
setCmdsErr(cmds[1:len(cmds)-1], err) setCmdsErr(retCmds, err)
return cmds[1 : len(cmds)-1], err return retCmds, err
} }
err = c.execCmds(cn, cmds) err = c.execCmds(cn, cmds)
c.base.putConn(cn, err) c.putConn(cn, err)
return cmds[1 : len(cmds)-1], err return retCmds, err
} }
func (c *Multi) execCmds(cn *conn, cmds []Cmder) error { func (c *Multi) execCmds(cn *conn, cmds []Cmder) error {

View File

@ -119,4 +119,30 @@ var _ = Describe("Multi", func() {
Expect(get.Val()).To(Equal("20000")) Expect(get.Val()).To(Equal("20000"))
}) })
It("should recover from bad connection", func() {
// Put bad connection in the pool.
cn, _, err := client.Pool().Get()
Expect(err).NotTo(HaveOccurred())
cn.SetNetConn(&badConn{})
err = client.Pool().Put(cn)
Expect(err).NotTo(HaveOccurred())
multi := client.Multi()
defer func() {
Expect(multi.Close()).NotTo(HaveOccurred())
}()
_, err = multi.Exec(func() error {
multi.Ping()
return nil
})
Expect(err).To(MatchError("bad connection"))
_, err = multi.Exec(func() error {
multi.Ping()
return nil
})
Expect(err).NotTo(HaveOccurred())
})
}) })

View File

@ -88,7 +88,7 @@ func (pipe *Pipeline) Exec() (cmds []Cmder, retErr error) {
failedCmds := cmds failedCmds := cmds
for i := 0; i <= pipe.client.opt.MaxRetries; i++ { for i := 0; i <= pipe.client.opt.MaxRetries; i++ {
cn, err := pipe.client.conn() cn, _, err := pipe.client.conn()
if err != nil { if err != nil {
setCmdsErr(failedCmds, err) setCmdsErr(failedCmds, err)
return cmds, err return cmds, err

44
pool.go
View File

@ -18,7 +18,7 @@ var (
type pool interface { type pool interface {
First() *conn First() *conn
Get() (*conn, error) Get() (*conn, bool, error)
Put(*conn) error Put(*conn) error
Remove(*conn) error Remove(*conn) error
Len() int Len() int
@ -212,33 +212,36 @@ func (p *connPool) new() (*conn, error) {
} }
// Get returns existed connection from the pool or creates a new one. // Get returns existed connection from the pool or creates a new one.
func (p *connPool) Get() (*conn, error) { func (p *connPool) Get() (cn *conn, isNew bool, err error) {
if p.closed() { if p.closed() {
return nil, errClosed err = errClosed
return
} }
// Fetch first non-idle connection, if available. // Fetch first non-idle connection, if available.
if cn := p.First(); cn != nil { if cn = p.First(); cn != nil {
return cn, nil return
} }
// Try to create a new one. // Try to create a new one.
if p.conns.Reserve() { if p.conns.Reserve() {
cn, err := p.new() cn, err = p.new()
if err != nil { if err != nil {
p.conns.Remove(nil) p.conns.Remove(nil)
return nil, err return
} }
p.conns.Add(cn) p.conns.Add(cn)
return cn, nil isNew = true
return
} }
// Otherwise, wait for the available connection. // Otherwise, wait for the available connection.
if cn := p.wait(); cn != nil { if cn = p.wait(); cn != nil {
return cn, nil return
} }
return nil, errPoolTimeout err = errPoolTimeout
return
} }
func (p *connPool) Put(cn *conn) error { func (p *connPool) Put(cn *conn) error {
@ -327,8 +330,8 @@ func (p *singleConnPool) First() *conn {
return p.cn return p.cn
} }
func (p *singleConnPool) Get() (*conn, error) { func (p *singleConnPool) Get() (*conn, bool, error) {
return p.cn, nil return p.cn, false, nil
} }
func (p *singleConnPool) Put(cn *conn) error { func (p *singleConnPool) Put(cn *conn) error {
@ -382,24 +385,25 @@ func (p *stickyConnPool) First() *conn {
return cn return cn
} }
func (p *stickyConnPool) Get() (*conn, error) { func (p *stickyConnPool) Get() (cn *conn, isNew bool, err error) {
defer p.mx.Unlock() defer p.mx.Unlock()
p.mx.Lock() p.mx.Lock()
if p.closed { if p.closed {
return nil, errClosed err = errClosed
return
} }
if p.cn != nil { if p.cn != nil {
return p.cn, nil cn = p.cn
return
} }
cn, err := p.pool.Get() cn, isNew, err = p.pool.Get()
if err != nil { if err != nil {
return nil, err return
} }
p.cn = cn p.cn = cn
return
return p.cn, nil
} }
func (p *stickyConnPool) put() (err error) { func (p *stickyConnPool) put() (err error) {

View File

@ -107,7 +107,7 @@ var _ = Describe("pool", func() {
}) })
It("should remove broken connections", func() { It("should remove broken connections", func() {
cn, err := client.Pool().Get() cn, _, err := client.Pool().Get()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(cn.Close()).NotTo(HaveOccurred()) Expect(cn.Close()).NotTo(HaveOccurred())
Expect(client.Pool().Put(cn)).NotTo(HaveOccurred()) Expect(client.Pool().Put(cn)).NotTo(HaveOccurred())
@ -141,12 +141,12 @@ var _ = Describe("pool", func() {
pool := client.Pool() pool := client.Pool()
// Reserve one connection. // Reserve one connection.
cn, err := client.Pool().Get() cn, _, err := client.Pool().Get()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
// Reserve the rest of connections. // Reserve the rest of connections.
for i := 0; i < 9; i++ { for i := 0; i < 9; i++ {
_, err := client.Pool().Get() _, _, err := client.Pool().Get()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
} }
@ -191,7 +191,7 @@ func BenchmarkPool(b *testing.B) {
b.RunParallel(func(pb *testing.PB) { b.RunParallel(func(pb *testing.PB) {
for pb.Next() { for pb.Next() {
conn, err := pool.Get() conn, _, err := pool.Get()
if err != nil { if err != nil {
b.Fatalf("no error expected on pool.Get but received: %s", err.Error()) b.Fatalf("no error expected on pool.Get but received: %s", err.Error())
} }

View File

@ -47,7 +47,7 @@ func (c *Client) PSubscribe(channels ...string) (*PubSub, error) {
} }
func (c *PubSub) subscribe(cmd string, channels ...string) error { func (c *PubSub) subscribe(cmd string, channels ...string) error {
cn, err := c.conn() cn, _, err := c.conn()
if err != nil { if err != nil {
return err return err
} }
@ -112,7 +112,7 @@ func (c *PubSub) PUnsubscribe(patterns ...string) error {
} }
func (c *PubSub) Ping(payload string) error { func (c *PubSub) Ping(payload string) error {
cn, err := c.conn() cn, _, err := c.conn()
if err != nil { if err != nil {
return err return err
} }
@ -208,14 +208,16 @@ func newMessage(reply []interface{}) (interface{}, error) {
// is not received in time. This is low-level API and most clients // is not received in time. This is low-level API and most clients
// should use ReceiveMessage. // should use ReceiveMessage.
func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) { func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) {
cn, err := c.conn() cn, _, err := c.conn()
if err != nil { if err != nil {
return nil, err return nil, err
} }
cn.ReadTimeout = timeout cn.ReadTimeout = timeout
cmd := NewSliceCmd() cmd := NewSliceCmd()
if err := cmd.readReply(cn); err != nil { err = cmd.readReply(cn)
c.putConn(cn, err)
if err != nil {
return nil, err return nil, err
} }
return newMessage(cmd.Val()) return newMessage(cmd.Val())
@ -229,7 +231,7 @@ func (c *PubSub) Receive() (interface{}, error) {
} }
func (c *PubSub) reconnect() { func (c *PubSub) reconnect() {
c.connPool.Remove(nil) // close current connection c.connPool.Remove(nil) // nil to force removal
if len(c.channels) > 0 { if len(c.channels) > 0 {
if err := c.Subscribe(c.channels...); err != nil { if err := c.Subscribe(c.channels...); err != nil {
log.Printf("redis: Subscribe failed: %s", err) log.Printf("redis: Subscribe failed: %s", err)

View File

@ -254,7 +254,7 @@ var _ = Describe("PubSub", func() {
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
defer pubsub.Close() defer pubsub.Close()
cn, err := pubsub.Pool().Get() cn, _, err := pubsub.Pool().Get()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
cn.SetNetConn(&badConn{ cn.SetNetConn(&badConn{
readErr: errTimeout, readErr: errTimeout,

View File

@ -16,20 +16,16 @@ func (c *baseClient) String() string {
return fmt.Sprintf("Redis<%s db:%d>", c.opt.Addr, c.opt.DB) return fmt.Sprintf("Redis<%s db:%d>", c.opt.Addr, c.opt.DB)
} }
func (c *baseClient) conn() (*conn, error) { func (c *baseClient) conn() (*conn, bool, error) {
return c.connPool.Get() return c.connPool.Get()
} }
func (c *baseClient) putConn(cn *conn, ei error) { func (c *baseClient) putConn(cn *conn, ei error) {
var err error var err error
if cn.rd.Buffered() > 0 { if isBadConn(cn, ei) {
err = c.connPool.Remove(cn) err = c.connPool.Remove(cn)
} else if ei == nil {
err = c.connPool.Put(cn)
} else if _, ok := ei.(redisError); ok {
err = c.connPool.Put(cn)
} else { } else {
err = c.connPool.Remove(cn) err = c.connPool.Put(cn)
} }
if err != nil { if err != nil {
log.Printf("redis: putConn failed: %s", err) log.Printf("redis: putConn failed: %s", err)
@ -42,7 +38,7 @@ func (c *baseClient) process(cmd Cmder) {
cmd.reset() cmd.reset()
} }
cn, err := c.conn() cn, _, err := c.conn()
if err != nil { if err != nil {
cmd.setErr(err) cmd.setErr(err)
return return

View File

@ -157,7 +157,7 @@ var _ = Describe("Client", func() {
}) })
// Put bad connection in the pool. // Put bad connection in the pool.
cn, err := client.Pool().Get() cn, _, err := client.Pool().Get()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
cn.SetNetConn(&badConn{}) cn.SetNetConn(&badConn{})

View File

@ -313,7 +313,7 @@ func (pipe *RingPipeline) Exec() (cmds []Cmder, retErr error) {
for name, cmds := range cmdsMap { for name, cmds := range cmdsMap {
client := pipe.ring.shards[name].Client client := pipe.ring.shards[name].Client
cn, err := client.conn() cn, _, err := client.conn()
if err != nil { if err != nil {
setCmdsErr(cmds, err) setCmdsErr(cmds, err)
if retErr == nil { if retErr == nil {