From 0884e48a21a780f22b43fb8fa5ba65a8f7bc9567 Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Mon, 21 Nov 2022 11:55:19 +0200 Subject: [PATCH] chore: improve cluster pipeline retries --- cluster.go | 122 +++++++++++++++++++++++++++++++++-------------------- 1 file changed, 77 insertions(+), 45 deletions(-) diff --git a/cluster.go b/cluster.go index cc077b3..a70ff12 100644 --- a/cluster.go +++ b/cluster.go @@ -846,8 +846,8 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient { c.cmdable = c.Process c.hooks.setProcess(c.process) - c.hooks.setProcessPipeline(c._processPipeline) - c.hooks.setProcessTxPipeline(c._processTxPipeline) + c.hooks.setProcessPipeline(c.processPipeline) + c.hooks.setProcessTxPipeline(c.processTxPipeline) return c } @@ -1187,7 +1187,7 @@ func (c *ClusterClient) Pipelined(ctx context.Context, fn func(Pipeliner) error) return c.Pipeline().Pipelined(ctx, fn) } -func (c *ClusterClient) _processPipeline(ctx context.Context, cmds []Cmder) error { +func (c *ClusterClient) processPipeline(ctx context.Context, cmds []Cmder) error { cmdsMap := newCmdsMap() if err := c.mapCmdsByNode(ctx, cmdsMap, cmds); err != nil { @@ -1210,7 +1210,7 @@ func (c *ClusterClient) _processPipeline(ctx context.Context, cmds []Cmder) erro wg.Add(1) go func(node *clusterNode, cmds []Cmder) { defer wg.Done() - c._processPipelineNode(ctx, node, cmds, failedCmds) + c.processPipelineNode(ctx, node, cmds, failedCmds) }(node, cmds) } @@ -1263,22 +1263,38 @@ func (c *ClusterClient) cmdsAreReadOnly(ctx context.Context, cmds []Cmder) bool return true } -func (c *ClusterClient) _processPipelineNode( +func (c *ClusterClient) processPipelineNode( ctx context.Context, node *clusterNode, cmds []Cmder, failedCmds *cmdsMap, ) { _ = node.Client.hooks.withProcessPipelineHook(ctx, cmds, func(ctx context.Context, cmds []Cmder) error { - return node.Client.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { - if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error { - return writeCmds(wr, cmds) - }); err != nil { - setCmdsErr(cmds, err) - return err - } + cn, err := node.Client.getConn(ctx) + if err != nil { + _ = c.mapCmdsByNode(ctx, failedCmds, cmds) + setCmdsErr(cmds, err) + return err + } - return cn.WithReader(c.context(ctx), c.opt.ReadTimeout, func(rd *proto.Reader) error { - return c.pipelineReadCmds(ctx, node, rd, cmds, failedCmds) - }) - }) + err = c.processPipelineNodeConn(ctx, node, cn, cmds, failedCmds) + node.Client.releaseConn(ctx, cn, err) + return err + }) +} + +func (c *ClusterClient) processPipelineNodeConn( + ctx context.Context, node *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds *cmdsMap, +) error { + if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error { + return writeCmds(wr, cmds) + }); err != nil { + if shouldRetry(err, true) { + _ = c.mapCmdsByNode(ctx, failedCmds, cmds) + } + setCmdsErr(cmds, err) + return err + } + + return cn.WithReader(c.context(ctx), c.opt.ReadTimeout, func(rd *proto.Reader) error { + return c.pipelineReadCmds(ctx, node, rd, cmds, failedCmds) }) } @@ -1365,7 +1381,7 @@ func (c *ClusterClient) TxPipelined(ctx context.Context, fn func(Pipeliner) erro return c.TxPipeline().Pipelined(ctx, fn) } -func (c *ClusterClient) _processTxPipeline(ctx context.Context, cmds []Cmder) error { +func (c *ClusterClient) processTxPipeline(ctx context.Context, cmds []Cmder) error { // Trim multi .. exec. cmds = cmds[1 : len(cmds)-1] @@ -1399,7 +1415,7 @@ func (c *ClusterClient) _processTxPipeline(ctx context.Context, cmds []Cmder) er wg.Add(1) go func(node *clusterNode, cmds []Cmder) { defer wg.Done() - c._processTxPipelineNode(ctx, node, cmds, failedCmds) + c.processTxPipelineNode(ctx, node, cmds, failedCmds) }(node, cmds) } @@ -1423,40 +1439,56 @@ func (c *ClusterClient) mapCmdsBySlot(ctx context.Context, cmds []Cmder) map[int return cmdsMap } -func (c *ClusterClient) _processTxPipelineNode( +func (c *ClusterClient) processTxPipelineNode( ctx context.Context, node *clusterNode, cmds []Cmder, failedCmds *cmdsMap, ) { cmds = wrapMultiExec(ctx, cmds) _ = node.Client.hooks.withProcessPipelineHook(ctx, cmds, func(ctx context.Context, cmds []Cmder) error { - return node.Client.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { - if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error { - return writeCmds(wr, cmds) - }); err != nil { - setCmdsErr(cmds, err) - return err + cn, err := node.Client.getConn(ctx) + if err != nil { + _ = c.mapCmdsByNode(ctx, failedCmds, cmds) + setCmdsErr(cmds, err) + return err + } + + err = c.processTxPipelineNodeConn(ctx, node, cn, cmds, failedCmds) + node.Client.releaseConn(ctx, cn, err) + return err + }) +} + +func (c *ClusterClient) processTxPipelineNodeConn( + ctx context.Context, node *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds *cmdsMap, +) error { + if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error { + return writeCmds(wr, cmds) + }); err != nil { + if shouldRetry(err, true) { + _ = c.mapCmdsByNode(ctx, failedCmds, cmds) + } + setCmdsErr(cmds, err) + return err + } + + return cn.WithReader(c.context(ctx), c.opt.ReadTimeout, func(rd *proto.Reader) error { + statusCmd := cmds[0].(*StatusCmd) + // Trim multi and exec. + trimmedCmds := cmds[1 : len(cmds)-1] + + if err := c.txPipelineReadQueued( + ctx, rd, statusCmd, trimmedCmds, failedCmds, + ); err != nil { + setCmdsErr(cmds, err) + + moved, ask, addr := isMovedError(err) + if moved || ask { + return c.cmdsMoved(ctx, trimmedCmds, moved, ask, addr, failedCmds) } - return cn.WithReader(c.context(ctx), c.opt.ReadTimeout, func(rd *proto.Reader) error { - statusCmd := cmds[0].(*StatusCmd) - // Trim multi and exec. - trimmedCmds := cmds[1 : len(cmds)-1] + return err + } - if err := c.txPipelineReadQueued( - ctx, rd, statusCmd, trimmedCmds, failedCmds, - ); err != nil { - setCmdsErr(cmds, err) - - moved, ask, addr := isMovedError(err) - if moved || ask { - return c.cmdsMoved(ctx, trimmedCmds, moved, ask, addr, failedCmds) - } - - return err - } - - return pipelineReadCmds(rd, trimmedCmds) - }) - }) + return pipelineReadCmds(rd, trimmedCmds) }) }