Merge pull request #1047 from go-redis/fix/context-deadline

Use Context.Deadline to set net.Conn deadline
This commit is contained in:
Vladimir Mihailenco 2019-06-08 15:34:34 +03:00 committed by GitHub
commit 530e66a66e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 57 additions and 46 deletions

View File

@ -1065,7 +1065,7 @@ func (c *ClusterClient) _processPipeline(ctx context.Context, cmds []Cmder) erro
return return
} }
err = c.pipelineProcessCmds(node, cn, cmds, failedCmds) err = c.pipelineProcessCmds(ctx, node, cn, cmds, failedCmds)
node.Client.releaseConnStrict(cn, err) node.Client.releaseConnStrict(cn, err)
}(node, cmds) }(node, cmds)
} }
@ -1129,9 +1129,9 @@ func (c *ClusterClient) cmdsAreReadOnly(cmds []Cmder) bool {
} }
func (c *ClusterClient) pipelineProcessCmds( func (c *ClusterClient) pipelineProcessCmds(
node *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds *cmdsMap, ctx context.Context, node *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds *cmdsMap,
) error { ) error {
err := cn.WithWriter(c.opt.WriteTimeout, func(wr *proto.Writer) error { err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
return writeCmd(wr, cmds...) return writeCmd(wr, cmds...)
}) })
if err != nil { if err != nil {
@ -1142,7 +1142,7 @@ func (c *ClusterClient) pipelineProcessCmds(
return err return err
} }
return cn.WithReader(c.opt.ReadTimeout, func(rd *proto.Reader) error { return cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
return c.pipelineReadCmds(node, rd, cmds, failedCmds) return c.pipelineReadCmds(node, rd, cmds, failedCmds)
}) })
} }
@ -1266,7 +1266,7 @@ func (c *ClusterClient) _processTxPipeline(ctx context.Context, cmds []Cmder) er
return return
} }
err = c.txPipelineProcessCmds(node, cn, cmds, failedCmds) err = c.txPipelineProcessCmds(ctx, node, cn, cmds, failedCmds)
node.Client.releaseConnStrict(cn, err) node.Client.releaseConnStrict(cn, err)
}(node, cmds) }(node, cmds)
} }
@ -1292,9 +1292,9 @@ func (c *ClusterClient) mapCmdsBySlot(cmds []Cmder) map[int][]Cmder {
} }
func (c *ClusterClient) txPipelineProcessCmds( func (c *ClusterClient) txPipelineProcessCmds(
node *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds *cmdsMap, ctx context.Context, node *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds *cmdsMap,
) error { ) error {
err := cn.WithWriter(c.opt.WriteTimeout, func(wr *proto.Writer) error { err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
return txPipelineWriteMulti(wr, cmds) return txPipelineWriteMulti(wr, cmds)
}) })
if err != nil { if err != nil {
@ -1305,7 +1305,7 @@ func (c *ClusterClient) txPipelineProcessCmds(
return err return err
} }
err = cn.WithReader(c.opt.ReadTimeout, func(rd *proto.Reader) error { err = cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
err := c.txPipelineReadQueued(rd, cmds, failedCmds) err := c.txPipelineReadQueued(rd, cmds, failedCmds)
if err != nil { if err != nil {
setCmdsErr(cmds, err) setCmdsErr(cmds, err)

View File

@ -1,6 +1,7 @@
package pool package pool
import ( import (
"context"
"net" "net"
"sync/atomic" "sync/atomic"
"time" "time"
@ -48,24 +49,6 @@ func (cn *Conn) SetNetConn(netConn net.Conn) {
cn.wr.Reset(netConn) cn.wr.Reset(netConn)
} }
func (cn *Conn) setReadTimeout(timeout time.Duration) error {
now := time.Now()
cn.SetUsedAt(now)
if timeout > 0 {
return cn.netConn.SetReadDeadline(now.Add(timeout))
}
return cn.netConn.SetReadDeadline(noDeadline)
}
func (cn *Conn) setWriteTimeout(timeout time.Duration) error {
now := time.Now()
cn.SetUsedAt(now)
if timeout > 0 {
return cn.netConn.SetWriteDeadline(now.Add(timeout))
}
return cn.netConn.SetWriteDeadline(noDeadline)
}
func (cn *Conn) Write(b []byte) (int, error) { func (cn *Conn) Write(b []byte) (int, error) {
return cn.netConn.Write(b) return cn.netConn.Write(b)
} }
@ -74,13 +57,17 @@ func (cn *Conn) RemoteAddr() net.Addr {
return cn.netConn.RemoteAddr() return cn.netConn.RemoteAddr()
} }
func (cn *Conn) WithReader(timeout time.Duration, fn func(rd *proto.Reader) error) error { func (cn *Conn) WithReader(ctx context.Context, timeout time.Duration, fn func(rd *proto.Reader) error) error {
_ = cn.setReadTimeout(timeout) tm := cn.deadline(ctx, timeout)
_ = cn.netConn.SetReadDeadline(tm)
return fn(cn.rd) return fn(cn.rd)
} }
func (cn *Conn) WithWriter(timeout time.Duration, fn func(wr *proto.Writer) error) error { func (cn *Conn) WithWriter(
_ = cn.setWriteTimeout(timeout) ctx context.Context, timeout time.Duration, fn func(wr *proto.Writer) error,
) error {
tm := cn.deadline(ctx, timeout)
_ = cn.netConn.SetWriteDeadline(tm)
firstErr := fn(cn.wr) firstErr := fn(cn.wr)
err := cn.wr.Flush() err := cn.wr.Flush()
@ -93,3 +80,22 @@ func (cn *Conn) WithWriter(timeout time.Duration, fn func(wr *proto.Writer) erro
func (cn *Conn) Close() error { func (cn *Conn) Close() error {
return cn.netConn.Close() return cn.netConn.Close()
} }
func (cn *Conn) deadline(ctx context.Context, timeout time.Duration) time.Time {
if ctx != nil {
tm, ok := ctx.Deadline()
if ok {
cn.SetUsedAt(tm)
return tm
}
}
now := time.Now()
if timeout > 0 {
cn.SetUsedAt(now)
return now.Add(timeout)
}
cn.SetUsedAt(now)
return noDeadline
}

View File

@ -1,6 +1,7 @@
package redis package redis
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"strings" "strings"
@ -83,8 +84,8 @@ func (c *PubSub) _conn(newChannels []string) (*pool.Conn, error) {
return cn, nil return cn, nil
} }
func (c *PubSub) writeCmd(cn *pool.Conn, cmd Cmder) error { func (c *PubSub) writeCmd(ctx context.Context, cn *pool.Conn, cmd Cmder) error {
return cn.WithWriter(c.opt.WriteTimeout, func(wr *proto.Writer) error { return cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
return writeCmd(wr, cmd) return writeCmd(wr, cmd)
}) })
} }
@ -128,7 +129,7 @@ func (c *PubSub) _subscribe(
args = append(args, channel) args = append(args, channel)
} }
cmd := NewSliceCmd(args...) cmd := NewSliceCmd(args...)
return c.writeCmd(cn, cmd) return c.writeCmd(context.TODO(), cn, cmd)
} }
func (c *PubSub) releaseConn(cn *pool.Conn, err error, allowTimeout bool) { func (c *PubSub) releaseConn(cn *pool.Conn, err error, allowTimeout bool) {
@ -258,7 +259,7 @@ func (c *PubSub) Ping(payload ...string) error {
return err return err
} }
err = c.writeCmd(cn, cmd) err = c.writeCmd(context.TODO(), cn, cmd)
c.releaseConn(cn, err, false) c.releaseConn(cn, err, false)
return err return err
} }
@ -350,7 +351,7 @@ func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) {
return nil, err return nil, err
} }
err = cn.WithReader(timeout, func(rd *proto.Reader) error { err = cn.WithReader(context.TODO(), timeout, func(rd *proto.Reader) error {
return c.cmd.readReply(rd) return c.cmd.readReply(rd)
}) })

