diff --git a/cluster.go b/cluster.go
index 2759d8f..f962684 100644
--- a/cluster.go
+++ b/cluster.go
@@ -1180,8 +1180,8 @@ func (c *ClusterClient) processPipeline(ctx context.Context, cmds []Cmder) error
 func (c *ClusterClient) _processPipeline(ctx context.Context, cmds []Cmder) error {
 	cmdsMap := newCmdsMap()
-	err := c.mapCmdsByNode(ctx, cmdsMap, cmds)
-	if err != nil {
+
+	if err := c.mapCmdsByNode(ctx, cmdsMap, cmds); err != nil {
 		setCmdsErr(cmds, err)
 		return err
 	}
 
@@ -1201,18 +1201,7 @@ func (c *ClusterClient) _processPipeline(ctx context.Context, cmds []Cmder) erro
 			wg.Add(1)
 			go func(node *clusterNode, cmds []Cmder) {
 				defer wg.Done()
-
-				err := c._processPipelineNode(ctx, node, cmds, failedCmds)
-				if err == nil {
-					return
-				}
-				if attempt < c.opt.MaxRedirects {
-					if err := c.mapCmdsByNode(ctx, failedCmds, cmds); err != nil {
-						setCmdsErr(cmds, err)
-					}
-				} else {
-					setCmdsErr(cmds, err)
-				}
+				c._processPipelineNode(ctx, node, cmds, failedCmds)
 			}(node, cmds)
 		}
 
@@ -1267,13 +1256,13 @@ func (c *ClusterClient) cmdsAreReadOnly(ctx context.Context, cmds []Cmder) bool
 func (c *ClusterClient) _processPipelineNode(
 	ctx context.Context, node *clusterNode, cmds []Cmder, failedCmds *cmdsMap,
-) error {
-	return node.Client.hooks.processPipeline(ctx, cmds, func(ctx context.Context, cmds []Cmder) error {
+) {
+	_ = node.Client.hooks.processPipeline(ctx, cmds, func(ctx context.Context, cmds []Cmder) error {
 		return node.Client.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
-			err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
+			if err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
 				return writeCmds(wr, cmds)
-			})
-			if err != nil {
+			}); err != nil {
+				setCmdsErr(cmds, err)
 				return err
 			}
 
@@ -1291,7 +1280,7 @@ func (c *ClusterClient) pipelineReadCmds(
 	cmds []Cmder,
 	failedCmds *cmdsMap,
 ) error {
-	for _, cmd := range cmds {
+	for i, cmd := range cmds {
 		err := cmd.readReply(rd)
 		cmd.SetErr(err)
 
@@ -1303,15 +1292,24 @@ func (c *ClusterClient) pipelineReadCmds(
 			continue
 		}
 
-		if c.opt.ReadOnly && (isLoadingError(err) || !isRedisError(err)) {
+		if c.opt.ReadOnly {
 			node.MarkAsFailing()
+		}
+
+		if !isRedisError(err) {
+			if shouldRetry(err, true) {
+				_ = c.mapCmdsByNode(ctx, failedCmds, cmds)
+			}
+			setCmdsErr(cmds[i+1:], err)
 			return err
 		}
-
-		if isRedisError(err) {
-			continue
-		}
+	}
+
+	if err := cmds[0].Err(); err != nil && shouldRetry(err, true) {
+		_ = c.mapCmdsByNode(ctx, failedCmds, cmds)
 		return err
 	}
+
 	return nil
 }
 
@@ -1393,19 +1391,7 @@ func (c *ClusterClient) _processTxPipeline(ctx context.Context, cmds []Cmder) er
 			wg.Add(1)
 			go func(node *clusterNode, cmds []Cmder) {
 				defer wg.Done()
-
-				err := c._processTxPipelineNode(ctx, node, cmds, failedCmds)
-				if err == nil {
-					return
-				}
-
-				if attempt < c.opt.MaxRedirects {
-					if err := c.mapCmdsByNode(ctx, failedCmds, cmds); err != nil {
-						setCmdsErr(cmds, err)
-					}
-				} else {
-					setCmdsErr(cmds, err)
-				}
+				c._processTxPipelineNode(ctx, node, cmds, failedCmds)
 			}(node, cmds)
 		}
 
@@ -1431,34 +1417,39 @@ func (c *ClusterClient) mapCmdsBySlot(ctx context.Context, cmds []Cmder) map[int
 func (c *ClusterClient) _processTxPipelineNode(
 	ctx context.Context, node *clusterNode, cmds []Cmder, failedCmds *cmdsMap,
-) error {
-	return node.Client.hooks.processTxPipeline(ctx, cmds, func(ctx context.Context, cmds []Cmder) error {
-		return node.Client.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
-			err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
-				return writeCmds(wr, cmds)
-			})
-			if err != nil {
-				return err
-			}
-
-			return cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
-				statusCmd := cmds[0].(*StatusCmd)
-				// Trim multi and exec.
-				cmds = cmds[1 : len(cmds)-1]
-
-				err := c.txPipelineReadQueued(ctx, rd, statusCmd, cmds, failedCmds)
-				if err != nil {
-					moved, ask, addr := isMovedError(err)
-					if moved || ask {
-						return c.cmdsMoved(ctx, cmds, moved, ask, addr, failedCmds)
-					}
+) {
+	_ = node.Client.hooks.processTxPipeline(
+		ctx, cmds, func(ctx context.Context, cmds []Cmder) error {
+			return node.Client.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
+				if err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
+					return writeCmds(wr, cmds)
+				}); err != nil {
+					setCmdsErr(cmds, err)
 					return err
 				}
-				return pipelineReadCmds(rd, cmds)
+
+				return cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
+					statusCmd := cmds[0].(*StatusCmd)
+					// Trim multi and exec.
+					trimmedCmds := cmds[1 : len(cmds)-1]
+
+					if err := c.txPipelineReadQueued(
+						ctx, rd, statusCmd, trimmedCmds, failedCmds,
+					); err != nil {
+						setCmdsErr(cmds, err)
+
+						moved, ask, addr := isMovedError(err)
+						if moved || ask {
+							return c.cmdsMoved(ctx, trimmedCmds, moved, ask, addr, failedCmds)
+						}
+
+						return err
+					}
+
+					return pipelineReadCmds(rd, trimmedCmds)
+				})
 			})
 		})
-	})
 }
 
 func (c *ClusterClient) txPipelineReadQueued(
diff --git a/cluster_test.go b/cluster_test.go
index 72938a2..92844eb 100644
--- a/cluster_test.go
+++ b/cluster_test.go
@@ -1276,20 +1276,33 @@ var _ = Describe("ClusterClient timeout", func() {
 	Context("read/write timeout", func() {
 		BeforeEach(func() {
 			opt := redisClusterOptions()
-			opt.ReadTimeout = 250 * time.Millisecond
-			opt.WriteTimeout = 250 * time.Millisecond
-			opt.MaxRedirects = 1
 			client = cluster.newClusterClient(ctx, opt)
 
 			err := client.ForEachShard(ctx, func(ctx context.Context, client *redis.Client) error {
-				return client.ClientPause(ctx, pause).Err()
+				err := client.ClientPause(ctx, pause).Err()
+
+				opt := client.Options()
+				opt.ReadTimeout = time.Nanosecond
+				opt.WriteTimeout = time.Nanosecond
+
+				return err
 			})
 			Expect(err).NotTo(HaveOccurred())
+
+			// Overwrite timeouts after the client is initialized.
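+			// (Assumption: newClusterClient keeps a reference to opt,
+			// so mutating it here still affects commands issued later.)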
+			opt.ReadTimeout = time.Nanosecond
+			opt.WriteTimeout = time.Nanosecond
+			opt.MaxRedirects = 0
 		})
 
 		AfterEach(func() {
 			_ = client.ForEachShard(ctx, func(ctx context.Context, client *redis.Client) error {
 				defer GinkgoRecover()
+
+				opt := client.Options()
+				opt.ReadTimeout = time.Second
+				opt.WriteTimeout = time.Second
+
 				Eventually(func() error {
 					return client.Ping(ctx).Err()
 				}, 2*pause).ShouldNot(HaveOccurred())
diff --git a/race_test.go b/race_test.go
index 52181e0..2265366 100644
--- a/race_test.go
+++ b/race_test.go
@@ -2,7 +2,6 @@ package redis_test
 
 import (
 	"bytes"
-	"context"
 	"fmt"
 	"net"
 	"strconv"
@@ -289,26 +288,6 @@ var _ = Describe("races", func() {
 		wg.Wait()
 		Expect(atomic.LoadUint32(&received)).To(Equal(uint32(C * N)))
 	})
-
-	It("should abort on context timeout", func() {
-		opt := redisClusterOptions()
-		client := cluster.newClusterClient(ctx, opt)
-
-		ctx, cancel := context.WithCancel(context.Background())
-
-		wg := performAsync(C, func(_ int) {
-			_, err := client.XRead(ctx, &redis.XReadArgs{
-				Streams: []string{"test", "$"},
-				Block:   1 * time.Second,
-			}).Result()
-			Expect(err).To(HaveOccurred())
-			Expect(err.Error()).To(Or(Equal(context.Canceled.Error()), ContainSubstring("operation was canceled")))
-		})
-
-		time.Sleep(10 * time.Millisecond)
-		cancel()
-		wg.Wait()
-	})
 })
 
 var _ = Describe("cluster races", func() {
diff --git a/redis.go b/redis.go
index 704d9e5..0428149 100644
--- a/redis.go
+++ b/redis.go
@@ -290,27 +290,7 @@ func (c *baseClient) withConn(
 		c.releaseConn(ctx, cn, err)
 	}()
 
-	done := ctx.Done() //nolint:ifshort
-
-	if done == nil {
-		err = fn(ctx, cn)
-		return err
-	}
-
-	errc := make(chan error, 1)
-	go func() { errc <- fn(ctx, cn) }()
-
-	select {
-	case <-done:
-		_ = cn.Close()
-		// Wait for the goroutine to finish and send something.
-		<-errc
-
-		err = ctx.Err()
-		return err
-	case err = <-errc:
-		return err
-	}
+	return fn(ctx, cn)
 }
 
 func (c *baseClient) process(ctx context.Context, cmd Cmder) error {
@@ -416,7 +396,6 @@ func (c *baseClient) generalProcessPipeline(
 ) error {
 	err := c._generalProcessPipeline(ctx, cmds, p)
 	if err != nil {
-		setCmdsErr(cmds, err)
 		return err
 	}
 	return cmdsFirstErr(cmds)
@@ -429,6 +408,7 @@ func (c *baseClient) _generalProcessPipeline(
 	for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
 		if attempt > 0 {
 			if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil {
+				setCmdsErr(cmds, err)
 				return err
 			}
 		}
@@ -449,53 +429,61 @@ func (c *baseClient) pipelineProcessCmds(
 	ctx context.Context, cn *pool.Conn, cmds []Cmder,
 ) (bool, error) {
-	err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
+	if err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
 		return writeCmds(wr, cmds)
-	})
-	if err != nil {
+	}); err != nil {
+		setCmdsErr(cmds, err)
 		return true, err
 	}
 
-	err = cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
+	if err := cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
 		return pipelineReadCmds(rd, cmds)
-	})
-	return true, err
+	}); err != nil {
+		return true, err
+	}
+
+	return false, nil
 }
 
 func pipelineReadCmds(rd *proto.Reader, cmds []Cmder) error {
-	for _, cmd := range cmds {
+	for i, cmd := range cmds {
 		err := cmd.readReply(rd)
 		cmd.SetErr(err)
 		if err != nil && !isRedisError(err) {
+			setCmdsErr(cmds[i+1:], err)
 			return err
 		}
 	}
-	return nil
+	// Retry errors like "LOADING redis is loading the dataset in memory".
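+	// Returning the first command's error here (instead of nil) lets the
+	// retry loop in _generalProcessPipeline decide, via shouldRetry,
+	// whether the whole pipeline should be re-run.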
+	return cmds[0].Err()
 }
 
 func (c *baseClient) txPipelineProcessCmds(
 	ctx context.Context, cn *pool.Conn, cmds []Cmder,
 ) (bool, error) {
-	err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
+	if err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
 		return writeCmds(wr, cmds)
-	})
-	if err != nil {
+	}); err != nil {
+		setCmdsErr(cmds, err)
 		return true, err
 	}
 
-	err = cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
+	if err := cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
 		statusCmd := cmds[0].(*StatusCmd)
 		// Trim multi and exec.
-		cmds = cmds[1 : len(cmds)-1]
+		trimmedCmds := cmds[1 : len(cmds)-1]
 
-		err := txPipelineReadQueued(rd, statusCmd, cmds)
-		if err != nil {
+		if err := txPipelineReadQueued(rd, statusCmd, trimmedCmds); err != nil {
+			setCmdsErr(cmds, err)
 			return err
 		}
 
-		return pipelineReadCmds(rd, cmds)
-	})
-	return false, err
+		return pipelineReadCmds(rd, trimmedCmds)
+	}); err != nil {
+		return false, err
+	}
+
+	return false, nil
 }
 
 func wrapMultiExec(ctx context.Context, cmds []Cmder) []Cmder {
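
Note: the example below is not part of the patch. It is a minimal usage sketch,
assuming go-redis v8 and a cluster reachable at the placeholder addresses, of
the behavior this patch targets: write failures are stamped on every queued
command via setCmdsErr, and a retryable first reply such as LOADING causes the
pipeline to be re-mapped and retried (bounded by MaxRedirects/MaxRetries)
instead of failing on the first attempt.

	package main

	import (
		"context"
		"fmt"
		"time"

		"github.com/go-redis/redis/v8"
	)

	func main() {
		ctx := context.Background()
		rdb := redis.NewClusterClient(&redis.ClusterOptions{
			Addrs:        []string{":7000", ":7001", ":7002"}, // placeholder addresses
			MaxRedirects: 3,                                   // also bounds cluster pipeline retries
		})

		pipe := rdb.Pipeline()
		incr := pipe.Incr(ctx, "counter")
		pipe.Expire(ctx, "counter", time.Hour)

		// Exec writes the batch and reads all replies; with this patch a
		// retryable error (e.g. LOADING after a restart) triggers a retry,
		// and a transport error is set on every command in the batch.
		if _, err := pipe.Exec(ctx); err != nil {
			fmt.Println("pipeline failed:", err, "incr err:", incr.Err())
			return
		}
		fmt.Println("counter:", incr.Val())
	}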