Merge pull request #434 from go-redis/fix/timeouts

Set read/write timeouts more consistently.
This commit is contained in:
Vladimir Mihailenco 2016-12-03 17:45:52 +02:00 committed by GitHub
commit 854c88a72c
18 changed files with 343 additions and 198 deletions

View File

@ -387,7 +387,13 @@ func (c *ClusterClient) cmdSlotAndNode(state *clusterState, cmd Cmder) (int, *cl
} }
func (c *ClusterClient) Watch(fn func(*Tx) error, keys ...string) error { func (c *ClusterClient) Watch(fn func(*Tx) error, keys ...string) error {
node, err := c.state().slotMasterNode(hashtag.Slot(keys[0])) var node *clusterNode
var err error
if len(keys) > 0 {
node, err = c.state().slotMasterNode(hashtag.Slot(keys[0]))
} else {
node, err = c.nodes.Random()
}
if err != nil { if err != nil {
return err return err
} }
@ -612,10 +618,10 @@ func (c *ClusterClient) Pipelined(fn func(*Pipeline) error) ([]Cmder, error) {
} }
func (c *ClusterClient) pipelineExec(cmds []Cmder) error { func (c *ClusterClient) pipelineExec(cmds []Cmder) error {
var retErr error var firstErr error
setRetErr := func(err error) { setFirstErr := func(err error) {
if retErr == nil { if firstErr == nil {
retErr = err firstErr = err
} }
} }
@ -625,7 +631,7 @@ func (c *ClusterClient) pipelineExec(cmds []Cmder) error {
_, node, err := c.cmdSlotAndNode(state, cmd) _, node, err := c.cmdSlotAndNode(state, cmd)
if err != nil { if err != nil {
cmd.setErr(err) cmd.setErr(err)
setRetErr(err) setFirstErr(err)
continue continue
} }
cmdsMap[node] = append(cmdsMap[node], cmd) cmdsMap[node] = append(cmdsMap[node], cmd)
@ -638,13 +644,13 @@ func (c *ClusterClient) pipelineExec(cmds []Cmder) error {
cn, _, err := node.Client.conn() cn, _, err := node.Client.conn()
if err != nil { if err != nil {
setCmdsErr(cmds, err) setCmdsErr(cmds, err)
setRetErr(err) setFirstErr(err)
continue continue
} }
failedCmds, err = c.execClusterCmds(cn, cmds, failedCmds) failedCmds, err = c.execClusterCmds(cn, cmds, failedCmds)
if err != nil { if err != nil {
setRetErr(err) setFirstErr(err)
} }
node.Client.putConn(cn, err, false) node.Client.putConn(cn, err, false)
} }
@ -652,24 +658,28 @@ func (c *ClusterClient) pipelineExec(cmds []Cmder) error {
cmdsMap = failedCmds cmdsMap = failedCmds
} }
return retErr return firstErr
} }
func (c *ClusterClient) execClusterCmds( func (c *ClusterClient) execClusterCmds(
cn *pool.Conn, cmds []Cmder, failedCmds map[*clusterNode][]Cmder, cn *pool.Conn, cmds []Cmder, failedCmds map[*clusterNode][]Cmder,
) (map[*clusterNode][]Cmder, error) { ) (map[*clusterNode][]Cmder, error) {
cn.SetWriteTimeout(c.opt.WriteTimeout)
if err := writeCmd(cn, cmds...); err != nil { if err := writeCmd(cn, cmds...); err != nil {
setCmdsErr(cmds, err) setCmdsErr(cmds, err)
return failedCmds, err return failedCmds, err
} }
var retErr error var firstErr error
setRetErr := func(err error) { setFirstErr := func(err error) {
if retErr == nil { if firstErr == nil {
retErr = err firstErr = err
} }
} }
// Set read timeout for all commands.
cn.SetReadTimeout(c.opt.ReadTimeout)
for i, cmd := range cmds { for i, cmd := range cmds {
err := cmd.readReply(cn) err := cmd.readReply(cn)
if err == nil { if err == nil {
@ -688,7 +698,7 @@ func (c *ClusterClient) execClusterCmds(
node, err := c.nodes.Get(addr) node, err := c.nodes.Get(addr)
if err != nil { if err != nil {
setRetErr(err) setFirstErr(err)
continue continue
} }
@ -697,16 +707,16 @@ func (c *ClusterClient) execClusterCmds(
} else if ask { } else if ask {
node, err := c.nodes.Get(addr) node, err := c.nodes.Get(addr)
if err != nil { if err != nil {
setRetErr(err) setFirstErr(err)
continue continue
} }
cmd.reset() cmd.reset()
failedCmds[node] = append(failedCmds[node], NewCmd("ASKING"), cmd) failedCmds[node] = append(failedCmds[node], NewCmd("ASKING"), cmd)
} else { } else {
setRetErr(err) setFirstErr(err)
} }
} }
return failedCmds, retErr return failedCmds, firstErr
} }

View File

@ -483,12 +483,19 @@ var _ = Describe("ClusterClient", func() {
describeClusterClient() describeClusterClient()
}) })
})
var _ = Describe("ClusterClient without nodes", func() {
var client *redis.ClusterClient
Describe("ClusterClient without nodes", func() {
BeforeEach(func() { BeforeEach(func() {
client = redis.NewClusterClient(&redis.ClusterOptions{}) client = redis.NewClusterClient(&redis.ClusterOptions{})
}) })
AfterEach(func() {
Expect(client.Close()).NotTo(HaveOccurred())
})
It("returns an error", func() { It("returns an error", func() {
err := client.Ping().Err() err := client.Ping().Err()
Expect(err).To(MatchError("redis: cluster has no nodes")) Expect(err).To(MatchError("redis: cluster has no nodes"))
@ -503,13 +510,19 @@ var _ = Describe("ClusterClient", func() {
}) })
}) })
Describe("ClusterClient without valid nodes", func() { var _ = Describe("ClusterClient without valid nodes", func() {
var client *redis.ClusterClient
BeforeEach(func() { BeforeEach(func() {
client = redis.NewClusterClient(&redis.ClusterOptions{ client = redis.NewClusterClient(&redis.ClusterOptions{
Addrs: []string{redisAddr}, Addrs: []string{redisAddr},
}) })
}) })
AfterEach(func() {
Expect(client.Close()).NotTo(HaveOccurred())
})
It("returns an error", func() { It("returns an error", func() {
err := client.Ping().Err() err := client.Ping().Err()
Expect(err).To(MatchError("ERR This instance has cluster support disabled")) Expect(err).To(MatchError("ERR This instance has cluster support disabled"))
@ -523,6 +536,72 @@ var _ = Describe("ClusterClient", func() {
Expect(err).To(MatchError("ERR This instance has cluster support disabled")) Expect(err).To(MatchError("ERR This instance has cluster support disabled"))
}) })
}) })
var _ = Describe("ClusterClient timeout", func() {
var client *redis.ClusterClient
AfterEach(func() {
Expect(client.Close()).NotTo(HaveOccurred())
})
testTimeout := func() {
It("Ping timeouts", func() {
err := client.Ping().Err()
Expect(err).To(HaveOccurred())
Expect(err.(net.Error).Timeout()).To(BeTrue())
})
It("Pipeline timeouts", func() {
_, err := client.Pipelined(func(pipe *redis.Pipeline) error {
pipe.Ping()
return nil
})
Expect(err).To(HaveOccurred())
Expect(err.(net.Error).Timeout()).To(BeTrue())
})
It("Tx timeouts", func() {
err := client.Watch(func(tx *redis.Tx) error {
return tx.Ping().Err()
})
Expect(err).To(HaveOccurred())
Expect(err.(net.Error).Timeout()).To(BeTrue())
})
It("Tx Pipeline timeouts", func() {
err := client.Watch(func(tx *redis.Tx) error {
_, err := tx.Pipelined(func(pipe *redis.Pipeline) error {
pipe.Ping()
return nil
})
return err
})
Expect(err).To(HaveOccurred())
Expect(err.(net.Error).Timeout()).To(BeTrue())
})
}
Context("read timeout", func() {
BeforeEach(func() {
opt := redisClusterOptions()
opt.ReadTimeout = time.Nanosecond
opt.WriteTimeout = -1
client = cluster.clusterClient(opt)
})
testTimeout()
})
Context("write timeout", func() {
BeforeEach(func() {
opt := redisClusterOptions()
opt.ReadTimeout = time.Nanosecond
opt.WriteTimeout = -1
client = cluster.clusterClient(opt)
})
testTimeout()
})
}) })
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------

View File

@ -64,7 +64,7 @@ func writeCmd(cn *pool.Conn, cmds ...Cmder) error {
} }
} }
_, err := cn.Write(cn.Wb.Bytes()) _, err := cn.NetConn.Write(cn.Wb.Bytes())
return err return err
} }