View File

@ -265,7 +265,7 @@ func (c *baseClient) process(ctx context.Context, cmd Cmder) error {
return err return err
} }
err = cn.WithWriter(c.opt.WriteTimeout, func(wr *proto.Writer) error { err = cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
return writeCmd(wr, cmd) return writeCmd(wr, cmd)
}) })
if err != nil { if err != nil {
@ -277,7 +277,7 @@ func (c *baseClient) process(ctx context.Context, cmd Cmder) error {
return err return err
} }
err = cn.WithReader(c.cmdTimeout(cmd), cmd.readReply) err = cn.WithReader(ctx, c.cmdTimeout(cmd), cmd.readReply)
c.releaseConn(cn, err) c.releaseConn(cn, err)
if err != nil && internal.IsRetryableError(err, cmd.readTimeout() == nil) { if err != nil && internal.IsRetryableError(err, cmd.readTimeout() == nil) {
continue continue
@ -333,7 +333,7 @@ func (c *baseClient) processTxPipeline(ctx context.Context, cmds []Cmder) error
return c.generalProcessPipeline(ctx, cmds, c.txPipelineProcessCmds) return c.generalProcessPipeline(ctx, cmds, c.txPipelineProcessCmds)
} }
type pipelineProcessor func(*pool.Conn, []Cmder) (bool, error) type pipelineProcessor func(context.Context, *pool.Conn, []Cmder) (bool, error)
func (c *baseClient) generalProcessPipeline( func (c *baseClient) generalProcessPipeline(
ctx context.Context, cmds []Cmder, p pipelineProcessor, ctx context.Context, cmds []Cmder, p pipelineProcessor,
@ -349,7 +349,7 @@ func (c *baseClient) generalProcessPipeline(
return err return err
} }
canRetry, err := p(cn, cmds) canRetry, err := p(ctx, cn, cmds)
c.releaseConnStrict(cn, err) c.releaseConnStrict(cn, err)
if !canRetry || !internal.IsRetryableError(err, true) { if !canRetry || !internal.IsRetryableError(err, true) {
@ -359,8 +359,10 @@ func (c *baseClient) generalProcessPipeline(
return cmdsFirstErr(cmds) return cmdsFirstErr(cmds)
} }
func (c *baseClient) pipelineProcessCmds(cn *pool.Conn, cmds []Cmder) (bool, error) { func (c *baseClient) pipelineProcessCmds(
err := cn.WithWriter(c.opt.WriteTimeout, func(wr *proto.Writer) error { ctx context.Context, cn *pool.Conn, cmds []Cmder,
) (bool, error) {
err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
return writeCmd(wr, cmds...) return writeCmd(wr, cmds...)
}) })
if err != nil { if err != nil {
@ -368,7 +370,7 @@ func (c *baseClient) pipelineProcessCmds(cn *pool.Conn, cmds []Cmder) (bool, err
return true, err return true, err
} }
err = cn.WithReader(c.opt.ReadTimeout, func(rd *proto.Reader) error { err = cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
return pipelineReadCmds(rd, cmds) return pipelineReadCmds(rd, cmds)
}) })
return true, err return true, err
@ -384,8 +386,10 @@ func pipelineReadCmds(rd *proto.Reader, cmds []Cmder) error {
return nil return nil
} }
func (c *baseClient) txPipelineProcessCmds(cn *pool.Conn, cmds []Cmder) (bool, error) { func (c *baseClient) txPipelineProcessCmds(
err := cn.WithWriter(c.opt.WriteTimeout, func(wr *proto.Writer) error { ctx context.Context, cn *pool.Conn, cmds []Cmder,
) (bool, error) {
err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
return txPipelineWriteMulti(wr, cmds) return txPipelineWriteMulti(wr, cmds)
}) })
if err != nil { if err != nil {
@ -393,7 +397,7 @@ func (c *baseClient) txPipelineProcessCmds(cn *pool.Conn, cmds []Cmder) (bool, e
return true, err return true, err
} }
err = cn.WithReader(c.opt.ReadTimeout, func(rd *proto.Reader) error { err = cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
err := txPipelineReadQueued(rd, cmds) err := txPipelineReadQueued(rd, cmds)
if err != nil { if err != nil {
setCmdsErr(cmds, err) setCmdsErr(cmds, err)

View File

@ -616,7 +616,7 @@ func (c *Ring) _processPipeline(ctx context.Context, cmds []Cmder) error {
return return
} }
canRetry, err := shard.Client.pipelineProcessCmds(cn, cmds) canRetry, err := shard.Client.pipelineProcessCmds(ctx, cn, cmds)
shard.Client.releaseConnStrict(cn, err) shard.Client.releaseConnStrict(cn, err)
if canRetry && internal.IsRetryableError(err, true) { if canRetry && internal.IsRetryableError(err, true) {