forked from mirror/redis
Merge pull request #1046 from go-redis/fix/dialer-context
Pass context to Dialer
This commit is contained in:
commit
4fe609d47c
|
@ -53,7 +53,7 @@ type ClusterOptions struct {
|
|||
|
||||
// Following options are copied from Options struct.
|
||||
|
||||
Dialer func(network, addr string) (net.Conn, error)
|
||||
Dialer func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||
|
||||
OnConnect func(*Conn) error
|
||||
|
||||
|
@ -1055,7 +1055,7 @@ func (c *ClusterClient) _processPipeline(ctx context.Context, cmds []Cmder) erro
|
|||
go func(node *clusterNode, cmds []Cmder) {
|
||||
defer wg.Done()
|
||||
|
||||
cn, err := node.Client.getConn()
|
||||
cn, err := node.Client.getConn(ctx)
|
||||
if err != nil {
|
||||
if err == pool.ErrClosed {
|
||||
c.mapCmdsByNode(cmds, failedCmds)
|
||||
|
@ -1256,7 +1256,7 @@ func (c *ClusterClient) _processTxPipeline(ctx context.Context, cmds []Cmder) er
|
|||
go func(node *clusterNode, cmds []Cmder) {
|
||||
defer wg.Done()
|
||||
|
||||
cn, err := node.Client.getConn()
|
||||
cn, err := node.Client.getConn(ctx)
|
||||
if err != nil {
|
||||
if err == pool.ErrClosed {
|
||||
c.mapCmdsByNode(cmds, failedCmds)
|
||||
|
|
|
@ -39,7 +39,7 @@ func BenchmarkPoolGetPut(b *testing.B) {
|
|||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
cn, err := connPool.Get()
|
||||
cn, err := connPool.Get(nil)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
@ -81,7 +81,7 @@ func BenchmarkPoolGetRemove(b *testing.B) {
|
|||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
cn, err := connPool.Get()
|
||||
cn, err := connPool.Get(nil)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package pool_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
|
@ -30,6 +31,6 @@ func perform(n int, cbs ...func(int)) {
|
|||
wg.Wait()
|
||||
}
|
||||
|
||||
func dummyDialer() (net.Conn, error) {
|
||||
func dummyDialer(context.Context) (net.Conn, error) {
|
||||
return &net.TCPConn{}, nil
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package pool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
|
@ -36,7 +37,7 @@ type Pooler interface {
|
|||
NewConn() (*Conn, error)
|
||||
CloseConn(*Conn) error
|
||||
|
||||
Get() (*Conn, error)
|
||||
Get(context.Context) (*Conn, error)
|
||||
Put(*Conn)
|
||||
Remove(*Conn)
|
||||
|
||||
|
@ -48,7 +49,7 @@ type Pooler interface {
|
|||
}
|
||||
|
||||
type Options struct {
|
||||
Dialer func() (net.Conn, error)
|
||||
Dialer func(c context.Context) (net.Conn, error)
|
||||
OnClose func(*Conn) error
|
||||
|
||||
PoolSize int
|
||||
|
@ -114,7 +115,7 @@ func (p *ConnPool) checkMinIdleConns() {
|
|||
}
|
||||
|
||||
func (p *ConnPool) addIdleConn() {
|
||||
cn, err := p.newConn(true)
|
||||
cn, err := p.newConn(nil, true)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -126,11 +127,11 @@ func (p *ConnPool) addIdleConn() {
|
|||
}
|
||||
|
||||
func (p *ConnPool) NewConn() (*Conn, error) {
|
||||
return p._NewConn(false)
|
||||
return p._NewConn(nil, false)
|
||||
}
|
||||
|
||||
func (p *ConnPool) _NewConn(pooled bool) (*Conn, error) {
|
||||
cn, err := p.newConn(pooled)
|
||||
func (p *ConnPool) _NewConn(c context.Context, pooled bool) (*Conn, error) {
|
||||
cn, err := p.newConn(c, pooled)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -148,7 +149,7 @@ func (p *ConnPool) _NewConn(pooled bool) (*Conn, error) {
|
|||
return cn, nil
|
||||
}
|
||||
|
||||
func (p *ConnPool) newConn(pooled bool) (*Conn, error) {
|
||||
func (p *ConnPool) newConn(c context.Context, pooled bool) (*Conn, error) {
|
||||
if p.closed() {
|
||||
return nil, ErrClosed
|
||||
}
|
||||
|
@ -157,7 +158,7 @@ func (p *ConnPool) newConn(pooled bool) (*Conn, error) {
|
|||
return nil, p.getLastDialError()
|
||||
}
|
||||
|
||||
netConn, err := p.opt.Dialer()
|
||||
netConn, err := p.opt.Dialer(c)
|
||||
if err != nil {
|
||||
p.setLastDialError(err)
|
||||
if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.opt.PoolSize) {
|
||||
|
@ -177,7 +178,7 @@ func (p *ConnPool) tryDial() {
|
|||
return
|
||||
}
|
||||
|
||||
conn, err := p.opt.Dialer()
|
||||
conn, err := p.opt.Dialer(nil)
|
||||
if err != nil {
|
||||
p.setLastDialError(err)
|
||||
time.Sleep(time.Second)
|
||||
|
@ -204,7 +205,7 @@ func (p *ConnPool) getLastDialError() error {
|
|||
}
|
||||
|
||||
// Get returns existed connection from the pool or creates a new one.
|
||||
func (p *ConnPool) Get() (*Conn, error) {
|
||||
func (p *ConnPool) Get(c context.Context) (*Conn, error) {
|
||||
if p.closed() {
|
||||
return nil, ErrClosed
|
||||
}
|
||||
|
@ -234,7 +235,7 @@ func (p *ConnPool) Get() (*Conn, error) {
|
|||
|
||||
atomic.AddUint32(&p.stats.Misses, 1)
|
||||
|
||||
newcn, err := p._NewConn(true)
|
||||
newcn, err := p._NewConn(c, true)
|
||||
if err != nil {
|
||||
p.freeTurn()
|
||||
return nil, err
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
package pool
|
||||
|
||||
import "context"
|
||||
|
||||
type SingleConnPool struct {
|
||||
cn *Conn
|
||||
}
|
||||
|
@ -20,7 +22,7 @@ func (p *SingleConnPool) CloseConn(*Conn) error {
|
|||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (p *SingleConnPool) Get() (*Conn, error) {
|
||||
func (p *SingleConnPool) Get(c context.Context) (*Conn, error) {
|
||||
return p.cn, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
package pool
|
||||
|
||||
import "sync"
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type StickyConnPool struct {
|
||||
pool *ConnPool
|
||||
|
@ -28,7 +31,7 @@ func (p *StickyConnPool) CloseConn(*Conn) error {
|
|||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (p *StickyConnPool) Get() (*Conn, error) {
|
||||
func (p *StickyConnPool) Get(c context.Context) (*Conn, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
|
@ -39,7 +42,7 @@ func (p *StickyConnPool) Get() (*Conn, error) {
|
|||
return p.cn, nil
|
||||
}
|
||||
|
||||
cn, err := p.pool.Get()
|
||||
cn, err := p.pool.Get(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -30,13 +30,13 @@ var _ = Describe("ConnPool", func() {
|
|||
|
||||
It("should unblock client when conn is removed", func() {
|
||||
// Reserve one connection.
|
||||
cn, err := connPool.Get()
|
||||
cn, err := connPool.Get(nil)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
// Reserve all other connections.
|
||||
var cns []*pool.Conn
|
||||
for i := 0; i < 9; i++ {
|
||||
cn, err := connPool.Get()
|
||||
cn, err := connPool.Get(nil)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
cns = append(cns, cn)
|
||||
}
|
||||
|
@ -47,7 +47,7 @@ var _ = Describe("ConnPool", func() {
|
|||
defer GinkgoRecover()
|
||||
|
||||
started <- true
|
||||
_, err := connPool.Get()
|
||||
_, err := connPool.Get(nil)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
done <- true
|
||||
|
||||
|
@ -110,7 +110,7 @@ var _ = Describe("MinIdleConns", func() {
|
|||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
cn, err = connPool.Get()
|
||||
cn, err = connPool.Get(nil)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
Eventually(func() int {
|
||||
|
@ -145,7 +145,7 @@ var _ = Describe("MinIdleConns", func() {
|
|||
perform(poolSize, func(_ int) {
|
||||
defer GinkgoRecover()
|
||||
|
||||
cn, err := connPool.Get()
|
||||
cn, err := connPool.Get(nil)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
mu.Lock()
|
||||
cns = append(cns, cn)
|
||||
|
@ -160,7 +160,7 @@ var _ = Describe("MinIdleConns", func() {
|
|||
It("Get is blocked", func() {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
connPool.Get()
|
||||
connPool.Get(nil)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
|
@ -274,7 +274,7 @@ var _ = Describe("conns reaper", func() {
|
|||
// add stale connections
|
||||
staleConns = nil
|
||||
for i := 0; i < 3; i++ {
|
||||
cn, err := connPool.Get()
|
||||
cn, err := connPool.Get(nil)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
switch typ {
|
||||
case "idle":
|
||||
|
@ -288,7 +288,7 @@ var _ = Describe("conns reaper", func() {
|
|||
|
||||
// add fresh connections
|
||||
for i := 0; i < 3; i++ {
|
||||
cn, err := connPool.Get()
|
||||
cn, err := connPool.Get(nil)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
conns = append(conns, cn)
|
||||
}
|
||||
|
@ -333,7 +333,7 @@ var _ = Describe("conns reaper", func() {
|
|||
for j := 0; j < 3; j++ {
|
||||
var freeCns []*pool.Conn
|
||||
for i := 0; i < 3; i++ {
|
||||
cn, err := connPool.Get()
|
||||
cn, err := connPool.Get(nil)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(cn).NotTo(BeNil())
|
||||
freeCns = append(freeCns, cn)
|
||||
|
@ -342,7 +342,7 @@ var _ = Describe("conns reaper", func() {
|
|||
Expect(connPool.Len()).To(Equal(3))
|
||||
Expect(connPool.IdleLen()).To(Equal(0))
|
||||
|
||||
cn, err := connPool.Get()
|
||||
cn, err := connPool.Get(nil)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(cn).NotTo(BeNil())
|
||||
conns = append(conns, cn)
|
||||
|
@ -396,7 +396,7 @@ var _ = Describe("race", func() {
|
|||
|
||||
perform(C, func(id int) {
|
||||
for i := 0; i < N; i++ {
|
||||
cn, err := connPool.Get()
|
||||
cn, err := connPool.Get(nil)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
if err == nil {
|
||||
connPool.Put(cn)
|
||||
|
@ -404,7 +404,7 @@ var _ = Describe("race", func() {
|
|||
}
|
||||
}, func(id int) {
|
||||
for i := 0; i < N; i++ {
|
||||
cn, err := connPool.Get()
|
||||
cn, err := connPool.Get(nil)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
if err == nil {
|
||||
connPool.Remove(cn)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package redis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -34,7 +35,7 @@ type Options struct {
|
|||
|
||||
// Dialer creates new network connection and has priority over
|
||||
// Network and Addr options.
|
||||
Dialer func(network, addr string) (net.Conn, error)
|
||||
Dialer func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||
|
||||
// Hook that is called when new connection is established.
|
||||
OnConnect func(*Conn) error
|
||||
|
@ -105,7 +106,7 @@ func (opt *Options) init() {
|
|||
opt.Addr = "localhost:6379"
|
||||
}
|
||||
if opt.Dialer == nil {
|
||||
opt.Dialer = func(network, addr string) (net.Conn, error) {
|
||||
opt.Dialer = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
netDialer := &net.Dialer{
|
||||
Timeout: opt.DialTimeout,
|
||||
KeepAlive: 5 * time.Minute,
|
||||
|
@ -215,8 +216,8 @@ func ParseURL(redisURL string) (*Options, error) {
|
|||
|
||||
func newConnPool(opt *Options) *pool.ConnPool {
|
||||
return pool.NewConnPool(&pool.Options{
|
||||
Dialer: func() (net.Conn, error) {
|
||||
return opt.Dialer(opt.Network, opt.Addr)
|
||||
Dialer: func(c context.Context) (net.Conn, error) {
|
||||
return opt.Dialer(c, opt.Network, opt.Addr)
|
||||
},
|
||||
PoolSize: opt.PoolSize,
|
||||
MinIdleConns: opt.MinIdleConns,
|
||||
|
|
|
@ -81,7 +81,7 @@ var _ = Describe("pool", func() {
|
|||
})
|
||||
|
||||
It("removes broken connections", func() {
|
||||
cn, err := client.Pool().Get()
|
||||
cn, err := client.Pool().Get(nil)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
cn.SetNetConn(&badConn{})
|
||||
client.Pool().Put(cn)
|
||||
|
|
20
redis.go
20
redis.go
|
@ -154,7 +154,7 @@ func (c *baseClient) newConn() (*pool.Conn, error) {
|
|||
return cn, nil
|
||||
}
|
||||
|
||||
func (c *baseClient) getConn() (*pool.Conn, error) {
|
||||
func (c *baseClient) getConn(ctx context.Context) (*pool.Conn, error) {
|
||||
if c.limiter != nil {
|
||||
err := c.limiter.Allow()
|
||||
if err != nil {
|
||||
|
@ -162,7 +162,7 @@ func (c *baseClient) getConn() (*pool.Conn, error) {
|
|||
}
|
||||
}
|
||||
|
||||
cn, err := c._getConn()
|
||||
cn, err := c._getConn(ctx)
|
||||
if err != nil {
|
||||
if c.limiter != nil {
|
||||
c.limiter.ReportResult(err)
|
||||
|
@ -172,8 +172,8 @@ func (c *baseClient) getConn() (*pool.Conn, error) {
|
|||
return cn, nil
|
||||
}
|
||||
|
||||
func (c *baseClient) _getConn() (*pool.Conn, error) {
|
||||
cn, err := c.connPool.Get()
|
||||
func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
|
||||
cn, err := c.connPool.Get(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -256,7 +256,7 @@ func (c *baseClient) process(ctx context.Context, cmd Cmder) error {
|
|||
time.Sleep(c.retryBackoff(attempt))
|
||||
}
|
||||
|
||||
cn, err := c.getConn()
|
||||
cn, err := c.getConn(ctx)
|
||||
if err != nil {
|
||||
cmd.setErr(err)
|
||||
if internal.IsRetryableError(err, true) {
|
||||
|
@ -326,22 +326,24 @@ func (c *baseClient) getAddr() string {
|
|||
}
|
||||
|
||||
func (c *baseClient) processPipeline(ctx context.Context, cmds []Cmder) error {
|
||||
return c.generalProcessPipeline(cmds, c.pipelineProcessCmds)
|
||||
return c.generalProcessPipeline(ctx, cmds, c.pipelineProcessCmds)
|
||||
}
|
||||
|
||||
func (c *baseClient) processTxPipeline(ctx context.Context, cmds []Cmder) error {
|
||||
return c.generalProcessPipeline(cmds, c.txPipelineProcessCmds)
|
||||
return c.generalProcessPipeline(ctx, cmds, c.txPipelineProcessCmds)
|
||||
}
|
||||
|
||||
type pipelineProcessor func(*pool.Conn, []Cmder) (bool, error)
|
||||
|
||||
func (c *baseClient) generalProcessPipeline(cmds []Cmder, p pipelineProcessor) error {
|
||||
func (c *baseClient) generalProcessPipeline(
|
||||
ctx context.Context, cmds []Cmder, p pipelineProcessor,
|
||||
) error {
|
||||
for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
|
||||
if attempt > 0 {
|
||||
time.Sleep(c.retryBackoff(attempt))
|
||||
}
|
||||
|
||||
cn, err := c.getConn()
|
||||
cn, err := c.getConn(ctx)
|
||||
if err != nil {
|
||||
setCmdsErr(cmds, err)
|
||||
return err
|
||||
|
|
|
@ -2,6 +2,7 @@ package redis_test
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
|
@ -41,7 +42,7 @@ var _ = Describe("Client", func() {
|
|||
custom := redis.NewClient(&redis.Options{
|
||||
Network: "tcp",
|
||||
Addr: redisAddr,
|
||||
Dialer: func(network, addr string) (net.Conn, error) {
|
||||
Dialer: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return net.Dial(network, addr)
|
||||
},
|
||||
})
|
||||
|
@ -146,7 +147,7 @@ var _ = Describe("Client", func() {
|
|||
})
|
||||
|
||||
// Put bad connection in the pool.
|
||||
cn, err := client.Pool().Get()
|
||||
cn, err := client.Pool().Get(nil)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
cn.SetNetConn(&badConn{})
|
||||
|
@ -184,7 +185,7 @@ var _ = Describe("Client", func() {
|
|||
})
|
||||
|
||||
It("should update conn.UsedAt on read/write", func() {
|
||||
cn, err := client.Pool().Get()
|
||||
cn, err := client.Pool().Get(nil)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(cn.UsedAt).NotTo(BeZero())
|
||||
createdAt := cn.UsedAt()
|
||||
|
@ -197,7 +198,7 @@ var _ = Describe("Client", func() {
|
|||
err = client.Ping().Err()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
cn, err = client.Pool().Get()
|
||||
cn, err = client.Pool().Get(nil)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(cn).NotTo(BeNil())
|
||||
Expect(cn.UsedAt().After(createdAt)).To(BeTrue())
|
||||
|
|
2
ring.go
2
ring.go
|
@ -610,7 +610,7 @@ func (c *Ring) _processPipeline(ctx context.Context, cmds []Cmder) error {
|
|||
return
|
||||
}
|
||||
|
||||
cn, err := shard.Client.getConn()
|
||||
cn, err := shard.Client.getConn(ctx)
|
||||
if err != nil {
|
||||
setCmdsErr(cmds, err)
|
||||
return
|
||||
|
|
12
sentinel.go
12
sentinel.go
|
@ -21,16 +21,16 @@ type FailoverOptions struct {
|
|||
// The master name.
|
||||
MasterName string
|
||||
// A seed list of host:port addresses of sentinel nodes.
|
||||
SentinelAddrs []string
|
||||
SentinelAddrs []string
|
||||
SentinelPassword string
|
||||
|
||||
// Following options are copied from Options struct.
|
||||
|
||||
Dialer func(network, addr string) (net.Conn, error)
|
||||
Dialer func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||
OnConnect func(*Conn) error
|
||||
|
||||
Password string
|
||||
SentinelPassword string
|
||||
DB int
|
||||
Password string
|
||||
DB int
|
||||
|
||||
MaxRetries int
|
||||
MinRetryBackoff time.Duration
|
||||
|
@ -312,7 +312,7 @@ func (c *sentinelFailover) Pool() *pool.ConnPool {
|
|||
return c.pool
|
||||
}
|
||||
|
||||
func (c *sentinelFailover) dial(network, addr string) (net.Conn, error) {
|
||||
func (c *sentinelFailover) dial(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
addr, err := c.MasterAddr()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -124,7 +124,7 @@ var _ = Describe("Tx", func() {
|
|||
|
||||
It("should recover from bad connection", func() {
|
||||
// Put bad connection in the pool.
|
||||
cn, err := client.Pool().Get()
|
||||
cn, err := client.Pool().Get(nil)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
cn.SetNetConn(&badConn{})
|
||||
|
|
|
@ -20,7 +20,7 @@ type UniversalOptions struct {
|
|||
|
||||
// Common options.
|
||||
|
||||
Dialer func(network, addr string) (net.Conn, error)
|
||||
Dialer func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||
OnConnect func(*Conn) error
|
||||
Password string
|
||||
MaxRetries int
|
||||
|
|
Loading…
Reference in New Issue