Fix nil ptr in case when all nodes are unavailable.

This commit is contained in:
Vladimir Mihailenco 2016-12-12 17:30:08 +02:00
parent 7f2a0bff84
commit c7dfbb54af
3 changed files with 77 additions and 9 deletions

View File

@ -156,7 +156,7 @@ func (c *clusterNodes) All() ([]*clusterNode, error) {
return nil, pool.ErrClosed return nil, pool.ErrClosed
} }
var nodes []*clusterNode nodes := make([]*clusterNode, 0, len(c.nodes))
for _, node := range c.nodes { for _, node := range c.nodes {
nodes = append(nodes, node) nodes = append(nodes, node)
} }
@ -208,7 +208,7 @@ func (c *clusterNodes) Random() (*clusterNode, error) {
} }
var nodeErr error var nodeErr error
for i := 0; i < 10; i++ { for i := 0; i <= c.opt.MaxRedirects; i++ {
n := rand.Intn(len(addrs)) n := rand.Intn(len(addrs))
node, err := c.Get(addrs[n]) node, err := c.Get(addrs[n])
if err != nil { if err != nil {
@ -446,6 +446,10 @@ func (c *ClusterClient) Process(cmd Cmder) error {
// On network errors try random node. // On network errors try random node.
if internal.IsRetryableError(err) { if internal.IsRetryableError(err) {
node, err = c.nodes.Random() node, err = c.nodes.Random()
if err != nil {
cmd.setErr(err)
return err
}
continue continue
} }
@ -475,6 +479,39 @@ func (c *ClusterClient) Process(cmd Cmder) error {
return cmd.Err() return cmd.Err()
} }
// ForEachNode concurrently calls the fn on each ever known node in the cluster.
// It returns the first error if any.
func (c *ClusterClient) ForEachNode(fn func(client *Client) error) error {
nodes, err := c.nodes.All()
if err != nil {
return err
}
var wg sync.WaitGroup
errCh := make(chan error, 1)
for _, node := range nodes {
wg.Add(1)
go func(node *clusterNode) {
defer wg.Done()
err := fn(node.Client)
if err != nil {
select {
case errCh <- err:
default:
}
}
}(node)
}
wg.Wait()
select {
case err := <-errCh:
return err
default:
return nil
}
}
// ForEachMaster concurrently calls the fn on each master node in the cluster. // ForEachMaster concurrently calls the fn on each master node in the cluster.
// It returns the first error if any. // It returns the first error if any.
func (c *ClusterClient) ForEachMaster(fn func(client *Client) error) error { func (c *ClusterClient) ForEachMaster(fn func(client *Client) error) error {
@ -649,10 +686,10 @@ func (c *ClusterClient) pipelineExec(cmds []Cmder) error {
} }
failedCmds, err = c.execClusterCmds(cn, cmds, failedCmds) failedCmds, err = c.execClusterCmds(cn, cmds, failedCmds)
node.Client.putConn(cn, err, false)
if err != nil { if err != nil {
setFirstErr(err) setFirstErr(err)
} }
node.Client.putConn(cn, err, false)
} }
cmdsMap = failedCmds cmdsMap = failedCmds
@ -686,9 +723,15 @@ func (c *ClusterClient) execClusterCmds(
continue continue
} }
if i == 0 && internal.IsNetworkError(err) { if i == 0 && internal.IsRetryableError(err) {
node, err := c.nodes.Random()
if err != nil {
setFirstErr(err)
continue
}
cmd.reset() cmd.reset()
failedCmds[nil] = append(failedCmds[nil], cmds...) failedCmds[node] = append(failedCmds[node], cmds...)
break break
} }

View File

@ -602,6 +602,33 @@ var _ = Describe("ClusterClient timeout", func() {
testTimeout() testTimeout()
}) })
Context("network timeout", func() {
const pause = time.Second
BeforeEach(func() {
opt := redisClusterOptions()
opt.ReadTimeout = 100 * time.Millisecond
opt.WriteTimeout = 100 * time.Millisecond
opt.MaxRedirects = 1
client = cluster.clusterClient(opt)
err := client.ForEachNode(func(client *redis.Client) error {
return client.ClientPause(pause).Err()
})
Expect(err).NotTo(HaveOccurred())
})
AfterEach(func() {
Eventually(func() error {
return client.ForEachNode(func(client *redis.Client) error {
return client.Ping().Err()
})
}, pause).ShouldNot(HaveOccurred())
})
testTimeout()
})
}) })
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------

6
tx.go
View File

@ -176,10 +176,8 @@ func (c *Tx) execCmds(cn *pool.Conn, cmds []Cmder) error {
// Loop starts from 1 to omit MULTI cmd. // Loop starts from 1 to omit MULTI cmd.
for i := 1; i < cmdsLen; i++ { for i := 1; i < cmdsLen; i++ {
cmd := cmds[i] cmd := cmds[i]
if err := cmd.readReply(cn); err != nil { if err := cmd.readReply(cn); err != nil && firstErr == nil {
if firstErr == nil { firstErr = err
firstErr = err
}
} }
} }