Merge pull request #1047 from go-redis/fix/context-deadline

Use Context.Deadline to set net.Conn deadline
This commit is contained in:
Vladimir Mihailenco 2019-06-08 15:34:34 +03:00 committed by GitHub
commit 530e66a66e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 57 additions and 46 deletions

View File

@ -1065,7 +1065,7 @@ func (c *ClusterClient) _processPipeline(ctx context.Context, cmds []Cmder) erro
return
}
err = c.pipelineProcessCmds(node, cn, cmds, failedCmds)
err = c.pipelineProcessCmds(ctx, node, cn, cmds, failedCmds)
node.Client.releaseConnStrict(cn, err)
}(node, cmds)
}
@ -1129,9 +1129,9 @@ func (c *ClusterClient) cmdsAreReadOnly(cmds []Cmder) bool {
}
func (c *ClusterClient) pipelineProcessCmds(
node *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds *cmdsMap,
ctx context.Context, node *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds *cmdsMap,
) error {
err := cn.WithWriter(c.opt.WriteTimeout, func(wr *proto.Writer) error {
err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
return writeCmd(wr, cmds...)
})
if err != nil {
@ -1142,7 +1142,7 @@ func (c *ClusterClient) pipelineProcessCmds(
return err
}
return cn.WithReader(c.opt.ReadTimeout, func(rd *proto.Reader) error {
return cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
return c.pipelineReadCmds(node, rd, cmds, failedCmds)
})
}
@ -1266,7 +1266,7 @@ func (c *ClusterClient) _processTxPipeline(ctx context.Context, cmds []Cmder) er
return
}
err = c.txPipelineProcessCmds(node, cn, cmds, failedCmds)
err = c.txPipelineProcessCmds(ctx, node, cn, cmds, failedCmds)
node.Client.releaseConnStrict(cn, err)
}(node, cmds)
}
@ -1292,9 +1292,9 @@ func (c *ClusterClient) mapCmdsBySlot(cmds []Cmder) map[int][]Cmder {
}
func (c *ClusterClient) txPipelineProcessCmds(
node *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds *cmdsMap,
ctx context.Context, node *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds *cmdsMap,
) error {
err := cn.WithWriter(c.opt.WriteTimeout, func(wr *proto.Writer) error {
err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
return txPipelineWriteMulti(wr, cmds)
})
if err != nil {
@ -1305,7 +1305,7 @@ func (c *ClusterClient) txPipelineProcessCmds(
return err
}
err = cn.WithReader(c.opt.ReadTimeout, func(rd *proto.Reader) error {
err = cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
err := c.txPipelineReadQueued(rd, cmds, failedCmds)
if err != nil {
setCmdsErr(cmds, err)

View File

@ -1,6 +1,7 @@
package pool
import (
"context"
"net"
"sync/atomic"
"time"
@ -48,24 +49,6 @@ func (cn *Conn) SetNetConn(netConn net.Conn) {
cn.wr.Reset(netConn)
}
func (cn *Conn) setReadTimeout(timeout time.Duration) error {
now := time.Now()
cn.SetUsedAt(now)
if timeout > 0 {
return cn.netConn.SetReadDeadline(now.Add(timeout))
}
return cn.netConn.SetReadDeadline(noDeadline)
}
func (cn *Conn) setWriteTimeout(timeout time.Duration) error {
now := time.Now()
cn.SetUsedAt(now)
if timeout > 0 {
return cn.netConn.SetWriteDeadline(now.Add(timeout))
}
return cn.netConn.SetWriteDeadline(noDeadline)
}
func (cn *Conn) Write(b []byte) (int, error) {
return cn.netConn.Write(b)
}
@ -74,13 +57,17 @@ func (cn *Conn) RemoteAddr() net.Addr {
return cn.netConn.RemoteAddr()
}
func (cn *Conn) WithReader(timeout time.Duration, fn func(rd *proto.Reader) error) error {
_ = cn.setReadTimeout(timeout)
func (cn *Conn) WithReader(ctx context.Context, timeout time.Duration, fn func(rd *proto.Reader) error) error {
tm := cn.deadline(ctx, timeout)
_ = cn.netConn.SetReadDeadline(tm)
return fn(cn.rd)
}
func (cn *Conn) WithWriter(timeout time.Duration, fn func(wr *proto.Writer) error) error {
_ = cn.setWriteTimeout(timeout)
func (cn *Conn) WithWriter(
ctx context.Context, timeout time.Duration, fn func(wr *proto.Writer) error,
) error {
tm := cn.deadline(ctx, timeout)
_ = cn.netConn.SetWriteDeadline(tm)
firstErr := fn(cn.wr)
err := cn.wr.Flush()
@ -93,3 +80,22 @@ func (cn *Conn) WithWriter(timeout time.Duration, fn func(wr *proto.Writer) erro
func (cn *Conn) Close() error {
return cn.netConn.Close()
}
func (cn *Conn) deadline(ctx context.Context, timeout time.Duration) time.Time {
if ctx != nil {
tm, ok := ctx.Deadline()
if ok {
cn.SetUsedAt(tm)
return tm
}
}
now := time.Now()
if timeout > 0 {
cn.SetUsedAt(now)
return now.Add(timeout)
}
cn.SetUsedAt(now)
return noDeadline
}

View File

@ -1,6 +1,7 @@
package redis
import (
"context"
"errors"
"fmt"
"strings"
@ -83,8 +84,8 @@ func (c *PubSub) _conn(newChannels []string) (*pool.Conn, error) {
return cn, nil
}
func (c *PubSub) writeCmd(cn *pool.Conn, cmd Cmder) error {
return cn.WithWriter(c.opt.WriteTimeout, func(wr *proto.Writer) error {
func (c *PubSub) writeCmd(ctx context.Context, cn *pool.Conn, cmd Cmder) error {
return cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
return writeCmd(wr, cmd)
})
}
@ -128,7 +129,7 @@ func (c *PubSub) _subscribe(
args = append(args, channel)
}
cmd := NewSliceCmd(args...)
return c.writeCmd(cn, cmd)
return c.writeCmd(context.TODO(), cn, cmd)
}
func (c *PubSub) releaseConn(cn *pool.Conn, err error, allowTimeout bool) {
@ -258,7 +259,7 @@ func (c *PubSub) Ping(payload ...string) error {
return err
}
err = c.writeCmd(cn, cmd)
err = c.writeCmd(context.TODO(), cn, cmd)
c.releaseConn(cn, err, false)
return err
}
@ -350,7 +351,7 @@ func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) {
return nil, err
}
err = cn.WithReader(timeout, func(rd *proto.Reader) error {
err = cn.WithReader(context.TODO(), timeout, func(rd *proto.Reader) error {
return c.cmd.readReply(rd)
})

View File

@ -265,7 +265,7 @@ func (c *baseClient) process(ctx context.Context, cmd Cmder) error {
return err
}
err = cn.WithWriter(c.opt.WriteTimeout, func(wr *proto.Writer) error {
err = cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
return writeCmd(wr, cmd)
})
if err != nil {
@ -277,7 +277,7 @@ func (c *baseClient) process(ctx context.Context, cmd Cmder) error {
return err
}
err = cn.WithReader(c.cmdTimeout(cmd), cmd.readReply)
err = cn.WithReader(ctx, c.cmdTimeout(cmd), cmd.readReply)
c.releaseConn(cn, err)
if err != nil && internal.IsRetryableError(err, cmd.readTimeout() == nil) {
continue
@ -333,7 +333,7 @@ func (c *baseClient) processTxPipeline(ctx context.Context, cmds []Cmder) error
return c.generalProcessPipeline(ctx, cmds, c.txPipelineProcessCmds)
}
type pipelineProcessor func(*pool.Conn, []Cmder) (bool, error)
type pipelineProcessor func(context.Context, *pool.Conn, []Cmder) (bool, error)
func (c *baseClient) generalProcessPipeline(
ctx context.Context, cmds []Cmder, p pipelineProcessor,
@ -349,7 +349,7 @@ func (c *baseClient) generalProcessPipeline(
return err
}
canRetry, err := p(cn, cmds)
canRetry, err := p(ctx, cn, cmds)
c.releaseConnStrict(cn, err)
if !canRetry || !internal.IsRetryableError(err, true) {
@ -359,8 +359,10 @@ func (c *baseClient) generalProcessPipeline(
return cmdsFirstErr(cmds)
}
func (c *baseClient) pipelineProcessCmds(cn *pool.Conn, cmds []Cmder) (bool, error) {
err := cn.WithWriter(c.opt.WriteTimeout, func(wr *proto.Writer) error {
func (c *baseClient) pipelineProcessCmds(
ctx context.Context, cn *pool.Conn, cmds []Cmder,
) (bool, error) {
err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
return writeCmd(wr, cmds...)
})
if err != nil {
@ -368,7 +370,7 @@ func (c *baseClient) pipelineProcessCmds(cn *pool.Conn, cmds []Cmder) (bool, err
return true, err
}
err = cn.WithReader(c.opt.ReadTimeout, func(rd *proto.Reader) error {
err = cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
return pipelineReadCmds(rd, cmds)
})
return true, err
@ -384,8 +386,10 @@ func pipelineReadCmds(rd *proto.Reader, cmds []Cmder) error {
return nil
}
func (c *baseClient) txPipelineProcessCmds(cn *pool.Conn, cmds []Cmder) (bool, error) {
err := cn.WithWriter(c.opt.WriteTimeout, func(wr *proto.Writer) error {
func (c *baseClient) txPipelineProcessCmds(
ctx context.Context, cn *pool.Conn, cmds []Cmder,
) (bool, error) {
err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
return txPipelineWriteMulti(wr, cmds)
})
if err != nil {
@ -393,7 +397,7 @@ func (c *baseClient) txPipelineProcessCmds(cn *pool.Conn, cmds []Cmder) (bool, e
return true, err
}
err = cn.WithReader(c.opt.ReadTimeout, func(rd *proto.Reader) error {
err = cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
err := txPipelineReadQueued(rd, cmds)
if err != nil {
setCmdsErr(cmds, err)

View File

@ -616,7 +616,7 @@ func (c *Ring) _processPipeline(ctx context.Context, cmds []Cmder) error {
return
}
canRetry, err := shard.Client.pipelineProcessCmds(cn, cmds)
canRetry, err := shard.Client.pipelineProcessCmds(ctx, cn, cmds)
shard.Client.releaseConnStrict(cn, err)
if canRetry && internal.IsRetryableError(err, true) {