Pass context to Dialer

This commit is contained in:
Vladimir Mihailenco 2019-06-04 14:05:29 +03:00
parent 9dba04507e
commit 53c8a4a6b7
15 changed files with 71 additions and 60 deletions

View File

@ -53,7 +53,7 @@ type ClusterOptions struct {
// Following options are copied from Options struct. // Following options are copied from Options struct.
Dialer func(network, addr string) (net.Conn, error) Dialer func(ctx context.Context, network, addr string) (net.Conn, error)
OnConnect func(*Conn) error OnConnect func(*Conn) error
@ -1055,7 +1055,7 @@ func (c *ClusterClient) _processPipeline(ctx context.Context, cmds []Cmder) erro
go func(node *clusterNode, cmds []Cmder) { go func(node *clusterNode, cmds []Cmder) {
defer wg.Done() defer wg.Done()
cn, err := node.Client.getConn() cn, err := node.Client.getConn(ctx)
if err != nil { if err != nil {
if err == pool.ErrClosed { if err == pool.ErrClosed {
c.mapCmdsByNode(cmds, failedCmds) c.mapCmdsByNode(cmds, failedCmds)
@ -1256,7 +1256,7 @@ func (c *ClusterClient) _processTxPipeline(ctx context.Context, cmds []Cmder) er
go func(node *clusterNode, cmds []Cmder) { go func(node *clusterNode, cmds []Cmder) {
defer wg.Done() defer wg.Done()
cn, err := node.Client.getConn() cn, err := node.Client.getConn(ctx)
if err != nil { if err != nil {
if err == pool.ErrClosed { if err == pool.ErrClosed {
c.mapCmdsByNode(cmds, failedCmds) c.mapCmdsByNode(cmds, failedCmds)

View File

@ -39,7 +39,7 @@ func BenchmarkPoolGetPut(b *testing.B) {
b.RunParallel(func(pb *testing.PB) { b.RunParallel(func(pb *testing.PB) {
for pb.Next() { for pb.Next() {
cn, err := connPool.Get() cn, err := connPool.Get(nil)
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }
@ -81,7 +81,7 @@ func BenchmarkPoolGetRemove(b *testing.B) {
b.RunParallel(func(pb *testing.PB) { b.RunParallel(func(pb *testing.PB) {
for pb.Next() { for pb.Next() {
cn, err := connPool.Get() cn, err := connPool.Get(nil)
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }

View File

@ -1,6 +1,7 @@
package pool_test package pool_test
import ( import (
"context"
"net" "net"
"sync" "sync"
"testing" "testing"
@ -30,6 +31,6 @@ func perform(n int, cbs ...func(int)) {
wg.Wait() wg.Wait()
} }
func dummyDialer() (net.Conn, error) { func dummyDialer(context.Context) (net.Conn, error) {
return &net.TCPConn{}, nil return &net.TCPConn{}, nil
} }

View File

@ -1,6 +1,7 @@
package pool package pool
import ( import (
"context"
"errors" "errors"
"net" "net"
"sync" "sync"
@ -36,7 +37,7 @@ type Pooler interface {
NewConn() (*Conn, error) NewConn() (*Conn, error)
CloseConn(*Conn) error CloseConn(*Conn) error
Get() (*Conn, error) Get(context.Context) (*Conn, error)
Put(*Conn) Put(*Conn)
Remove(*Conn) Remove(*Conn)
@ -48,7 +49,7 @@ type Pooler interface {
} }
type Options struct { type Options struct {
Dialer func() (net.Conn, error) Dialer func(c context.Context) (net.Conn, error)
OnClose func(*Conn) error OnClose func(*Conn) error
PoolSize int PoolSize int
@ -114,7 +115,7 @@ func (p *ConnPool) checkMinIdleConns() {
} }
func (p *ConnPool) addIdleConn() { func (p *ConnPool) addIdleConn() {
cn, err := p.newConn(true) cn, err := p.newConn(nil, true)
if err != nil { if err != nil {
return return
} }
@ -126,11 +127,11 @@ func (p *ConnPool) addIdleConn() {
} }
func (p *ConnPool) NewConn() (*Conn, error) { func (p *ConnPool) NewConn() (*Conn, error) {
return p._NewConn(false) return p._NewConn(nil, false)
} }
func (p *ConnPool) _NewConn(pooled bool) (*Conn, error) { func (p *ConnPool) _NewConn(c context.Context, pooled bool) (*Conn, error) {
cn, err := p.newConn(pooled) cn, err := p.newConn(c, pooled)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -148,7 +149,7 @@ func (p *ConnPool) _NewConn(pooled bool) (*Conn, error) {
return cn, nil return cn, nil
} }
func (p *ConnPool) newConn(pooled bool) (*Conn, error) { func (p *ConnPool) newConn(c context.Context, pooled bool) (*Conn, error) {
if p.closed() { if p.closed() {
return nil, ErrClosed return nil, ErrClosed
} }
@ -157,7 +158,7 @@ func (p *ConnPool) newConn(pooled bool) (*Conn, error) {
return nil, p.getLastDialError() return nil, p.getLastDialError()
} }
netConn, err := p.opt.Dialer() netConn, err := p.opt.Dialer(c)
if err != nil { if err != nil {
p.setLastDialError(err) p.setLastDialError(err)
if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.opt.PoolSize) { if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.opt.PoolSize) {
@ -177,7 +178,7 @@ func (p *ConnPool) tryDial() {
return return
} }
conn, err := p.opt.Dialer() conn, err := p.opt.Dialer(nil)
if err != nil { if err != nil {
p.setLastDialError(err) p.setLastDialError(err)
time.Sleep(time.Second) time.Sleep(time.Second)
@ -204,7 +205,7 @@ func (p *ConnPool) getLastDialError() error {
} }
// Get returns existed connection from the pool or creates a new one. // Get returns existed connection from the pool or creates a new one.
func (p *ConnPool) Get() (*Conn, error) { func (p *ConnPool) Get(c context.Context) (*Conn, error) {
if p.closed() { if p.closed() {
return nil, ErrClosed return nil, ErrClosed
} }
@ -234,7 +235,7 @@ func (p *ConnPool) Get() (*Conn, error) {
atomic.AddUint32(&p.stats.Misses, 1) atomic.AddUint32(&p.stats.Misses, 1)
newcn, err := p._NewConn(true) newcn, err := p._NewConn(c, true)
if err != nil { if err != nil {
p.freeTurn() p.freeTurn()
return nil, err return nil, err

View File

@ -1,5 +1,7 @@
package pool package pool
import "context"
type SingleConnPool struct { type SingleConnPool struct {
cn *Conn cn *Conn
} }
@ -20,7 +22,7 @@ func (p *SingleConnPool) CloseConn(*Conn) error {
panic("not implemented") panic("not implemented")
} }
func (p *SingleConnPool) Get() (*Conn, error) { func (p *SingleConnPool) Get(c context.Context) (*Conn, error) {
return p.cn, nil return p.cn, nil
} }

View File

@ -1,6 +1,9 @@
package pool package pool
import "sync" import (
"context"
"sync"
)
type StickyConnPool struct { type StickyConnPool struct {
pool *ConnPool pool *ConnPool
@ -28,7 +31,7 @@ func (p *StickyConnPool) CloseConn(*Conn) error {
panic("not implemented") panic("not implemented")
} }
func (p *StickyConnPool) Get() (*Conn, error) { func (p *StickyConnPool) Get(c context.Context) (*Conn, error) {
p.mu.Lock() p.mu.Lock()
defer p.mu.Unlock() defer p.mu.Unlock()
@ -39,7 +42,7 @@ func (p *StickyConnPool) Get() (*Conn, error) {
return p.cn, nil return p.cn, nil
} }
cn, err := p.pool.Get() cn, err := p.pool.Get(c)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -30,13 +30,13 @@ var _ = Describe("ConnPool", func() {
It("should unblock client when conn is removed", func() { It("should unblock client when conn is removed", func() {
// Reserve one connection. // Reserve one connection.
cn, err := connPool.Get() cn, err := connPool.Get(nil)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
// Reserve all other connections. // Reserve all other connections.
var cns []*pool.Conn var cns []*pool.Conn
for i := 0; i < 9; i++ { for i := 0; i < 9; i++ {
cn, err := connPool.Get() cn, err := connPool.Get(nil)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
cns = append(cns, cn) cns = append(cns, cn)
} }
@ -47,7 +47,7 @@ var _ = Describe("ConnPool", func() {
defer GinkgoRecover() defer GinkgoRecover()
started <- true started <- true
_, err := connPool.Get() _, err := connPool.Get(nil)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
done <- true done <- true
@ -110,7 +110,7 @@ var _ = Describe("MinIdleConns", func() {
BeforeEach(func() { BeforeEach(func() {
var err error var err error
cn, err = connPool.Get() cn, err = connPool.Get(nil)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Eventually(func() int { Eventually(func() int {
@ -145,7 +145,7 @@ var _ = Describe("MinIdleConns", func() {
perform(poolSize, func(_ int) { perform(poolSize, func(_ int) {
defer GinkgoRecover() defer GinkgoRecover()
cn, err := connPool.Get() cn, err := connPool.Get(nil)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
mu.Lock() mu.Lock()
cns = append(cns, cn) cns = append(cns, cn)
@ -160,7 +160,7 @@ var _ = Describe("MinIdleConns", func() {
It("Get is blocked", func() { It("Get is blocked", func() {
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
connPool.Get() connPool.Get(nil)
close(done) close(done)
}() }()
@ -274,7 +274,7 @@ var _ = Describe("conns reaper", func() {
// add stale connections // add stale connections
staleConns = nil staleConns = nil
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
cn, err := connPool.Get() cn, err := connPool.Get(nil)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
switch typ { switch typ {
case "idle": case "idle":
@ -288,7 +288,7 @@ var _ = Describe("conns reaper", func() {
// add fresh connections // add fresh connections
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
cn, err := connPool.Get() cn, err := connPool.Get(nil)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
conns = append(conns, cn) conns = append(conns, cn)
} }
@ -333,7 +333,7 @@ var _ = Describe("conns reaper", func() {
for j := 0; j < 3; j++ { for j := 0; j < 3; j++ {
var freeCns []*pool.Conn var freeCns []*pool.Conn
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
cn, err := connPool.Get() cn, err := connPool.Get(nil)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(cn).NotTo(BeNil()) Expect(cn).NotTo(BeNil())
freeCns = append(freeCns, cn) freeCns = append(freeCns, cn)
@ -342,7 +342,7 @@ var _ = Describe("conns reaper", func() {
Expect(connPool.Len()).To(Equal(3)) Expect(connPool.Len()).To(Equal(3))
Expect(connPool.IdleLen()).To(Equal(0)) Expect(connPool.IdleLen()).To(Equal(0))
cn, err := connPool.Get() cn, err := connPool.Get(nil)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(cn).NotTo(BeNil()) Expect(cn).NotTo(BeNil())
conns = append(conns, cn) conns = append(conns, cn)
@ -396,7 +396,7 @@ var _ = Describe("race", func() {
perform(C, func(id int) { perform(C, func(id int) {
for i := 0; i < N; i++ { for i := 0; i < N; i++ {
cn, err := connPool.Get() cn, err := connPool.Get(nil)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
if err == nil { if err == nil {
connPool.Put(cn) connPool.Put(cn)
@ -404,7 +404,7 @@ var _ = Describe("race", func() {
} }
}, func(id int) { }, func(id int) {
for i := 0; i < N; i++ { for i := 0; i < N; i++ {
cn, err := connPool.Get() cn, err := connPool.Get(nil)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
if err == nil { if err == nil {
connPool.Remove(cn) connPool.Remove(cn)

View File

@ -1,6 +1,7 @@
package redis package redis
import ( import (
"context"
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
@ -34,7 +35,7 @@ type Options struct {
// Dialer creates new network connection and has priority over // Dialer creates new network connection and has priority over
// Network and Addr options. // Network and Addr options.
Dialer func(network, addr string) (net.Conn, error) Dialer func(ctx context.Context, network, addr string) (net.Conn, error)
// Hook that is called when new connection is established. // Hook that is called when new connection is established.
OnConnect func(*Conn) error OnConnect func(*Conn) error
@ -105,7 +106,7 @@ func (opt *Options) init() {
opt.Addr = "localhost:6379" opt.Addr = "localhost:6379"
} }
if opt.Dialer == nil { if opt.Dialer == nil {
opt.Dialer = func(network, addr string) (net.Conn, error) { opt.Dialer = func(ctx context.Context, network, addr string) (net.Conn, error) {
netDialer := &net.Dialer{ netDialer := &net.Dialer{
Timeout: opt.DialTimeout, Timeout: opt.DialTimeout,
KeepAlive: 5 * time.Minute, KeepAlive: 5 * time.Minute,
@ -215,8 +216,8 @@ func ParseURL(redisURL string) (*Options, error) {
func newConnPool(opt *Options) *pool.ConnPool { func newConnPool(opt *Options) *pool.ConnPool {
return pool.NewConnPool(&pool.Options{ return pool.NewConnPool(&pool.Options{
Dialer: func() (net.Conn, error) { Dialer: func(c context.Context) (net.Conn, error) {
return opt.Dialer(opt.Network, opt.Addr) return opt.Dialer(c, opt.Network, opt.Addr)
}, },
PoolSize: opt.PoolSize, PoolSize: opt.PoolSize,
MinIdleConns: opt.MinIdleConns, MinIdleConns: opt.MinIdleConns,

View File

@ -81,7 +81,7 @@ var _ = Describe("pool", func() {
}) })
It("removes broken connections", func() { It("removes broken connections", func() {
cn, err := client.Pool().Get() cn, err := client.Pool().Get(nil)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
cn.SetNetConn(&badConn{}) cn.SetNetConn(&badConn{})
client.Pool().Put(cn) client.Pool().Put(cn)

View File

@ -154,7 +154,7 @@ func (c *baseClient) newConn() (*pool.Conn, error) {
return cn, nil return cn, nil
} }
func (c *baseClient) getConn() (*pool.Conn, error) { func (c *baseClient) getConn(ctx context.Context) (*pool.Conn, error) {
if c.limiter != nil { if c.limiter != nil {
err := c.limiter.Allow() err := c.limiter.Allow()
if err != nil { if err != nil {
@ -162,7 +162,7 @@ func (c *baseClient) getConn() (*pool.Conn, error) {
} }
} }
cn, err := c._getConn() cn, err := c._getConn(ctx)
if err != nil { if err != nil {
if c.limiter != nil { if c.limiter != nil {
c.limiter.ReportResult(err) c.limiter.ReportResult(err)
@ -172,8 +172,8 @@ func (c *baseClient) getConn() (*pool.Conn, error) {
return cn, nil return cn, nil
} }
func (c *baseClient) _getConn() (*pool.Conn, error) { func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
cn, err := c.connPool.Get() cn, err := c.connPool.Get(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -256,7 +256,7 @@ func (c *baseClient) process(ctx context.Context, cmd Cmder) error {
time.Sleep(c.retryBackoff(attempt)) time.Sleep(c.retryBackoff(attempt))
} }
cn, err := c.getConn() cn, err := c.getConn(ctx)
if err != nil { if err != nil {
cmd.setErr(err) cmd.setErr(err)
if internal.IsRetryableError(err, true) { if internal.IsRetryableError(err, true) {
@ -326,22 +326,24 @@ func (c *baseClient) getAddr() string {
} }
func (c *baseClient) processPipeline(ctx context.Context, cmds []Cmder) error { func (c *baseClient) processPipeline(ctx context.Context, cmds []Cmder) error {
return c.generalProcessPipeline(cmds, c.pipelineProcessCmds) return c.generalProcessPipeline(ctx, cmds, c.pipelineProcessCmds)
} }
func (c *baseClient) processTxPipeline(ctx context.Context, cmds []Cmder) error { func (c *baseClient) processTxPipeline(ctx context.Context, cmds []Cmder) error {
return c.generalProcessPipeline(cmds, c.txPipelineProcessCmds) return c.generalProcessPipeline(ctx, cmds, c.txPipelineProcessCmds)
} }
type pipelineProcessor func(*pool.Conn, []Cmder) (bool, error) type pipelineProcessor func(*pool.Conn, []Cmder) (bool, error)
func (c *baseClient) generalProcessPipeline(cmds []Cmder, p pipelineProcessor) error { func (c *baseClient) generalProcessPipeline(
ctx context.Context, cmds []Cmder, p pipelineProcessor,
) error {
for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ { for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
if attempt > 0 { if attempt > 0 {
time.Sleep(c.retryBackoff(attempt)) time.Sleep(c.retryBackoff(attempt))
} }
cn, err := c.getConn() cn, err := c.getConn(ctx)
if err != nil { if err != nil {
setCmdsErr(cmds, err) setCmdsErr(cmds, err)
return err return err

View File

@ -2,6 +2,7 @@ package redis_test
import ( import (
"bytes" "bytes"
"context"
"net" "net"
"time" "time"
@ -41,7 +42,7 @@ var _ = Describe("Client", func() {
custom := redis.NewClient(&redis.Options{ custom := redis.NewClient(&redis.Options{
Network: "tcp", Network: "tcp",
Addr: redisAddr, Addr: redisAddr,
Dialer: func(network, addr string) (net.Conn, error) { Dialer: func(ctx context.Context, network, addr string) (net.Conn, error) {
return net.Dial(network, addr) return net.Dial(network, addr)
}, },
}) })
@ -146,7 +147,7 @@ var _ = Describe("Client", func() {
}) })
// Put bad connection in the pool. // Put bad connection in the pool.
cn, err := client.Pool().Get() cn, err := client.Pool().Get(nil)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
cn.SetNetConn(&badConn{}) cn.SetNetConn(&badConn{})
@ -184,7 +185,7 @@ var _ = Describe("Client", func() {
}) })
It("should update conn.UsedAt on read/write", func() { It("should update conn.UsedAt on read/write", func() {
cn, err := client.Pool().Get() cn, err := client.Pool().Get(nil)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(cn.UsedAt).NotTo(BeZero()) Expect(cn.UsedAt).NotTo(BeZero())
createdAt := cn.UsedAt() createdAt := cn.UsedAt()
@ -197,7 +198,7 @@ var _ = Describe("Client", func() {
err = client.Ping().Err() err = client.Ping().Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
cn, err = client.Pool().Get() cn, err = client.Pool().Get(nil)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(cn).NotTo(BeNil()) Expect(cn).NotTo(BeNil())
Expect(cn.UsedAt().After(createdAt)).To(BeTrue()) Expect(cn.UsedAt().After(createdAt)).To(BeTrue())

View File

@ -610,7 +610,7 @@ func (c *Ring) _processPipeline(ctx context.Context, cmds []Cmder) error {
return return
} }
cn, err := shard.Client.getConn() cn, err := shard.Client.getConn(ctx)
if err != nil { if err != nil {
setCmdsErr(cmds, err) setCmdsErr(cmds, err)
return return

View File

@ -22,14 +22,14 @@ type FailoverOptions struct {
MasterName string MasterName string
// A seed list of host:port addresses of sentinel nodes. // A seed list of host:port addresses of sentinel nodes.
SentinelAddrs []string SentinelAddrs []string
SentinelPassword string
// Following options are copied from Options struct. // Following options are copied from Options struct.
Dialer func(network, addr string) (net.Conn, error) Dialer func(ctx context.Context, network, addr string) (net.Conn, error)
OnConnect func(*Conn) error OnConnect func(*Conn) error
Password string Password string
SentinelPassword string
DB int DB int
MaxRetries int MaxRetries int
@ -312,7 +312,7 @@ func (c *sentinelFailover) Pool() *pool.ConnPool {
return c.pool return c.pool
} }
func (c *sentinelFailover) dial(network, addr string) (net.Conn, error) { func (c *sentinelFailover) dial(ctx context.Context, network, addr string) (net.Conn, error) {
addr, err := c.MasterAddr() addr, err := c.MasterAddr()
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -124,7 +124,7 @@ var _ = Describe("Tx", func() {
It("should recover from bad connection", func() { It("should recover from bad connection", func() {
// Put bad connection in the pool. // Put bad connection in the pool.
cn, err := client.Pool().Get() cn, err := client.Pool().Get(nil)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
cn.SetNetConn(&badConn{}) cn.SetNetConn(&badConn{})

View File

@ -20,7 +20,7 @@ type UniversalOptions struct {
// Common options. // Common options.
Dialer func(network, addr string) (net.Conn, error) Dialer func(ctx context.Context, network, addr string) (net.Conn, error)
OnConnect func(*Conn) error OnConnect func(*Conn) error
Password string Password string
MaxRetries int MaxRetries int