Set read/write timeouts more consistently.

Vladimir Mihailenco 2016-12-03 17:30:13 +02:00
parent e7f23a300b
commit b4efc45f1c
18 changed files with 343 additions and 198 deletions


@ -387,7 +387,13 @@ func (c *ClusterClient) cmdSlotAndNode(state *clusterState, cmd Cmder) (int, *cl
}
func (c *ClusterClient) Watch(fn func(*Tx) error, keys ...string) error {
node, err := c.state().slotMasterNode(hashtag.Slot(keys[0]))
var node *clusterNode
var err error
if len(keys) > 0 {
node, err = c.state().slotMasterNode(hashtag.Slot(keys[0]))
} else {
node, err = c.nodes.Random()
}
if err != nil {
return err
}
@ -612,10 +618,10 @@ func (c *ClusterClient) Pipelined(fn func(*Pipeline) error) ([]Cmder, error) {
}
func (c *ClusterClient) pipelineExec(cmds []Cmder) error {
var retErr error
setRetErr := func(err error) {
if retErr == nil {
retErr = err
var firstErr error
setFirstErr := func(err error) {
if firstErr == nil {
firstErr = err
}
}
@ -625,7 +631,7 @@ func (c *ClusterClient) pipelineExec(cmds []Cmder) error {
_, node, err := c.cmdSlotAndNode(state, cmd)
if err != nil {
cmd.setErr(err)
setRetErr(err)
setFirstErr(err)
continue
}
cmdsMap[node] = append(cmdsMap[node], cmd)
@ -638,13 +644,13 @@ func (c *ClusterClient) pipelineExec(cmds []Cmder) error {
cn, _, err := node.Client.conn()
if err != nil {
setCmdsErr(cmds, err)
setRetErr(err)
setFirstErr(err)
continue
}
failedCmds, err = c.execClusterCmds(cn, cmds, failedCmds)
if err != nil {
setRetErr(err)
setFirstErr(err)
}
node.Client.putConn(cn, err, false)
}
@ -652,24 +658,28 @@ func (c *ClusterClient) pipelineExec(cmds []Cmder) error {
cmdsMap = failedCmds
}
return retErr
return firstErr
}
func (c *ClusterClient) execClusterCmds(
cn *pool.Conn, cmds []Cmder, failedCmds map[*clusterNode][]Cmder,
) (map[*clusterNode][]Cmder, error) {
cn.SetWriteTimeout(c.opt.WriteTimeout)
if err := writeCmd(cn, cmds...); err != nil {
setCmdsErr(cmds, err)
return failedCmds, err
}
var retErr error
setRetErr := func(err error) {
if retErr == nil {
retErr = err
var firstErr error
setFirstErr := func(err error) {
if firstErr == nil {
firstErr = err
}
}
// Set read timeout for all commands.
cn.SetReadTimeout(c.opt.ReadTimeout)
for i, cmd := range cmds {
err := cmd.readReply(cn)
if err == nil {
@ -688,7 +698,7 @@ func (c *ClusterClient) execClusterCmds(
node, err := c.nodes.Get(addr)
if err != nil {
setRetErr(err)
setFirstErr(err)
continue
}
@ -697,16 +707,16 @@ func (c *ClusterClient) execClusterCmds(
} else if ask {
node, err := c.nodes.Get(addr)
if err != nil {
setRetErr(err)
setFirstErr(err)
continue
}
cmd.reset()
failedCmds[node] = append(failedCmds[node], NewCmd("ASKING"), cmd)
} else {
setRetErr(err)
setFirstErr(err)
}
}
return failedCmds, retErr
return failedCmds, firstErr
}
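The Watch change above makes the keyless call safe: with keys, the transaction is still routed to the master that owns the first key's slot; without keys, it now falls back to a random node instead of indexing keys[0] and panicking. A usage sketch (cluster addresses are hypothetical; NewClusterClient and Watch are the APIs shown in this diff):

package main

import (
	"fmt"

	redis "gopkg.in/redis.v5"
)

func main() {
	client := redis.NewClusterClient(&redis.ClusterOptions{
		Addrs: []string{":7000", ":7001", ":7002"}, // hypothetical cluster
	})
	defer client.Close()

	// Keyed: routed to the master owning the slot for "key".
	err := client.Watch(func(tx *redis.Tx) error {
		return tx.Ping().Err()
	}, "key")
	fmt.Println(err)

	// Keyless: previously an index-out-of-range panic; now a random node.
	err = client.Watch(func(tx *redis.Tx) error {
		return tx.Ping().Err()
	})
	fmt.Println(err)
}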


@ -483,12 +483,19 @@ var _ = Describe("ClusterClient", func() {
describeClusterClient()
})
})
var _ = Describe("ClusterClient without nodes", func() {
var client *redis.ClusterClient
Describe("ClusterClient without nodes", func() {
BeforeEach(func() {
client = redis.NewClusterClient(&redis.ClusterOptions{})
})
AfterEach(func() {
Expect(client.Close()).NotTo(HaveOccurred())
})
It("returns an error", func() {
err := client.Ping().Err()
Expect(err).To(MatchError("redis: cluster has no nodes"))
@ -501,15 +508,21 @@ var _ = Describe("ClusterClient", func() {
})
Expect(err).To(MatchError("redis: cluster has no nodes"))
})
})
})
var _ = Describe("ClusterClient without valid nodes", func() {
var client *redis.ClusterClient
Describe("ClusterClient without valid nodes", func() {
BeforeEach(func() {
client = redis.NewClusterClient(&redis.ClusterOptions{
Addrs: []string{redisAddr},
})
})
AfterEach(func() {
Expect(client.Close()).NotTo(HaveOccurred())
})
It("returns an error", func() {
err := client.Ping().Err()
Expect(err).To(MatchError("ERR This instance has cluster support disabled"))
@ -522,6 +535,72 @@ var _ = Describe("ClusterClient", func() {
})
Expect(err).To(MatchError("ERR This instance has cluster support disabled"))
})
})
var _ = Describe("ClusterClient timeout", func() {
var client *redis.ClusterClient
AfterEach(func() {
Expect(client.Close()).NotTo(HaveOccurred())
})
testTimeout := func() {
It("Ping timeouts", func() {
err := client.Ping().Err()
Expect(err).To(HaveOccurred())
Expect(err.(net.Error).Timeout()).To(BeTrue())
})
It("Pipeline timeouts", func() {
_, err := client.Pipelined(func(pipe *redis.Pipeline) error {
pipe.Ping()
return nil
})
Expect(err).To(HaveOccurred())
Expect(err.(net.Error).Timeout()).To(BeTrue())
})
It("Tx timeouts", func() {
err := client.Watch(func(tx *redis.Tx) error {
return tx.Ping().Err()
})
Expect(err).To(HaveOccurred())
Expect(err.(net.Error).Timeout()).To(BeTrue())
})
It("Tx Pipeline timeouts", func() {
err := client.Watch(func(tx *redis.Tx) error {
_, err := tx.Pipelined(func(pipe *redis.Pipeline) error {
pipe.Ping()
return nil
})
return err
})
Expect(err).To(HaveOccurred())
Expect(err.(net.Error).Timeout()).To(BeTrue())
})
}
Context("read timeout", func() {
BeforeEach(func() {
opt := redisClusterOptions()
opt.ReadTimeout = time.Nanosecond
opt.WriteTimeout = -1
client = cluster.clusterClient(opt)
})
testTimeout()
})
Context("write timeout", func() {
BeforeEach(func() {
opt := redisClusterOptions()
opt.ReadTimeout = -1
opt.WriteTimeout = time.Nanosecond
client = cluster.clusterClient(opt)
})
testTimeout()
})
})


@ -64,7 +64,7 @@ func writeCmd(cn *pool.Conn, cmds ...Cmder) error {
}
}
_, err := cn.Write(cn.Wb.Bytes())
_, err := cn.NetConn.Write(cn.Wb.Bytes())
return err
}


@ -18,9 +18,6 @@ type Conn struct {
Inited bool
UsedAt time.Time
ReadTimeout time.Duration
WriteTimeout time.Duration
}
func NewConn(netConn net.Conn) *Conn {
@ -30,7 +27,7 @@ func NewConn(netConn net.Conn) *Conn {
UsedAt: time.Now(),
}
cn.Rd = proto.NewReader(cn)
cn.Rd = proto.NewReader(cn.NetConn)
return cn
}
@ -38,28 +35,21 @@ func (cn *Conn) IsStale(timeout time.Duration) bool {
return timeout > 0 && time.Since(cn.UsedAt) > timeout
}
func (cn *Conn) Read(b []byte) (int, error) {
func (cn *Conn) SetReadTimeout(timeout time.Duration) error {
cn.UsedAt = time.Now()
if cn.ReadTimeout != 0 {
cn.NetConn.SetReadDeadline(cn.UsedAt.Add(cn.ReadTimeout))
} else {
cn.NetConn.SetReadDeadline(noDeadline)
if timeout > 0 {
return cn.NetConn.SetReadDeadline(cn.UsedAt.Add(timeout))
}
return cn.NetConn.Read(b)
return cn.NetConn.SetReadDeadline(noDeadline)
}
func (cn *Conn) Write(b []byte) (int, error) {
func (cn *Conn) SetWriteTimeout(timeout time.Duration) error {
cn.UsedAt = time.Now()
if cn.WriteTimeout != 0 {
cn.NetConn.SetWriteDeadline(cn.UsedAt.Add(cn.WriteTimeout))
} else {
cn.NetConn.SetWriteDeadline(noDeadline)
if timeout > 0 {
return cn.NetConn.SetWriteDeadline(cn.UsedAt.Add(timeout))
}
return cn.NetConn.Write(b)
}
func (cn *Conn) RemoteAddr() net.Addr {
return cn.NetConn.RemoteAddr()
return cn.NetConn.SetWriteDeadline(noDeadline)
}
func (cn *Conn) Close() error {
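With the Read and Write wrappers gone, a deadline is armed only when a caller explicitly asks for one, and writeCmd (see the command.go hunk above) writes straight to cn.NetConn. A minimal sketch of the calling convention the rest of this commit adopts; Cmder, writeCmd, and readReply are package internals shown elsewhere in the diff:

// roundTrip sketches the write-then-read deadline protocol used by
// defaultProcess and the pipeline paths below.
func roundTrip(cn *pool.Conn, opt *Options, cmd Cmder) error {
	cn.SetWriteTimeout(opt.WriteTimeout) // arm the write deadline
	if err := writeCmd(cn, cmd); err != nil {
		return err
	}
	cn.SetReadTimeout(opt.ReadTimeout) // arm the read deadline once
	return cmd.readReply(cn)
}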


@ -266,19 +266,19 @@ func (p *ConnPool) Closed() bool {
return atomic.LoadInt32(&p._closed) == 1
}
func (p *ConnPool) Close() (retErr error) {
func (p *ConnPool) Close() error {
if !atomic.CompareAndSwapInt32(&p._closed, 0, 1) {
return ErrClosed
}
p.connsMu.Lock()
// Close all connections.
var firstErr error
for _, cn := range p.conns {
if cn == nil {
continue
}
if err := p.closeConn(cn, ErrClosed); err != nil && retErr == nil {
retErr = err
if err := p.closeConn(cn, ErrClosed); err != nil && firstErr == nil {
firstErr = err
}
}
p.conns = nil
@ -288,7 +288,7 @@ func (p *ConnPool) Close() (retErr error) {
p.freeConns = nil
p.freeConnsMu.Unlock()
return retErr
return firstErr
}
func (p *ConnPool) closeConn(cn *Conn, reason error) error {
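Close keeps iterating after a failure and reports only the first error; the same first-error idiom appears in pipelineExec above and in baseClient.Close, Ring.Close, and Watch below. The idiom in isolation (closeAll is a hypothetical helper; it needs only the standard io package):

// closeAll closes everything, remembers the first failure, and never
// short-circuits, so one bad connection cannot leak the rest.
func closeAll(closers []io.Closer) error {
	var firstErr error
	for _, c := range closers {
		if err := c.Close(); err != nil && firstErr == nil {
			firstErr = err
		}
	}
	return firstErr
}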


@ -49,7 +49,7 @@ func (p *StickyConnPool) Get() (*Conn, bool, error) {
return cn, true, nil
}
func (p *StickyConnPool) put() (err error) {
func (p *StickyConnPool) putUpstream() (err error) {
err = p.pool.Put(p.cn)
p.cn = nil
return err
@ -67,7 +67,7 @@ func (p *StickyConnPool) Put(cn *Conn) error {
return nil
}
func (p *StickyConnPool) remove(reason error) error {
func (p *StickyConnPool) removeUpstream(reason error) error {
err := p.pool.Remove(p.cn, reason)
p.cn = nil
return err
@ -85,7 +85,7 @@ func (p *StickyConnPool) Remove(cn *Conn, reason error) error {
if cn != nil && p.cn != cn {
panic("p.cn != cn")
}
return p.remove(reason)
return p.removeUpstream(reason)
}
func (p *StickyConnPool) Len() int {
@ -120,10 +120,10 @@ func (p *StickyConnPool) Close() error {
var err error
if p.cn != nil {
if p.reusable {
err = p.put()
err = p.putUpstream()
} else {
reason := errors.New("redis: sticky not reusable connection")
err = p.remove(reason)
reason := errors.New("redis: unreusable sticky connection")
err = p.removeUpstream(reason)
}
}
return err


@ -2,6 +2,7 @@ package redis_test
import (
"errors"
"fmt"
"net"
"os"
"os/exec"
@ -159,8 +160,7 @@ func perform(n int, cbs ...func(int)) {
func eventually(fn func() error, timeout time.Duration) error {
var exit int32
var retErr error
var mu sync.Mutex
errCh := make(chan error, 1)
done := make(chan struct{})
go func() {
@ -172,9 +172,10 @@ func eventually(fn func() error, timeout time.Duration) error {
close(done)
return
}
mu.Lock()
retErr = err
mu.Unlock()
select {
case errCh <- err:
default:
}
time.Sleep(timeout / 100)
}
}()
@ -184,10 +185,12 @@ func eventually(fn func() error, timeout time.Duration) error {
return nil
case <-time.After(timeout):
atomic.StoreInt32(&exit, 1)
mu.Lock()
err := retErr
mu.Unlock()
select {
case err := <-errCh:
return err
default:
return fmt.Errorf("timeout after %s", timeout)
}
}
}
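The mutex-guarded retErr becomes a one-slot error channel: the polling goroutine parks a failure with a non-blocking send (the one-element buffer lets the send succeed without a waiting receiver), and the timeout branch drains it with a non-blocking receive, falling back to a generic timeout error. The exchange in isolation:

package main

import (
	"errors"
	"fmt"
)

// tryPark offers err without blocking; once the slot is full,
// later errors are dropped.
func tryPark(errCh chan error, err error) {
	select {
	case errCh <- err:
	default:
	}
}

func main() {
	errCh := make(chan error, 1)
	tryPark(errCh, errors.New("attempt 1 failed"))
	tryPark(errCh, errors.New("attempt 2 failed")) // dropped: slot taken

	select {
	case err := <-errCh:
		fmt.Println(err) // attempt 1 failed
	default:
		fmt.Println("timed out with no recorded error")
	}
}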


@ -90,9 +90,13 @@ func (opt *Options) init() {
}
if opt.ReadTimeout == 0 {
opt.ReadTimeout = 3 * time.Second
} else if opt.ReadTimeout == -1 {
opt.ReadTimeout = 0
}
if opt.WriteTimeout == 0 {
opt.WriteTimeout = opt.ReadTimeout
} else if opt.WriteTimeout == -1 {
opt.WriteTimeout = 0
}
if opt.PoolTimeout == 0 {
opt.PoolTimeout = opt.ReadTimeout + time.Second
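The init rules give the timeout fields three states: zero keeps the library default (3 seconds for reads, with writes inheriting reads), -1 explicitly disables the deadline (normalized to 0 before use), and any positive value is taken as-is. Because ReadTimeout is normalized first, a zero WriteTimeout inherits the already-normalized read value. A configuration sketch (address hypothetical):

package main

import redis "gopkg.in/redis.v5"

func main() {
	client := redis.NewClient(&redis.Options{
		Addr:        ":6379", // hypothetical address
		ReadTimeout: -1,      // no read deadline
		// WriteTimeout left at zero: it inherits the normalized
		// ReadTimeout, so writes are unbounded here too.
	})
	defer client.Close()
}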


@ -1,9 +1,9 @@
package redis
import (
"errors"
"sync"
"gopkg.in/redis.v5/internal"
"gopkg.in/redis.v5/internal/pool"
)
@ -67,7 +67,7 @@ func (c *Pipeline) Exec() ([]Cmder, error) {
}
if len(c.cmds) == 0 {
return c.cmds, nil
return nil, errors.New("redis: pipeline is empty")
}
cmds := c.cmds
@ -84,24 +84,3 @@ func (c *Pipeline) pipelined(fn func(*Pipeline) error) ([]Cmder, error) {
_ = c.Close()
return cmds, err
}
func execCmds(cn *pool.Conn, cmds []Cmder) (retry bool, firstErr error) {
if err := writeCmd(cn, cmds...); err != nil {
setCmdsErr(cmds, err)
return true, err
}
for i, cmd := range cmds {
err := cmd.readReply(cn)
if err == nil {
continue
}
if i == 0 && internal.IsNetworkError(err) {
return true, err
}
if firstErr == nil {
firstErr = err
}
}
return false, firstErr
}
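Exec on an empty pipeline now fails fast instead of quietly returning an empty result set, a behavior change the updated tests below pin down. Usage sketch (address hypothetical):

package main

import (
	"fmt"

	redis "gopkg.in/redis.v5"
)

func main() {
	client := redis.NewClient(&redis.Options{Addr: ":6379"}) // hypothetical
	defer client.Close()

	pipe := client.Pipeline()
	defer pipe.Close()

	_, err := pipe.Exec() // nothing queued
	fmt.Println(err)      // redis: pipeline is empty
}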


@ -4,13 +4,13 @@ import (
"strconv"
"sync"
"gopkg.in/redis.v5"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"gopkg.in/redis.v5"
)
var _ = Describe("Pipelining", func() {
var _ = Describe("Pipeline", func() {
var client *redis.Client
BeforeEach(func() {
@ -51,15 +51,12 @@ var _ = Describe("Pipelining", func() {
Expect(getNil.Val()).To(Equal(""))
})
It("should discard", func() {
It("discards queued commands", func() {
pipeline := client.Pipeline()
pipeline.Get("key")
pipeline.Discard()
cmds, err := pipeline.Exec()
Expect(err).NotTo(HaveOccurred())
Expect(cmds).To(HaveLen(0))
Expect(pipeline.Close()).NotTo(HaveOccurred())
_, err := pipeline.Exec()
Expect(err).To(MatchError("redis: pipeline is empty"))
})
It("should support block style", func() {
@ -84,12 +81,10 @@ var _ = Describe("Pipelining", func() {
Expect(pipeline.Close()).NotTo(HaveOccurred())
})
It("should pipeline with empty queue", func() {
It("returns an error when there are no commands", func() {
pipeline := client.Pipeline()
cmds, err := pipeline.Exec()
Expect(err).NotTo(HaveOccurred())
Expect(cmds).To(HaveLen(0))
Expect(pipeline.Close()).NotTo(HaveOccurred())
_, err := pipeline.Exec()
Expect(err).To(MatchError("redis: pipeline is empty"))
})
It("should increment correctly", func() {


@ -35,12 +35,6 @@ func (c *PubSub) putConn(cn *pool.Conn, err error) {
}
func (c *PubSub) subscribe(redisCmd string, channels ...string) error {
cn, _, err := c.conn()
if err != nil {
return err
}
c.putConn(cn, err)
args := make([]interface{}, 1+len(channels))
args[0] = redisCmd
for i, channel := range channels {
@ -48,7 +42,15 @@ func (c *PubSub) subscribe(redisCmd string, channels ...string) error {
}
cmd := NewSliceCmd(args...)
return writeCmd(cn, cmd)
cn, _, err := c.conn()
if err != nil {
return err
}
cn.SetWriteTimeout(c.base.opt.WriteTimeout)
err = writeCmd(cn, cmd)
c.putConn(cn, err)
return err
}
// Subscribes the client to the specified channels.
@ -94,17 +96,21 @@ func (c *PubSub) Close() error {
}
func (c *PubSub) Ping(payload string) error {
cn, _, err := c.conn()
if err != nil {
return err
}
args := []interface{}{"PING"}
if payload != "" {
args = append(args, payload)
}
cmd := NewCmd(args...)
return writeCmd(cn, cmd)
cn, _, err := c.conn()
if err != nil {
return err
}
cn.SetWriteTimeout(c.base.opt.WriteTimeout)
err = writeCmd(cn, cmd)
c.putConn(cn, err)
return err
}
// Message received after a successful subscription to channel.
@ -176,13 +182,14 @@ func (c *PubSub) newMessage(reply []interface{}) (interface{}, error) {
// is not received in time. This is a low-level API and most clients
// should use ReceiveMessage.
func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) {
cmd := NewSliceCmd()
cn, _, err := c.conn()
if err != nil {
return nil, err
}
cn.ReadTimeout = timeout
cmd := NewSliceCmd()
cn.SetReadTimeout(timeout)
err = cmd.readReply(cn)
c.putConn(cn, err)
if err != nil {
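subscribe and Ping now build the command first, then take a connection, arm the write deadline, write, and return the connection along with the write error, so a timed-out write can evict the bad connection. ReceiveTimeout keeps its per-call read deadline via SetReadTimeout. A usage sketch (address and channel hypothetical):

package main

import (
	"fmt"
	"net"
	"time"

	redis "gopkg.in/redis.v5"
)

func main() {
	client := redis.NewClient(&redis.Options{Addr: ":6379"}) // hypothetical
	defer client.Close()

	pubsub, err := client.Subscribe("events")
	if err != nil {
		panic(err)
	}
	defer pubsub.Close()

	// ReceiveTimeout arms a read deadline for this call only.
	msg, err := pubsub.ReceiveTimeout(time.Second)
	if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
		fmt.Println("no message within a second")
		return
	}
	fmt.Println(msg, err)
}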


@ -315,7 +315,7 @@ var _ = Describe("PubSub", func() {
Eventually(done).Should(Receive())
stats := client.PoolStats()
Expect(stats.Requests).To(Equal(uint32(4)))
Expect(stats.Requests).To(Equal(uint32(3)))
Expect(stats.Hits).To(Equal(uint32(1)))
}


@ -3,6 +3,7 @@ package redis // import "gopkg.in/redis.v5"
import (
"fmt"
"log"
"time"
"gopkg.in/redis.v5/internal"
"gopkg.in/redis.v5/internal/pool"
@ -105,14 +106,7 @@ func (c *baseClient) defaultProcess(cmd Cmder) error {
return err
}
readTimeout := cmd.readTimeout()
if readTimeout != nil {
cn.ReadTimeout = *readTimeout
} else {
cn.ReadTimeout = c.opt.ReadTimeout
}
cn.WriteTimeout = c.opt.WriteTimeout
cn.SetWriteTimeout(c.opt.WriteTimeout)
if err := writeCmd(cn, cmd); err != nil {
c.putConn(cn, err, false)
cmd.setErr(err)
@ -122,8 +116,9 @@ func (c *baseClient) defaultProcess(cmd Cmder) error {
return err
}
cn.SetReadTimeout(c.cmdTimeout(cmd))
err = cmd.readReply(cn)
c.putConn(cn, err, readTimeout != nil)
c.putConn(cn, err, false)
if err != nil && internal.IsRetryableError(err) {
continue
}
@ -134,6 +129,14 @@ func (c *baseClient) defaultProcess(cmd Cmder) error {
return cmd.Err()
}
func (c *baseClient) cmdTimeout(cmd Cmder) time.Duration {
if timeout := cmd.readTimeout(); timeout != nil {
return *timeout
} else {
return c.opt.ReadTimeout
}
}
func (c *baseClient) closed() bool {
return c.connPool.Closed()
}
@ -143,16 +146,16 @@ func (c *baseClient) closed() bool {
// It is rare to Close a Client, as the Client is meant to be
// long-lived and shared between many goroutines.
func (c *baseClient) Close() error {
var retErr error
var firstErr error
if c.onClose != nil {
if err := c.onClose(); err != nil && retErr == nil {
retErr = err
if err := c.onClose(); err != nil && firstErr == nil {
firstErr = err
}
}
if err := c.connPool.Close(); err != nil && retErr == nil {
retErr = err
if err := c.connPool.Close(); err != nil && firstErr == nil {
firstErr = err
}
return retErr
return firstErr
}
func (c *baseClient) getAddr() string {
@ -225,7 +228,7 @@ func (c *Client) pipelineExec(cmds []Cmder) error {
return err
}
retry, err := execCmds(cn, cmds)
retry, err := c.execCmds(cn, cmds)
c.putConn(cn, err, false)
if err == nil {
return nil
@ -240,6 +243,31 @@ func (c *Client) pipelineExec(cmds []Cmder) error {
return firstErr
}
func (c *Client) execCmds(cn *pool.Conn, cmds []Cmder) (retry bool, firstErr error) {
cn.SetWriteTimeout(c.opt.WriteTimeout)
if err := writeCmd(cn, cmds...); err != nil {
setCmdsErr(cmds, err)
return true, err
}
// Set read timeout for all commands.
cn.SetReadTimeout(c.opt.ReadTimeout)
for i, cmd := range cmds {
err := cmd.readReply(cn)
if err == nil {
continue
}
if i == 0 && internal.IsNetworkError(err) {
return true, err
}
if firstErr == nil {
firstErr = err
}
}
return false, firstErr
}
func (c *Client) pubSub() *PubSub {
return &PubSub{
base: baseClient{
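The new cmdTimeout helper preserves the per-command override: a command that carries its own read timeout wins over Options.ReadTimeout, while everything else gets the configured default. This is what keeps blocking commands usable under a short global deadline, assuming (as in go-redis) that BLPop propagates its timeout argument as the command's read timeout:

package main

import (
	"fmt"
	"time"

	redis "gopkg.in/redis.v5"
)

func main() {
	client := redis.NewClient(&redis.Options{
		Addr:        ":6379",     // hypothetical address
		ReadTimeout: time.Second, // default read deadline
	})
	defer client.Close()

	// BLPop may legitimately block for up to 5s; its own read timeout
	// overrides the 1s default, so the deadline does not fire early.
	vals, err := client.BLPop(5*time.Second, "queue").Result()
	fmt.Println(vals, err)
}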


@ -3,11 +3,12 @@ package redis_test
import (
"bytes"
"net"
"time"
"gopkg.in/redis.v5"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"gopkg.in/redis.v5"
)
var _ = Describe("Client", func() {
@ -15,7 +16,7 @@ var _ = Describe("Client", func() {
BeforeEach(func() {
client = redis.NewClient(redisOptions())
Expect(client.FlushDb().Err()).To(BeNil())
Expect(client.FlushDb().Err()).NotTo(HaveOccurred())
})
AfterEach(func() {
@ -174,7 +175,7 @@ var _ = Describe("Client", func() {
Expect(cn.UsedAt.After(createdAt)).To(BeTrue())
})
It("should escape special chars", func() {
It("should process command with special chars", func() {
set := client.Set("key", "hello1\r\nhello2\r\n", 0)
Expect(set.Err()).NotTo(HaveOccurred())
Expect(set.Val()).To(Equal("OK"))
@ -191,12 +192,84 @@ var _ = Describe("Client", func() {
Expect(err).NotTo(HaveOccurred())
// Reconnect to get new connection.
Expect(client.Close()).To(BeNil())
Expect(client.Close()).NotTo(HaveOccurred())
client = redis.NewClient(redisOptions())
got, err := client.Get("key").Bytes()
Expect(err).NotTo(HaveOccurred())
Expect(got).To(Equal(bigVal))
})
})
var _ = Describe("Client timeout", func() {
var client *redis.Client
AfterEach(func() {
Expect(client.Close()).NotTo(HaveOccurred())
})
testTimeout := func() {
It("Ping timeouts", func() {
err := client.Ping().Err()
Expect(err).To(HaveOccurred())
Expect(err.(net.Error).Timeout()).To(BeTrue())
})
It("Pipeline timeouts", func() {
_, err := client.Pipelined(func(pipe *redis.Pipeline) error {
pipe.Ping()
return nil
})
Expect(err).To(HaveOccurred())
Expect(err.(net.Error).Timeout()).To(BeTrue())
})
It("Subscribe timeouts", func() {
_, err := client.Subscribe("_")
Expect(err).To(HaveOccurred())
Expect(err.(net.Error).Timeout()).To(BeTrue())
})
It("Tx timeouts", func() {
err := client.Watch(func(tx *redis.Tx) error {
return tx.Ping().Err()
})
Expect(err).To(HaveOccurred())
Expect(err.(net.Error).Timeout()).To(BeTrue())
})
It("Tx Pipeline timeouts", func() {
err := client.Watch(func(tx *redis.Tx) error {
_, err := tx.Pipelined(func(pipe *redis.Pipeline) error {
pipe.Ping()
return nil
})
return err
})
Expect(err).To(HaveOccurred())
Expect(err.(net.Error).Timeout()).To(BeTrue())
})
}
Context("read timeout", func() {
BeforeEach(func() {
opt := redisOptions()
opt.ReadTimeout = time.Nanosecond
opt.WriteTimeout = -1
client = redis.NewClient(opt)
})
testTimeout()
})
Context("write timeout", func() {
BeforeEach(func() {
opt := redisOptions()
opt.ReadTimeout = -1
opt.WriteTimeout = time.Nanosecond
client = redis.NewClient(opt)
})
testTimeout()
})
})

ring.go

@ -332,7 +332,7 @@ func (c *Ring) heartbeat() {
//
// It is rare to Close a Ring, as the Ring is meant to be long-lived
// and shared between many goroutines.
func (c *Ring) Close() (retErr error) {
func (c *Ring) Close() error {
defer c.mu.Unlock()
c.mu.Lock()
@ -341,15 +341,16 @@ func (c *Ring) Close() (retErr error) {
}
c.closed = true
var firstErr error
for _, shard := range c.shards {
if err := shard.Client.Close(); err != nil {
retErr = err
if err := shard.Client.Close(); err != nil && firstErr == nil {
firstErr = err
}
}
c.hash = nil
c.shards = nil
return retErr
return firstErr
}
func (c *Ring) Pipeline() *Pipeline {
@ -402,7 +403,7 @@ func (c *Ring) pipelineExec(cmds []Cmder) (firstErr error) {
continue
}
retry, err := execCmds(cn, cmds)
retry, err := shard.Client.execCmds(cn, cmds)
shard.Client.putConn(cn, err, false)
if err == nil {
continue


@ -267,10 +267,10 @@ func (d *sentinelFailover) closeOldConns(newMaster string) {
if cn == nil {
break
}
if cn.RemoteAddr().String() != newMaster {
if cn.NetConn.RemoteAddr().String() != newMaster {
err := fmt.Errorf(
"sentinel: closing connection to the old master %s",
cn.RemoteAddr(),
cn.NetConn.RemoteAddr(),
)
internal.Logf(err.Error())
d.pool.Remove(cn, err)

tx.go

@ -45,11 +45,11 @@ func (c *Client) Watch(fn func(*Tx) error, keys ...string) error {
return err
}
}
retErr := fn(tx)
if err := tx.close(); err != nil && retErr == nil {
retErr = err
firstErr := fn(tx)
if err := tx.close(); err != nil && firstErr == nil {
firstErr = err
}
return retErr
return firstErr
}
// close closes the transaction, releasing any open resources.
@ -133,12 +133,16 @@ func (c *Tx) exec(cmds []Cmder) error {
}
func (c *Tx) execCmds(cn *pool.Conn, cmds []Cmder) error {
cn.SetWriteTimeout(c.opt.WriteTimeout)
err := writeCmd(cn, cmds...)
if err != nil {
setCmdsErr(cmds[1:len(cmds)-1], err)
return err
}
// Set read timeout for all commands.
cn.SetReadTimeout(c.opt.ReadTimeout)
// Omit last command (EXEC).
cmdsLen := len(cmds) - 1
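For context, exec hands execCmds the queued commands already wrapped in MULTI/EXEC, which is why a write error is assigned to cmds[1:len(cmds)-1] (the user's commands, minus both wrappers) and why the reply loop below stops one short of the end. A sketch of that framing; the constructors are exported by this package, but the exact wrapper command types are an assumption:

// Hypothetical framing: wrapped = [MULTI, user commands..., EXEC].
var userCmds []redis.Cmder // the commands queued by the Tx pipeline
wrapped := make([]redis.Cmder, 0, len(userCmds)+2)
wrapped = append(wrapped, redis.NewStatusCmd("MULTI"))
wrapped = append(wrapped, userCmds...)
wrapped = append(wrapped, redis.NewSliceCmd("EXEC"))

// wrapped[1:len(wrapped)-1] is exactly userCmds, the slice that
// receives the error when the batched write fails.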


@ -86,14 +86,12 @@ var _ = Describe("Tx", func() {
Expect(get.Val()).To(Equal("hello2"))
})
It("should exec empty", func() {
It("returns an error when there are no commands", func() {
err := client.Watch(func(tx *redis.Tx) error {
cmds, err := tx.Pipelined(func(*redis.Pipeline) error { return nil })
Expect(err).NotTo(HaveOccurred())
Expect(cmds).To(HaveLen(0))
_, err := tx.Pipelined(func(*redis.Pipeline) error { return nil })
return err
})
Expect(err).NotTo(HaveOccurred())
Expect(err).To(MatchError("redis: pipeline is empty"))
v, err := client.Ping().Result()
Expect(err).NotTo(HaveOccurred())
@ -150,30 +148,4 @@ var _ = Describe("Tx", func() {
err = do()
Expect(err).NotTo(HaveOccurred())
})
It("should recover from bad connection when there are no commands", func() {
// Put bad connection in the pool.
cn, _, err := client.Pool().Get()
Expect(err).NotTo(HaveOccurred())
cn.NetConn = &badConn{}
err = client.Pool().Put(cn)
Expect(err).NotTo(HaveOccurred())
do := func() error {
err := client.Watch(func(tx *redis.Tx) error {
_, err := tx.Pipelined(func(pipe *redis.Pipeline) error {
return nil
})
return err
}, "key")
return err
}
err = do()
Expect(err).To(MatchError("bad connection"))
err = do()
Expect(err).NotTo(HaveOccurred())
})
})