Fix WithContext and add tests

This commit is contained in:
Vladimir Mihailenco 2019-07-04 11:18:06 +03:00
parent 73d3c18522
commit 2cbb5194fb
14 changed files with 114 additions and 90 deletions

View File

@ -673,6 +673,7 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient {
opt: opt, opt: opt,
nodes: newClusterNodes(opt), nodes: newClusterNodes(opt),
}, },
ctx: context.Background(),
} }
c.state = newClusterStateHolder(c.loadState) c.state = newClusterStateHolder(c.loadState)
c.cmdsInfoCache = newCmdsInfoCache(c.cmdsInfo) c.cmdsInfoCache = newCmdsInfoCache(c.cmdsInfo)
@ -690,10 +691,7 @@ func (c *ClusterClient) init() {
} }
func (c *ClusterClient) Context() context.Context { func (c *ClusterClient) Context() context.Context {
if c.ctx != nil {
return c.ctx return c.ctx
}
return context.Background()
} }
func (c *ClusterClient) WithContext(ctx context.Context) *ClusterClient { func (c *ClusterClient) WithContext(ctx context.Context) *ClusterClient {
@ -702,6 +700,7 @@ func (c *ClusterClient) WithContext(ctx context.Context) *ClusterClient {
} }
clone := *c clone := *c
clone.ctx = ctx clone.ctx = ctx
clone.init()
return &clone return &clone
} }
@ -732,7 +731,7 @@ func (c *ClusterClient) Do(args ...interface{}) *Cmd {
func (c *ClusterClient) DoContext(ctx context.Context, args ...interface{}) *Cmd { func (c *ClusterClient) DoContext(ctx context.Context, args ...interface{}) *Cmd {
cmd := NewCmd(args...) cmd := NewCmd(args...)
c.ProcessContext(ctx, cmd) _ = c.ProcessContext(ctx, cmd)
return cmd return cmd
} }
@ -1035,7 +1034,7 @@ func (c *ClusterClient) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) {
} }
func (c *ClusterClient) processPipeline(ctx context.Context, cmds []Cmder) error { func (c *ClusterClient) processPipeline(ctx context.Context, cmds []Cmder) error {
return c.hooks.processPipeline(c.ctx, cmds, c._processPipeline) return c.hooks.processPipeline(ctx, cmds, c._processPipeline)
} }
func (c *ClusterClient) _processPipeline(ctx context.Context, cmds []Cmder) error { func (c *ClusterClient) _processPipeline(ctx context.Context, cmds []Cmder) error {

View File

@ -1,6 +1,7 @@
package redis_test package redis_test
import ( import (
"context"
"fmt" "fmt"
"net" "net"
"strconv" "strconv"
@ -241,6 +242,14 @@ var _ = Describe("ClusterClient", func() {
var client *redis.ClusterClient var client *redis.ClusterClient
assertClusterClient := func() { assertClusterClient := func() {
It("supports WithContext", func() {
c, cancel := context.WithCancel(context.Background())
cancel()
err := client.WithContext(c).Ping().Err()
Expect(err).To(MatchError("context canceled"))
})
It("should GET/SET/DEL", func() { It("should GET/SET/DEL", func() {
err := client.Get("A").Err() err := client.Get("A").Err()
Expect(err).To(Equal(redis.Nil)) Expect(err).To(Equal(redis.Nil))

View File

@ -1,6 +1,7 @@
package pool_test package pool_test
import ( import (
"context"
"fmt" "fmt"
"testing" "testing"
"time" "time"
@ -39,7 +40,7 @@ func BenchmarkPoolGetPut(b *testing.B) {
b.RunParallel(func(pb *testing.PB) { b.RunParallel(func(pb *testing.PB) {
for pb.Next() { for pb.Next() {
cn, err := connPool.Get(nil) cn, err := connPool.Get(context.Background())
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }
@ -81,7 +82,7 @@ func BenchmarkPoolGetRemove(b *testing.B) {
b.RunParallel(func(pb *testing.PB) { b.RunParallel(func(pb *testing.PB) {
for pb.Next() { for pb.Next() {
cn, err := connPool.Get(nil) cn, err := connPool.Get(context.Background())
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }

View File

@ -250,22 +250,23 @@ func (p *ConnPool) getTurn() {
} }
func (p *ConnPool) waitTurn(ctx context.Context) error { func (p *ConnPool) waitTurn(ctx context.Context) error {
var done <-chan struct{} select {
if ctx != nil { case <-ctx.Done():
done = ctx.Done() return ctx.Err()
default:
} }
select { select {
case <-done:
return ctx.Err()
case p.queue <- struct{}{}: case p.queue <- struct{}{}:
return nil return nil
default: default:
}
timer := timers.Get().(*time.Timer) timer := timers.Get().(*time.Timer)
timer.Reset(p.opt.PoolTimeout) timer.Reset(p.opt.PoolTimeout)
select { select {
case <-done: case <-ctx.Done():
if !timer.Stop() { if !timer.Stop() {
<-timer.C <-timer.C
} }
@ -282,7 +283,6 @@ func (p *ConnPool) waitTurn(ctx context.Context) error {
atomic.AddUint32(&p.stats.Timeouts, 1) atomic.AddUint32(&p.stats.Timeouts, 1)
return ErrPoolTimeout return ErrPoolTimeout
} }
}
} }
func (p *ConnPool) freeTurn() { func (p *ConnPool) freeTurn() {

View File

@ -1,6 +1,7 @@
package pool_test package pool_test
import ( import (
"context"
"sync" "sync"
"testing" "testing"
"time" "time"
@ -12,6 +13,7 @@ import (
) )
var _ = Describe("ConnPool", func() { var _ = Describe("ConnPool", func() {
c := context.Background()
var connPool *pool.ConnPool var connPool *pool.ConnPool
BeforeEach(func() { BeforeEach(func() {
@ -30,13 +32,13 @@ var _ = Describe("ConnPool", func() {
It("should unblock client when conn is removed", func() { It("should unblock client when conn is removed", func() {
// Reserve one connection. // Reserve one connection.
cn, err := connPool.Get(nil) cn, err := connPool.Get(c)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
// Reserve all other connections. // Reserve all other connections.
var cns []*pool.Conn var cns []*pool.Conn
for i := 0; i < 9; i++ { for i := 0; i < 9; i++ {
cn, err := connPool.Get(nil) cn, err := connPool.Get(c)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
cns = append(cns, cn) cns = append(cns, cn)
} }
@ -47,7 +49,7 @@ var _ = Describe("ConnPool", func() {
defer GinkgoRecover() defer GinkgoRecover()
started <- true started <- true
_, err := connPool.Get(nil) _, err := connPool.Get(c)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
done <- true done <- true
@ -80,6 +82,7 @@ var _ = Describe("ConnPool", func() {
}) })
var _ = Describe("MinIdleConns", func() { var _ = Describe("MinIdleConns", func() {
c := context.Background()
const poolSize = 100 const poolSize = 100
var minIdleConns int var minIdleConns int
var connPool *pool.ConnPool var connPool *pool.ConnPool
@ -110,7 +113,7 @@ var _ = Describe("MinIdleConns", func() {
BeforeEach(func() { BeforeEach(func() {
var err error var err error
cn, err = connPool.Get(nil) cn, err = connPool.Get(c)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Eventually(func() int { Eventually(func() int {
@ -145,7 +148,7 @@ var _ = Describe("MinIdleConns", func() {
perform(poolSize, func(_ int) { perform(poolSize, func(_ int) {
defer GinkgoRecover() defer GinkgoRecover()
cn, err := connPool.Get(nil) cn, err := connPool.Get(c)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
mu.Lock() mu.Lock()
cns = append(cns, cn) cns = append(cns, cn)
@ -160,7 +163,7 @@ var _ = Describe("MinIdleConns", func() {
It("Get is blocked", func() { It("Get is blocked", func() {
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
connPool.Get(nil) connPool.Get(c)
close(done) close(done)
}() }()
@ -247,6 +250,8 @@ var _ = Describe("MinIdleConns", func() {
}) })
var _ = Describe("conns reaper", func() { var _ = Describe("conns reaper", func() {
c := context.Background()
const idleTimeout = time.Minute const idleTimeout = time.Minute
const maxAge = time.Hour const maxAge = time.Hour
@ -274,7 +279,7 @@ var _ = Describe("conns reaper", func() {
// add stale connections // add stale connections
staleConns = nil staleConns = nil
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
cn, err := connPool.Get(nil) cn, err := connPool.Get(c)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
switch typ { switch typ {
case "idle": case "idle":
@ -288,7 +293,7 @@ var _ = Describe("conns reaper", func() {
// add fresh connections // add fresh connections
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
cn, err := connPool.Get(nil) cn, err := connPool.Get(c)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
conns = append(conns, cn) conns = append(conns, cn)
} }
@ -333,7 +338,7 @@ var _ = Describe("conns reaper", func() {
for j := 0; j < 3; j++ { for j := 0; j < 3; j++ {
var freeCns []*pool.Conn var freeCns []*pool.Conn
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
cn, err := connPool.Get(nil) cn, err := connPool.Get(c)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(cn).NotTo(BeNil()) Expect(cn).NotTo(BeNil())
freeCns = append(freeCns, cn) freeCns = append(freeCns, cn)
@ -342,7 +347,7 @@ var _ = Describe("conns reaper", func() {
Expect(connPool.Len()).To(Equal(3)) Expect(connPool.Len()).To(Equal(3))
Expect(connPool.IdleLen()).To(Equal(0)) Expect(connPool.IdleLen()).To(Equal(0))
cn, err := connPool.Get(nil) cn, err := connPool.Get(c)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(cn).NotTo(BeNil()) Expect(cn).NotTo(BeNil())
conns = append(conns, cn) conns = append(conns, cn)
@ -370,6 +375,7 @@ var _ = Describe("conns reaper", func() {
}) })
var _ = Describe("race", func() { var _ = Describe("race", func() {
c := context.Background()
var connPool *pool.ConnPool var connPool *pool.ConnPool
var C, N int var C, N int
@ -396,7 +402,7 @@ var _ = Describe("race", func() {
perform(C, func(id int) { perform(C, func(id int) {
for i := 0; i < N; i++ { for i := 0; i < N; i++ {
cn, err := connPool.Get(nil) cn, err := connPool.Get(c)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
if err == nil { if err == nil {
connPool.Put(cn) connPool.Put(cn)
@ -404,7 +410,7 @@ var _ = Describe("race", func() {
} }
}, func(id int) { }, func(id int) {
for i := 0; i < N; i++ { for i := 0; i < N; i++ {
cn, err := connPool.Get(nil) cn, err := connPool.Get(c)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
if err == nil { if err == nil {
connPool.Remove(cn) connPool.Remove(cn)

View File

@ -98,7 +98,7 @@ func (c *Pipeline) discard() error {
// Exec always returns list of commands and error of the first failed // Exec always returns list of commands and error of the first failed
// command if any. // command if any.
func (c *Pipeline) Exec() ([]Cmder, error) { func (c *Pipeline) Exec() ([]Cmder, error) {
return c.ExecContext(nil) return c.ExecContext(context.Background())
} }
func (c *Pipeline) ExecContext(ctx context.Context) ([]Cmder, error) { func (c *Pipeline) ExecContext(ctx context.Context) ([]Cmder, error) {

View File

@ -1,6 +1,7 @@
package redis_test package redis_test
import ( import (
"context"
"time" "time"
"github.com/go-redis/redis" "github.com/go-redis/redis"
@ -81,7 +82,7 @@ var _ = Describe("pool", func() {
}) })
It("removes broken connections", func() { It("removes broken connections", func() {
cn, err := client.Pool().Get(nil) cn, err := client.Pool().Get(context.Background())
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
cn.SetNetConn(&badConn{}) cn.SetNetConn(&badConn{})
client.Pool().Put(cn) client.Pool().Put(cn)

View File

@ -32,10 +32,6 @@ type hooks struct {
hooks []Hook hooks []Hook
} }
func (hs *hooks) lazyCopy() {
hs.hooks = hs.hooks[:len(hs.hooks):len(hs.hooks)]
}
func (hs *hooks) AddHook(hook Hook) { func (hs *hooks) AddHook(hook Hook) {
hs.hooks = append(hs.hooks, hook) hs.hooks = append(hs.hooks, hook)
} }
@ -475,6 +471,7 @@ func NewClient(opt *Options) *Client {
connPool: newConnPool(opt), connPool: newConnPool(opt),
}, },
}, },
ctx: context.Background(),
} }
c.init() c.init()
@ -486,10 +483,7 @@ func (c *Client) init() {
} }
func (c *Client) Context() context.Context { func (c *Client) Context() context.Context {
if c.ctx != nil {
return c.ctx return c.ctx
}
return context.Background()
} }
func (c *Client) WithContext(ctx context.Context) *Client { func (c *Client) WithContext(ctx context.Context) *Client {
@ -498,6 +492,7 @@ func (c *Client) WithContext(ctx context.Context) *Client {
} }
clone := *c clone := *c
clone.ctx = ctx clone.ctx = ctx
clone.init()
return &clone return &clone
} }

View File

@ -24,6 +24,14 @@ var _ = Describe("Client", func() {
client.Close() client.Close()
}) })
It("supports WithContext", func() {
c, cancel := context.WithCancel(context.Background())
cancel()
err := client.WithContext(c).Ping().Err()
Expect(err).To(MatchError("context canceled"))
})
It("should Stringer", func() { It("should Stringer", func() {
Expect(client.String()).To(Equal("Redis<:6380 db:15>")) Expect(client.String()).To(Equal("Redis<:6380 db:15>"))
}) })
@ -129,7 +137,7 @@ var _ = Describe("Client", func() {
It("processes custom commands", func() { It("processes custom commands", func() {
cmd := redis.NewCmd("PING") cmd := redis.NewCmd("PING")
client.Process(cmd) _ = client.Process(cmd)
// Flush buffers. // Flush buffers.
Expect(client.Echo("hello").Err()).NotTo(HaveOccurred()) Expect(client.Echo("hello").Err()).NotTo(HaveOccurred())
@ -147,7 +155,7 @@ var _ = Describe("Client", func() {
}) })
// Put bad connection in the pool. // Put bad connection in the pool.
cn, err := client.Pool().Get(nil) cn, err := client.Pool().Get(context.Background())
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
cn.SetNetConn(&badConn{}) cn.SetNetConn(&badConn{})
@ -185,7 +193,7 @@ var _ = Describe("Client", func() {
}) })
It("should update conn.UsedAt on read/write", func() { It("should update conn.UsedAt on read/write", func() {
cn, err := client.Pool().Get(nil) cn, err := client.Pool().Get(context.Background())
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(cn.UsedAt).NotTo(BeZero()) Expect(cn.UsedAt).NotTo(BeZero())
createdAt := cn.UsedAt() createdAt := cn.UsedAt()
@ -198,7 +206,7 @@ var _ = Describe("Client", func() {
err = client.Ping().Err() err = client.Ping().Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
cn, err = client.Pool().Get(nil) cn, err = client.Pool().Get(context.Background())
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(cn).NotTo(BeNil()) Expect(cn).NotTo(BeNil())
Expect(cn.UsedAt().After(createdAt)).To(BeTrue()) Expect(cn.UsedAt().After(createdAt)).To(BeTrue())

View File

@ -358,6 +358,7 @@ func NewRing(opt *RingOptions) *Ring {
opt: opt, opt: opt,
shards: newRingShards(opt), shards: newRingShards(opt),
}, },
ctx: context.Background(),
} }
ring.init() ring.init()
@ -379,10 +380,7 @@ func (c *Ring) init() {
} }
func (c *Ring) Context() context.Context { func (c *Ring) Context() context.Context {
if c.ctx != nil {
return c.ctx return c.ctx
}
return context.Background()
} }
func (c *Ring) WithContext(ctx context.Context) *Ring { func (c *Ring) WithContext(ctx context.Context) *Ring {
@ -391,6 +389,7 @@ func (c *Ring) WithContext(ctx context.Context) *Ring {
} }
clone := *c clone := *c
clone.ctx = ctx clone.ctx = ctx
clone.init()
return &clone return &clone
} }
@ -401,7 +400,7 @@ func (c *Ring) Do(args ...interface{}) *Cmd {
func (c *Ring) DoContext(ctx context.Context, args ...interface{}) *Cmd { func (c *Ring) DoContext(ctx context.Context, args ...interface{}) *Cmd {
cmd := NewCmd(args...) cmd := NewCmd(args...)
c.ProcessContext(ctx, cmd) _ = c.ProcessContext(ctx, cmd)
return cmd return cmd
} }

View File

@ -1,6 +1,7 @@
package redis_test package redis_test
import ( import (
"context"
"crypto/rand" "crypto/rand"
"fmt" "fmt"
"net" "net"
@ -41,6 +42,14 @@ var _ = Describe("Redis Ring", func() {
Expect(ring.Close()).NotTo(HaveOccurred()) Expect(ring.Close()).NotTo(HaveOccurred())
}) })
It("supports WithContext", func() {
c, cancel := context.WithCancel(context.Background())
cancel()
err := ring.WithContext(c).Ping().Err()
Expect(err).To(MatchError("context canceled"))
})
It("distributes keys", func() { It("distributes keys", func() {
setRingKeys() setRingKeys()

View File

@ -97,8 +97,9 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client {
onClose: failover.Close, onClose: failover.Close,
}, },
}, },
ctx: context.Background(),
} }
c.cmdable = c.Process c.init()
return &c return &c
} }
@ -117,15 +118,13 @@ func NewSentinelClient(opt *Options) *SentinelClient {
opt: opt, opt: opt,
connPool: newConnPool(opt), connPool: newConnPool(opt),
}, },
ctx: context.Background(),
} }
return c return c
} }
func (c *SentinelClient) Context() context.Context { func (c *SentinelClient) Context() context.Context {
if c.ctx != nil {
return c.ctx return c.ctx
}
return context.Background()
} }
func (c *SentinelClient) WithContext(ctx context.Context) *SentinelClient { func (c *SentinelClient) WithContext(ctx context.Context) *SentinelClient {
@ -162,7 +161,7 @@ func (c *SentinelClient) pubSub() *PubSub {
// measure latency. // measure latency.
func (c *SentinelClient) Ping() *StringCmd { func (c *SentinelClient) Ping() *StringCmd {
cmd := NewStringCmd("ping") cmd := NewStringCmd("ping")
c.Process(cmd) _ = c.Process(cmd)
return cmd return cmd
} }
@ -188,13 +187,13 @@ func (c *SentinelClient) PSubscribe(channels ...string) *PubSub {
func (c *SentinelClient) GetMasterAddrByName(name string) *StringSliceCmd { func (c *SentinelClient) GetMasterAddrByName(name string) *StringSliceCmd {
cmd := NewStringSliceCmd("sentinel", "get-master-addr-by-name", name) cmd := NewStringSliceCmd("sentinel", "get-master-addr-by-name", name)
c.Process(cmd) _ = c.Process(cmd)
return cmd return cmd
} }
func (c *SentinelClient) Sentinels(name string) *SliceCmd { func (c *SentinelClient) Sentinels(name string) *SliceCmd {
cmd := NewSliceCmd("sentinel", "sentinels", name) cmd := NewSliceCmd("sentinel", "sentinels", name)
c.Process(cmd) _ = c.Process(cmd)
return cmd return cmd
} }
@ -202,7 +201,7 @@ func (c *SentinelClient) Sentinels(name string) *SliceCmd {
// asking for agreement to other Sentinels. // asking for agreement to other Sentinels.
func (c *SentinelClient) Failover(name string) *StatusCmd { func (c *SentinelClient) Failover(name string) *StatusCmd {
cmd := NewStatusCmd("sentinel", "failover", name) cmd := NewStatusCmd("sentinel", "failover", name)
c.Process(cmd) _ = c.Process(cmd)
return cmd return cmd
} }
@ -212,7 +211,7 @@ func (c *SentinelClient) Failover(name string) *StatusCmd {
// already discovered and associated with the master. // already discovered and associated with the master.
func (c *SentinelClient) Reset(pattern string) *IntCmd { func (c *SentinelClient) Reset(pattern string) *IntCmd {
cmd := NewIntCmd("sentinel", "reset", pattern) cmd := NewIntCmd("sentinel", "reset", pattern)
c.Process(cmd) _ = c.Process(cmd)
return cmd return cmd
} }
@ -220,28 +219,28 @@ func (c *SentinelClient) Reset(pattern string) *IntCmd {
// the current Sentinel state. // the current Sentinel state.
func (c *SentinelClient) FlushConfig() *StatusCmd { func (c *SentinelClient) FlushConfig() *StatusCmd {
cmd := NewStatusCmd("sentinel", "flushconfig") cmd := NewStatusCmd("sentinel", "flushconfig")
c.Process(cmd) _ = c.Process(cmd)
return cmd return cmd
} }
// Master shows the state and info of the specified master. // Master shows the state and info of the specified master.
func (c *SentinelClient) Master(name string) *StringStringMapCmd { func (c *SentinelClient) Master(name string) *StringStringMapCmd {
cmd := NewStringStringMapCmd("sentinel", "master", name) cmd := NewStringStringMapCmd("sentinel", "master", name)
c.Process(cmd) _ = c.Process(cmd)
return cmd return cmd
} }
// Masters shows a list of monitored masters and their state. // Masters shows a list of monitored masters and their state.
func (c *SentinelClient) Masters() *SliceCmd { func (c *SentinelClient) Masters() *SliceCmd {
cmd := NewSliceCmd("sentinel", "masters") cmd := NewSliceCmd("sentinel", "masters")
c.Process(cmd) _ = c.Process(cmd)
return cmd return cmd
} }
// Slaves shows a list of slaves for the specified master and their state. // Slaves shows a list of slaves for the specified master and their state.
func (c *SentinelClient) Slaves(name string) *SliceCmd { func (c *SentinelClient) Slaves(name string) *SliceCmd {
cmd := NewSliceCmd("sentinel", "slaves", name) cmd := NewSliceCmd("sentinel", "slaves", name)
c.Process(cmd) _ = c.Process(cmd)
return cmd return cmd
} }
@ -251,7 +250,7 @@ func (c *SentinelClient) Slaves(name string) *SliceCmd {
// Sentinel deployment is ok. // Sentinel deployment is ok.
func (c *SentinelClient) CkQuorum(name string) *StringCmd { func (c *SentinelClient) CkQuorum(name string) *StringCmd {
cmd := NewStringCmd("sentinel", "ckquorum", name) cmd := NewStringCmd("sentinel", "ckquorum", name)
c.Process(cmd) _ = c.Process(cmd)
return cmd return cmd
} }
@ -259,14 +258,14 @@ func (c *SentinelClient) CkQuorum(name string) *StringCmd {
// name, ip, port, and quorum. // name, ip, port, and quorum.
func (c *SentinelClient) Monitor(name, ip, port, quorum string) *StringCmd { func (c *SentinelClient) Monitor(name, ip, port, quorum string) *StringCmd {
cmd := NewStringCmd("sentinel", "monitor", name, ip, port, quorum) cmd := NewStringCmd("sentinel", "monitor", name, ip, port, quorum)
c.Process(cmd) _ = c.Process(cmd)
return cmd return cmd
} }
// Set is used in order to change configuration parameters of a specific master. // Set is used in order to change configuration parameters of a specific master.
func (c *SentinelClient) Set(name, option, value string) *StringCmd { func (c *SentinelClient) Set(name, option, value string) *StringCmd {
cmd := NewStringCmd("sentinel", "set", name, option, value) cmd := NewStringCmd("sentinel", "set", name, option, value)
c.Process(cmd) _ = c.Process(cmd)
return cmd return cmd
} }
@ -275,7 +274,7 @@ func (c *SentinelClient) Set(name, option, value string) *StringCmd {
// the Sentinel. // the Sentinel.
func (c *SentinelClient) Remove(name string) *StringCmd { func (c *SentinelClient) Remove(name string) *StringCmd {
cmd := NewStringCmd("sentinel", "remove", name) cmd := NewStringCmd("sentinel", "remove", name)
c.Process(cmd) _ = c.Process(cmd)
return cmd return cmd
} }
@ -313,7 +312,7 @@ func (c *sentinelFailover) Pool() *pool.ConnPool {
return c.pool return c.pool
} }
func (c *sentinelFailover) dial(ctx context.Context, network, addr string) (net.Conn, error) { func (c *sentinelFailover) dial(ctx context.Context, network, _ string) (net.Conn, error) {
addr, err := c.MasterAddr() addr, err := c.MasterAddr()
if err != nil { if err != nil {
return nil, err return nil, err
@ -396,7 +395,7 @@ func (c *sentinelFailover) getMasterAddr() string {
c.masterName, err) c.masterName, err)
c.mu.Lock() c.mu.Lock()
if c.sentinel == sentinel { if c.sentinel == sentinel {
c.closeSentinel() _ = c.closeSentinel()
} }
c.mu.Unlock() c.mu.Unlock()
return "" return ""
@ -436,13 +435,13 @@ func (c *sentinelFailover) closeSentinel() error {
var firstErr error var firstErr error
err := c.pubsub.Close() err := c.pubsub.Close()
if err != nil && firstErr == err { if err != nil && firstErr == nil {
firstErr = err firstErr = err
} }
c.pubsub = nil c.pubsub = nil
err = c.sentinel.Close() err = c.sentinel.Close()
if err != nil && firstErr == err { if err != nil && firstErr == nil {
firstErr = err firstErr = err
} }
c.sentinel = nil c.sentinel = nil

3
tx.go
View File

@ -40,10 +40,7 @@ func (c *Tx) init() {
} }
func (c *Tx) Context() context.Context { func (c *Tx) Context() context.Context {
if c.ctx != nil {
return c.ctx return c.ctx
}
return context.Background()
} }
func (c *Tx) WithContext(ctx context.Context) *Tx { func (c *Tx) WithContext(ctx context.Context) *Tx {

View File

@ -1,6 +1,7 @@
package redis_test package redis_test
import ( import (
"context"
"strconv" "strconv"
"sync" "sync"
@ -124,7 +125,7 @@ var _ = Describe("Tx", func() {
It("should recover from bad connection", func() { It("should recover from bad connection", func() {
// Put bad connection in the pool. // Put bad connection in the pool.
cn, err := client.Pool().Get(nil) cn, err := client.Pool().Get(context.Background())
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
cn.SetNetConn(&badConn{}) cn.SetNetConn(&badConn{})