diff --git a/cluster.go b/cluster.go index 60373872..29ae471a 100644 --- a/cluster.go +++ b/cluster.go @@ -53,7 +53,7 @@ type ClusterOptions struct { // Following options are copied from Options struct. - Dialer func(network, addr string) (net.Conn, error) + Dialer func(ctx context.Context, network, addr string) (net.Conn, error) OnConnect func(*Conn) error @@ -1055,7 +1055,7 @@ func (c *ClusterClient) _processPipeline(ctx context.Context, cmds []Cmder) erro go func(node *clusterNode, cmds []Cmder) { defer wg.Done() - cn, err := node.Client.getConn() + cn, err := node.Client.getConn(ctx) if err != nil { if err == pool.ErrClosed { c.mapCmdsByNode(cmds, failedCmds) @@ -1256,7 +1256,7 @@ func (c *ClusterClient) _processTxPipeline(ctx context.Context, cmds []Cmder) er go func(node *clusterNode, cmds []Cmder) { defer wg.Done() - cn, err := node.Client.getConn() + cn, err := node.Client.getConn(ctx) if err != nil { if err == pool.ErrClosed { c.mapCmdsByNode(cmds, failedCmds) diff --git a/internal/pool/bench_test.go b/internal/pool/bench_test.go index 54fbf913..82e2d5d9 100644 --- a/internal/pool/bench_test.go +++ b/internal/pool/bench_test.go @@ -39,7 +39,7 @@ func BenchmarkPoolGetPut(b *testing.B) { b.RunParallel(func(pb *testing.PB) { for pb.Next() { - cn, err := connPool.Get() + cn, err := connPool.Get(nil) if err != nil { b.Fatal(err) } @@ -81,7 +81,7 @@ func BenchmarkPoolGetRemove(b *testing.B) { b.RunParallel(func(pb *testing.PB) { for pb.Next() { - cn, err := connPool.Get() + cn, err := connPool.Get(nil) if err != nil { b.Fatal(err) } diff --git a/internal/pool/main_test.go b/internal/pool/main_test.go index 43afe3fa..2365dbc6 100644 --- a/internal/pool/main_test.go +++ b/internal/pool/main_test.go @@ -1,6 +1,7 @@ package pool_test import ( + "context" "net" "sync" "testing" @@ -30,6 +31,6 @@ func perform(n int, cbs ...func(int)) { wg.Wait() } -func dummyDialer() (net.Conn, error) { +func dummyDialer(context.Context) (net.Conn, error) { return &net.TCPConn{}, nil } diff --git a/internal/pool/pool.go b/internal/pool/pool.go index 88b059ca..fa6855c7 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -1,6 +1,7 @@ package pool import ( + "context" "errors" "net" "sync" @@ -36,7 +37,7 @@ type Pooler interface { NewConn() (*Conn, error) CloseConn(*Conn) error - Get() (*Conn, error) + Get(context.Context) (*Conn, error) Put(*Conn) Remove(*Conn) @@ -48,7 +49,7 @@ type Pooler interface { } type Options struct { - Dialer func() (net.Conn, error) + Dialer func(c context.Context) (net.Conn, error) OnClose func(*Conn) error PoolSize int @@ -114,7 +115,7 @@ func (p *ConnPool) checkMinIdleConns() { } func (p *ConnPool) addIdleConn() { - cn, err := p.newConn(true) + cn, err := p.newConn(nil, true) if err != nil { return } @@ -126,11 +127,11 @@ func (p *ConnPool) addIdleConn() { } func (p *ConnPool) NewConn() (*Conn, error) { - return p._NewConn(false) + return p._NewConn(nil, false) } -func (p *ConnPool) _NewConn(pooled bool) (*Conn, error) { - cn, err := p.newConn(pooled) +func (p *ConnPool) _NewConn(c context.Context, pooled bool) (*Conn, error) { + cn, err := p.newConn(c, pooled) if err != nil { return nil, err } @@ -148,7 +149,7 @@ func (p *ConnPool) _NewConn(pooled bool) (*Conn, error) { return cn, nil } -func (p *ConnPool) newConn(pooled bool) (*Conn, error) { +func (p *ConnPool) newConn(c context.Context, pooled bool) (*Conn, error) { if p.closed() { return nil, ErrClosed } @@ -157,7 +158,7 @@ func (p *ConnPool) newConn(pooled bool) (*Conn, error) { return nil, p.getLastDialError() } - netConn, err := p.opt.Dialer() + netConn, err := p.opt.Dialer(c) if err != nil { p.setLastDialError(err) if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.opt.PoolSize) { @@ -177,7 +178,7 @@ func (p *ConnPool) tryDial() { return } - conn, err := p.opt.Dialer() + conn, err := p.opt.Dialer(nil) if err != nil { p.setLastDialError(err) time.Sleep(time.Second) @@ -204,7 +205,7 @@ func (p *ConnPool) getLastDialError() error { } // Get returns existed connection from the pool or creates a new one. -func (p *ConnPool) Get() (*Conn, error) { +func (p *ConnPool) Get(c context.Context) (*Conn, error) { if p.closed() { return nil, ErrClosed } @@ -234,7 +235,7 @@ func (p *ConnPool) Get() (*Conn, error) { atomic.AddUint32(&p.stats.Misses, 1) - newcn, err := p._NewConn(true) + newcn, err := p._NewConn(c, true) if err != nil { p.freeTurn() return nil, err diff --git a/internal/pool/pool_single.go b/internal/pool/pool_single.go index b35b78af..6112d8f3 100644 --- a/internal/pool/pool_single.go +++ b/internal/pool/pool_single.go @@ -1,5 +1,7 @@ package pool +import "context" + type SingleConnPool struct { cn *Conn } @@ -20,7 +22,7 @@ func (p *SingleConnPool) CloseConn(*Conn) error { panic("not implemented") } -func (p *SingleConnPool) Get() (*Conn, error) { +func (p *SingleConnPool) Get(c context.Context) (*Conn, error) { return p.cn, nil } diff --git a/internal/pool/pool_sticky.go b/internal/pool/pool_sticky.go index 91bd9133..1e632ec1 100644 --- a/internal/pool/pool_sticky.go +++ b/internal/pool/pool_sticky.go @@ -1,6 +1,9 @@ package pool -import "sync" +import ( + "context" + "sync" +) type StickyConnPool struct { pool *ConnPool @@ -28,7 +31,7 @@ func (p *StickyConnPool) CloseConn(*Conn) error { panic("not implemented") } -func (p *StickyConnPool) Get() (*Conn, error) { +func (p *StickyConnPool) Get(c context.Context) (*Conn, error) { p.mu.Lock() defer p.mu.Unlock() @@ -39,7 +42,7 @@ func (p *StickyConnPool) Get() (*Conn, error) { return p.cn, nil } - cn, err := p.pool.Get() + cn, err := p.pool.Get(c) if err != nil { return nil, err } diff --git a/internal/pool/pool_test.go b/internal/pool/pool_test.go index 07fb48ac..18d02780 100644 --- a/internal/pool/pool_test.go +++ b/internal/pool/pool_test.go @@ -30,13 +30,13 @@ var _ = Describe("ConnPool", func() { It("should unblock client when conn is removed", func() { // Reserve one connection. - cn, err := connPool.Get() + cn, err := connPool.Get(nil) Expect(err).NotTo(HaveOccurred()) // Reserve all other connections. var cns []*pool.Conn for i := 0; i < 9; i++ { - cn, err := connPool.Get() + cn, err := connPool.Get(nil) Expect(err).NotTo(HaveOccurred()) cns = append(cns, cn) } @@ -47,7 +47,7 @@ var _ = Describe("ConnPool", func() { defer GinkgoRecover() started <- true - _, err := connPool.Get() + _, err := connPool.Get(nil) Expect(err).NotTo(HaveOccurred()) done <- true @@ -110,7 +110,7 @@ var _ = Describe("MinIdleConns", func() { BeforeEach(func() { var err error - cn, err = connPool.Get() + cn, err = connPool.Get(nil) Expect(err).NotTo(HaveOccurred()) Eventually(func() int { @@ -145,7 +145,7 @@ var _ = Describe("MinIdleConns", func() { perform(poolSize, func(_ int) { defer GinkgoRecover() - cn, err := connPool.Get() + cn, err := connPool.Get(nil) Expect(err).NotTo(HaveOccurred()) mu.Lock() cns = append(cns, cn) @@ -160,7 +160,7 @@ var _ = Describe("MinIdleConns", func() { It("Get is blocked", func() { done := make(chan struct{}) go func() { - connPool.Get() + connPool.Get(nil) close(done) }() @@ -274,7 +274,7 @@ var _ = Describe("conns reaper", func() { // add stale connections staleConns = nil for i := 0; i < 3; i++ { - cn, err := connPool.Get() + cn, err := connPool.Get(nil) Expect(err).NotTo(HaveOccurred()) switch typ { case "idle": @@ -288,7 +288,7 @@ var _ = Describe("conns reaper", func() { // add fresh connections for i := 0; i < 3; i++ { - cn, err := connPool.Get() + cn, err := connPool.Get(nil) Expect(err).NotTo(HaveOccurred()) conns = append(conns, cn) } @@ -333,7 +333,7 @@ var _ = Describe("conns reaper", func() { for j := 0; j < 3; j++ { var freeCns []*pool.Conn for i := 0; i < 3; i++ { - cn, err := connPool.Get() + cn, err := connPool.Get(nil) Expect(err).NotTo(HaveOccurred()) Expect(cn).NotTo(BeNil()) freeCns = append(freeCns, cn) @@ -342,7 +342,7 @@ var _ = Describe("conns reaper", func() { Expect(connPool.Len()).To(Equal(3)) Expect(connPool.IdleLen()).To(Equal(0)) - cn, err := connPool.Get() + cn, err := connPool.Get(nil) Expect(err).NotTo(HaveOccurred()) Expect(cn).NotTo(BeNil()) conns = append(conns, cn) @@ -396,7 +396,7 @@ var _ = Describe("race", func() { perform(C, func(id int) { for i := 0; i < N; i++ { - cn, err := connPool.Get() + cn, err := connPool.Get(nil) Expect(err).NotTo(HaveOccurred()) if err == nil { connPool.Put(cn) @@ -404,7 +404,7 @@ var _ = Describe("race", func() { } }, func(id int) { for i := 0; i < N; i++ { - cn, err := connPool.Get() + cn, err := connPool.Get(nil) Expect(err).NotTo(HaveOccurred()) if err == nil { connPool.Remove(cn) diff --git a/options.go b/options.go index 2a4d39f5..b4a2d12d 100644 --- a/options.go +++ b/options.go @@ -1,6 +1,7 @@ package redis import ( + "context" "crypto/tls" "errors" "fmt" @@ -34,7 +35,7 @@ type Options struct { // Dialer creates new network connection and has priority over // Network and Addr options. - Dialer func(network, addr string) (net.Conn, error) + Dialer func(ctx context.Context, network, addr string) (net.Conn, error) // Hook that is called when new connection is established. OnConnect func(*Conn) error @@ -105,7 +106,7 @@ func (opt *Options) init() { opt.Addr = "localhost:6379" } if opt.Dialer == nil { - opt.Dialer = func(network, addr string) (net.Conn, error) { + opt.Dialer = func(ctx context.Context, network, addr string) (net.Conn, error) { netDialer := &net.Dialer{ Timeout: opt.DialTimeout, KeepAlive: 5 * time.Minute, @@ -215,8 +216,8 @@ func ParseURL(redisURL string) (*Options, error) { func newConnPool(opt *Options) *pool.ConnPool { return pool.NewConnPool(&pool.Options{ - Dialer: func() (net.Conn, error) { - return opt.Dialer(opt.Network, opt.Addr) + Dialer: func(c context.Context) (net.Conn, error) { + return opt.Dialer(c, opt.Network, opt.Addr) }, PoolSize: opt.PoolSize, MinIdleConns: opt.MinIdleConns, diff --git a/pool_test.go b/pool_test.go index c8fa8590..3adcebbb 100644 --- a/pool_test.go +++ b/pool_test.go @@ -81,7 +81,7 @@ var _ = Describe("pool", func() { }) It("removes broken connections", func() { - cn, err := client.Pool().Get() + cn, err := client.Pool().Get(nil) Expect(err).NotTo(HaveOccurred()) cn.SetNetConn(&badConn{}) client.Pool().Put(cn) diff --git a/redis.go b/redis.go index 793a4ccf..6870e2b0 100644 --- a/redis.go +++ b/redis.go @@ -154,7 +154,7 @@ func (c *baseClient) newConn() (*pool.Conn, error) { return cn, nil } -func (c *baseClient) getConn() (*pool.Conn, error) { +func (c *baseClient) getConn(ctx context.Context) (*pool.Conn, error) { if c.limiter != nil { err := c.limiter.Allow() if err != nil { @@ -162,7 +162,7 @@ func (c *baseClient) getConn() (*pool.Conn, error) { } } - cn, err := c._getConn() + cn, err := c._getConn(ctx) if err != nil { if c.limiter != nil { c.limiter.ReportResult(err) @@ -172,8 +172,8 @@ func (c *baseClient) getConn() (*pool.Conn, error) { return cn, nil } -func (c *baseClient) _getConn() (*pool.Conn, error) { - cn, err := c.connPool.Get() +func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) { + cn, err := c.connPool.Get(ctx) if err != nil { return nil, err } @@ -256,7 +256,7 @@ func (c *baseClient) process(ctx context.Context, cmd Cmder) error { time.Sleep(c.retryBackoff(attempt)) } - cn, err := c.getConn() + cn, err := c.getConn(ctx) if err != nil { cmd.setErr(err) if internal.IsRetryableError(err, true) { @@ -326,22 +326,24 @@ func (c *baseClient) getAddr() string { } func (c *baseClient) processPipeline(ctx context.Context, cmds []Cmder) error { - return c.generalProcessPipeline(cmds, c.pipelineProcessCmds) + return c.generalProcessPipeline(ctx, cmds, c.pipelineProcessCmds) } func (c *baseClient) processTxPipeline(ctx context.Context, cmds []Cmder) error { - return c.generalProcessPipeline(cmds, c.txPipelineProcessCmds) + return c.generalProcessPipeline(ctx, cmds, c.txPipelineProcessCmds) } type pipelineProcessor func(*pool.Conn, []Cmder) (bool, error) -func (c *baseClient) generalProcessPipeline(cmds []Cmder, p pipelineProcessor) error { +func (c *baseClient) generalProcessPipeline( + ctx context.Context, cmds []Cmder, p pipelineProcessor, +) error { for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ { if attempt > 0 { time.Sleep(c.retryBackoff(attempt)) } - cn, err := c.getConn() + cn, err := c.getConn(ctx) if err != nil { setCmdsErr(cmds, err) return err diff --git a/redis_test.go b/redis_test.go index e829686c..80175350 100644 --- a/redis_test.go +++ b/redis_test.go @@ -2,6 +2,7 @@ package redis_test import ( "bytes" + "context" "net" "time" @@ -41,7 +42,7 @@ var _ = Describe("Client", func() { custom := redis.NewClient(&redis.Options{ Network: "tcp", Addr: redisAddr, - Dialer: func(network, addr string) (net.Conn, error) { + Dialer: func(ctx context.Context, network, addr string) (net.Conn, error) { return net.Dial(network, addr) }, }) @@ -146,7 +147,7 @@ var _ = Describe("Client", func() { }) // Put bad connection in the pool. - cn, err := client.Pool().Get() + cn, err := client.Pool().Get(nil) Expect(err).NotTo(HaveOccurred()) cn.SetNetConn(&badConn{}) @@ -184,7 +185,7 @@ var _ = Describe("Client", func() { }) It("should update conn.UsedAt on read/write", func() { - cn, err := client.Pool().Get() + cn, err := client.Pool().Get(nil) Expect(err).NotTo(HaveOccurred()) Expect(cn.UsedAt).NotTo(BeZero()) createdAt := cn.UsedAt() @@ -197,7 +198,7 @@ var _ = Describe("Client", func() { err = client.Ping().Err() Expect(err).NotTo(HaveOccurred()) - cn, err = client.Pool().Get() + cn, err = client.Pool().Get(nil) Expect(err).NotTo(HaveOccurred()) Expect(cn).NotTo(BeNil()) Expect(cn.UsedAt().After(createdAt)).To(BeTrue()) diff --git a/ring.go b/ring.go index c2070393..fa77a974 100644 --- a/ring.go +++ b/ring.go @@ -610,7 +610,7 @@ func (c *Ring) _processPipeline(ctx context.Context, cmds []Cmder) error { return } - cn, err := shard.Client.getConn() + cn, err := shard.Client.getConn(ctx) if err != nil { setCmdsErr(cmds, err) return diff --git a/sentinel.go b/sentinel.go index 15e2160f..3f13cf6a 100644 --- a/sentinel.go +++ b/sentinel.go @@ -21,16 +21,16 @@ type FailoverOptions struct { // The master name. MasterName string // A seed list of host:port addresses of sentinel nodes. - SentinelAddrs []string + SentinelAddrs []string + SentinelPassword string // Following options are copied from Options struct. - Dialer func(network, addr string) (net.Conn, error) + Dialer func(ctx context.Context, network, addr string) (net.Conn, error) OnConnect func(*Conn) error - Password string - SentinelPassword string - DB int + Password string + DB int MaxRetries int MinRetryBackoff time.Duration @@ -312,7 +312,7 @@ func (c *sentinelFailover) Pool() *pool.ConnPool { return c.pool } -func (c *sentinelFailover) dial(network, addr string) (net.Conn, error) { +func (c *sentinelFailover) dial(ctx context.Context, network, addr string) (net.Conn, error) { addr, err := c.MasterAddr() if err != nil { return nil, err diff --git a/tx_test.go b/tx_test.go index c70f08ce..c5cb4b3c 100644 --- a/tx_test.go +++ b/tx_test.go @@ -124,7 +124,7 @@ var _ = Describe("Tx", func() { It("should recover from bad connection", func() { // Put bad connection in the pool. - cn, err := client.Pool().Get() + cn, err := client.Pool().Get(nil) Expect(err).NotTo(HaveOccurred()) cn.SetNetConn(&badConn{}) diff --git a/universal.go b/universal.go index dd2b5f71..62fd0fd0 100644 --- a/universal.go +++ b/universal.go @@ -20,7 +20,7 @@ type UniversalOptions struct { // Common options. - Dialer func(network, addr string) (net.Conn, error) + Dialer func(ctx context.Context, network, addr string) (net.Conn, error) OnConnect func(*Conn) error Password string MaxRetries int