View File

@ -18,9 +18,6 @@ type Conn struct {
Inited bool Inited bool
UsedAt time.Time UsedAt time.Time
ReadTimeout time.Duration
WriteTimeout time.Duration
} }
func NewConn(netConn net.Conn) *Conn { func NewConn(netConn net.Conn) *Conn {
@ -30,7 +27,7 @@ func NewConn(netConn net.Conn) *Conn {
UsedAt: time.Now(), UsedAt: time.Now(),
} }
cn.Rd = proto.NewReader(cn) cn.Rd = proto.NewReader(cn.NetConn)
return cn return cn
} }
@ -38,28 +35,21 @@ func (cn *Conn) IsStale(timeout time.Duration) bool {
return timeout > 0 && time.Since(cn.UsedAt) > timeout return timeout > 0 && time.Since(cn.UsedAt) > timeout
} }
func (cn *Conn) Read(b []byte) (int, error) { func (cn *Conn) SetReadTimeout(timeout time.Duration) error {
cn.UsedAt = time.Now() cn.UsedAt = time.Now()
if cn.ReadTimeout != 0 { if timeout > 0 {
cn.NetConn.SetReadDeadline(cn.UsedAt.Add(cn.ReadTimeout)) return cn.NetConn.SetReadDeadline(cn.UsedAt.Add(timeout))
} else {
cn.NetConn.SetReadDeadline(noDeadline)
} }
return cn.NetConn.Read(b) return cn.NetConn.SetReadDeadline(noDeadline)
} }
func (cn *Conn) Write(b []byte) (int, error) { func (cn *Conn) SetWriteTimeout(timeout time.Duration) error {
cn.UsedAt = time.Now() cn.UsedAt = time.Now()
if cn.WriteTimeout != 0 { if timeout > 0 {
cn.NetConn.SetWriteDeadline(cn.UsedAt.Add(cn.WriteTimeout)) return cn.NetConn.SetWriteDeadline(cn.UsedAt.Add(timeout))
} else {
cn.NetConn.SetWriteDeadline(noDeadline)
} }
return cn.NetConn.Write(b) return cn.NetConn.SetWriteDeadline(noDeadline)
}
func (cn *Conn) RemoteAddr() net.Addr {
return cn.NetConn.RemoteAddr()
} }
func (cn *Conn) Close() error { func (cn *Conn) Close() error {

View File

@ -266,19 +266,19 @@ func (p *ConnPool) Closed() bool {
return atomic.LoadInt32(&p._closed) == 1 return atomic.LoadInt32(&p._closed) == 1
} }
func (p *ConnPool) Close() (retErr error) { func (p *ConnPool) Close() error {
if !atomic.CompareAndSwapInt32(&p._closed, 0, 1) { if !atomic.CompareAndSwapInt32(&p._closed, 0, 1) {
return ErrClosed return ErrClosed
} }
p.connsMu.Lock() p.connsMu.Lock()
// Close all connections. var firstErr error
for _, cn := range p.conns { for _, cn := range p.conns {
if cn == nil { if cn == nil {
continue continue
} }
if err := p.closeConn(cn, ErrClosed); err != nil && retErr == nil { if err := p.closeConn(cn, ErrClosed); err != nil && firstErr == nil {
retErr = err firstErr = err
} }
} }
p.conns = nil p.conns = nil
@ -288,7 +288,7 @@ func (p *ConnPool) Close() (retErr error) {
p.freeConns = nil p.freeConns = nil
p.freeConnsMu.Unlock() p.freeConnsMu.Unlock()
return retErr return firstErr
} }
func (p *ConnPool) closeConn(cn *Conn, reason error) error { func (p *ConnPool) closeConn(cn *Conn, reason error) error {

View File

@ -49,7 +49,7 @@ func (p *StickyConnPool) Get() (*Conn, bool, error) {
return cn, true, nil return cn, true, nil
} }
func (p *StickyConnPool) put() (err error) { func (p *StickyConnPool) putUpstream() (err error) {
err = p.pool.Put(p.cn) err = p.pool.Put(p.cn)
p.cn = nil p.cn = nil
return err return err
@ -67,7 +67,7 @@ func (p *StickyConnPool) Put(cn *Conn) error {
return nil return nil
} }
func (p *StickyConnPool) remove(reason error) error { func (p *StickyConnPool) removeUpstream(reason error) error {
err := p.pool.Remove(p.cn, reason) err := p.pool.Remove(p.cn, reason)
p.cn = nil p.cn = nil
return err return err
@ -85,7 +85,7 @@ func (p *StickyConnPool) Remove(cn *Conn, reason error) error {
if cn != nil && p.cn != cn { if cn != nil && p.cn != cn {
panic("p.cn != cn") panic("p.cn != cn")
} }
return p.remove(reason) return p.removeUpstream(reason)
} }
func (p *StickyConnPool) Len() int { func (p *StickyConnPool) Len() int {
@ -120,10 +120,10 @@ func (p *StickyConnPool) Close() error {
var err error var err error
if p.cn != nil { if p.cn != nil {
if p.reusable { if p.reusable {
err = p.put() err = p.putUpstream()
} else { } else {
reason := errors.New("redis: sticky not reusable connection") reason := errors.New("redis: unreusable sticky connection")
err = p.remove(reason) err = p.removeUpstream(reason)
} }
} }
return err return err

View File

@ -2,6 +2,7 @@ package redis_test
import ( import (
"errors" "errors"
"fmt"
"net" "net"
"os" "os"
"os/exec" "os/exec"
@ -159,8 +160,7 @@ func perform(n int, cbs ...func(int)) {
func eventually(fn func() error, timeout time.Duration) error { func eventually(fn func() error, timeout time.Duration) error {
var exit int32 var exit int32
var retErr error errCh := make(chan error)
var mu sync.Mutex
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
@ -172,9 +172,10 @@ func eventually(fn func() error, timeout time.Duration) error {
close(done) close(done)
return return
} }
mu.Lock() select {
retErr = err case errCh <- err:
mu.Unlock() default:
}
time.Sleep(timeout / 100) time.Sleep(timeout / 100)
} }
}() }()
@ -184,10 +185,12 @@ func eventually(fn func() error, timeout time.Duration) error {
return nil return nil
case <-time.After(timeout): case <-time.After(timeout):
atomic.StoreInt32(&exit, 1) atomic.StoreInt32(&exit, 1)
mu.Lock() select {
err := retErr case err := <-errCh:
mu.Unlock()
return err return err
default:
return fmt.Errorf("timeout after %s", timeout)
}
} }
} }

View File

@ -90,9 +90,13 @@ func (opt *Options) init() {
} }
if opt.ReadTimeout == 0 { if opt.ReadTimeout == 0 {
opt.ReadTimeout = 3 * time.Second opt.ReadTimeout = 3 * time.Second
} else if opt.ReadTimeout == -1 {
opt.ReadTimeout = 0
} }
if opt.WriteTimeout == 0 { if opt.WriteTimeout == 0 {
opt.WriteTimeout = opt.ReadTimeout opt.WriteTimeout = opt.ReadTimeout
} else if opt.WriteTimeout == -1 {
opt.WriteTimeout = 0
} }
if opt.PoolTimeout == 0 { if opt.PoolTimeout == 0 {
opt.PoolTimeout = opt.ReadTimeout + time.Second opt.PoolTimeout = opt.ReadTimeout + time.Second

View File

@ -1,9 +1,9 @@
package redis package redis
import ( import (
"errors"
"sync" "sync"
"gopkg.in/redis.v5/internal"
"gopkg.in/redis.v5/internal/pool" "gopkg.in/redis.v5/internal/pool"
) )
@ -67,7 +67,7 @@ func (c *Pipeline) Exec() ([]Cmder, error) {
} }
if len(c.cmds) == 0 { if len(c.cmds) == 0 {
return c.cmds, nil return nil, errors.New("redis: pipeline is empty")
} }
cmds := c.cmds cmds := c.cmds
@ -84,24 +84,3 @@ func (c *Pipeline) pipelined(fn func(*Pipeline) error) ([]Cmder, error) {
_ = c.Close() _ = c.Close()
return cmds, err return cmds, err
} }
func execCmds(cn *pool.Conn, cmds []Cmder) (retry bool, firstErr error) {
if err := writeCmd(cn, cmds...); err != nil {
setCmdsErr(cmds, err)
return true, err
}
for i, cmd := range cmds {
err := cmd.readReply(cn)
if err == nil {
continue
}
if i == 0 && internal.IsNetworkError(err) {
return true, err
}
if firstErr == nil {
firstErr = err
}
}
return false, firstErr
}

View File

@ -4,13 +4,13 @@ import (
"strconv" "strconv"
"sync" "sync"
"gopkg.in/redis.v5"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
"gopkg.in/redis.v5"
) )
var _ = Describe("Pipelining", func() { var _ = Describe("Pipeline", func() {
var client *redis.Client var client *redis.Client
BeforeEach(func() { BeforeEach(func() {
@ -51,15 +51,12 @@ var _ = Describe("Pipelining", func() {
Expect(getNil.Val()).To(Equal("")) Expect(getNil.Val()).To(Equal(""))
}) })
It("should discard", func() { It("discards queued commands", func() {
pipeline := client.Pipeline() pipeline := client.Pipeline()
pipeline.Get("key") pipeline.Get("key")
pipeline.Discard() pipeline.Discard()
cmds, err := pipeline.Exec() _, err := pipeline.Exec()
Expect(err).NotTo(HaveOccurred()) Expect(err).To(MatchError("redis: pipeline is empty"))
Expect(cmds).To(HaveLen(0))
Expect(pipeline.Close()).NotTo(HaveOccurred())
}) })
It("should support block style", func() { It("should support block style", func() {
@ -84,12 +81,10 @@ var _ = Describe("Pipelining", func() {
Expect(pipeline.Close()).NotTo(HaveOccurred()) Expect(pipeline.Close()).NotTo(HaveOccurred())
}) })
It("should pipeline with empty queue", func() { It("returns an error when there are no commands", func() {
pipeline := client.Pipeline() pipeline := client.Pipeline()
cmds, err := pipeline.Exec() _, err := pipeline.Exec()
Expect(err).NotTo(HaveOccurred()) Expect(err).To(MatchError("redis: pipeline is empty"))
Expect(cmds).To(HaveLen(0))
Expect(pipeline.Close()).NotTo(HaveOccurred())
}) })
It("should increment correctly", func() { It("should increment correctly", func() {

View File

@ -35,12 +35,6 @@ func (c *PubSub) putConn(cn *pool.Conn, err error) {
} }
func (c *PubSub) subscribe(redisCmd string, channels ...string) error { func (c *PubSub) subscribe(redisCmd string, channels ...string) error {
cn, _, err := c.conn()
if err != nil {
return err
}
c.putConn(cn, err)
args := make([]interface{}, 1+len(channels)) args := make([]interface{}, 1+len(channels))
args[0] = redisCmd args[0] = redisCmd
for i, channel := range channels { for i, channel := range channels {
@ -48,7 +42,15 @@ func (c *PubSub) subscribe(redisCmd string, channels ...string) error {
} }
cmd := NewSliceCmd(args...) cmd := NewSliceCmd(args...)
return writeCmd(cn, cmd) cn, _, err := c.conn()
if err != nil {
return err
}
cn.SetWriteTimeout(c.base.opt.WriteTimeout)
err = writeCmd(cn, cmd)
c.putConn(cn, err)
return err
} }
// Subscribes the client to the specified channels. // Subscribes the client to the specified channels.
@ -94,17 +96,21 @@ func (c *PubSub) Close() error {
} }
func (c *PubSub) Ping(payload string) error { func (c *PubSub) Ping(payload string) error {
cn, _, err := c.conn()
if err != nil {
return err
}
args := []interface{}{"PING"} args := []interface{}{"PING"}
if payload != "" { if payload != "" {
args = append(args, payload) args = append(args, payload)
} }
cmd := NewCmd(args...) cmd := NewCmd(args...)
return writeCmd(cn, cmd)
cn, _, err := c.conn()
if err != nil {
return err
}
cn.SetWriteTimeout(c.base.opt.WriteTimeout)
err = writeCmd(cn, cmd)
c.putConn(cn, err)
return err
} }
// Message received after a successful subscription to channel. // Message received after a successful subscription to channel.
@ -176,13 +182,14 @@ func (c *PubSub) newMessage(reply []interface{}) (interface{}, error) {
// is not received in time. This is low-level API and most clients // is not received in time. This is low-level API and most clients
// should use ReceiveMessage. // should use ReceiveMessage.
func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) { func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) {
cmd := NewSliceCmd()
cn, _, err := c.conn() cn, _, err := c.conn()
if err != nil { if err != nil {
return nil, err return nil, err
} }
cn.ReadTimeout = timeout
cmd := NewSliceCmd() cn.SetReadTimeout(timeout)
err = cmd.readReply(cn) err = cmd.readReply(cn)
c.putConn(cn, err) c.putConn(cn, err)
if err != nil { if err != nil {

View File

@ -315,7 +315,7 @@ var _ = Describe("PubSub", func() {
Eventually(done).Should(Receive()) Eventually(done).Should(Receive())
stats := client.PoolStats() stats := client.PoolStats()
Expect(stats.Requests).To(Equal(uint32(4))) Expect(stats.Requests).To(Equal(uint32(3)))
Expect(stats.Hits).To(Equal(uint32(1))) Expect(stats.Hits).To(Equal(uint32(1)))
} }

View File

@ -3,6 +3,7 @@ package redis // import "gopkg.in/redis.v5"
import ( import (
"fmt" "fmt"
"log" "log"
"time"
"gopkg.in/redis.v5/internal" "gopkg.in/redis.v5/internal"
"gopkg.in/redis.v5/internal/pool" "gopkg.in/redis.v5/internal/pool"
@ -105,14 +106,7 @@ func (c *baseClient) defaultProcess(cmd Cmder) error {
return err return err
} }
readTimeout := cmd.readTimeout() cn.SetWriteTimeout(c.opt.WriteTimeout)
if readTimeout != nil {
cn.ReadTimeout = *readTimeout
} else {
cn.ReadTimeout = c.opt.ReadTimeout
}
cn.WriteTimeout = c.opt.WriteTimeout
if err := writeCmd(cn, cmd); err != nil { if err := writeCmd(cn, cmd); err != nil {
c.putConn(cn, err, false) c.putConn(cn, err, false)
cmd.setErr(err) cmd.setErr(err)
@ -122,8 +116,9 @@ func (c *baseClient) defaultProcess(cmd Cmder) error {
return err return err
} }
cn.SetReadTimeout(c.cmdTimeout(cmd))
err = cmd.readReply(cn) err = cmd.readReply(cn)
c.putConn(cn, err, readTimeout != nil) c.putConn(cn, err, false)
if err != nil && internal.IsRetryableError(err) { if err != nil && internal.IsRetryableError(err) {
continue continue
} }
@ -134,6 +129,14 @@ func (c *baseClient) defaultProcess(cmd Cmder) error {
return cmd.Err() return cmd.Err()
} }
func (c *baseClient) cmdTimeout(cmd Cmder) time.Duration {
if timeout := cmd.readTimeout(); timeout != nil {
return *timeout
} else {
return c.opt.ReadTimeout
}
}
func (c *baseClient) closed() bool { func (c *baseClient) closed() bool {
return c.connPool.Closed() return c.connPool.Closed()
} }
@ -143,16 +146,16 @@ func (c *baseClient) closed() bool {
// It is rare to Close a Client, as the Client is meant to be // It is rare to Close a Client, as the Client is meant to be
// long-lived and shared between many goroutines. // long-lived and shared between many goroutines.
func (c *baseClient) Close() error { func (c *baseClient) Close() error {
var retErr error var firstErr error
if c.onClose != nil { if c.onClose != nil {
if err := c.onClose(); err != nil && retErr == nil { if err := c.onClose(); err != nil && firstErr == nil {
retErr = err firstErr = err
} }
} }
if err := c.connPool.Close(); err != nil && retErr == nil { if err := c.connPool.Close(); err != nil && firstErr == nil {
retErr = err firstErr = err
} }
return retErr return firstErr
} }
func (c *baseClient) getAddr() string { func (c *baseClient) getAddr() string {
@ -225,7 +228,7 @@ func (c *Client) pipelineExec(cmds []Cmder) error {
return err return err
} }
retry, err := execCmds(cn, cmds) retry, err := c.execCmds(cn, cmds)
c.putConn(cn, err, false) c.putConn(cn, err, false)
if err == nil { if err == nil {
return nil return nil
@ -240,6 +243,31 @@ func (c *Client) pipelineExec(cmds []Cmder) error {
return firstErr return firstErr
} }
func (c *Client) execCmds(cn *pool.Conn, cmds []Cmder) (retry bool, firstErr error) {
cn.SetWriteTimeout(c.opt.WriteTimeout)
if err := writeCmd(cn, cmds...); err != nil {
setCmdsErr(cmds, err)
return true, err
}
// Set read timeout for all commands.
cn.SetReadTimeout(c.opt.ReadTimeout)
for i, cmd := range cmds {
err := cmd.readReply(cn)
if err == nil {
continue
}
if i == 0 && internal.IsNetworkError(err) {
return true, err
}
if firstErr == nil {
firstErr = err
}
}
return false, firstErr
}
func (c *Client) pubSub() *PubSub { func (c *Client) pubSub() *PubSub {
return &PubSub{ return &PubSub{
base: baseClient{ base: baseClient{

View File

@ -3,11 +3,12 @@ package redis_test
import ( import (
"bytes" "bytes"
"net" "net"
"time"
"gopkg.in/redis.v5"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
"gopkg.in/redis.v5"
) )
var _ = Describe("Client", func() { var _ = Describe("Client", func() {
@ -15,7 +16,7 @@ var _ = Describe("Client", func() {
BeforeEach(func() { BeforeEach(func() {
client = redis.NewClient(redisOptions()) client = redis.NewClient(redisOptions())
Expect(client.FlushDb().Err()).To(BeNil()) Expect(client.FlushDb().Err()).NotTo(HaveOccurred())
}) })
AfterEach(func() { AfterEach(func() {
@ -174,7 +175,7 @@ var _ = Describe("Client", func() {
Expect(cn.UsedAt.After(createdAt)).To(BeTrue()) Expect(cn.UsedAt.After(createdAt)).To(BeTrue())
}) })
It("should escape special chars", func() { It("should process command with special chars", func() {
set := client.Set("key", "hello1\r\nhello2\r\n", 0) set := client.Set("key", "hello1\r\nhello2\r\n", 0)
Expect(set.Err()).NotTo(HaveOccurred()) Expect(set.Err()).NotTo(HaveOccurred())
Expect(set.Val()).To(Equal("OK")) Expect(set.Val()).To(Equal("OK"))
@ -191,12 +192,84 @@ var _ = Describe("Client", func() {
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
// Reconnect to get new connection. // Reconnect to get new connection.
Expect(client.Close()).To(BeNil()) Expect(client.Close()).NotTo(HaveOccurred())
client = redis.NewClient(redisOptions()) client = redis.NewClient(redisOptions())
got, err := client.Get("key").Bytes() got, err := client.Get("key").Bytes()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(got).To(Equal(bigVal)) Expect(got).To(Equal(bigVal))
}) })
})
var _ = Describe("Client timeout", func() {
var client *redis.Client
AfterEach(func() {
Expect(client.Close()).NotTo(HaveOccurred())
})
testTimeout := func() {
It("Ping timeouts", func() {
err := client.Ping().Err()
Expect(err).To(HaveOccurred())
Expect(err.(net.Error).Timeout()).To(BeTrue())
})
It("Pipeline timeouts", func() {
_, err := client.Pipelined(func(pipe *redis.Pipeline) error {
pipe.Ping()
return nil
})
Expect(err).To(HaveOccurred())
Expect(err.(net.Error).Timeout()).To(BeTrue())
})
It("Subscribe timeouts", func() {
_, err := client.Subscribe("_")
Expect(err).To(HaveOccurred())
Expect(err.(net.Error).Timeout()).To(BeTrue())
})
It("Tx timeouts", func() {
err := client.Watch(func(tx *redis.Tx) error {
return tx.Ping().Err()
})
Expect(err).To(HaveOccurred())
Expect(err.(net.Error).Timeout()).To(BeTrue())
})
It("Tx Pipeline timeouts", func() {
err := client.Watch(func(tx *redis.Tx) error {
_, err := tx.Pipelined(func(pipe *redis.Pipeline) error {
pipe.Ping()
return nil
})
return err
})
Expect(err).To(HaveOccurred())
Expect(err.(net.Error).Timeout()).To(BeTrue())
})
}
Context("read timeout", func() {
BeforeEach(func() {
opt := redisOptions()
opt.ReadTimeout = time.Nanosecond
opt.WriteTimeout = -1
client = redis.NewClient(opt)
})
testTimeout()
})
Context("write timeout", func() {
BeforeEach(func() {
opt := redisOptions()
opt.ReadTimeout = -1
opt.WriteTimeout = time.Nanosecond
client = redis.NewClient(opt)
})
testTimeout()
})
}) })

11
ring.go
View File

@ -332,7 +332,7 @@ func (c *Ring) heartbeat() {
// //
// It is rare to Close a Ring, as the Ring is meant to be long-lived // It is rare to Close a Ring, as the Ring is meant to be long-lived
// and shared between many goroutines. // and shared between many goroutines.
func (c *Ring) Close() (retErr error) { func (c *Ring) Close() error {
defer c.mu.Unlock() defer c.mu.Unlock()
c.mu.Lock() c.mu.Lock()
@ -341,15 +341,16 @@ func (c *Ring) Close() (retErr error) {
} }
c.closed = true c.closed = true
var firstErr error
for _, shard := range c.shards { for _, shard := range c.shards {
if err := shard.Client.Close(); err != nil { if err := shard.Client.Close(); err != nil && firstErr == nil {
retErr = err firstErr = err
} }
} }
c.hash = nil c.hash = nil
c.shards = nil c.shards = nil
return retErr return firstErr
} }
func (c *Ring) Pipeline() *Pipeline { func (c *Ring) Pipeline() *Pipeline {
@ -402,7 +403,7 @@ func (c *Ring) pipelineExec(cmds []Cmder) (firstErr error) {
continue continue
} }
retry, err := execCmds(cn, cmds) retry, err := shard.Client.execCmds(cn, cmds)
shard.Client.putConn(cn, err, false) shard.Client.putConn(cn, err, false)
if err == nil { if err == nil {
continue continue

View File

@ -267,10 +267,10 @@ func (d *sentinelFailover) closeOldConns(newMaster string) {
if cn == nil { if cn == nil {
break break
} }
if cn.RemoteAddr().String() != newMaster { if cn.NetConn.RemoteAddr().String() != newMaster {
err := fmt.Errorf( err := fmt.Errorf(
"sentinel: closing connection to the old master %s", "sentinel: closing connection to the old master %s",
cn.RemoteAddr(), cn.NetConn.RemoteAddr(),
) )
internal.Logf(err.Error()) internal.Logf(err.Error())
d.pool.Remove(cn, err) d.pool.Remove(cn, err)

12
tx.go
View File

@ -45,11 +45,11 @@ func (c *Client) Watch(fn func(*Tx) error, keys ...string) error {
return err return err
} }
} }
retErr := fn(tx) firstErr := fn(tx)
if err := tx.close(); err != nil && retErr == nil { if err := tx.close(); err != nil && firstErr == nil {
retErr = err firstErr = err
} }
return retErr return firstErr
} }
// close closes the transaction, releasing any open resources. // close closes the transaction, releasing any open resources.
@ -133,12 +133,16 @@ func (c *Tx) exec(cmds []Cmder) error {
} }
func (c *Tx) execCmds(cn *pool.Conn, cmds []Cmder) error { func (c *Tx) execCmds(cn *pool.Conn, cmds []Cmder) error {
cn.SetWriteTimeout(c.opt.WriteTimeout)
err := writeCmd(cn, cmds...) err := writeCmd(cn, cmds...)
if err != nil { if err != nil {
setCmdsErr(cmds[1:len(cmds)-1], err) setCmdsErr(cmds[1:len(cmds)-1], err)
return err return err
} }
// Set read timeout for all commands.
cn.SetReadTimeout(c.opt.ReadTimeout)
// Omit last command (EXEC). // Omit last command (EXEC).
cmdsLen := len(cmds) - 1 cmdsLen := len(cmds) - 1

View File

@ -86,14 +86,12 @@ var _ = Describe("Tx", func() {
Expect(get.Val()).To(Equal("hello2")) Expect(get.Val()).To(Equal("hello2"))
}) })
It("should exec empty", func() { It("returns an error when there are no commands", func() {
err := client.Watch(func(tx *redis.Tx) error { err := client.Watch(func(tx *redis.Tx) error {
cmds, err := tx.Pipelined(func(*redis.Pipeline) error { return nil }) _, err := tx.Pipelined(func(*redis.Pipeline) error { return nil })
Expect(err).NotTo(HaveOccurred())
Expect(cmds).To(HaveLen(0))
return err return err
}) })
Expect(err).NotTo(HaveOccurred()) Expect(err).To(MatchError("redis: pipeline is empty"))
v, err := client.Ping().Result() v, err := client.Ping().Result()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
@ -150,30 +148,4 @@ var _ = Describe("Tx", func() {
err = do() err = do()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
}) })
It("should recover from bad connection when there are no commands", func() {
// Put bad connection in the pool.
cn, _, err := client.Pool().Get()
Expect(err).NotTo(HaveOccurred())
cn.NetConn = &badConn{}
err = client.Pool().Put(cn)
Expect(err).NotTo(HaveOccurred())
do := func() error {
err := client.Watch(func(tx *redis.Tx) error {
_, err := tx.Pipelined(func(pipe *redis.Pipeline) error {
return nil
})
return err
}, "key")
return err
}
err = do()
Expect(err).To(MatchError("bad connection"))
err = do()
Expect(err).NotTo(HaveOccurred())
})
}) })