diff --git a/cluster.go b/cluster.go index e294a1a..571d5c3 100644 --- a/cluster.go +++ b/cluster.go @@ -387,7 +387,13 @@ func (c *ClusterClient) cmdSlotAndNode(state *clusterState, cmd Cmder) (int, *cl } func (c *ClusterClient) Watch(fn func(*Tx) error, keys ...string) error { - node, err := c.state().slotMasterNode(hashtag.Slot(keys[0])) + var node *clusterNode + var err error + if len(keys) > 0 { + node, err = c.state().slotMasterNode(hashtag.Slot(keys[0])) + } else { + node, err = c.nodes.Random() + } if err != nil { return err } @@ -612,10 +618,10 @@ func (c *ClusterClient) Pipelined(fn func(*Pipeline) error) ([]Cmder, error) { } func (c *ClusterClient) pipelineExec(cmds []Cmder) error { - var retErr error - setRetErr := func(err error) { - if retErr == nil { - retErr = err + var firstErr error + setFirstErr := func(err error) { + if firstErr == nil { + firstErr = err } } @@ -625,7 +631,7 @@ func (c *ClusterClient) pipelineExec(cmds []Cmder) error { _, node, err := c.cmdSlotAndNode(state, cmd) if err != nil { cmd.setErr(err) - setRetErr(err) + setFirstErr(err) continue } cmdsMap[node] = append(cmdsMap[node], cmd) @@ -638,13 +644,13 @@ func (c *ClusterClient) pipelineExec(cmds []Cmder) error { cn, _, err := node.Client.conn() if err != nil { setCmdsErr(cmds, err) - setRetErr(err) + setFirstErr(err) continue } failedCmds, err = c.execClusterCmds(cn, cmds, failedCmds) if err != nil { - setRetErr(err) + setFirstErr(err) } node.Client.putConn(cn, err, false) } @@ -652,24 +658,28 @@ func (c *ClusterClient) pipelineExec(cmds []Cmder) error { cmdsMap = failedCmds } - return retErr + return firstErr } func (c *ClusterClient) execClusterCmds( cn *pool.Conn, cmds []Cmder, failedCmds map[*clusterNode][]Cmder, ) (map[*clusterNode][]Cmder, error) { + cn.SetWriteTimeout(c.opt.WriteTimeout) if err := writeCmd(cn, cmds...); err != nil { setCmdsErr(cmds, err) return failedCmds, err } - var retErr error - setRetErr := func(err error) { - if retErr == nil { - retErr = err + var firstErr error + setFirstErr := func(err error) { + if firstErr == nil { + firstErr = err } } + // Set read timeout for all commands. + cn.SetReadTimeout(c.opt.ReadTimeout) + for i, cmd := range cmds { err := cmd.readReply(cn) if err == nil { @@ -688,7 +698,7 @@ func (c *ClusterClient) execClusterCmds( node, err := c.nodes.Get(addr) if err != nil { - setRetErr(err) + setFirstErr(err) continue } @@ -697,16 +707,16 @@ func (c *ClusterClient) execClusterCmds( } else if ask { node, err := c.nodes.Get(addr) if err != nil { - setRetErr(err) + setFirstErr(err) continue } cmd.reset() failedCmds[node] = append(failedCmds[node], NewCmd("ASKING"), cmd) } else { - setRetErr(err) + setFirstErr(err) } } - return failedCmds, retErr + return failedCmds, firstErr } diff --git a/cluster_test.go b/cluster_test.go index 4a3dd7f..2c49f99 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -483,45 +483,124 @@ var _ = Describe("ClusterClient", func() { describeClusterClient() }) +}) - Describe("ClusterClient without nodes", func() { - BeforeEach(func() { - client = redis.NewClusterClient(&redis.ClusterOptions{}) +var _ = Describe("ClusterClient without nodes", func() { + var client *redis.ClusterClient + + BeforeEach(func() { + client = redis.NewClusterClient(&redis.ClusterOptions{}) + }) + + AfterEach(func() { + Expect(client.Close()).NotTo(HaveOccurred()) + }) + + It("returns an error", func() { + err := client.Ping().Err() + Expect(err).To(MatchError("redis: cluster has no nodes")) + }) + + It("pipeline returns an error", func() { + _, err := client.Pipelined(func(pipe *redis.Pipeline) error { + pipe.Ping() + return nil }) + Expect(err).To(MatchError("redis: cluster has no nodes")) + }) +}) - It("returns an error", func() { - err := client.Ping().Err() - Expect(err).To(MatchError("redis: cluster has no nodes")) - }) +var _ = Describe("ClusterClient without valid nodes", func() { + var client *redis.ClusterClient - It("pipeline returns an error", func() { - _, err := client.Pipelined(func(pipe *redis.Pipeline) error { - pipe.Ping() - return nil - }) - Expect(err).To(MatchError("redis: cluster has no nodes")) + BeforeEach(func() { + client = redis.NewClusterClient(&redis.ClusterOptions{ + Addrs: []string{redisAddr}, }) }) - Describe("ClusterClient without valid nodes", func() { - BeforeEach(func() { - client = redis.NewClusterClient(&redis.ClusterOptions{ - Addrs: []string{redisAddr}, - }) - }) + AfterEach(func() { + Expect(client.Close()).NotTo(HaveOccurred()) + }) - It("returns an error", func() { + It("returns an error", func() { + err := client.Ping().Err() + Expect(err).To(MatchError("ERR This instance has cluster support disabled")) + }) + + It("pipeline returns an error", func() { + _, err := client.Pipelined(func(pipe *redis.Pipeline) error { + pipe.Ping() + return nil + }) + Expect(err).To(MatchError("ERR This instance has cluster support disabled")) + }) +}) + +var _ = Describe("ClusterClient timeout", func() { + var client *redis.ClusterClient + + AfterEach(func() { + Expect(client.Close()).NotTo(HaveOccurred()) + }) + + testTimeout := func() { + It("Ping timeouts", func() { err := client.Ping().Err() - Expect(err).To(MatchError("ERR This instance has cluster support disabled")) + Expect(err).To(HaveOccurred()) + Expect(err.(net.Error).Timeout()).To(BeTrue()) }) - It("pipeline returns an error", func() { + It("Pipeline timeouts", func() { _, err := client.Pipelined(func(pipe *redis.Pipeline) error { pipe.Ping() return nil }) - Expect(err).To(MatchError("ERR This instance has cluster support disabled")) + Expect(err).To(HaveOccurred()) + Expect(err.(net.Error).Timeout()).To(BeTrue()) }) + + It("Tx timeouts", func() { + err := client.Watch(func(tx *redis.Tx) error { + return tx.Ping().Err() + }) + Expect(err).To(HaveOccurred()) + Expect(err.(net.Error).Timeout()).To(BeTrue()) + }) + + It("Tx Pipeline timeouts", func() { + err := client.Watch(func(tx *redis.Tx) error { + _, err := tx.Pipelined(func(pipe *redis.Pipeline) error { + pipe.Ping() + return nil + }) + return err + }) + Expect(err).To(HaveOccurred()) + Expect(err.(net.Error).Timeout()).To(BeTrue()) + }) + } + + Context("read timeout", func() { + BeforeEach(func() { + opt := redisClusterOptions() + opt.ReadTimeout = time.Nanosecond + opt.WriteTimeout = -1 + client = cluster.clusterClient(opt) + }) + + testTimeout() + }) + + Context("write timeout", func() { + BeforeEach(func() { + opt := redisClusterOptions() + opt.ReadTimeout = time.Nanosecond + opt.WriteTimeout = -1 + client = cluster.clusterClient(opt) + }) + + testTimeout() }) }) diff --git a/command.go b/command.go index 0108f8b..2bdf816 100644 --- a/command.go +++ b/command.go @@ -64,7 +64,7 @@ func writeCmd(cn *pool.Conn, cmds ...Cmder) error { } } - _, err := cn.Write(cn.Wb.Bytes()) + _, err := cn.NetConn.Write(cn.Wb.Bytes()) return err } diff --git a/internal/pool/conn.go b/internal/pool/conn.go index 785fc21..b716cc2 100644 --- a/internal/pool/conn.go +++ b/internal/pool/conn.go @@ -18,9 +18,6 @@ type Conn struct { Inited bool UsedAt time.Time - - ReadTimeout time.Duration - WriteTimeout time.Duration } func NewConn(netConn net.Conn) *Conn { @@ -30,7 +27,7 @@ func NewConn(netConn net.Conn) *Conn { UsedAt: time.Now(), } - cn.Rd = proto.NewReader(cn) + cn.Rd = proto.NewReader(cn.NetConn) return cn } @@ -38,28 +35,21 @@ func (cn *Conn) IsStale(timeout time.Duration) bool { return timeout > 0 && time.Since(cn.UsedAt) > timeout } -func (cn *Conn) Read(b []byte) (int, error) { +func (cn *Conn) SetReadTimeout(timeout time.Duration) error { cn.UsedAt = time.Now() - if cn.ReadTimeout != 0 { - cn.NetConn.SetReadDeadline(cn.UsedAt.Add(cn.ReadTimeout)) - } else { - cn.NetConn.SetReadDeadline(noDeadline) + if timeout > 0 { + return cn.NetConn.SetReadDeadline(cn.UsedAt.Add(timeout)) } - return cn.NetConn.Read(b) + return cn.NetConn.SetReadDeadline(noDeadline) + } -func (cn *Conn) Write(b []byte) (int, error) { +func (cn *Conn) SetWriteTimeout(timeout time.Duration) error { cn.UsedAt = time.Now() - if cn.WriteTimeout != 0 { - cn.NetConn.SetWriteDeadline(cn.UsedAt.Add(cn.WriteTimeout)) - } else { - cn.NetConn.SetWriteDeadline(noDeadline) + if timeout > 0 { + return cn.NetConn.SetWriteDeadline(cn.UsedAt.Add(timeout)) } - return cn.NetConn.Write(b) -} - -func (cn *Conn) RemoteAddr() net.Addr { - return cn.NetConn.RemoteAddr() + return cn.NetConn.SetWriteDeadline(noDeadline) } func (cn *Conn) Close() error { diff --git a/internal/pool/pool.go b/internal/pool/pool.go index 389a3d2..b7e8977 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -266,19 +266,19 @@ func (p *ConnPool) Closed() bool { return atomic.LoadInt32(&p._closed) == 1 } -func (p *ConnPool) Close() (retErr error) { +func (p *ConnPool) Close() error { if !atomic.CompareAndSwapInt32(&p._closed, 0, 1) { return ErrClosed } p.connsMu.Lock() - // Close all connections. + var firstErr error for _, cn := range p.conns { if cn == nil { continue } - if err := p.closeConn(cn, ErrClosed); err != nil && retErr == nil { - retErr = err + if err := p.closeConn(cn, ErrClosed); err != nil && firstErr == nil { + firstErr = err } } p.conns = nil @@ -288,7 +288,7 @@ func (p *ConnPool) Close() (retErr error) { p.freeConns = nil p.freeConnsMu.Unlock() - return retErr + return firstErr } func (p *ConnPool) closeConn(cn *Conn, reason error) error { diff --git a/internal/pool/pool_sticky.go b/internal/pool/pool_sticky.go index ce45f4b..d25bf99 100644 --- a/internal/pool/pool_sticky.go +++ b/internal/pool/pool_sticky.go @@ -49,7 +49,7 @@ func (p *StickyConnPool) Get() (*Conn, bool, error) { return cn, true, nil } -func (p *StickyConnPool) put() (err error) { +func (p *StickyConnPool) putUpstream() (err error) { err = p.pool.Put(p.cn) p.cn = nil return err @@ -67,7 +67,7 @@ func (p *StickyConnPool) Put(cn *Conn) error { return nil } -func (p *StickyConnPool) remove(reason error) error { +func (p *StickyConnPool) removeUpstream(reason error) error { err := p.pool.Remove(p.cn, reason) p.cn = nil return err @@ -85,7 +85,7 @@ func (p *StickyConnPool) Remove(cn *Conn, reason error) error { if cn != nil && p.cn != cn { panic("p.cn != cn") } - return p.remove(reason) + return p.removeUpstream(reason) } func (p *StickyConnPool) Len() int { @@ -120,10 +120,10 @@ func (p *StickyConnPool) Close() error { var err error if p.cn != nil { if p.reusable { - err = p.put() + err = p.putUpstream() } else { - reason := errors.New("redis: sticky not reusable connection") - err = p.remove(reason) + reason := errors.New("redis: unreusable sticky connection") + err = p.removeUpstream(reason) } } return err diff --git a/main_test.go b/main_test.go index 67886e4..e48dfc9 100644 --- a/main_test.go +++ b/main_test.go @@ -2,6 +2,7 @@ package redis_test import ( "errors" + "fmt" "net" "os" "os/exec" @@ -159,8 +160,7 @@ func perform(n int, cbs ...func(int)) { func eventually(fn func() error, timeout time.Duration) error { var exit int32 - var retErr error - var mu sync.Mutex + errCh := make(chan error) done := make(chan struct{}) go func() { @@ -172,9 +172,10 @@ func eventually(fn func() error, timeout time.Duration) error { close(done) return } - mu.Lock() - retErr = err - mu.Unlock() + select { + case errCh <- err: + default: + } time.Sleep(timeout / 100) } }() @@ -184,10 +185,12 @@ func eventually(fn func() error, timeout time.Duration) error { return nil case <-time.After(timeout): atomic.StoreInt32(&exit, 1) - mu.Lock() - err := retErr - mu.Unlock() - return err + select { + case err := <-errCh: + return err + default: + return fmt.Errorf("timeout after %s", timeout) + } } } diff --git a/options.go b/options.go index 77ffbce..a7f7fd7 100644 --- a/options.go +++ b/options.go @@ -90,9 +90,13 @@ func (opt *Options) init() { } if opt.ReadTimeout == 0 { opt.ReadTimeout = 3 * time.Second + } else if opt.ReadTimeout == -1 { + opt.ReadTimeout = 0 } if opt.WriteTimeout == 0 { opt.WriteTimeout = opt.ReadTimeout + } else if opt.WriteTimeout == -1 { + opt.WriteTimeout = 0 } if opt.PoolTimeout == 0 { opt.PoolTimeout = opt.ReadTimeout + time.Second diff --git a/pipeline.go b/pipeline.go index fcd1f99..ef5510b 100644 --- a/pipeline.go +++ b/pipeline.go @@ -1,9 +1,9 @@ package redis import ( + "errors" "sync" - "gopkg.in/redis.v5/internal" "gopkg.in/redis.v5/internal/pool" ) @@ -67,7 +67,7 @@ func (c *Pipeline) Exec() ([]Cmder, error) { } if len(c.cmds) == 0 { - return c.cmds, nil + return nil, errors.New("redis: pipeline is empty") } cmds := c.cmds @@ -84,24 +84,3 @@ func (c *Pipeline) pipelined(fn func(*Pipeline) error) ([]Cmder, error) { _ = c.Close() return cmds, err } - -func execCmds(cn *pool.Conn, cmds []Cmder) (retry bool, firstErr error) { - if err := writeCmd(cn, cmds...); err != nil { - setCmdsErr(cmds, err) - return true, err - } - - for i, cmd := range cmds { - err := cmd.readReply(cn) - if err == nil { - continue - } - if i == 0 && internal.IsNetworkError(err) { - return true, err - } - if firstErr == nil { - firstErr = err - } - } - return false, firstErr -} diff --git a/pipeline_test.go b/pipeline_test.go index 01fcc09..6c6fb96 100644 --- a/pipeline_test.go +++ b/pipeline_test.go @@ -4,13 +4,13 @@ import ( "strconv" "sync" + "gopkg.in/redis.v5" + . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" - - "gopkg.in/redis.v5" ) -var _ = Describe("Pipelining", func() { +var _ = Describe("Pipeline", func() { var client *redis.Client BeforeEach(func() { @@ -51,15 +51,12 @@ var _ = Describe("Pipelining", func() { Expect(getNil.Val()).To(Equal("")) }) - It("should discard", func() { + It("discards queued commands", func() { pipeline := client.Pipeline() - pipeline.Get("key") pipeline.Discard() - cmds, err := pipeline.Exec() - Expect(err).NotTo(HaveOccurred()) - Expect(cmds).To(HaveLen(0)) - Expect(pipeline.Close()).NotTo(HaveOccurred()) + _, err := pipeline.Exec() + Expect(err).To(MatchError("redis: pipeline is empty")) }) It("should support block style", func() { @@ -84,12 +81,10 @@ var _ = Describe("Pipelining", func() { Expect(pipeline.Close()).NotTo(HaveOccurred()) }) - It("should pipeline with empty queue", func() { + It("returns an error when there are no commands", func() { pipeline := client.Pipeline() - cmds, err := pipeline.Exec() - Expect(err).NotTo(HaveOccurred()) - Expect(cmds).To(HaveLen(0)) - Expect(pipeline.Close()).NotTo(HaveOccurred()) + _, err := pipeline.Exec() + Expect(err).To(MatchError("redis: pipeline is empty")) }) It("should increment correctly", func() { diff --git a/pubsub.go b/pubsub.go index 223518a..c9205a0 100644 --- a/pubsub.go +++ b/pubsub.go @@ -35,12 +35,6 @@ func (c *PubSub) putConn(cn *pool.Conn, err error) { } func (c *PubSub) subscribe(redisCmd string, channels ...string) error { - cn, _, err := c.conn() - if err != nil { - return err - } - c.putConn(cn, err) - args := make([]interface{}, 1+len(channels)) args[0] = redisCmd for i, channel := range channels { @@ -48,7 +42,15 @@ func (c *PubSub) subscribe(redisCmd string, channels ...string) error { } cmd := NewSliceCmd(args...) - return writeCmd(cn, cmd) + cn, _, err := c.conn() + if err != nil { + return err + } + + cn.SetWriteTimeout(c.base.opt.WriteTimeout) + err = writeCmd(cn, cmd) + c.putConn(cn, err) + return err } // Subscribes the client to the specified channels. @@ -94,17 +96,21 @@ func (c *PubSub) Close() error { } func (c *PubSub) Ping(payload string) error { - cn, _, err := c.conn() - if err != nil { - return err - } - args := []interface{}{"PING"} if payload != "" { args = append(args, payload) } cmd := NewCmd(args...) - return writeCmd(cn, cmd) + + cn, _, err := c.conn() + if err != nil { + return err + } + + cn.SetWriteTimeout(c.base.opt.WriteTimeout) + err = writeCmd(cn, cmd) + c.putConn(cn, err) + return err } // Message received after a successful subscription to channel. @@ -176,13 +182,14 @@ func (c *PubSub) newMessage(reply []interface{}) (interface{}, error) { // is not received in time. This is low-level API and most clients // should use ReceiveMessage. func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) { + cmd := NewSliceCmd() + cn, _, err := c.conn() if err != nil { return nil, err } - cn.ReadTimeout = timeout - cmd := NewSliceCmd() + cn.SetReadTimeout(timeout) err = cmd.readReply(cn) c.putConn(cn, err) if err != nil { diff --git a/pubsub_test.go b/pubsub_test.go index ecbea76..f366213 100644 --- a/pubsub_test.go +++ b/pubsub_test.go @@ -315,7 +315,7 @@ var _ = Describe("PubSub", func() { Eventually(done).Should(Receive()) stats := client.PoolStats() - Expect(stats.Requests).To(Equal(uint32(4))) + Expect(stats.Requests).To(Equal(uint32(3))) Expect(stats.Hits).To(Equal(uint32(1))) } diff --git a/redis.go b/redis.go index 1c5fdd1..6fcd5c4 100644 --- a/redis.go +++ b/redis.go @@ -3,6 +3,7 @@ package redis // import "gopkg.in/redis.v5" import ( "fmt" "log" + "time" "gopkg.in/redis.v5/internal" "gopkg.in/redis.v5/internal/pool" @@ -105,14 +106,7 @@ func (c *baseClient) defaultProcess(cmd Cmder) error { return err } - readTimeout := cmd.readTimeout() - if readTimeout != nil { - cn.ReadTimeout = *readTimeout - } else { - cn.ReadTimeout = c.opt.ReadTimeout - } - cn.WriteTimeout = c.opt.WriteTimeout - + cn.SetWriteTimeout(c.opt.WriteTimeout) if err := writeCmd(cn, cmd); err != nil { c.putConn(cn, err, false) cmd.setErr(err) @@ -122,8 +116,9 @@ func (c *baseClient) defaultProcess(cmd Cmder) error { return err } + cn.SetReadTimeout(c.cmdTimeout(cmd)) err = cmd.readReply(cn) - c.putConn(cn, err, readTimeout != nil) + c.putConn(cn, err, false) if err != nil && internal.IsRetryableError(err) { continue } @@ -134,6 +129,14 @@ func (c *baseClient) defaultProcess(cmd Cmder) error { return cmd.Err() } +func (c *baseClient) cmdTimeout(cmd Cmder) time.Duration { + if timeout := cmd.readTimeout(); timeout != nil { + return *timeout + } else { + return c.opt.ReadTimeout + } +} + func (c *baseClient) closed() bool { return c.connPool.Closed() } @@ -143,16 +146,16 @@ func (c *baseClient) closed() bool { // It is rare to Close a Client, as the Client is meant to be // long-lived and shared between many goroutines. func (c *baseClient) Close() error { - var retErr error + var firstErr error if c.onClose != nil { - if err := c.onClose(); err != nil && retErr == nil { - retErr = err + if err := c.onClose(); err != nil && firstErr == nil { + firstErr = err } } - if err := c.connPool.Close(); err != nil && retErr == nil { - retErr = err + if err := c.connPool.Close(); err != nil && firstErr == nil { + firstErr = err } - return retErr + return firstErr } func (c *baseClient) getAddr() string { @@ -225,7 +228,7 @@ func (c *Client) pipelineExec(cmds []Cmder) error { return err } - retry, err := execCmds(cn, cmds) + retry, err := c.execCmds(cn, cmds) c.putConn(cn, err, false) if err == nil { return nil @@ -240,6 +243,31 @@ func (c *Client) pipelineExec(cmds []Cmder) error { return firstErr } +func (c *Client) execCmds(cn *pool.Conn, cmds []Cmder) (retry bool, firstErr error) { + cn.SetWriteTimeout(c.opt.WriteTimeout) + if err := writeCmd(cn, cmds...); err != nil { + setCmdsErr(cmds, err) + return true, err + } + + // Set read timeout for all commands. + cn.SetReadTimeout(c.opt.ReadTimeout) + + for i, cmd := range cmds { + err := cmd.readReply(cn) + if err == nil { + continue + } + if i == 0 && internal.IsNetworkError(err) { + return true, err + } + if firstErr == nil { + firstErr = err + } + } + return false, firstErr +} + func (c *Client) pubSub() *PubSub { return &PubSub{ base: baseClient{ diff --git a/redis_test.go b/redis_test.go index 9bb4b68..e15c871 100644 --- a/redis_test.go +++ b/redis_test.go @@ -3,11 +3,12 @@ package redis_test import ( "bytes" "net" + "time" + + "gopkg.in/redis.v5" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" - - "gopkg.in/redis.v5" ) var _ = Describe("Client", func() { @@ -15,7 +16,7 @@ var _ = Describe("Client", func() { BeforeEach(func() { client = redis.NewClient(redisOptions()) - Expect(client.FlushDb().Err()).To(BeNil()) + Expect(client.FlushDb().Err()).NotTo(HaveOccurred()) }) AfterEach(func() { @@ -174,7 +175,7 @@ var _ = Describe("Client", func() { Expect(cn.UsedAt.After(createdAt)).To(BeTrue()) }) - It("should escape special chars", func() { + It("should process command with special chars", func() { set := client.Set("key", "hello1\r\nhello2\r\n", 0) Expect(set.Err()).NotTo(HaveOccurred()) Expect(set.Val()).To(Equal("OK")) @@ -191,12 +192,84 @@ var _ = Describe("Client", func() { Expect(err).NotTo(HaveOccurred()) // Reconnect to get new connection. - Expect(client.Close()).To(BeNil()) + Expect(client.Close()).NotTo(HaveOccurred()) client = redis.NewClient(redisOptions()) got, err := client.Get("key").Bytes() Expect(err).NotTo(HaveOccurred()) Expect(got).To(Equal(bigVal)) }) - +}) + +var _ = Describe("Client timeout", func() { + var client *redis.Client + + AfterEach(func() { + Expect(client.Close()).NotTo(HaveOccurred()) + }) + + testTimeout := func() { + It("Ping timeouts", func() { + err := client.Ping().Err() + Expect(err).To(HaveOccurred()) + Expect(err.(net.Error).Timeout()).To(BeTrue()) + }) + + It("Pipeline timeouts", func() { + _, err := client.Pipelined(func(pipe *redis.Pipeline) error { + pipe.Ping() + return nil + }) + Expect(err).To(HaveOccurred()) + Expect(err.(net.Error).Timeout()).To(BeTrue()) + }) + + It("Subscribe timeouts", func() { + _, err := client.Subscribe("_") + Expect(err).To(HaveOccurred()) + Expect(err.(net.Error).Timeout()).To(BeTrue()) + }) + + It("Tx timeouts", func() { + err := client.Watch(func(tx *redis.Tx) error { + return tx.Ping().Err() + }) + Expect(err).To(HaveOccurred()) + Expect(err.(net.Error).Timeout()).To(BeTrue()) + }) + + It("Tx Pipeline timeouts", func() { + err := client.Watch(func(tx *redis.Tx) error { + _, err := tx.Pipelined(func(pipe *redis.Pipeline) error { + pipe.Ping() + return nil + }) + return err + }) + Expect(err).To(HaveOccurred()) + Expect(err.(net.Error).Timeout()).To(BeTrue()) + }) + } + + Context("read timeout", func() { + BeforeEach(func() { + opt := redisOptions() + opt.ReadTimeout = time.Nanosecond + opt.WriteTimeout = -1 + client = redis.NewClient(opt) + }) + + testTimeout() + }) + + Context("write timeout", func() { + BeforeEach(func() { + opt := redisOptions() + opt.ReadTimeout = -1 + opt.WriteTimeout = time.Nanosecond + client = redis.NewClient(opt) + }) + + testTimeout() + }) }) diff --git a/ring.go b/ring.go index 53c3f11..11945b4 100644 --- a/ring.go +++ b/ring.go @@ -332,7 +332,7 @@ func (c *Ring) heartbeat() { // // It is rare to Close a Ring, as the Ring is meant to be long-lived // and shared between many goroutines. -func (c *Ring) Close() (retErr error) { +func (c *Ring) Close() error { defer c.mu.Unlock() c.mu.Lock() @@ -341,15 +341,16 @@ func (c *Ring) Close() (retErr error) { } c.closed = true + var firstErr error for _, shard := range c.shards { - if err := shard.Client.Close(); err != nil { - retErr = err + if err := shard.Client.Close(); err != nil && firstErr == nil { + firstErr = err } } c.hash = nil c.shards = nil - return retErr + return firstErr } func (c *Ring) Pipeline() *Pipeline { @@ -402,7 +403,7 @@ func (c *Ring) pipelineExec(cmds []Cmder) (firstErr error) { continue } - retry, err := execCmds(cn, cmds) + retry, err := shard.Client.execCmds(cn, cmds) shard.Client.putConn(cn, err, false) if err == nil { continue diff --git a/sentinel.go b/sentinel.go index 0849dda..2a32647 100644 --- a/sentinel.go +++ b/sentinel.go @@ -267,10 +267,10 @@ func (d *sentinelFailover) closeOldConns(newMaster string) { if cn == nil { break } - if cn.RemoteAddr().String() != newMaster { + if cn.NetConn.RemoteAddr().String() != newMaster { err := fmt.Errorf( "sentinel: closing connection to the old master %s", - cn.RemoteAddr(), + cn.NetConn.RemoteAddr(), ) internal.Logf(err.Error()) d.pool.Remove(cn, err) diff --git a/tx.go b/tx.go index 2a45990..772b3c9 100644 --- a/tx.go +++ b/tx.go @@ -45,11 +45,11 @@ func (c *Client) Watch(fn func(*Tx) error, keys ...string) error { return err } } - retErr := fn(tx) - if err := tx.close(); err != nil && retErr == nil { - retErr = err + firstErr := fn(tx) + if err := tx.close(); err != nil && firstErr == nil { + firstErr = err } - return retErr + return firstErr } // close closes the transaction, releasing any open resources. @@ -133,12 +133,16 @@ func (c *Tx) exec(cmds []Cmder) error { } func (c *Tx) execCmds(cn *pool.Conn, cmds []Cmder) error { + cn.SetWriteTimeout(c.opt.WriteTimeout) err := writeCmd(cn, cmds...) if err != nil { setCmdsErr(cmds[1:len(cmds)-1], err) return err } + // Set read timeout for all commands. + cn.SetReadTimeout(c.opt.ReadTimeout) + // Omit last command (EXEC). cmdsLen := len(cmds) - 1 diff --git a/tx_test.go b/tx_test.go index 9631e4c..156a890 100644 --- a/tx_test.go +++ b/tx_test.go @@ -86,14 +86,12 @@ var _ = Describe("Tx", func() { Expect(get.Val()).To(Equal("hello2")) }) - It("should exec empty", func() { + It("returns an error when there are no commands", func() { err := client.Watch(func(tx *redis.Tx) error { - cmds, err := tx.Pipelined(func(*redis.Pipeline) error { return nil }) - Expect(err).NotTo(HaveOccurred()) - Expect(cmds).To(HaveLen(0)) + _, err := tx.Pipelined(func(*redis.Pipeline) error { return nil }) return err }) - Expect(err).NotTo(HaveOccurred()) + Expect(err).To(MatchError("redis: pipeline is empty")) v, err := client.Ping().Result() Expect(err).NotTo(HaveOccurred()) @@ -150,30 +148,4 @@ var _ = Describe("Tx", func() { err = do() Expect(err).NotTo(HaveOccurred()) }) - - It("should recover from bad connection when there are no commands", func() { - // Put bad connection in the pool. - cn, _, err := client.Pool().Get() - Expect(err).NotTo(HaveOccurred()) - - cn.NetConn = &badConn{} - err = client.Pool().Put(cn) - Expect(err).NotTo(HaveOccurred()) - - do := func() error { - err := client.Watch(func(tx *redis.Tx) error { - _, err := tx.Pipelined(func(pipe *redis.Pipeline) error { - return nil - }) - return err - }, "key") - return err - } - - err = do() - Expect(err).To(MatchError("bad connection")) - - err = do() - Expect(err).NotTo(HaveOccurred()) - }) })