diff --git a/cluster.go b/cluster.go index ec5a51f..f9a4fef 100644 --- a/cluster.go +++ b/cluster.go @@ -156,7 +156,7 @@ func (c *clusterNodes) All() ([]*clusterNode, error) { return nil, pool.ErrClosed } - var nodes []*clusterNode + nodes := make([]*clusterNode, 0, len(c.nodes)) for _, node := range c.nodes { nodes = append(nodes, node) } @@ -208,7 +208,7 @@ func (c *clusterNodes) Random() (*clusterNode, error) { } var nodeErr error - for i := 0; i < 10; i++ { + for i := 0; i <= c.opt.MaxRedirects; i++ { n := rand.Intn(len(addrs)) node, err := c.Get(addrs[n]) if err != nil { @@ -446,6 +446,10 @@ func (c *ClusterClient) Process(cmd Cmder) error { // On network errors try random node. if internal.IsRetryableError(err) { node, err = c.nodes.Random() + if err != nil { + cmd.setErr(err) + return err + } continue } @@ -475,6 +479,39 @@ func (c *ClusterClient) Process(cmd Cmder) error { return cmd.Err() } +// ForEachNode concurrently calls the fn on each ever known node in the cluster. +// It returns the first error if any. +func (c *ClusterClient) ForEachNode(fn func(client *Client) error) error { + nodes, err := c.nodes.All() + if err != nil { + return err + } + + var wg sync.WaitGroup + errCh := make(chan error, 1) + for _, node := range nodes { + wg.Add(1) + go func(node *clusterNode) { + defer wg.Done() + err := fn(node.Client) + if err != nil { + select { + case errCh <- err: + default: + } + } + }(node) + } + wg.Wait() + + select { + case err := <-errCh: + return err + default: + return nil + } +} + // ForEachMaster concurrently calls the fn on each master node in the cluster. // It returns the first error if any. func (c *ClusterClient) ForEachMaster(fn func(client *Client) error) error { @@ -649,10 +686,10 @@ func (c *ClusterClient) pipelineExec(cmds []Cmder) error { } failedCmds, err = c.execClusterCmds(cn, cmds, failedCmds) + node.Client.putConn(cn, err, false) if err != nil { setFirstErr(err) } - node.Client.putConn(cn, err, false) } cmdsMap = failedCmds @@ -686,9 +723,15 @@ func (c *ClusterClient) execClusterCmds( continue } - if i == 0 && internal.IsNetworkError(err) { + if i == 0 && internal.IsRetryableError(err) { + node, err := c.nodes.Random() + if err != nil { + setFirstErr(err) + continue + } + cmd.reset() - failedCmds[nil] = append(failedCmds[nil], cmds...) + failedCmds[node] = append(failedCmds[node], cmds...) break } diff --git a/cluster_test.go b/cluster_test.go index 2c49f99..3406768 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -602,6 +602,33 @@ var _ = Describe("ClusterClient timeout", func() { testTimeout() }) + + Context("network timeout", func() { + const pause = time.Second + + BeforeEach(func() { + opt := redisClusterOptions() + opt.ReadTimeout = 100 * time.Millisecond + opt.WriteTimeout = 100 * time.Millisecond + opt.MaxRedirects = 1 + client = cluster.clusterClient(opt) + + err := client.ForEachNode(func(client *redis.Client) error { + return client.ClientPause(pause).Err() + }) + Expect(err).NotTo(HaveOccurred()) + }) + + AfterEach(func() { + Eventually(func() error { + return client.ForEachNode(func(client *redis.Client) error { + return client.Ping().Err() + }) + }, pause).ShouldNot(HaveOccurred()) + }) + + testTimeout() + }) }) //------------------------------------------------------------------------------ diff --git a/tx.go b/tx.go index 772b3c9..30beda9 100644 --- a/tx.go +++ b/tx.go @@ -176,10 +176,8 @@ func (c *Tx) execCmds(cn *pool.Conn, cmds []Cmder) error { // Loop starts from 1 to omit MULTI cmd. for i := 1; i < cmdsLen; i++ { cmd := cmds[i] - if err := cmd.readReply(cn); err != nil { - if firstErr == nil { - firstErr = err - } + if err := cmd.readReply(cn); err != nil && firstErr == nil { + firstErr = err } }