diff --git a/cluster.go b/cluster.go index 29ae471..bc453e3 100644 --- a/cluster.go +++ b/cluster.go @@ -1065,7 +1065,7 @@ func (c *ClusterClient) _processPipeline(ctx context.Context, cmds []Cmder) erro return } - err = c.pipelineProcessCmds(node, cn, cmds, failedCmds) + err = c.pipelineProcessCmds(ctx, node, cn, cmds, failedCmds) node.Client.releaseConnStrict(cn, err) }(node, cmds) } @@ -1129,9 +1129,9 @@ func (c *ClusterClient) cmdsAreReadOnly(cmds []Cmder) bool { } func (c *ClusterClient) pipelineProcessCmds( - node *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds *cmdsMap, + ctx context.Context, node *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds *cmdsMap, ) error { - err := cn.WithWriter(c.opt.WriteTimeout, func(wr *proto.Writer) error { + err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { return writeCmd(wr, cmds...) }) if err != nil { @@ -1142,7 +1142,7 @@ func (c *ClusterClient) pipelineProcessCmds( return err } - return cn.WithReader(c.opt.ReadTimeout, func(rd *proto.Reader) error { + return cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error { return c.pipelineReadCmds(node, rd, cmds, failedCmds) }) } @@ -1266,7 +1266,7 @@ func (c *ClusterClient) _processTxPipeline(ctx context.Context, cmds []Cmder) er return } - err = c.txPipelineProcessCmds(node, cn, cmds, failedCmds) + err = c.txPipelineProcessCmds(ctx, node, cn, cmds, failedCmds) node.Client.releaseConnStrict(cn, err) }(node, cmds) } @@ -1292,9 +1292,9 @@ func (c *ClusterClient) mapCmdsBySlot(cmds []Cmder) map[int][]Cmder { } func (c *ClusterClient) txPipelineProcessCmds( - node *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds *cmdsMap, + ctx context.Context, node *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds *cmdsMap, ) error { - err := cn.WithWriter(c.opt.WriteTimeout, func(wr *proto.Writer) error { + err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { return txPipelineWriteMulti(wr, cmds) }) if err != nil { @@ -1305,7 +1305,7 @@ func (c *ClusterClient) txPipelineProcessCmds( return err } - err = cn.WithReader(c.opt.ReadTimeout, func(rd *proto.Reader) error { + err = cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error { err := c.txPipelineReadQueued(rd, cmds, failedCmds) if err != nil { setCmdsErr(cmds, err) diff --git a/internal/pool/conn.go b/internal/pool/conn.go index 78b29f5..687c1e8 100644 --- a/internal/pool/conn.go +++ b/internal/pool/conn.go @@ -1,6 +1,7 @@ package pool import ( + "context" "net" "sync/atomic" "time" @@ -48,24 +49,6 @@ func (cn *Conn) SetNetConn(netConn net.Conn) { cn.wr.Reset(netConn) } -func (cn *Conn) setReadTimeout(timeout time.Duration) error { - now := time.Now() - cn.SetUsedAt(now) - if timeout > 0 { - return cn.netConn.SetReadDeadline(now.Add(timeout)) - } - return cn.netConn.SetReadDeadline(noDeadline) -} - -func (cn *Conn) setWriteTimeout(timeout time.Duration) error { - now := time.Now() - cn.SetUsedAt(now) - if timeout > 0 { - return cn.netConn.SetWriteDeadline(now.Add(timeout)) - } - return cn.netConn.SetWriteDeadline(noDeadline) -} - func (cn *Conn) Write(b []byte) (int, error) { return cn.netConn.Write(b) } @@ -74,13 +57,17 @@ func (cn *Conn) RemoteAddr() net.Addr { return cn.netConn.RemoteAddr() } -func (cn *Conn) WithReader(timeout time.Duration, fn func(rd *proto.Reader) error) error { - _ = cn.setReadTimeout(timeout) +func (cn *Conn) WithReader(ctx context.Context, timeout time.Duration, fn func(rd *proto.Reader) error) error { + tm := cn.deadline(ctx, timeout) + _ = cn.netConn.SetReadDeadline(tm) return fn(cn.rd) } -func (cn *Conn) WithWriter(timeout time.Duration, fn func(wr *proto.Writer) error) error { - _ = cn.setWriteTimeout(timeout) +func (cn *Conn) WithWriter( + ctx context.Context, timeout time.Duration, fn func(wr *proto.Writer) error, +) error { + tm := cn.deadline(ctx, timeout) + _ = cn.netConn.SetWriteDeadline(tm) firstErr := fn(cn.wr) err := cn.wr.Flush() @@ -93,3 +80,22 @@ func (cn *Conn) WithWriter(timeout time.Duration, fn func(wr *proto.Writer) erro func (cn *Conn) Close() error { return cn.netConn.Close() } + +func (cn *Conn) deadline(ctx context.Context, timeout time.Duration) time.Time { + if ctx != nil { + tm, ok := ctx.Deadline() + if ok { + cn.SetUsedAt(tm) + return tm + } + } + + now := time.Now() + if timeout > 0 { + cn.SetUsedAt(now) + return now.Add(timeout) + } + + cn.SetUsedAt(now) + return noDeadline +} diff --git a/pubsub.go b/pubsub.go index 03b0156..e3df4f4 100644 --- a/pubsub.go +++ b/pubsub.go @@ -1,6 +1,7 @@ package redis import ( + "context" "errors" "fmt" "strings" @@ -83,8 +84,8 @@ func (c *PubSub) _conn(newChannels []string) (*pool.Conn, error) { return cn, nil } -func (c *PubSub) writeCmd(cn *pool.Conn, cmd Cmder) error { - return cn.WithWriter(c.opt.WriteTimeout, func(wr *proto.Writer) error { +func (c *PubSub) writeCmd(ctx context.Context, cn *pool.Conn, cmd Cmder) error { + return cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { return writeCmd(wr, cmd) }) } @@ -128,7 +129,7 @@ func (c *PubSub) _subscribe( args = append(args, channel) } cmd := NewSliceCmd(args...) - return c.writeCmd(cn, cmd) + return c.writeCmd(context.TODO(), cn, cmd) } func (c *PubSub) releaseConn(cn *pool.Conn, err error, allowTimeout bool) { @@ -258,7 +259,7 @@ func (c *PubSub) Ping(payload ...string) error { return err } - err = c.writeCmd(cn, cmd) + err = c.writeCmd(context.TODO(), cn, cmd) c.releaseConn(cn, err, false) return err } @@ -350,7 +351,7 @@ func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) { return nil, err } - err = cn.WithReader(timeout, func(rd *proto.Reader) error { + err = cn.WithReader(context.TODO(), timeout, func(rd *proto.Reader) error { return c.cmd.readReply(rd) }) diff --git a/redis.go b/redis.go index 6870e2b..d47c2b3 100644 --- a/redis.go +++ b/redis.go @@ -265,7 +265,7 @@ func (c *baseClient) process(ctx context.Context, cmd Cmder) error { return err } - err = cn.WithWriter(c.opt.WriteTimeout, func(wr *proto.Writer) error { + err = cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { return writeCmd(wr, cmd) }) if err != nil { @@ -277,7 +277,7 @@ func (c *baseClient) process(ctx context.Context, cmd Cmder) error { return err } - err = cn.WithReader(c.cmdTimeout(cmd), cmd.readReply) + err = cn.WithReader(ctx, c.cmdTimeout(cmd), cmd.readReply) c.releaseConn(cn, err) if err != nil && internal.IsRetryableError(err, cmd.readTimeout() == nil) { continue @@ -333,7 +333,7 @@ func (c *baseClient) processTxPipeline(ctx context.Context, cmds []Cmder) error return c.generalProcessPipeline(ctx, cmds, c.txPipelineProcessCmds) } -type pipelineProcessor func(*pool.Conn, []Cmder) (bool, error) +type pipelineProcessor func(context.Context, *pool.Conn, []Cmder) (bool, error) func (c *baseClient) generalProcessPipeline( ctx context.Context, cmds []Cmder, p pipelineProcessor, @@ -349,7 +349,7 @@ func (c *baseClient) generalProcessPipeline( return err } - canRetry, err := p(cn, cmds) + canRetry, err := p(ctx, cn, cmds) c.releaseConnStrict(cn, err) if !canRetry || !internal.IsRetryableError(err, true) { @@ -359,8 +359,10 @@ func (c *baseClient) generalProcessPipeline( return cmdsFirstErr(cmds) } -func (c *baseClient) pipelineProcessCmds(cn *pool.Conn, cmds []Cmder) (bool, error) { - err := cn.WithWriter(c.opt.WriteTimeout, func(wr *proto.Writer) error { +func (c *baseClient) pipelineProcessCmds( + ctx context.Context, cn *pool.Conn, cmds []Cmder, +) (bool, error) { + err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { return writeCmd(wr, cmds...) }) if err != nil { @@ -368,7 +370,7 @@ func (c *baseClient) pipelineProcessCmds(cn *pool.Conn, cmds []Cmder) (bool, err return true, err } - err = cn.WithReader(c.opt.ReadTimeout, func(rd *proto.Reader) error { + err = cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error { return pipelineReadCmds(rd, cmds) }) return true, err @@ -384,8 +386,10 @@ func pipelineReadCmds(rd *proto.Reader, cmds []Cmder) error { return nil } -func (c *baseClient) txPipelineProcessCmds(cn *pool.Conn, cmds []Cmder) (bool, error) { - err := cn.WithWriter(c.opt.WriteTimeout, func(wr *proto.Writer) error { +func (c *baseClient) txPipelineProcessCmds( + ctx context.Context, cn *pool.Conn, cmds []Cmder, +) (bool, error) { + err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { return txPipelineWriteMulti(wr, cmds) }) if err != nil { @@ -393,7 +397,7 @@ func (c *baseClient) txPipelineProcessCmds(cn *pool.Conn, cmds []Cmder) (bool, e return true, err } - err = cn.WithReader(c.opt.ReadTimeout, func(rd *proto.Reader) error { + err = cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error { err := txPipelineReadQueued(rd, cmds) if err != nil { setCmdsErr(cmds, err) diff --git a/ring.go b/ring.go index fa77a97..3e9e0c4 100644 --- a/ring.go +++ b/ring.go @@ -616,7 +616,7 @@ func (c *Ring) _processPipeline(ctx context.Context, cmds []Cmder) error { return } - canRetry, err := shard.Client.pipelineProcessCmds(cn, cmds) + canRetry, err := shard.Client.pipelineProcessCmds(ctx, cn, cmds) shard.Client.releaseConnStrict(cn, err) if canRetry && internal.IsRetryableError(err, true) {