Pass context to Dialer

This commit is contained in:
Vladimir Mihailenco 2019-06-04 14:05:29 +03:00
parent 9dba04507e
commit 53c8a4a6b7
15 changed files with 71 additions and 60 deletions

View File

@ -53,7 +53,7 @@ type ClusterOptions struct {
// Following options are copied from Options struct.
Dialer func(network, addr string) (net.Conn, error)
Dialer func(ctx context.Context, network, addr string) (net.Conn, error)
OnConnect func(*Conn) error
@ -1055,7 +1055,7 @@ func (c *ClusterClient) _processPipeline(ctx context.Context, cmds []Cmder) erro
go func(node *clusterNode, cmds []Cmder) {
defer wg.Done()
cn, err := node.Client.getConn()
cn, err := node.Client.getConn(ctx)
if err != nil {
if err == pool.ErrClosed {
c.mapCmdsByNode(cmds, failedCmds)
@ -1256,7 +1256,7 @@ func (c *ClusterClient) _processTxPipeline(ctx context.Context, cmds []Cmder) er
go func(node *clusterNode, cmds []Cmder) {
defer wg.Done()
cn, err := node.Client.getConn()
cn, err := node.Client.getConn(ctx)
if err != nil {
if err == pool.ErrClosed {
c.mapCmdsByNode(cmds, failedCmds)

View File

@ -39,7 +39,7 @@ func BenchmarkPoolGetPut(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
cn, err := connPool.Get()
cn, err := connPool.Get(nil)
if err != nil {
b.Fatal(err)
}
@ -81,7 +81,7 @@ func BenchmarkPoolGetRemove(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
cn, err := connPool.Get()
cn, err := connPool.Get(nil)
if err != nil {
b.Fatal(err)
}

View File

@ -1,6 +1,7 @@
package pool_test
import (
"context"
"net"
"sync"
"testing"
@ -30,6 +31,6 @@ func perform(n int, cbs ...func(int)) {
wg.Wait()
}
func dummyDialer() (net.Conn, error) {
func dummyDialer(context.Context) (net.Conn, error) {
return &net.TCPConn{}, nil
}

View File

@ -1,6 +1,7 @@
package pool
import (
"context"
"errors"
"net"
"sync"
@ -36,7 +37,7 @@ type Pooler interface {
NewConn() (*Conn, error)
CloseConn(*Conn) error
Get() (*Conn, error)
Get(context.Context) (*Conn, error)
Put(*Conn)
Remove(*Conn)
@ -48,7 +49,7 @@ type Pooler interface {
}
type Options struct {
Dialer func() (net.Conn, error)
Dialer func(c context.Context) (net.Conn, error)
OnClose func(*Conn) error
PoolSize int
@ -114,7 +115,7 @@ func (p *ConnPool) checkMinIdleConns() {
}
func (p *ConnPool) addIdleConn() {
cn, err := p.newConn(true)
cn, err := p.newConn(nil, true)
if err != nil {
return
}
@ -126,11 +127,11 @@ func (p *ConnPool) addIdleConn() {
}
func (p *ConnPool) NewConn() (*Conn, error) {
return p._NewConn(false)
return p._NewConn(nil, false)
}
func (p *ConnPool) _NewConn(pooled bool) (*Conn, error) {
cn, err := p.newConn(pooled)
func (p *ConnPool) _NewConn(c context.Context, pooled bool) (*Conn, error) {
cn, err := p.newConn(c, pooled)
if err != nil {
return nil, err
}
@ -148,7 +149,7 @@ func (p *ConnPool) _NewConn(pooled bool) (*Conn, error) {
return cn, nil
}
func (p *ConnPool) newConn(pooled bool) (*Conn, error) {
func (p *ConnPool) newConn(c context.Context, pooled bool) (*Conn, error) {
if p.closed() {
return nil, ErrClosed
}
@ -157,7 +158,7 @@ func (p *ConnPool) newConn(pooled bool) (*Conn, error) {
return nil, p.getLastDialError()
}
netConn, err := p.opt.Dialer()
netConn, err := p.opt.Dialer(c)
if err != nil {
p.setLastDialError(err)
if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.opt.PoolSize) {
@ -177,7 +178,7 @@ func (p *ConnPool) tryDial() {
return
}
conn, err := p.opt.Dialer()
conn, err := p.opt.Dialer(nil)
if err != nil {
p.setLastDialError(err)
time.Sleep(time.Second)
@ -204,7 +205,7 @@ func (p *ConnPool) getLastDialError() error {
}
// Get returns existed connection from the pool or creates a new one.
func (p *ConnPool) Get() (*Conn, error) {
func (p *ConnPool) Get(c context.Context) (*Conn, error) {
if p.closed() {
return nil, ErrClosed
}
@ -234,7 +235,7 @@ func (p *ConnPool) Get() (*Conn, error) {
atomic.AddUint32(&p.stats.Misses, 1)
newcn, err := p._NewConn(true)
newcn, err := p._NewConn(c, true)
if err != nil {
p.freeTurn()
return nil, err

View File

@ -1,5 +1,7 @@
package pool
import "context"
type SingleConnPool struct {
cn *Conn
}
@ -20,7 +22,7 @@ func (p *SingleConnPool) CloseConn(*Conn) error {
panic("not implemented")
}
func (p *SingleConnPool) Get() (*Conn, error) {
func (p *SingleConnPool) Get(c context.Context) (*Conn, error) {
return p.cn, nil
}

View File

@ -1,6 +1,9 @@
package pool
import "sync"
import (
"context"
"sync"
)
type StickyConnPool struct {
pool *ConnPool
@ -28,7 +31,7 @@ func (p *StickyConnPool) CloseConn(*Conn) error {
panic("not implemented")
}
func (p *StickyConnPool) Get() (*Conn, error) {
func (p *StickyConnPool) Get(c context.Context) (*Conn, error) {
p.mu.Lock()
defer p.mu.Unlock()
@ -39,7 +42,7 @@ func (p *StickyConnPool) Get() (*Conn, error) {
return p.cn, nil
}
cn, err := p.pool.Get()
cn, err := p.pool.Get(c)
if err != nil {
return nil, err
}

View File

@ -30,13 +30,13 @@ var _ = Describe("ConnPool", func() {
It("should unblock client when conn is removed", func() {
// Reserve one connection.
cn, err := connPool.Get()
cn, err := connPool.Get(nil)
Expect(err).NotTo(HaveOccurred())
// Reserve all other connections.
var cns []*pool.Conn
for i := 0; i < 9; i++ {
cn, err := connPool.Get()
cn, err := connPool.Get(nil)
Expect(err).NotTo(HaveOccurred())
cns = append(cns, cn)
}
@ -47,7 +47,7 @@ var _ = Describe("ConnPool", func() {
defer GinkgoRecover()
started <- true
_, err := connPool.Get()
_, err := connPool.Get(nil)
Expect(err).NotTo(HaveOccurred())
done <- true
@ -110,7 +110,7 @@ var _ = Describe("MinIdleConns", func() {
BeforeEach(func() {
var err error
cn, err = connPool.Get()
cn, err = connPool.Get(nil)
Expect(err).NotTo(HaveOccurred())
Eventually(func() int {
@ -145,7 +145,7 @@ var _ = Describe("MinIdleConns", func() {
perform(poolSize, func(_ int) {
defer GinkgoRecover()
cn, err := connPool.Get()
cn, err := connPool.Get(nil)
Expect(err).NotTo(HaveOccurred())
mu.Lock()
cns = append(cns, cn)
@ -160,7 +160,7 @@ var _ = Describe("MinIdleConns", func() {
It("Get is blocked", func() {
done := make(chan struct{})
go func() {
connPool.Get()
connPool.Get(nil)
close(done)
}()
@ -274,7 +274,7 @@ var _ = Describe("conns reaper", func() {
// add stale connections
staleConns = nil
for i := 0; i < 3; i++ {
cn, err := connPool.Get()
cn, err := connPool.Get(nil)
Expect(err).NotTo(HaveOccurred())
switch typ {
case "idle":
@ -288,7 +288,7 @@ var _ = Describe("conns reaper", func() {
// add fresh connections
for i := 0; i < 3; i++ {
cn, err := connPool.Get()
cn, err := connPool.Get(nil)
Expect(err).NotTo(HaveOccurred())
conns = append(conns, cn)
}
@ -333,7 +333,7 @@ var _ = Describe("conns reaper", func() {
for j := 0; j < 3; j++ {
var freeCns []*pool.Conn
for i := 0; i < 3; i++ {
cn, err := connPool.Get()
cn, err := connPool.Get(nil)
Expect(err).NotTo(HaveOccurred())
Expect(cn).NotTo(BeNil())
freeCns = append(freeCns, cn)
@ -342,7 +342,7 @@ var _ = Describe("conns reaper", func() {
Expect(connPool.Len()).To(Equal(3))
Expect(connPool.IdleLen()).To(Equal(0))
cn, err := connPool.Get()
cn, err := connPool.Get(nil)
Expect(err).NotTo(HaveOccurred())
Expect(cn).NotTo(BeNil())
conns = append(conns, cn)
@ -396,7 +396,7 @@ var _ = Describe("race", func() {
perform(C, func(id int) {
for i := 0; i < N; i++ {
cn, err := connPool.Get()
cn, err := connPool.Get(nil)
Expect(err).NotTo(HaveOccurred())
if err == nil {
connPool.Put(cn)
@ -404,7 +404,7 @@ var _ = Describe("race", func() {
}
}, func(id int) {
for i := 0; i < N; i++ {
cn, err := connPool.Get()
cn, err := connPool.Get(nil)
Expect(err).NotTo(HaveOccurred())
if err == nil {
connPool.Remove(cn)

View File

@ -1,6 +1,7 @@
package redis
import (
"context"
"crypto/tls"
"errors"
"fmt"
@ -34,7 +35,7 @@ type Options struct {
// Dialer creates new network connection and has priority over
// Network and Addr options.
Dialer func(network, addr string) (net.Conn, error)
Dialer func(ctx context.Context, network, addr string) (net.Conn, error)
// Hook that is called when new connection is established.
OnConnect func(*Conn) error
@ -105,7 +106,7 @@ func (opt *Options) init() {
opt.Addr = "localhost:6379"
}
if opt.Dialer == nil {
opt.Dialer = func(network, addr string) (net.Conn, error) {
opt.Dialer = func(ctx context.Context, network, addr string) (net.Conn, error) {
netDialer := &net.Dialer{
Timeout: opt.DialTimeout,
KeepAlive: 5 * time.Minute,
@ -215,8 +216,8 @@ func ParseURL(redisURL string) (*Options, error) {
func newConnPool(opt *Options) *pool.ConnPool {
return pool.NewConnPool(&pool.Options{
Dialer: func() (net.Conn, error) {
return opt.Dialer(opt.Network, opt.Addr)
Dialer: func(c context.Context) (net.Conn, error) {
return opt.Dialer(c, opt.Network, opt.Addr)
},
PoolSize: opt.PoolSize,
MinIdleConns: opt.MinIdleConns,

View File

@ -81,7 +81,7 @@ var _ = Describe("pool", func() {
})
It("removes broken connections", func() {
cn, err := client.Pool().Get()
cn, err := client.Pool().Get(nil)
Expect(err).NotTo(HaveOccurred())
cn.SetNetConn(&badConn{})
client.Pool().Put(cn)

View File

@ -154,7 +154,7 @@ func (c *baseClient) newConn() (*pool.Conn, error) {
return cn, nil
}
func (c *baseClient) getConn() (*pool.Conn, error) {
func (c *baseClient) getConn(ctx context.Context) (*pool.Conn, error) {
if c.limiter != nil {
err := c.limiter.Allow()
if err != nil {
@ -162,7 +162,7 @@ func (c *baseClient) getConn() (*pool.Conn, error) {
}
}
cn, err := c._getConn()
cn, err := c._getConn(ctx)
if err != nil {
if c.limiter != nil {
c.limiter.ReportResult(err)
@ -172,8 +172,8 @@ func (c *baseClient) getConn() (*pool.Conn, error) {
return cn, nil
}
func (c *baseClient) _getConn() (*pool.Conn, error) {
cn, err := c.connPool.Get()
func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
cn, err := c.connPool.Get(ctx)
if err != nil {
return nil, err
}
@ -256,7 +256,7 @@ func (c *baseClient) process(ctx context.Context, cmd Cmder) error {
time.Sleep(c.retryBackoff(attempt))
}
cn, err := c.getConn()
cn, err := c.getConn(ctx)
if err != nil {
cmd.setErr(err)
if internal.IsRetryableError(err, true) {
@ -326,22 +326,24 @@ func (c *baseClient) getAddr() string {
}
func (c *baseClient) processPipeline(ctx context.Context, cmds []Cmder) error {
return c.generalProcessPipeline(cmds, c.pipelineProcessCmds)
return c.generalProcessPipeline(ctx, cmds, c.pipelineProcessCmds)
}
func (c *baseClient) processTxPipeline(ctx context.Context, cmds []Cmder) error {
return c.generalProcessPipeline(cmds, c.txPipelineProcessCmds)
return c.generalProcessPipeline(ctx, cmds, c.txPipelineProcessCmds)
}
type pipelineProcessor func(*pool.Conn, []Cmder) (bool, error)
func (c *baseClient) generalProcessPipeline(cmds []Cmder, p pipelineProcessor) error {
func (c *baseClient) generalProcessPipeline(
ctx context.Context, cmds []Cmder, p pipelineProcessor,
) error {
for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
if attempt > 0 {
time.Sleep(c.retryBackoff(attempt))
}
cn, err := c.getConn()
cn, err := c.getConn(ctx)
if err != nil {
setCmdsErr(cmds, err)
return err

View File

@ -2,6 +2,7 @@ package redis_test
import (
"bytes"
"context"
"net"
"time"
@ -41,7 +42,7 @@ var _ = Describe("Client", func() {
custom := redis.NewClient(&redis.Options{
Network: "tcp",
Addr: redisAddr,
Dialer: func(network, addr string) (net.Conn, error) {
Dialer: func(ctx context.Context, network, addr string) (net.Conn, error) {
return net.Dial(network, addr)
},
})
@ -146,7 +147,7 @@ var _ = Describe("Client", func() {
})
// Put bad connection in the pool.
cn, err := client.Pool().Get()
cn, err := client.Pool().Get(nil)
Expect(err).NotTo(HaveOccurred())
cn.SetNetConn(&badConn{})
@ -184,7 +185,7 @@ var _ = Describe("Client", func() {
})
It("should update conn.UsedAt on read/write", func() {
cn, err := client.Pool().Get()
cn, err := client.Pool().Get(nil)
Expect(err).NotTo(HaveOccurred())
Expect(cn.UsedAt).NotTo(BeZero())
createdAt := cn.UsedAt()
@ -197,7 +198,7 @@ var _ = Describe("Client", func() {
err = client.Ping().Err()
Expect(err).NotTo(HaveOccurred())
cn, err = client.Pool().Get()
cn, err = client.Pool().Get(nil)
Expect(err).NotTo(HaveOccurred())
Expect(cn).NotTo(BeNil())
Expect(cn.UsedAt().After(createdAt)).To(BeTrue())

View File

@ -610,7 +610,7 @@ func (c *Ring) _processPipeline(ctx context.Context, cmds []Cmder) error {
return
}
cn, err := shard.Client.getConn()
cn, err := shard.Client.getConn(ctx)
if err != nil {
setCmdsErr(cmds, err)
return

View File

@ -21,16 +21,16 @@ type FailoverOptions struct {
// The master name.
MasterName string
// A seed list of host:port addresses of sentinel nodes.
SentinelAddrs []string
SentinelAddrs []string
SentinelPassword string
// Following options are copied from Options struct.
Dialer func(network, addr string) (net.Conn, error)
Dialer func(ctx context.Context, network, addr string) (net.Conn, error)
OnConnect func(*Conn) error
Password string
SentinelPassword string
DB int
Password string
DB int
MaxRetries int
MinRetryBackoff time.Duration
@ -312,7 +312,7 @@ func (c *sentinelFailover) Pool() *pool.ConnPool {
return c.pool
}
func (c *sentinelFailover) dial(network, addr string) (net.Conn, error) {
func (c *sentinelFailover) dial(ctx context.Context, network, addr string) (net.Conn, error) {
addr, err := c.MasterAddr()
if err != nil {
return nil, err

View File

@ -124,7 +124,7 @@ var _ = Describe("Tx", func() {
It("should recover from bad connection", func() {
// Put bad connection in the pool.
cn, err := client.Pool().Get()
cn, err := client.Pool().Get(nil)
Expect(err).NotTo(HaveOccurred())
cn.SetNetConn(&badConn{})

View File

@ -20,7 +20,7 @@ type UniversalOptions struct {
// Common options.
Dialer func(network, addr string) (net.Conn, error)
Dialer func(ctx context.Context, network, addr string) (net.Conn, error)
OnConnect func(*Conn) error
Password string
MaxRetries int