diff --git a/cluster.go b/cluster.go
index 3559b12..b84c579 100644
--- a/cluster.go
+++ b/cluster.go
@@ -745,14 +745,22 @@ func (c *ClusterClient) ProcessContext(ctx context.Context, cmd Cmder) error {
 }
 
 func (c *ClusterClient) process(ctx context.Context, cmd Cmder) error {
+	err := c._process(ctx, cmd)
+	if err != nil {
+		cmd.setErr(err)
+		return err
+	}
+	return nil
+}
+
+func (c *ClusterClient) _process(ctx context.Context, cmd Cmder) error {
 	cmdInfo := c.cmdInfo(cmd.Name())
 	slot := c.cmdSlot(cmd)
 
 	var node *clusterNode
 	var ask bool
+	var lastErr error
 	for attempt := 0; attempt <= c.opt.MaxRedirects; attempt++ {
-		var err error
-
 		if attempt > 0 {
 			if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil {
 				return err
@@ -760,10 +768,10 @@ func (c *ClusterClient) process(ctx context.Context, cmd Cmder) error {
 		}
 
 		if node == nil {
+			var err error
 			node, err = c.cmdNode(cmdInfo, slot)
 			if err != nil {
-				cmd.setErr(err)
-				break
+				return err
 			}
 		}
 
@@ -771,23 +779,27 @@ func (c *ClusterClient) process(ctx context.Context, cmd Cmder) error {
 			pipe := node.Client.Pipeline()
 			_ = pipe.Process(NewCmd("ASKING"))
 			_ = pipe.Process(cmd)
-			_, err = pipe.ExecContext(ctx)
+			_, lastErr = pipe.ExecContext(ctx)
 			_ = pipe.Close()
 			ask = false
 		} else {
-			err = node.Client.ProcessContext(ctx, cmd)
+			lastErr = node.Client._process(ctx, cmd)
 		}
 
 		// If there is no error - we are done.
-		if err == nil {
-			break
+		if lastErr == nil {
+			return nil
 		}
-		if err != Nil {
+		if lastErr != Nil {
 			c.state.LazyReload()
 		}
+		if lastErr == pool.ErrClosed || isReadOnlyError(lastErr) {
+			node = nil
+			continue
+		}
 
 		// If slave is loading - pick another node.
-		if c.opt.ReadOnly && isLoadingError(err) {
+		if c.opt.ReadOnly && isLoadingError(lastErr) {
 			node.MarkAsFailing()
 			node = nil
 			continue
@@ -795,21 +807,17 @@ func (c *ClusterClient) process(ctx context.Context, cmd Cmder) error {
 
 		var moved bool
 		var addr string
-		moved, ask, addr = isMovedError(err)
+		moved, ask, addr = isMovedError(lastErr)
 		if moved || ask {
+			var err error
 			node, err = c.nodes.Get(addr)
 			if err != nil {
-				break
+				return err
 			}
 			continue
 		}
 
-		if err == pool.ErrClosed || isReadOnlyError(err) {
-			node = nil
-			continue
-		}
-
-		if isRetryableError(err, cmd.readTimeout() == nil) {
+		if isRetryableError(lastErr, cmd.readTimeout() == nil) {
 			// First retry the same node.
 			if attempt == 0 {
 				continue
@@ -821,10 +829,9 @@ func (c *ClusterClient) process(ctx context.Context, cmd Cmder) error {
 			continue
 		}
 
-		break
+		return lastErr
 	}
-
-	return cmd.Err()
+	return lastErr
 }
 
 // ForEachMaster concurrently calls the fn on each master node in the cluster.
@@ -1052,6 +1059,7 @@ func (c *ClusterClient) _processPipeline(ctx context.Context, cmds []Cmder) erro
 	for attempt := 0; attempt <= c.opt.MaxRedirects; attempt++ {
 		if attempt > 0 {
 			if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil {
+				setCmdsErr(cmds, err)
 				return err
 			}
 		}
@@ -1064,18 +1072,24 @@
 			go func(node *clusterNode, cmds []Cmder) {
 				defer wg.Done()
 
-				cn, err := node.Client.getConn(ctx)
+				err := node.Client.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
+					err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
+						return writeCmd(wr, cmds...)
+					})
+					if err != nil {
+						return err
+					}
+
+					return cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
+						return c.pipelineReadCmds(node, rd, cmds, failedCmds)
+					})
+				})
 				if err != nil {
-					if err == pool.ErrClosed {
-						_ = c.mapCmdsByNode(cmds, failedCmds)
-					} else {
+					err = c.mapCmdsByNode(cmds, failedCmds)
+					if err != nil {
 						setCmdsErr(cmds, err)
 					}
-					return
 				}
-
-				err = c.pipelineProcessCmds(ctx, node, cn, cmds, failedCmds)
-				node.Client.releaseConn(cn, err)
 			}(node, cmds)
 		}
 
@@ -1100,10 +1114,15 @@ func newCmdsMap() *cmdsMap {
 	}
 }
 
+func (m *cmdsMap) Add(node *clusterNode, cmds ...Cmder) {
+	m.mu.Lock()
+	m.m[node] = append(m.m[node], cmds...)
+	m.mu.Unlock()
+}
+
 func (c *ClusterClient) mapCmdsByNode(cmds []Cmder, cmdsMap *cmdsMap) error {
 	state, err := c.state.Get()
 	if err != nil {
-		setCmdsErr(cmds, err)
 		return err
 	}
 
@@ -1122,10 +1141,7 @@
 		if err != nil {
 			return err
 		}
-
-		cmdsMap.mu.Lock()
-		cmdsMap.m[node] = append(cmdsMap.m[node], cmd)
-		cmdsMap.mu.Unlock()
+		cmdsMap.Add(node, cmd)
 	}
 	return nil
 }
@@ -1140,87 +1156,55 @@ func (c *ClusterClient) cmdsAreReadOnly(cmds []Cmder) bool {
 	return true
 }
 
-func (c *ClusterClient) pipelineProcessCmds(
-	ctx context.Context, node *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds *cmdsMap,
-) error {
-	err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
-		return writeCmd(wr, cmds...)
-	})
-	if err != nil {
-		setCmdsErr(cmds, err)
-		failedCmds.mu.Lock()
-		failedCmds.m[node] = cmds
-		failedCmds.mu.Unlock()
-		return err
-	}
-
-	return cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
-		return c.pipelineReadCmds(node, rd, cmds, failedCmds)
-	})
-}
-
 func (c *ClusterClient) pipelineReadCmds(
 	node *clusterNode, rd *proto.Reader, cmds []Cmder, failedCmds *cmdsMap,
 ) error {
-	var firstErr error
 	for _, cmd := range cmds {
 		err := cmd.readReply(rd)
 		if err == nil {
 			continue
 		}
-
 		if c.checkMovedErr(cmd, err, failedCmds) {
 			continue
 		}
 
 		if c.opt.ReadOnly && isLoadingError(err) {
 			node.MarkAsFailing()
-		} else if isRedisError(err) {
+			return err
+		}
+		if isRedisError(err) {
 			continue
 		}
-
-		failedCmds.mu.Lock()
-		failedCmds.m[node] = append(failedCmds.m[node], cmd)
-		failedCmds.mu.Unlock()
-		if firstErr == nil {
-			firstErr = err
-		}
+		return err
 	}
-	return firstErr
+	return nil
 }
 
 func (c *ClusterClient) checkMovedErr(
 	cmd Cmder, err error, failedCmds *cmdsMap,
 ) bool {
 	moved, ask, addr := isMovedError(err)
+	if !moved && !ask {
+		return false
+	}
+
+	node, err := c.nodes.Get(addr)
+	if err != nil {
+		return false
+	}
 
 	if moved {
 		c.state.LazyReload()
-
-		node, err := c.nodes.Get(addr)
-		if err != nil {
-			return false
-		}
-
-		failedCmds.mu.Lock()
-		failedCmds.m[node] = append(failedCmds.m[node], cmd)
-		failedCmds.mu.Unlock()
+		failedCmds.Add(node, cmd)
 		return true
 	}
 
 	if ask {
-		node, err := c.nodes.Get(addr)
-		if err != nil {
-			return false
-		}
-
-		failedCmds.mu.Lock()
-		failedCmds.m[node] = append(failedCmds.m[node], NewCmd("ASKING"), cmd)
-		failedCmds.mu.Unlock()
+		failedCmds.Add(node, NewCmd("ASKING"), cmd)
 		return true
 	}
 
-	return false
+	panic("not reached")
 }
 
 // TxPipeline acts like Pipeline, but wraps queued commands with MULTI/EXEC.
@@ -1244,6 +1228,7 @@ func (c *ClusterClient) processTxPipeline(ctx context.Context, cmds []Cmder) err
 func (c *ClusterClient) _processTxPipeline(ctx context.Context, cmds []Cmder) error {
 	state, err := c.state.Get()
 	if err != nil {
+		setCmdsErr(cmds, err)
 		return err
 	}
 
@@ -1254,11 +1239,12 @@ func (c *ClusterClient) _processTxPipeline(ctx context.Context, cmds []Cmder) er
 			setCmdsErr(cmds, err)
 			continue
 		}
-		cmdsMap := map[*clusterNode][]Cmder{node: cmds}
 
+		cmdsMap := map[*clusterNode][]Cmder{node: cmds}
 		for attempt := 0; attempt <= c.opt.MaxRedirects; attempt++ {
 			if attempt > 0 {
 				if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil {
+					setCmdsErr(cmds, err)
 					return err
 				}
 			}
@@ -1271,18 +1257,33 @@ func (c *ClusterClient) _processTxPipeline(ctx context.Context, cmds []Cmder) er
 				go func(node *clusterNode, cmds []Cmder) {
 					defer wg.Done()
 
-					cn, err := node.Client.getConn(ctx)
+					err := node.Client.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
+						err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
+							return txPipelineWriteMulti(wr, cmds)
+						})
+						if err != nil {
+							return err
+						}
+
+						err = cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
+							err := c.txPipelineReadQueued(rd, cmds, failedCmds)
+							if err != nil {
+								moved, ask, addr := isMovedError(err)
+								if moved || ask {
+									return c.cmdsMoved(cmds, moved, ask, addr, failedCmds)
+								}
+								return err
+							}
+							return pipelineReadCmds(rd, cmds)
+						})
+						return err
+					})
 					if err != nil {
-						if err == pool.ErrClosed {
-							_ = c.mapCmdsByNode(cmds, failedCmds)
-						} else {
+						err = c.mapCmdsByNode(cmds, failedCmds)
+						if err != nil {
 							setCmdsErr(cmds, err)
 						}
-						return
 					}
-
-					err = c.txPipelineProcessCmds(ctx, node, cn, cmds, failedCmds)
-					node.Client.releaseConn(cn, err)
 				}(node, cmds)
 			}
 
@@ -1306,31 +1307,6 @@ func (c *ClusterClient) mapCmdsBySlot(cmds []Cmder) map[int][]Cmder {
 	return cmdsMap
 }
 
-func (c *ClusterClient) txPipelineProcessCmds(
-	ctx context.Context, node *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds *cmdsMap,
-) error {
-	err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
-		return txPipelineWriteMulti(wr, cmds)
-	})
-	if err != nil {
-		setCmdsErr(cmds, err)
-		failedCmds.mu.Lock()
-		failedCmds.m[node] = cmds
-		failedCmds.mu.Unlock()
-		return err
-	}
-
-	err = cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
-		err := c.txPipelineReadQueued(rd, cmds, failedCmds)
-		if err != nil {
-			setCmdsErr(cmds, err)
-			return err
-		}
-		return pipelineReadCmds(rd, cmds)
-	})
-	return err
-}
-
 func (c *ClusterClient) txPipelineReadQueued(
 	rd *proto.Reader, cmds []Cmder, failedCmds *cmdsMap,
 ) error {
@@ -1342,14 +1318,9 @@ func (c *ClusterClient) txPipelineReadQueued(
 
 	for _, cmd := range cmds {
 		err := statusCmd.readReply(rd)
-		if err == nil {
+		if err == nil || c.checkMovedErr(cmd, err, failedCmds) || isRedisError(err) {
 			continue
 		}
-
-		if c.checkMovedErr(cmd, err, failedCmds) || isRedisError(err) {
-			continue
-		}
-
 		return err
 	}
 
@@ -1364,20 +1335,39 @@ func (c *ClusterClient) txPipelineReadQueued(
 
 	switch line[0] {
 	case proto.ErrorReply:
-		err := proto.ParseErrorReply(line)
-		for _, cmd := range cmds {
-			if !c.checkMovedErr(cmd, err, failedCmds) {
-				break
-			}
-		}
-		return err
+		return proto.ParseErrorReply(line)
 	case proto.ArrayReply:
 		// ok
 	default:
-		err := fmt.Errorf("redis: expected '*', but got line %q", line)
+		return fmt.Errorf("redis: expected '*', but got line %q", line)
+	}
+
+	return nil
+}
+
+func (c *ClusterClient) cmdsMoved(
+	cmds []Cmder, moved, ask bool, addr string, failedCmds *cmdsMap,
+) error {
+	node, err := c.nodes.Get(addr)
+	if err != nil {
 		return err
 	}
 
+	if moved {
+		c.state.LazyReload()
+		for _, cmd := range cmds {
+			failedCmds.Add(node, cmd)
+		}
+		return nil
+	}
+
+	if ask {
+		for _, cmd := range cmds {
+			failedCmds.Add(node, NewCmd("ASKING"), cmd)
+		}
+		return nil
+	}
+
 	return nil
 }
 
diff --git a/command.go b/command.go
index 02a3e1e..b8f6977 100644
--- a/command.go
+++ b/command.go
@@ -123,6 +123,10 @@ func (cmd *baseCmd) stringArg(pos int) string {
 	return s
 }
 
+func (cmd *baseCmd) setErr(e error) {
+	cmd.err = e
+}
+
 func (cmd *baseCmd) Err() error {
 	return cmd.err
 }
@@ -135,10 +139,6 @@ func (cmd *baseCmd) setReadTimeout(d time.Duration) {
 	cmd._readTimeout = &d
 }
 
-func (cmd *baseCmd) setErr(e error) {
-	cmd.err = e
-}
-
 //------------------------------------------------------------------------------
 
 type Cmd struct {
diff --git a/redis.go b/redis.go
index 239d36d..6634b80 100644
--- a/redis.go
+++ b/redis.go
@@ -235,7 +235,32 @@ func (c *baseClient) releaseConn(cn *pool.Conn, err error) {
 	}
 }
 
+func (c *baseClient) withConn(
+	ctx context.Context, fn func(context.Context, *pool.Conn) error,
+) error {
+	cn, err := c.getConn(ctx)
+	if err != nil {
+		return err
+	}
+	defer func() {
+		c.releaseConn(cn, err)
+	}()
+
+	err = fn(ctx, cn)
+	return err
+}
+
 func (c *baseClient) process(ctx context.Context, cmd Cmder) error {
+	err := c._process(ctx, cmd)
+	if err != nil {
+		cmd.setErr(err)
+		return err
+	}
+	return nil
+}
+
+func (c *baseClient) _process(ctx context.Context, cmd Cmder) error {
+	var lastErr error
 	for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
 		if attempt > 0 {
 			if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil {
@@ -243,37 +268,29 @@ func (c *baseClient) process(ctx context.Context, cmd Cmder) error {
 			}
 		}
 
-		cn, err := c.getConn(ctx)
-		if err != nil {
-			cmd.setErr(err)
-			if isRetryableError(err, true) {
-				continue
+		var retryTimeout bool
+		lastErr = c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
+			err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
+				return writeCmd(wr, cmd)
+			})
+			if err != nil {
+				retryTimeout = true
+				return err
 			}
-			return err
-		}
 
-		err = cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
-			return writeCmd(wr, cmd)
+			err = cn.WithReader(ctx, c.cmdTimeout(cmd), cmd.readReply)
+			if err != nil {
+				retryTimeout = cmd.readTimeout() == nil
+				return err
+			}
+
+			return nil
 		})
-		if err != nil {
-			c.releaseConn(cn, err)
-			cmd.setErr(err)
-			if isRetryableError(err, true) {
-				continue
-			}
-			return err
+		if lastErr == nil || !isRetryableError(lastErr, retryTimeout) {
+			return lastErr
 		}
-
-		err = cn.WithReader(ctx, c.cmdTimeout(cmd), cmd.readReply)
-		c.releaseConn(cn, err)
-		if err != nil && isRetryableError(err, cmd.readTimeout() == nil) {
-			continue
-		}
-
-		return err
 	}
-
-	return cmd.Err()
+	return lastErr
 }
 
 func (c *baseClient) retryBackoff(attempt int) time.Duration {
@@ -325,6 +342,18 @@ type pipelineProcessor func(context.Context, *pool.Conn, []Cmder) (bool, error)
 func (c *baseClient) generalProcessPipeline(
 	ctx context.Context, cmds []Cmder, p pipelineProcessor,
 ) error {
+	err := c._generalProcessPipeline(ctx, cmds, p)
+	if err != nil {
+		setCmdsErr(cmds, err)
+		return err
+	}
+	return cmdsFirstErr(cmds)
+}
+
+func (c *baseClient) _generalProcessPipeline(
+	ctx context.Context, cmds []Cmder, p pipelineProcessor,
+) error {
+	var lastErr error
 	for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
 		if attempt > 0 {
 			if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil {
@@ -332,20 +361,17 @@ func (c *baseClient) generalProcessPipeline(
 			}
 		}
 
-		cn, err := c.getConn(ctx)
-		if err != nil {
-			setCmdsErr(cmds, err)
+		var canRetry bool
+		lastErr = c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
+			var err error
+			canRetry, err = p(ctx, cn, cmds)
 			return err
-		}
-
-		canRetry, err := p(ctx, cn, cmds)
-		c.releaseConn(cn, err)
-
-		if !canRetry || !isRetryableError(err, true) {
-			break
+		})
+		if lastErr == nil || !canRetry || !isRetryableError(lastErr, true) {
+			return lastErr
 		}
 	}
-	return cmdsFirstErr(cmds)
+	return lastErr
 }
 
 func (c *baseClient) pipelineProcessCmds(
@@ -355,7 +381,6 @@
 		return writeCmd(wr, cmds...)
 	})
 	if err != nil {
-		setCmdsErr(cmds, err)
 		return true, err
 	}
 
@@ -382,14 +407,12 @@ func (c *baseClient) txPipelineProcessCmds(
 		return txPipelineWriteMulti(wr, cmds)
 	})
 	if err != nil {
-		setCmdsErr(cmds, err)
 		return true, err
 	}
 
 	err = cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
 		err := txPipelineReadQueued(rd, cmds)
 		if err != nil {
-			setCmdsErr(cmds, err)
 			return err
 		}
 		return pipelineReadCmds(rd, cmds)
diff --git a/ring.go b/ring.go
index c6f592d..1aeb35a 100644
--- a/ring.go
+++ b/ring.go
@@ -551,6 +551,16 @@ func (c *Ring) cmdShard(cmd Cmder) (*ringShard, error) {
 }
 
 func (c *Ring) process(ctx context.Context, cmd Cmder) error {
+	err := c._process(ctx, cmd)
+	if err != nil {
+		cmd.setErr(err)
+		return err
+	}
+	return nil
+}
+
+func (c *Ring) _process(ctx context.Context, cmd Cmder) error {
+	var lastErr error
 	for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
 		if attempt > 0 {
 			if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil {
@@ -560,19 +570,15 @@ func (c *Ring) process(ctx context.Context, cmd Cmder) error {
 
 		shard, err := c.cmdShard(cmd)
 		if err != nil {
-			cmd.setErr(err)
 			return err
 		}
 
-		err = shard.Client.ProcessContext(ctx, cmd)
-		if err == nil {
-			return nil
-		}
-		if !isRetryableError(err, cmd.readTimeout() == nil) {
-			return err
+		lastErr = shard.Client._process(ctx, cmd)
+		if lastErr == nil || !isRetryableError(lastErr, cmd.readTimeout() == nil) {
+			return lastErr
 		}
 	}
-	return cmd.Err()
+	return lastErr
 }
 
 func (c *Ring) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) {
@@ -626,63 +632,42 @@ func (c *Ring) generalProcessPipeline(
 		cmdsMap[hash] = append(cmdsMap[hash], cmd)
 	}
 
-	for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
-		if attempt > 0 {
-			if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil {
-				return err
+	var wg sync.WaitGroup
+	for hash, cmds := range cmdsMap {
+		wg.Add(1)
+		go func(hash string, cmds []Cmder) {
+			defer wg.Done()
+
+			err := c.processShardPipeline(ctx, hash, cmds, tx)
+			if err != nil {
+				setCmdsErr(cmds, err)
 			}
-		}
-
-		var mu sync.Mutex
-		var failedCmdsMap map[string][]Cmder
-		var wg sync.WaitGroup
-
-		for hash, cmds := range cmdsMap {
-			wg.Add(1)
-			go func(hash string, cmds []Cmder) {
-				defer wg.Done()
-
-				shard, err := c.shards.GetByHash(hash)
-				if err != nil {
-					setCmdsErr(cmds, err)
-					return
-				}
-
-				cn, err := shard.Client.getConn(ctx)
-				if err != nil {
-					setCmdsErr(cmds, err)
-					return
-				}
-
-				var canRetry bool
-				if tx {
-					canRetry, err = shard.Client.txPipelineProcessCmds(ctx, cn, cmds)
-				} else {
-					canRetry, err = shard.Client.pipelineProcessCmds(ctx, cn, cmds)
-				}
-				shard.Client.releaseConn(cn, err)
-
-				if canRetry && isRetryableError(err, true) {
-					mu.Lock()
-					if failedCmdsMap == nil {
-						failedCmdsMap = make(map[string][]Cmder)
-					}
-					failedCmdsMap[hash] = cmds
-					mu.Unlock()
-				}
-			}(hash, cmds)
-		}
-
-		wg.Wait()
-		if len(failedCmdsMap) == 0 {
-			break
-		}
-		cmdsMap = failedCmdsMap
+		}(hash, cmds)
 	}
 
+	wg.Wait()
 	return cmdsFirstErr(cmds)
 }
 
+func (c *Ring) processShardPipeline(
+	ctx context.Context, hash string, cmds []Cmder, tx bool,
+) error {
+	//TODO: retry?
+	shard, err := c.shards.GetByHash(hash)
+	if err != nil {
+		return err
+	}
+
+	if tx {
+		err = shard.Client._generalProcessPipeline(
+			ctx, cmds, shard.Client.txPipelineProcessCmds)
+	} else {
+		err = shard.Client._generalProcessPipeline(
+			ctx, cmds, shard.Client.pipelineProcessCmds)
+	}
+	return err
+}
+
 // Close closes the ring client, releasing any open resources.
 //
 // It is rare to Close a Ring, as the Ring is meant to be long-lived
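
Reviewer note (not part of the diff): the core of this change is baseClient.withConn, which scopes a pooled connection to a callback so call sites can no longer forget to release the connection on an error path. The sketch below is a minimal, self-contained illustration of that shape only; Pool, Conn and Client here are invented stand-ins for the example, not the go-redis or internal/pool types, and the real releaseConn hides the put/remove decision behind a single call.

```go
package main

import (
	"context"
	"errors"
	"fmt"
)

// Conn and Pool are simplified stand-ins for a pooled connection and
// its pool; they are not the real go-redis types.
type Conn struct{ id int }

type Pool struct{ next int }

func (p *Pool) Get(ctx context.Context) (*Conn, error) {
	p.next++
	return &Conn{id: p.next}, nil
}

func (p *Pool) Put(cn *Conn)             { fmt.Println("returned conn", cn.id) }
func (p *Pool) Remove(cn *Conn, _ error) { fmt.Println("removed conn", cn.id) }

type Client struct{ pool *Pool }

// withConn mirrors the shape of baseClient.withConn in the diff:
// acquire a connection, run fn, then release or discard the connection
// based on fn's error - callers never touch the pool directly.
func (c *Client) withConn(ctx context.Context, fn func(context.Context, *Conn) error) error {
	cn, err := c.pool.Get(ctx)
	if err != nil {
		return err
	}
	defer func() {
		if err != nil {
			c.pool.Remove(cn, err) // broken connection is not reused
		} else {
			c.pool.Put(cn)
		}
	}()

	err = fn(ctx, cn)
	return err
}

func main() {
	c := &Client{pool: &Pool{}}

	_ = c.withConn(context.Background(), func(ctx context.Context, cn *Conn) error {
		fmt.Println("using conn", cn.id)
		return nil
	})

	_ = c.withConn(context.Background(), func(ctx context.Context, cn *Conn) error {
		return errors.New("write failed")
	})
}
```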
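A second pattern the diff applies to baseClient, ClusterClient and Ring alike: process becomes a thin wrapper that records the error exactly once via cmd.setErr, while the inner _process loop only tracks lastErr and decides whether another attempt is worthwhile. The following is a rough, runnable sketch of that retry shape under stated assumptions; isRetryable, retryBackoff, doOnce and maxRetries are illustrative placeholders, not the real go-redis helpers.

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

var errTimeout = errors.New("i/o timeout")

// isRetryable and retryBackoff are placeholders standing in for the
// real retry predicates and backoff calculation.
func isRetryable(err error) bool              { return errors.Is(err, errTimeout) }
func retryBackoff(attempt int) time.Duration  { return time.Duration(attempt*10) * time.Millisecond }

// doOnce stands in for the single write+read round trip that the real
// _process performs through withConn; it fails twice, then succeeds.
func doOnce(attempt int) error {
	if attempt < 2 {
		return errTimeout
	}
	return nil
}

// process wraps _process and records the failure once, mirroring the
// process/_process split introduced by the diff.
func process(ctx context.Context, maxRetries int) error {
	err := _process(ctx, maxRetries)
	if err != nil {
		fmt.Println("setErr:", err) // the real code calls cmd.setErr(err)
		return err
	}
	return nil
}

func _process(ctx context.Context, maxRetries int) error {
	var lastErr error
	for attempt := 0; attempt <= maxRetries; attempt++ {
		if attempt > 0 {
			select {
			case <-time.After(retryBackoff(attempt)):
			case <-ctx.Done():
				return ctx.Err()
			}
		}

		lastErr = doOnce(attempt)
		if lastErr == nil || !isRetryable(lastErr) {
			return lastErr
		}
	}
	return lastErr
}

func main() {
	fmt.Println(process(context.Background(), 3)) // succeeds on the third attempt
}
```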
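Finally, Ring.generalProcessPipeline drops its outer retry loop and instead fans out one goroutine per shard hash, delegating retries to each shard client's _generalProcessPipeline and marking only that shard's commands on failure. The sketch below shows just that fan-out-and-collect shape with plain strings standing in for shards and commands; shardPipeline is a hypothetical placeholder for processShardPipeline, not the real function.

```go
package main

import (
	"context"
	"fmt"
	"sync"
)

// shardPipeline is a placeholder for Ring.processShardPipeline: it
// either runs a shard's batch or reports that the shard failed.
func shardPipeline(ctx context.Context, shard string, cmds []string) error {
	if shard == "bad" {
		return fmt.Errorf("shard %s is down", shard)
	}
	fmt.Println(shard, "ran", cmds)
	return nil
}

// fanOut mirrors the new generalProcessPipeline shape: one goroutine
// per shard group, with failures recorded only for that group.
func fanOut(ctx context.Context, byShard map[string][]string) map[string]error {
	var (
		mu   sync.Mutex
		errs = make(map[string]error)
		wg   sync.WaitGroup
	)
	for shard, cmds := range byShard {
		wg.Add(1)
		go func(shard string, cmds []string) {
			defer wg.Done()
			if err := shardPipeline(ctx, shard, cmds); err != nil {
				mu.Lock()
				errs[shard] = err // the real code calls setCmdsErr(cmds, err)
				mu.Unlock()
			}
		}(shard, cmds)
	}
	wg.Wait()
	return errs
}

func main() {
	errs := fanOut(context.Background(), map[string][]string{
		"shard-1": {"SET a 1", "GET a"},
		"bad":     {"GET b"},
	})
	fmt.Println(errs)
}
```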