Fix WithContext and add tests

This commit is contained in:
Vladimir Mihailenco 2019-07-04 11:18:06 +03:00
parent 73d3c18522
commit 2cbb5194fb
14 changed files with 114 additions and 90 deletions

View File

@ -673,6 +673,7 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient {
opt: opt,
nodes: newClusterNodes(opt),
},
ctx: context.Background(),
}
c.state = newClusterStateHolder(c.loadState)
c.cmdsInfoCache = newCmdsInfoCache(c.cmdsInfo)
@ -690,11 +691,8 @@ func (c *ClusterClient) init() {
}
func (c *ClusterClient) Context() context.Context {
if c.ctx != nil {
return c.ctx
}
return context.Background()
}
func (c *ClusterClient) WithContext(ctx context.Context) *ClusterClient {
if ctx == nil {
@ -702,6 +700,7 @@ func (c *ClusterClient) WithContext(ctx context.Context) *ClusterClient {
}
clone := *c
clone.ctx = ctx
clone.init()
return &clone
}
@ -732,7 +731,7 @@ func (c *ClusterClient) Do(args ...interface{}) *Cmd {
func (c *ClusterClient) DoContext(ctx context.Context, args ...interface{}) *Cmd {
cmd := NewCmd(args...)
c.ProcessContext(ctx, cmd)
_ = c.ProcessContext(ctx, cmd)
return cmd
}
@ -1035,7 +1034,7 @@ func (c *ClusterClient) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) {
}
func (c *ClusterClient) processPipeline(ctx context.Context, cmds []Cmder) error {
return c.hooks.processPipeline(c.ctx, cmds, c._processPipeline)
return c.hooks.processPipeline(ctx, cmds, c._processPipeline)
}
func (c *ClusterClient) _processPipeline(ctx context.Context, cmds []Cmder) error {

View File

@ -1,6 +1,7 @@
package redis_test
import (
"context"
"fmt"
"net"
"strconv"
@ -241,6 +242,14 @@ var _ = Describe("ClusterClient", func() {
var client *redis.ClusterClient
assertClusterClient := func() {
It("supports WithContext", func() {
c, cancel := context.WithCancel(context.Background())
cancel()
err := client.WithContext(c).Ping().Err()
Expect(err).To(MatchError("context canceled"))
})
It("should GET/SET/DEL", func() {
err := client.Get("A").Err()
Expect(err).To(Equal(redis.Nil))

View File

@ -1,6 +1,7 @@
package pool_test
import (
"context"
"fmt"
"testing"
"time"
@ -39,7 +40,7 @@ func BenchmarkPoolGetPut(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
cn, err := connPool.Get(nil)
cn, err := connPool.Get(context.Background())
if err != nil {
b.Fatal(err)
}
@ -81,7 +82,7 @@ func BenchmarkPoolGetRemove(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
cn, err := connPool.Get(nil)
cn, err := connPool.Get(context.Background())
if err != nil {
b.Fatal(err)
}

View File

@ -250,22 +250,23 @@ func (p *ConnPool) getTurn() {
}
func (p *ConnPool) waitTurn(ctx context.Context) error {
var done <-chan struct{}
if ctx != nil {
done = ctx.Done()
select {
case <-ctx.Done():
return ctx.Err()
default:
}
select {
case <-done:
return ctx.Err()
case p.queue <- struct{}{}:
return nil
default:
}
timer := timers.Get().(*time.Timer)
timer.Reset(p.opt.PoolTimeout)
select {
case <-done:
case <-ctx.Done():
if !timer.Stop() {
<-timer.C
}
@ -283,7 +284,6 @@ func (p *ConnPool) waitTurn(ctx context.Context) error {
return ErrPoolTimeout
}
}
}
func (p *ConnPool) freeTurn() {
<-p.queue

View File

@ -1,6 +1,7 @@
package pool_test
import (
"context"
"sync"
"testing"
"time"
@ -12,6 +13,7 @@ import (
)
var _ = Describe("ConnPool", func() {
c := context.Background()
var connPool *pool.ConnPool
BeforeEach(func() {
@ -30,13 +32,13 @@ var _ = Describe("ConnPool", func() {
It("should unblock client when conn is removed", func() {
// Reserve one connection.
cn, err := connPool.Get(nil)
cn, err := connPool.Get(c)
Expect(err).NotTo(HaveOccurred())
// Reserve all other connections.
var cns []*pool.Conn
for i := 0; i < 9; i++ {
cn, err := connPool.Get(nil)
cn, err := connPool.Get(c)
Expect(err).NotTo(HaveOccurred())
cns = append(cns, cn)
}
@ -47,7 +49,7 @@ var _ = Describe("ConnPool", func() {
defer GinkgoRecover()
started <- true
_, err := connPool.Get(nil)
_, err := connPool.Get(c)
Expect(err).NotTo(HaveOccurred())
done <- true
@ -80,6 +82,7 @@ var _ = Describe("ConnPool", func() {
})
var _ = Describe("MinIdleConns", func() {
c := context.Background()
const poolSize = 100
var minIdleConns int
var connPool *pool.ConnPool
@ -110,7 +113,7 @@ var _ = Describe("MinIdleConns", func() {
BeforeEach(func() {
var err error
cn, err = connPool.Get(nil)
cn, err = connPool.Get(c)
Expect(err).NotTo(HaveOccurred())
Eventually(func() int {
@ -145,7 +148,7 @@ var _ = Describe("MinIdleConns", func() {
perform(poolSize, func(_ int) {
defer GinkgoRecover()
cn, err := connPool.Get(nil)
cn, err := connPool.Get(c)
Expect(err).NotTo(HaveOccurred())
mu.Lock()
cns = append(cns, cn)
@ -160,7 +163,7 @@ var _ = Describe("MinIdleConns", func() {
It("Get is blocked", func() {
done := make(chan struct{})
go func() {
connPool.Get(nil)
connPool.Get(c)
close(done)
}()
@ -247,6 +250,8 @@ var _ = Describe("MinIdleConns", func() {
})
var _ = Describe("conns reaper", func() {
c := context.Background()
const idleTimeout = time.Minute
const maxAge = time.Hour
@ -274,7 +279,7 @@ var _ = Describe("conns reaper", func() {
// add stale connections
staleConns = nil
for i := 0; i < 3; i++ {
cn, err := connPool.Get(nil)
cn, err := connPool.Get(c)
Expect(err).NotTo(HaveOccurred())
switch typ {
case "idle":
@ -288,7 +293,7 @@ var _ = Describe("conns reaper", func() {
// add fresh connections
for i := 0; i < 3; i++ {
cn, err := connPool.Get(nil)
cn, err := connPool.Get(c)
Expect(err).NotTo(HaveOccurred())
conns = append(conns, cn)
}
@ -333,7 +338,7 @@ var _ = Describe("conns reaper", func() {
for j := 0; j < 3; j++ {
var freeCns []*pool.Conn
for i := 0; i < 3; i++ {
cn, err := connPool.Get(nil)
cn, err := connPool.Get(c)
Expect(err).NotTo(HaveOccurred())
Expect(cn).NotTo(BeNil())
freeCns = append(freeCns, cn)
@ -342,7 +347,7 @@ var _ = Describe("conns reaper", func() {
Expect(connPool.Len()).To(Equal(3))
Expect(connPool.IdleLen()).To(Equal(0))
cn, err := connPool.Get(nil)
cn, err := connPool.Get(c)
Expect(err).NotTo(HaveOccurred())
Expect(cn).NotTo(BeNil())
conns = append(conns, cn)
@ -370,6 +375,7 @@ var _ = Describe("conns reaper", func() {
})
var _ = Describe("race", func() {
c := context.Background()
var connPool *pool.ConnPool
var C, N int
@ -396,7 +402,7 @@ var _ = Describe("race", func() {
perform(C, func(id int) {
for i := 0; i < N; i++ {
cn, err := connPool.Get(nil)
cn, err := connPool.Get(c)
Expect(err).NotTo(HaveOccurred())
if err == nil {
connPool.Put(cn)
@ -404,7 +410,7 @@ var _ = Describe("race", func() {
}
}, func(id int) {
for i := 0; i < N; i++ {
cn, err := connPool.Get(nil)
cn, err := connPool.Get(c)
Expect(err).NotTo(HaveOccurred())
if err == nil {
connPool.Remove(cn)

View File

@ -98,7 +98,7 @@ func (c *Pipeline) discard() error {
// Exec always returns list of commands and error of the first failed
// command if any.
func (c *Pipeline) Exec() ([]Cmder, error) {
return c.ExecContext(nil)
return c.ExecContext(context.Background())
}
func (c *Pipeline) ExecContext(ctx context.Context) ([]Cmder, error) {

View File

@ -1,6 +1,7 @@
package redis_test
import (
"context"
"time"
"github.com/go-redis/redis"
@ -81,7 +82,7 @@ var _ = Describe("pool", func() {
})
It("removes broken connections", func() {
cn, err := client.Pool().Get(nil)
cn, err := client.Pool().Get(context.Background())
Expect(err).NotTo(HaveOccurred())
cn.SetNetConn(&badConn{})
client.Pool().Put(cn)

View File

@ -32,10 +32,6 @@ type hooks struct {
hooks []Hook
}
func (hs *hooks) lazyCopy() {
hs.hooks = hs.hooks[:len(hs.hooks):len(hs.hooks)]
}
func (hs *hooks) AddHook(hook Hook) {
hs.hooks = append(hs.hooks, hook)
}
@ -475,6 +471,7 @@ func NewClient(opt *Options) *Client {
connPool: newConnPool(opt),
},
},
ctx: context.Background(),
}
c.init()
@ -486,11 +483,8 @@ func (c *Client) init() {
}
func (c *Client) Context() context.Context {
if c.ctx != nil {
return c.ctx
}
return context.Background()
}
func (c *Client) WithContext(ctx context.Context) *Client {
if ctx == nil {
@ -498,6 +492,7 @@ func (c *Client) WithContext(ctx context.Context) *Client {
}
clone := *c
clone.ctx = ctx
clone.init()
return &clone
}

View File

@ -24,6 +24,14 @@ var _ = Describe("Client", func() {
client.Close()
})
It("supports WithContext", func() {
c, cancel := context.WithCancel(context.Background())
cancel()
err := client.WithContext(c).Ping().Err()
Expect(err).To(MatchError("context canceled"))
})
It("should Stringer", func() {
Expect(client.String()).To(Equal("Redis<:6380 db:15>"))
})
@ -129,7 +137,7 @@ var _ = Describe("Client", func() {
It("processes custom commands", func() {
cmd := redis.NewCmd("PING")
client.Process(cmd)
_ = client.Process(cmd)
// Flush buffers.
Expect(client.Echo("hello").Err()).NotTo(HaveOccurred())
@ -147,7 +155,7 @@ var _ = Describe("Client", func() {
})
// Put bad connection in the pool.
cn, err := client.Pool().Get(nil)
cn, err := client.Pool().Get(context.Background())
Expect(err).NotTo(HaveOccurred())
cn.SetNetConn(&badConn{})
@ -185,7 +193,7 @@ var _ = Describe("Client", func() {
})
It("should update conn.UsedAt on read/write", func() {
cn, err := client.Pool().Get(nil)
cn, err := client.Pool().Get(context.Background())
Expect(err).NotTo(HaveOccurred())
Expect(cn.UsedAt).NotTo(BeZero())
createdAt := cn.UsedAt()
@ -198,7 +206,7 @@ var _ = Describe("Client", func() {
err = client.Ping().Err()
Expect(err).NotTo(HaveOccurred())
cn, err = client.Pool().Get(nil)
cn, err = client.Pool().Get(context.Background())
Expect(err).NotTo(HaveOccurred())
Expect(cn).NotTo(BeNil())
Expect(cn.UsedAt().After(createdAt)).To(BeTrue())

View File

@ -358,6 +358,7 @@ func NewRing(opt *RingOptions) *Ring {
opt: opt,
shards: newRingShards(opt),
},
ctx: context.Background(),
}
ring.init()
@ -379,11 +380,8 @@ func (c *Ring) init() {
}
func (c *Ring) Context() context.Context {
if c.ctx != nil {
return c.ctx
}
return context.Background()
}
func (c *Ring) WithContext(ctx context.Context) *Ring {
if ctx == nil {
@ -391,6 +389,7 @@ func (c *Ring) WithContext(ctx context.Context) *Ring {
}
clone := *c
clone.ctx = ctx
clone.init()
return &clone
}
@ -401,7 +400,7 @@ func (c *Ring) Do(args ...interface{}) *Cmd {
func (c *Ring) DoContext(ctx context.Context, args ...interface{}) *Cmd {
cmd := NewCmd(args...)
c.ProcessContext(ctx, cmd)
_ = c.ProcessContext(ctx, cmd)
return cmd
}

View File

@ -1,6 +1,7 @@
package redis_test
import (
"context"
"crypto/rand"
"fmt"
"net"
@ -41,6 +42,14 @@ var _ = Describe("Redis Ring", func() {
Expect(ring.Close()).NotTo(HaveOccurred())
})
It("supports WithContext", func() {
c, cancel := context.WithCancel(context.Background())
cancel()
err := ring.WithContext(c).Ping().Err()
Expect(err).To(MatchError("context canceled"))
})
It("distributes keys", func() {
setRingKeys()

View File

@ -97,8 +97,9 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client {
onClose: failover.Close,
},
},
ctx: context.Background(),
}
c.cmdable = c.Process
c.init()
return &c
}
@ -117,16 +118,14 @@ func NewSentinelClient(opt *Options) *SentinelClient {
opt: opt,
connPool: newConnPool(opt),
},
ctx: context.Background(),
}
return c
}
func (c *SentinelClient) Context() context.Context {
if c.ctx != nil {
return c.ctx
}
return context.Background()
}
func (c *SentinelClient) WithContext(ctx context.Context) *SentinelClient {
if ctx == nil {
@ -162,7 +161,7 @@ func (c *SentinelClient) pubSub() *PubSub {
// measure latency.
func (c *SentinelClient) Ping() *StringCmd {
cmd := NewStringCmd("ping")
c.Process(cmd)
_ = c.Process(cmd)
return cmd
}
@ -188,13 +187,13 @@ func (c *SentinelClient) PSubscribe(channels ...string) *PubSub {
func (c *SentinelClient) GetMasterAddrByName(name string) *StringSliceCmd {
cmd := NewStringSliceCmd("sentinel", "get-master-addr-by-name", name)
c.Process(cmd)
_ = c.Process(cmd)
return cmd
}
func (c *SentinelClient) Sentinels(name string) *SliceCmd {
cmd := NewSliceCmd("sentinel", "sentinels", name)
c.Process(cmd)
_ = c.Process(cmd)
return cmd
}
@ -202,7 +201,7 @@ func (c *SentinelClient) Sentinels(name string) *SliceCmd {
// asking for agreement to other Sentinels.
func (c *SentinelClient) Failover(name string) *StatusCmd {
cmd := NewStatusCmd("sentinel", "failover", name)
c.Process(cmd)
_ = c.Process(cmd)
return cmd
}
@ -212,7 +211,7 @@ func (c *SentinelClient) Failover(name string) *StatusCmd {
// already discovered and associated with the master.
func (c *SentinelClient) Reset(pattern string) *IntCmd {
cmd := NewIntCmd("sentinel", "reset", pattern)
c.Process(cmd)
_ = c.Process(cmd)
return cmd
}
@ -220,28 +219,28 @@ func (c *SentinelClient) Reset(pattern string) *IntCmd {
// the current Sentinel state.
func (c *SentinelClient) FlushConfig() *StatusCmd {
cmd := NewStatusCmd("sentinel", "flushconfig")
c.Process(cmd)
_ = c.Process(cmd)
return cmd
}
// Master shows the state and info of the specified master.
func (c *SentinelClient) Master(name string) *StringStringMapCmd {
cmd := NewStringStringMapCmd("sentinel", "master", name)
c.Process(cmd)
_ = c.Process(cmd)
return cmd
}
// Masters shows a list of monitored masters and their state.
func (c *SentinelClient) Masters() *SliceCmd {
cmd := NewSliceCmd("sentinel", "masters")
c.Process(cmd)
_ = c.Process(cmd)
return cmd
}
// Slaves shows a list of slaves for the specified master and their state.
func (c *SentinelClient) Slaves(name string) *SliceCmd {
cmd := NewSliceCmd("sentinel", "slaves", name)
c.Process(cmd)
_ = c.Process(cmd)
return cmd
}
@ -251,7 +250,7 @@ func (c *SentinelClient) Slaves(name string) *SliceCmd {
// Sentinel deployment is ok.
func (c *SentinelClient) CkQuorum(name string) *StringCmd {
cmd := NewStringCmd("sentinel", "ckquorum", name)
c.Process(cmd)
_ = c.Process(cmd)
return cmd
}
@ -259,14 +258,14 @@ func (c *SentinelClient) CkQuorum(name string) *StringCmd {
// name, ip, port, and quorum.
func (c *SentinelClient) Monitor(name, ip, port, quorum string) *StringCmd {
cmd := NewStringCmd("sentinel", "monitor", name, ip, port, quorum)
c.Process(cmd)
_ = c.Process(cmd)
return cmd
}
// Set is used in order to change configuration parameters of a specific master.
func (c *SentinelClient) Set(name, option, value string) *StringCmd {
cmd := NewStringCmd("sentinel", "set", name, option, value)
c.Process(cmd)
_ = c.Process(cmd)
return cmd
}
@ -275,7 +274,7 @@ func (c *SentinelClient) Set(name, option, value string) *StringCmd {
// the Sentinel.
func (c *SentinelClient) Remove(name string) *StringCmd {
cmd := NewStringCmd("sentinel", "remove", name)
c.Process(cmd)
_ = c.Process(cmd)
return cmd
}
@ -313,7 +312,7 @@ func (c *sentinelFailover) Pool() *pool.ConnPool {
return c.pool
}
func (c *sentinelFailover) dial(ctx context.Context, network, addr string) (net.Conn, error) {
func (c *sentinelFailover) dial(ctx context.Context, network, _ string) (net.Conn, error) {
addr, err := c.MasterAddr()
if err != nil {
return nil, err
@ -396,7 +395,7 @@ func (c *sentinelFailover) getMasterAddr() string {
c.masterName, err)
c.mu.Lock()
if c.sentinel == sentinel {
c.closeSentinel()
_ = c.closeSentinel()
}
c.mu.Unlock()
return ""
@ -436,13 +435,13 @@ func (c *sentinelFailover) closeSentinel() error {
var firstErr error
err := c.pubsub.Close()
if err != nil && firstErr == err {
if err != nil && firstErr == nil {
firstErr = err
}
c.pubsub = nil
err = c.sentinel.Close()
if err != nil && firstErr == err {
if err != nil && firstErr == nil {
firstErr = err
}
c.sentinel = nil

3
tx.go
View File

@ -40,11 +40,8 @@ func (c *Tx) init() {
}
func (c *Tx) Context() context.Context {
if c.ctx != nil {
return c.ctx
}
return context.Background()
}
func (c *Tx) WithContext(ctx context.Context) *Tx {
if ctx == nil {

View File

@ -1,6 +1,7 @@
package redis_test
import (
"context"
"strconv"
"sync"
@ -124,7 +125,7 @@ var _ = Describe("Tx", func() {
It("should recover from bad connection", func() {
// Put bad connection in the pool.
cn, err := client.Pool().Get(nil)
cn, err := client.Pool().Get(context.Background())
Expect(err).NotTo(HaveOccurred())
cn.SetNetConn(&badConn{})