Add ctx as first arg

This commit is contained in:
Vladimir Mihailenco 2020-03-11 16:26:42 +02:00
parent 64bb0b7f3a
commit f5593121e0
36 changed files with 3200 additions and 2970 deletions

View File

@ -11,7 +11,7 @@ import (
"github.com/go-redis/redis/v7"
)
func benchmarkRedisClient(poolSize int) *redis.Client {
func benchmarkRedisClient(ctx context.Context, poolSize int) *redis.Client {
client := redis.NewClient(&redis.Options{
Addr: ":6379",
DialTimeout: time.Second,
@ -19,21 +19,22 @@ func benchmarkRedisClient(poolSize int) *redis.Client {
WriteTimeout: time.Second,
PoolSize: poolSize,
})
if err := client.FlushDB().Err(); err != nil {
if err := client.FlushDB(ctx).Err(); err != nil {
panic(err)
}
return client
}
func BenchmarkRedisPing(b *testing.B) {
client := benchmarkRedisClient(10)
ctx := context.Background()
client := benchmarkRedisClient(ctx, 10)
defer client.Close()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
if err := client.Ping().Err(); err != nil {
if err := client.Ping(ctx).Err(); err != nil {
b.Fatal(err)
}
}
@ -41,14 +42,15 @@ func BenchmarkRedisPing(b *testing.B) {
}
func BenchmarkRedisGetNil(b *testing.B) {
client := benchmarkRedisClient(10)
ctx := context.Background()
client := benchmarkRedisClient(ctx, 10)
defer client.Close()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
if err := client.Get("key").Err(); err != redis.Nil {
if err := client.Get(ctx, "key").Err(); err != redis.Nil {
b.Fatal(err)
}
}
@ -80,7 +82,8 @@ func BenchmarkRedisSetString(b *testing.B) {
}
for _, bm := range benchmarks {
b.Run(bm.String(), func(b *testing.B) {
client := benchmarkRedisClient(bm.poolSize)
ctx := context.Background()
client := benchmarkRedisClient(ctx, bm.poolSize)
defer client.Close()
value := strings.Repeat("1", bm.valueSize)
@ -89,7 +92,7 @@ func BenchmarkRedisSetString(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
err := client.Set("key", value, 0).Err()
err := client.Set(ctx, "key", value, 0).Err()
if err != nil {
b.Fatal(err)
}
@ -100,7 +103,8 @@ func BenchmarkRedisSetString(b *testing.B) {
}
func BenchmarkRedisSetGetBytes(b *testing.B) {
client := benchmarkRedisClient(10)
ctx := context.Background()
client := benchmarkRedisClient(ctx, 10)
defer client.Close()
value := bytes.Repeat([]byte{'1'}, 10000)
@ -109,11 +113,11 @@ func BenchmarkRedisSetGetBytes(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
if err := client.Set("key", value, 0).Err(); err != nil {
if err := client.Set(ctx, "key", value, 0).Err(); err != nil {
b.Fatal(err)
}
got, err := client.Get("key").Bytes()
got, err := client.Get(ctx, "key").Bytes()
if err != nil {
b.Fatal(err)
}
@ -125,10 +129,11 @@ func BenchmarkRedisSetGetBytes(b *testing.B) {
}
func BenchmarkRedisMGet(b *testing.B) {
client := benchmarkRedisClient(10)
ctx := context.Background()
client := benchmarkRedisClient(ctx, 10)
defer client.Close()
if err := client.MSet("key1", "hello1", "key2", "hello2").Err(); err != nil {
if err := client.MSet(ctx, "key1", "hello1", "key2", "hello2").Err(); err != nil {
b.Fatal(err)
}
@ -136,7 +141,7 @@ func BenchmarkRedisMGet(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
if err := client.MGet("key1", "key2").Err(); err != nil {
if err := client.MGet(ctx, "key1", "key2").Err(); err != nil {
b.Fatal(err)
}
}
@ -144,17 +149,18 @@ func BenchmarkRedisMGet(b *testing.B) {
}
func BenchmarkSetExpire(b *testing.B) {
client := benchmarkRedisClient(10)
ctx := context.Background()
client := benchmarkRedisClient(ctx, 10)
defer client.Close()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
if err := client.Set("key", "hello", 0).Err(); err != nil {
if err := client.Set(ctx, "key", "hello", 0).Err(); err != nil {
b.Fatal(err)
}
if err := client.Expire("key", time.Second).Err(); err != nil {
if err := client.Expire(ctx, "key", time.Second).Err(); err != nil {
b.Fatal(err)
}
}
@ -162,16 +168,17 @@ func BenchmarkSetExpire(b *testing.B) {
}
func BenchmarkPipeline(b *testing.B) {
client := benchmarkRedisClient(10)
ctx := context.Background()
client := benchmarkRedisClient(ctx, 10)
defer client.Close()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_, err := client.Pipelined(func(pipe redis.Pipeliner) error {
pipe.Set("key", "hello", 0)
pipe.Expire("key", time.Second)
_, err := client.Pipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Set(ctx, "key", "hello", 0)
pipe.Expire(ctx, "key", time.Second)
return nil
})
if err != nil {
@ -182,14 +189,15 @@ func BenchmarkPipeline(b *testing.B) {
}
func BenchmarkZAdd(b *testing.B) {
client := benchmarkRedisClient(10)
ctx := context.Background()
client := benchmarkRedisClient(ctx, 10)
defer client.Close()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
err := client.ZAdd("key", &redis.Z{
err := client.ZAdd(ctx, "key", &redis.Z{
Score: float64(1),
Member: "hello",
}).Err()
@ -203,10 +211,9 @@ func BenchmarkZAdd(b *testing.B) {
var clientSink *redis.Client
func BenchmarkWithContext(b *testing.B) {
rdb := benchmarkRedisClient(10)
defer rdb.Close()
ctx := context.Background()
rdb := benchmarkRedisClient(ctx, 10)
defer rdb.Close()
b.ResetTimer()
b.ReportAllocs()
@ -219,11 +226,10 @@ func BenchmarkWithContext(b *testing.B) {
var ringSink *redis.Ring
func BenchmarkRingWithContext(b *testing.B) {
ctx := context.Background()
rdb := redis.NewRing(&redis.RingOptions{})
defer rdb.Close()
ctx := context.Background()
b.ResetTimer()
b.ReportAllocs()
@ -248,20 +254,21 @@ func BenchmarkClusterPing(b *testing.B) {
b.Skip("skipping in short mode")
}
ctx := context.Background()
cluster := newClusterScenario()
if err := startCluster(cluster); err != nil {
if err := startCluster(ctx, cluster); err != nil {
b.Fatal(err)
}
defer stopCluster(cluster)
client := cluster.newClusterClient(redisClusterOptions())
client := cluster.newClusterClient(ctx, redisClusterOptions())
defer client.Close()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
err := client.Ping().Err()
err := client.Ping(ctx).Err()
if err != nil {
b.Fatal(err)
}
@ -274,13 +281,14 @@ func BenchmarkClusterSetString(b *testing.B) {
b.Skip("skipping in short mode")
}
ctx := context.Background()
cluster := newClusterScenario()
if err := startCluster(cluster); err != nil {
if err := startCluster(ctx, cluster); err != nil {
b.Fatal(err)
}
defer stopCluster(cluster)
client := cluster.newClusterClient(redisClusterOptions())
client := cluster.newClusterClient(ctx, redisClusterOptions())
defer client.Close()
value := string(bytes.Repeat([]byte{'1'}, 10000))
@ -289,7 +297,7 @@ func BenchmarkClusterSetString(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
err := client.Set("key", value, 0).Err()
err := client.Set(ctx, "key", value, 0).Err()
if err != nil {
b.Fatal(err)
}
@ -302,19 +310,20 @@ func BenchmarkClusterReloadState(b *testing.B) {
b.Skip("skipping in short mode")
}
ctx := context.Background()
cluster := newClusterScenario()
if err := startCluster(cluster); err != nil {
if err := startCluster(ctx, cluster); err != nil {
b.Fatal(err)
}
defer stopCluster(cluster)
client := cluster.newClusterClient(redisClusterOptions())
client := cluster.newClusterClient(ctx, redisClusterOptions())
defer client.Close()
b.ResetTimer()
for i := 0; i < b.N; i++ {
err := client.ReloadState()
err := client.ReloadState(ctx)
if err != nil {
b.Fatal(err)
}
@ -324,11 +333,10 @@ func BenchmarkClusterReloadState(b *testing.B) {
var clusterSink *redis.ClusterClient
func BenchmarkClusterWithContext(b *testing.B) {
ctx := context.Background()
rdb := redis.NewClusterClient(&redis.ClusterOptions{})
defer rdb.Close()
ctx := context.Background()
b.ResetTimer()
b.ReportAllocs()

View File

@ -191,7 +191,7 @@ func (n *clusterNode) updateLatency() {
var latency uint32
for i := 0; i < probes; i++ {
start := time.Now()
n.Client.Ping()
n.Client.Ping(context.TODO())
probe := uint32(time.Since(start) / time.Microsecond)
latency = (latency + probe) / 2
}
@ -588,20 +588,20 @@ func (c *clusterState) slotNodes(slot int) []*clusterNode {
//------------------------------------------------------------------------------
type clusterStateHolder struct {
load func() (*clusterState, error)
load func(ctx context.Context) (*clusterState, error)
state atomic.Value
reloading uint32 // atomic
}
func newClusterStateHolder(fn func() (*clusterState, error)) *clusterStateHolder {
func newClusterStateHolder(fn func(ctx context.Context) (*clusterState, error)) *clusterStateHolder {
return &clusterStateHolder{
load: fn,
}
}
func (c *clusterStateHolder) Reload() (*clusterState, error) {
state, err := c.load()
func (c *clusterStateHolder) Reload(ctx context.Context) (*clusterState, error) {
state, err := c.load(ctx)
if err != nil {
return nil, err
}
@ -609,14 +609,14 @@ func (c *clusterStateHolder) Reload() (*clusterState, error) {
return state, nil
}
func (c *clusterStateHolder) LazyReload() {
func (c *clusterStateHolder) LazyReload(ctx context.Context) {
if !atomic.CompareAndSwapUint32(&c.reloading, 0, 1) {
return
}
go func() {
defer atomic.StoreUint32(&c.reloading, 0)
_, err := c.Reload()
_, err := c.Reload(ctx)
if err != nil {
return
}
@ -624,24 +624,24 @@ func (c *clusterStateHolder) LazyReload() {
}()
}
func (c *clusterStateHolder) Get() (*clusterState, error) {
func (c *clusterStateHolder) Get(ctx context.Context) (*clusterState, error) {
v := c.state.Load()
if v != nil {
state := v.(*clusterState)
if time.Since(state.createdAt) > time.Minute {
c.LazyReload()
c.LazyReload(ctx)
}
return state, nil
}
return c.Reload()
return c.Reload(ctx)
}
func (c *clusterStateHolder) ReloadOrGet() (*clusterState, error) {
state, err := c.Reload()
func (c *clusterStateHolder) ReloadOrGet(ctx context.Context) (*clusterState, error) {
state, err := c.Reload(ctx)
if err == nil {
return state, nil
}
return c.Get()
return c.Get(ctx)
}
//------------------------------------------------------------------------------
@ -708,8 +708,8 @@ func (c *ClusterClient) Options() *ClusterOptions {
// ReloadState reloads cluster state. If available it calls ClusterSlots func
// to get cluster slots information.
func (c *ClusterClient) ReloadState() error {
_, err := c.state.Reload()
func (c *ClusterClient) ReloadState(ctx context.Context) error {
_, err := c.state.Reload(ctx)
return err
}
@ -722,21 +722,13 @@ func (c *ClusterClient) Close() error {
}
// Do creates a Cmd from the args and processes the cmd.
func (c *ClusterClient) Do(args ...interface{}) *Cmd {
return c.DoContext(c.ctx, args...)
}
func (c *ClusterClient) DoContext(ctx context.Context, args ...interface{}) *Cmd {
cmd := NewCmd(args...)
_ = c.ProcessContext(ctx, cmd)
func (c *ClusterClient) Do(ctx context.Context, args ...interface{}) *Cmd {
cmd := NewCmd(ctx, args...)
_ = c.Process(ctx, cmd)
return cmd
}
func (c *ClusterClient) Process(cmd Cmder) error {
return c.ProcessContext(c.ctx, cmd)
}
func (c *ClusterClient) ProcessContext(ctx context.Context, cmd Cmder) error {
func (c *ClusterClient) Process(ctx context.Context, cmd Cmder) error {
return c.hooks.process(ctx, cmd, c.process)
}
@ -765,7 +757,7 @@ func (c *ClusterClient) _process(ctx context.Context, cmd Cmder) error {
if node == nil {
var err error
node, err = c.cmdNode(cmdInfo, slot)
node, err = c.cmdNode(ctx, cmdInfo, slot)
if err != nil {
return err
}
@ -773,13 +765,13 @@ func (c *ClusterClient) _process(ctx context.Context, cmd Cmder) error {
if ask {
pipe := node.Client.Pipeline()
_ = pipe.Process(NewCmd("asking"))
_ = pipe.Process(cmd)
_, lastErr = pipe.ExecContext(ctx)
_ = pipe.Process(ctx, NewCmd(ctx, "asking"))
_ = pipe.Process(ctx, cmd)
_, lastErr = pipe.Exec(ctx)
_ = pipe.Close()
ask = false
} else {
lastErr = node.Client.ProcessContext(ctx, cmd)
lastErr = node.Client.Process(ctx, cmd)
}
// If there is no error - we are done.
@ -787,7 +779,7 @@ func (c *ClusterClient) _process(ctx context.Context, cmd Cmder) error {
return nil
}
if lastErr != Nil {
c.state.LazyReload()
c.state.LazyReload(ctx)
}
if lastErr == pool.ErrClosed || isReadOnlyError(lastErr) {
node = nil
@ -832,8 +824,11 @@ func (c *ClusterClient) _process(ctx context.Context, cmd Cmder) error {
// ForEachMaster concurrently calls the fn on each master node in the cluster.
// It returns the first error if any.
func (c *ClusterClient) ForEachMaster(fn func(client *Client) error) error {
state, err := c.state.ReloadOrGet()
func (c *ClusterClient) ForEachMaster(
ctx context.Context,
fn func(ctx context.Context, client *Client) error,
) error {
state, err := c.state.ReloadOrGet(ctx)
if err != nil {
return err
}
@ -845,7 +840,7 @@ func (c *ClusterClient) ForEachMaster(fn func(client *Client) error) error {
wg.Add(1)
go func(node *clusterNode) {
defer wg.Done()
err := fn(node.Client)
err := fn(ctx, node.Client)
if err != nil {
select {
case errCh <- err:
@ -867,8 +862,11 @@ func (c *ClusterClient) ForEachMaster(fn func(client *Client) error) error {
// ForEachSlave concurrently calls the fn on each slave node in the cluster.
// It returns the first error if any.
func (c *ClusterClient) ForEachSlave(fn func(client *Client) error) error {
state, err := c.state.ReloadOrGet()
func (c *ClusterClient) ForEachSlave(
ctx context.Context,
fn func(ctx context.Context, client *Client) error,
) error {
state, err := c.state.ReloadOrGet(ctx)
if err != nil {
return err
}
@ -880,7 +878,7 @@ func (c *ClusterClient) ForEachSlave(fn func(client *Client) error) error {
wg.Add(1)
go func(node *clusterNode) {
defer wg.Done()
err := fn(node.Client)
err := fn(ctx, node.Client)
if err != nil {
select {
case errCh <- err:
@ -902,8 +900,11 @@ func (c *ClusterClient) ForEachSlave(fn func(client *Client) error) error {
// ForEachNode concurrently calls the fn on each known node in the cluster.
// It returns the first error if any.
func (c *ClusterClient) ForEachNode(fn func(client *Client) error) error {
state, err := c.state.ReloadOrGet()
func (c *ClusterClient) ForEachNode(
ctx context.Context,
fn func(ctx context.Context, client *Client) error,
) error {
state, err := c.state.ReloadOrGet(ctx)
if err != nil {
return err
}
@ -913,7 +914,7 @@ func (c *ClusterClient) ForEachNode(fn func(client *Client) error) error {
worker := func(node *clusterNode) {
defer wg.Done()
err := fn(node.Client)
err := fn(ctx, node.Client)
if err != nil {
select {
case errCh <- err:
@ -945,7 +946,7 @@ func (c *ClusterClient) ForEachNode(fn func(client *Client) error) error {
func (c *ClusterClient) PoolStats() *PoolStats {
var acc PoolStats
state, _ := c.state.Get()
state, _ := c.state.Get(context.TODO())
if state == nil {
return &acc
}
@ -975,7 +976,7 @@ func (c *ClusterClient) PoolStats() *PoolStats {
return &acc
}
func (c *ClusterClient) loadState() (*clusterState, error) {
func (c *ClusterClient) loadState(ctx context.Context) (*clusterState, error) {
if c.opt.ClusterSlots != nil {
slots, err := c.opt.ClusterSlots()
if err != nil {
@ -999,7 +1000,7 @@ func (c *ClusterClient) loadState() (*clusterState, error) {
continue
}
slots, err := node.Client.ClusterSlots().Result()
slots, err := node.Client.ClusterSlots(ctx).Result()
if err != nil {
if firstErr == nil {
firstErr = err
@ -1042,8 +1043,8 @@ func (c *ClusterClient) Pipeline() Pipeliner {
return &pipe
}
func (c *ClusterClient) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) {
return c.Pipeline().Pipelined(fn)
func (c *ClusterClient) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
return c.Pipeline().Pipelined(ctx, fn)
}
func (c *ClusterClient) processPipeline(ctx context.Context, cmds []Cmder) error {
@ -1052,7 +1053,7 @@ func (c *ClusterClient) processPipeline(ctx context.Context, cmds []Cmder) error
func (c *ClusterClient) _processPipeline(ctx context.Context, cmds []Cmder) error {
cmdsMap := newCmdsMap()
err := c.mapCmdsByNode(cmdsMap, cmds)
err := c.mapCmdsByNode(ctx, cmdsMap, cmds)
if err != nil {
setCmdsErr(cmds, err)
return err
@ -1079,7 +1080,7 @@ func (c *ClusterClient) _processPipeline(ctx context.Context, cmds []Cmder) erro
return
}
if attempt < c.opt.MaxRedirects {
if err := c.mapCmdsByNode(failedCmds, cmds); err != nil {
if err := c.mapCmdsByNode(ctx, failedCmds, cmds); err != nil {
setCmdsErr(cmds, err)
}
} else {
@ -1098,8 +1099,8 @@ func (c *ClusterClient) _processPipeline(ctx context.Context, cmds []Cmder) erro
return cmdsFirstErr(cmds)
}
func (c *ClusterClient) mapCmdsByNode(cmdsMap *cmdsMap, cmds []Cmder) error {
state, err := c.state.Get()
func (c *ClusterClient) mapCmdsByNode(ctx context.Context, cmdsMap *cmdsMap, cmds []Cmder) error {
state, err := c.state.Get(ctx)
if err != nil {
return err
}
@ -1150,21 +1151,25 @@ func (c *ClusterClient) _processPipelineNode(
}
return cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
return c.pipelineReadCmds(node, rd, cmds, failedCmds)
return c.pipelineReadCmds(ctx, node, rd, cmds, failedCmds)
})
})
})
}
func (c *ClusterClient) pipelineReadCmds(
node *clusterNode, rd *proto.Reader, cmds []Cmder, failedCmds *cmdsMap,
ctx context.Context,
node *clusterNode,
rd *proto.Reader,
cmds []Cmder,
failedCmds *cmdsMap,
) error {
for _, cmd := range cmds {
err := cmd.readReply(rd)
if err == nil {
continue
}
if c.checkMovedErr(cmd, err, failedCmds) {
if c.checkMovedErr(ctx, cmd, err, failedCmds) {
continue
}
@ -1181,7 +1186,7 @@ func (c *ClusterClient) pipelineReadCmds(
}
func (c *ClusterClient) checkMovedErr(
cmd Cmder, err error, failedCmds *cmdsMap,
ctx context.Context, cmd Cmder, err error, failedCmds *cmdsMap,
) bool {
moved, ask, addr := isMovedError(err)
if !moved && !ask {
@ -1194,13 +1199,13 @@ func (c *ClusterClient) checkMovedErr(
}
if moved {
c.state.LazyReload()
c.state.LazyReload(ctx)
failedCmds.Add(node, cmd)
return true
}
if ask {
failedCmds.Add(node, NewCmd("asking"), cmd)
failedCmds.Add(node, NewCmd(ctx, "asking"), cmd)
return true
}
@ -1217,8 +1222,8 @@ func (c *ClusterClient) TxPipeline() Pipeliner {
return &pipe
}
func (c *ClusterClient) TxPipelined(fn func(Pipeliner) error) ([]Cmder, error) {
return c.TxPipeline().Pipelined(fn)
func (c *ClusterClient) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
return c.TxPipeline().Pipelined(ctx, fn)
}
func (c *ClusterClient) processTxPipeline(ctx context.Context, cmds []Cmder) error {
@ -1226,7 +1231,7 @@ func (c *ClusterClient) processTxPipeline(ctx context.Context, cmds []Cmder) err
}
func (c *ClusterClient) _processTxPipeline(ctx context.Context, cmds []Cmder) error {
state, err := c.state.Get()
state, err := c.state.Get(ctx)
if err != nil {
setCmdsErr(cmds, err)
return err
@ -1262,7 +1267,7 @@ func (c *ClusterClient) _processTxPipeline(ctx context.Context, cmds []Cmder) er
return
}
if attempt < c.opt.MaxRedirects {
if err := c.mapCmdsByNode(failedCmds, cmds); err != nil {
if err := c.mapCmdsByNode(ctx, failedCmds, cmds); err != nil {
setCmdsErr(cmds, err)
}
} else {
@ -1308,11 +1313,11 @@ func (c *ClusterClient) _processTxPipelineNode(
// Trim multi and exec.
cmds = cmds[1 : len(cmds)-1]
err := c.txPipelineReadQueued(rd, statusCmd, cmds, failedCmds)
err := c.txPipelineReadQueued(ctx, rd, statusCmd, cmds, failedCmds)
if err != nil {
moved, ask, addr := isMovedError(err)
if moved || ask {
return c.cmdsMoved(cmds, moved, ask, addr, failedCmds)
return c.cmdsMoved(ctx, cmds, moved, ask, addr, failedCmds)
}
return err
}
@ -1324,7 +1329,11 @@ func (c *ClusterClient) _processTxPipelineNode(
}
func (c *ClusterClient) txPipelineReadQueued(
rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder, failedCmds *cmdsMap,
ctx context.Context,
rd *proto.Reader,
statusCmd *StatusCmd,
cmds []Cmder,
failedCmds *cmdsMap,
) error {
// Parse queued replies.
if err := statusCmd.readReply(rd); err != nil {
@ -1333,7 +1342,7 @@ func (c *ClusterClient) txPipelineReadQueued(
for _, cmd := range cmds {
err := statusCmd.readReply(rd)
if err == nil || c.checkMovedErr(cmd, err, failedCmds) || isRedisError(err) {
if err == nil || c.checkMovedErr(ctx, cmd, err, failedCmds) || isRedisError(err) {
continue
}
return err
@ -1361,7 +1370,10 @@ func (c *ClusterClient) txPipelineReadQueued(
}
func (c *ClusterClient) cmdsMoved(
cmds []Cmder, moved, ask bool, addr string, failedCmds *cmdsMap,
ctx context.Context, cmds []Cmder,
moved, ask bool,
addr string,
failedCmds *cmdsMap,
) error {
node, err := c.nodes.Get(addr)
if err != nil {
@ -1369,7 +1381,7 @@ func (c *ClusterClient) cmdsMoved(
}
if moved {
c.state.LazyReload()
c.state.LazyReload(ctx)
for _, cmd := range cmds {
failedCmds.Add(node, cmd)
}
@ -1378,7 +1390,7 @@ func (c *ClusterClient) cmdsMoved(
if ask {
for _, cmd := range cmds {
failedCmds.Add(node, NewCmd("asking"), cmd)
failedCmds.Add(node, NewCmd(ctx, "asking"), cmd)
}
return nil
}
@ -1386,11 +1398,7 @@ func (c *ClusterClient) cmdsMoved(
return nil
}
func (c *ClusterClient) Watch(fn func(*Tx) error, keys ...string) error {
return c.WatchContext(c.ctx, fn, keys...)
}
func (c *ClusterClient) WatchContext(ctx context.Context, fn func(*Tx) error, keys ...string) error {
func (c *ClusterClient) Watch(ctx context.Context, fn func(*Tx) error, keys ...string) error {
if len(keys) == 0 {
return fmt.Errorf("redis: Watch requires at least one key")
}
@ -1403,7 +1411,7 @@ func (c *ClusterClient) WatchContext(ctx context.Context, fn func(*Tx) error, ke
}
}
node, err := c.slotMasterNode(slot)
node, err := c.slotMasterNode(ctx, slot)
if err != nil {
return err
}
@ -1415,12 +1423,12 @@ func (c *ClusterClient) WatchContext(ctx context.Context, fn func(*Tx) error, ke
}
}
err = node.Client.WatchContext(ctx, fn, keys...)
err = node.Client.Watch(ctx, fn, keys...)
if err == nil {
break
}
if err != Nil {
c.state.LazyReload()
c.state.LazyReload(ctx)
}
moved, ask, addr := isMovedError(err)
@ -1433,7 +1441,7 @@ func (c *ClusterClient) WatchContext(ctx context.Context, fn func(*Tx) error, ke
}
if err == pool.ErrClosed || isReadOnlyError(err) {
node, err = c.slotMasterNode(slot)
node, err = c.slotMasterNode(ctx, slot)
if err != nil {
return err
}
@ -1455,7 +1463,7 @@ func (c *ClusterClient) pubSub() *PubSub {
pubsub := &PubSub{
opt: c.opt.clientOptions(),
newConn: func(channels []string) (*pool.Conn, error) {
newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) {
if node != nil {
panic("node != nil")
}
@ -1463,7 +1471,7 @@ func (c *ClusterClient) pubSub() *PubSub {
var err error
if len(channels) > 0 {
slot := hashtag.Slot(channels[0])
node, err = c.slotMasterNode(slot)
node, err = c.slotMasterNode(ctx, slot)
} else {
node, err = c.nodes.Random()
}
@ -1493,20 +1501,20 @@ func (c *ClusterClient) pubSub() *PubSub {
// Subscribe subscribes the client to the specified channels.
// Channels can be omitted to create empty subscription.
func (c *ClusterClient) Subscribe(channels ...string) *PubSub {
func (c *ClusterClient) Subscribe(ctx context.Context, channels ...string) *PubSub {
pubsub := c.pubSub()
if len(channels) > 0 {
_ = pubsub.Subscribe(channels...)
_ = pubsub.Subscribe(ctx, channels...)
}
return pubsub
}
// PSubscribe subscribes the client to the given patterns.
// Patterns can be omitted to create empty subscription.
func (c *ClusterClient) PSubscribe(channels ...string) *PubSub {
func (c *ClusterClient) PSubscribe(ctx context.Context, channels ...string) *PubSub {
pubsub := c.pubSub()
if len(channels) > 0 {
_ = pubsub.PSubscribe(channels...)
_ = pubsub.PSubscribe(ctx, channels...)
}
return pubsub
}
@ -1531,7 +1539,7 @@ func (c *ClusterClient) cmdsInfo() (map[string]*CommandInfo, error) {
continue
}
info, err := node.Client.Command().Result()
info, err := node.Client.Command(context.TODO()).Result()
if err == nil {
return info, nil
}
@ -1573,8 +1581,12 @@ func cmdSlot(cmd Cmder, pos int) int {
return hashtag.Slot(firstKey)
}
func (c *ClusterClient) cmdNode(cmdInfo *CommandInfo, slot int) (*clusterNode, error) {
state, err := c.state.Get()
func (c *ClusterClient) cmdNode(
ctx context.Context,
cmdInfo *CommandInfo,
slot int,
) (*clusterNode, error) {
state, err := c.state.Get(ctx)
if err != nil {
return nil, err
}
@ -1595,8 +1607,8 @@ func (c *clusterClient) slotReadOnlyNode(state *clusterState, slot int) (*cluste
return state.slotSlaveNode(slot)
}
func (c *ClusterClient) slotMasterNode(slot int) (*clusterNode, error) {
state, err := c.state.Get()
func (c *ClusterClient) slotMasterNode(ctx context.Context, slot int) (*clusterNode, error) {
state, err := c.state.Get(ctx)
if err != nil {
return nil, err
}

View File

@ -1,12 +1,15 @@
package redis
import "sync/atomic"
import (
"context"
"sync/atomic"
)
func (c *ClusterClient) DBSize() *IntCmd {
cmd := NewIntCmd("dbsize")
func (c *ClusterClient) DBSize(ctx context.Context) *IntCmd {
cmd := NewIntCmd(ctx, "dbsize")
var size int64
err := c.ForEachMaster(func(master *Client) error {
n, err := master.DBSize().Result()
err := c.ForEachMaster(ctx, func(ctx context.Context, master *Client) error {
n, err := master.DBSize(ctx).Result()
if err != nil {
return err
}

View File

@ -53,7 +53,9 @@ func (s *clusterScenario) newClusterClientUnsafe(opt *redis.ClusterOptions) *red
}
func (s *clusterScenario) newClusterClient(opt *redis.ClusterOptions) *redis.ClusterClient {
func (s *clusterScenario) newClusterClient(
ctx context.Context, opt *redis.ClusterOptions,
) *redis.ClusterClient {
client := s.newClusterClientUnsafe(opt)
err := eventually(func() error {
@ -61,12 +63,12 @@ func (s *clusterScenario) newClusterClient(opt *redis.ClusterOptions) *redis.Clu
return nil
}
state, err := client.LoadState()
state, err := client.LoadState(ctx)
if err != nil {
return err
}
if !state.IsConsistent() {
if !state.IsConsistent(ctx) {
return fmt.Errorf("cluster state is not consistent")
}
@ -79,7 +81,7 @@ func (s *clusterScenario) newClusterClient(opt *redis.ClusterOptions) *redis.Clu
return client
}
func startCluster(scenario *clusterScenario) error {
func startCluster(ctx context.Context, scenario *clusterScenario) error {
// Start processes and collect node ids
for pos, port := range scenario.ports {
process, err := startRedis(port, "--cluster-enabled", "yes")
@ -91,7 +93,7 @@ func startCluster(scenario *clusterScenario) error {
Addr: ":" + port,
})
info, err := client.ClusterNodes().Result()
info, err := client.ClusterNodes(ctx).Result()
if err != nil {
return err
}
@ -103,7 +105,7 @@ func startCluster(scenario *clusterScenario) error {
// Meet cluster nodes.
for _, client := range scenario.clients {
err := client.ClusterMeet("127.0.0.1", scenario.ports[0]).Err()
err := client.ClusterMeet(ctx, "127.0.0.1", scenario.ports[0]).Err()
if err != nil {
return err
}
@ -112,7 +114,7 @@ func startCluster(scenario *clusterScenario) error {
// Bootstrap masters.
slots := []int{0, 5000, 10000, 16384}
for pos, master := range scenario.masters() {
err := master.ClusterAddSlotsRange(slots[pos], slots[pos+1]-1).Err()
err := master.ClusterAddSlotsRange(ctx, slots[pos], slots[pos+1]-1).Err()
if err != nil {
return err
}
@ -124,7 +126,7 @@ func startCluster(scenario *clusterScenario) error {
// Wait until master is available
err := eventually(func() error {
s := slave.ClusterNodes().Val()
s := slave.ClusterNodes(ctx).Val()
wanted := masterID
if !strings.Contains(s, wanted) {
return fmt.Errorf("%q does not contain %q", s, wanted)
@ -135,7 +137,7 @@ func startCluster(scenario *clusterScenario) error {
return err
}
err = slave.ClusterReplicate(masterID).Err()
err = slave.ClusterReplicate(ctx, masterID).Err()
if err != nil {
return err
}
@ -175,7 +177,7 @@ func startCluster(scenario *clusterScenario) error {
}}
for _, client := range scenario.clients {
err := eventually(func() error {
res, err := client.ClusterSlots().Result()
res, err := client.ClusterSlots(ctx).Result()
if err != nil {
return err
}
@ -243,48 +245,48 @@ var _ = Describe("ClusterClient", func() {
assertClusterClient := func() {
It("supports WithContext", func() {
c, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(ctx)
cancel()
err := client.WithContext(c).Ping().Err()
err := client.Ping(ctx).Err()
Expect(err).To(MatchError("context canceled"))
})
It("should GET/SET/DEL", func() {
err := client.Get("A").Err()
err := client.Get(ctx, "A").Err()
Expect(err).To(Equal(redis.Nil))
err = client.Set("A", "VALUE", 0).Err()
err = client.Set(ctx, "A", "VALUE", 0).Err()
Expect(err).NotTo(HaveOccurred())
Eventually(func() string {
return client.Get("A").Val()
return client.Get(ctx, "A").Val()
}, 30*time.Second).Should(Equal("VALUE"))
cnt, err := client.Del("A").Result()
cnt, err := client.Del(ctx, "A").Result()
Expect(err).NotTo(HaveOccurred())
Expect(cnt).To(Equal(int64(1)))
})
It("GET follows redirects", func() {
err := client.Set("A", "VALUE", 0).Err()
err := client.Set(ctx, "A", "VALUE", 0).Err()
Expect(err).NotTo(HaveOccurred())
if !failover {
Eventually(func() int64 {
nodes, err := client.Nodes("A")
nodes, err := client.Nodes(ctx, "A")
if err != nil {
return 0
}
return nodes[1].Client.DBSize().Val()
return nodes[1].Client.DBSize(ctx).Val()
}, 30*time.Second).Should(Equal(int64(1)))
Eventually(func() error {
return client.SwapNodes("A")
return client.SwapNodes(ctx, "A")
}, 30*time.Second).ShouldNot(HaveOccurred())
}
v, err := client.Get("A").Result()
v, err := client.Get(ctx, "A").Result()
Expect(err).NotTo(HaveOccurred())
Expect(v).To(Equal("VALUE"))
})
@ -292,28 +294,28 @@ var _ = Describe("ClusterClient", func() {
It("SET follows redirects", func() {
if !failover {
Eventually(func() error {
return client.SwapNodes("A")
return client.SwapNodes(ctx, "A")
}, 30*time.Second).ShouldNot(HaveOccurred())
}
err := client.Set("A", "VALUE", 0).Err()
err := client.Set(ctx, "A", "VALUE", 0).Err()
Expect(err).NotTo(HaveOccurred())
v, err := client.Get("A").Result()
v, err := client.Get(ctx, "A").Result()
Expect(err).NotTo(HaveOccurred())
Expect(v).To(Equal("VALUE"))
})
It("distributes keys", func() {
for i := 0; i < 100; i++ {
err := client.Set(fmt.Sprintf("key%d", i), "value", 0).Err()
err := client.Set(ctx, fmt.Sprintf("key%d", i), "value", 0).Err()
Expect(err).NotTo(HaveOccurred())
}
client.ForEachMaster(func(master *redis.Client) error {
client.ForEachMaster(ctx, func(ctx context.Context, master *redis.Client) error {
defer GinkgoRecover()
Eventually(func() string {
return master.Info("keyspace").Val()
return master.Info(ctx, "keyspace").Val()
}, 30*time.Second).Should(Or(
ContainSubstring("keys=31"),
ContainSubstring("keys=29"),
@ -332,14 +334,14 @@ var _ = Describe("ClusterClient", func() {
var key string
for i := 0; i < 100; i++ {
key = fmt.Sprintf("key%d", i)
err := script.Run(client, []string{key}, "value").Err()
err := script.Run(ctx, client, []string{key}, "value").Err()
Expect(err).NotTo(HaveOccurred())
}
client.ForEachMaster(func(master *redis.Client) error {
client.ForEachMaster(ctx, func(ctx context.Context, master *redis.Client) error {
defer GinkgoRecover()
Eventually(func() string {
return master.Info("keyspace").Val()
return master.Info(ctx, "keyspace").Val()
}, 30*time.Second).Should(Or(
ContainSubstring("keys=31"),
ContainSubstring("keys=29"),
@ -354,14 +356,14 @@ var _ = Describe("ClusterClient", func() {
// Transactionally increments key using GET and SET commands.
incr = func(key string) error {
err := client.Watch(func(tx *redis.Tx) error {
n, err := tx.Get(key).Int64()
err := client.Watch(ctx, func(tx *redis.Tx) error {
n, err := tx.Get(ctx, key).Int64()
if err != nil && err != redis.Nil {
return err
}
_, err = tx.TxPipelined(func(pipe redis.Pipeliner) error {
pipe.Set(key, strconv.FormatInt(n+1, 10), 0)
_, err = tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Set(ctx, key, strconv.FormatInt(n+1, 10), 0)
return nil
})
return err
@ -386,7 +388,7 @@ var _ = Describe("ClusterClient", func() {
wg.Wait()
Eventually(func() string {
return client.Get("key").Val()
return client.Get(ctx, "key").Val()
}, 30*time.Second).Should(Equal("100"))
})
@ -400,23 +402,23 @@ var _ = Describe("ClusterClient", func() {
if !failover {
for _, key := range keys {
Eventually(func() error {
return client.SwapNodes(key)
return client.SwapNodes(ctx, key)
}, 30*time.Second).ShouldNot(HaveOccurred())
}
}
for i, key := range keys {
pipe.Set(key, key+"_value", 0)
pipe.Expire(key, time.Duration(i+1)*time.Hour)
pipe.Set(ctx, key, key+"_value", 0)
pipe.Expire(ctx, key, time.Duration(i+1)*time.Hour)
}
cmds, err := pipe.Exec()
cmds, err := pipe.Exec(ctx)
Expect(err).NotTo(HaveOccurred())
Expect(cmds).To(HaveLen(14))
_ = client.ForEachNode(func(node *redis.Client) error {
_ = client.ForEachNode(ctx, func(ctx context.Context, node *redis.Client) error {
defer GinkgoRecover()
Eventually(func() int64 {
return node.DBSize().Val()
return node.DBSize(ctx).Val()
}, 30*time.Second).ShouldNot(BeZero())
return nil
})
@ -424,16 +426,16 @@ var _ = Describe("ClusterClient", func() {
if !failover {
for _, key := range keys {
Eventually(func() error {
return client.SwapNodes(key)
return client.SwapNodes(ctx, key)
}, 30*time.Second).ShouldNot(HaveOccurred())
}
}
for _, key := range keys {
pipe.Get(key)
pipe.TTL(key)
pipe.Get(ctx, key)
pipe.TTL(ctx, key)
}
cmds, err = pipe.Exec()
cmds, err = pipe.Exec(ctx)
Expect(err).NotTo(HaveOccurred())
Expect(cmds).To(HaveLen(14))
@ -448,15 +450,15 @@ var _ = Describe("ClusterClient", func() {
})
It("works with missing keys", func() {
pipe.Set("A", "A_value", 0)
pipe.Set("C", "C_value", 0)
_, err := pipe.Exec()
pipe.Set(ctx, "A", "A_value", 0)
pipe.Set(ctx, "C", "C_value", 0)
_, err := pipe.Exec(ctx)
Expect(err).NotTo(HaveOccurred())
a := pipe.Get("A")
b := pipe.Get("B")
c := pipe.Get("C")
cmds, err := pipe.Exec()
a := pipe.Get(ctx, "A")
b := pipe.Get(ctx, "B")
c := pipe.Get(ctx, "C")
cmds, err := pipe.Exec(ctx)
Expect(err).To(Equal(redis.Nil))
Expect(cmds).To(HaveLen(3))
@ -497,16 +499,16 @@ var _ = Describe("ClusterClient", func() {
})
It("supports PubSub", func() {
pubsub := client.Subscribe("mychannel")
pubsub := client.Subscribe(ctx, "mychannel")
defer pubsub.Close()
Eventually(func() error {
_, err := client.Publish("mychannel", "hello").Result()
_, err := client.Publish(ctx, "mychannel", "hello").Result()
if err != nil {
return err
}
msg, err := pubsub.ReceiveTimeout(time.Second)
msg, err := pubsub.ReceiveTimeout(ctx, time.Second)
if err != nil {
return err
}
@ -521,19 +523,19 @@ var _ = Describe("ClusterClient", func() {
})
It("supports PubSub.Ping without channels", func() {
pubsub := client.Subscribe()
pubsub := client.Subscribe(ctx)
defer pubsub.Close()
err := pubsub.Ping()
err := pubsub.Ping(ctx)
Expect(err).NotTo(HaveOccurred())
})
It("supports Process hook", func() {
err := client.Ping().Err()
err := client.Ping(ctx).Err()
Expect(err).NotTo(HaveOccurred())
err = client.ForEachNode(func(node *redis.Client) error {
return node.Ping().Err()
err = client.ForEachNode(ctx, func(ctx context.Context, node *redis.Client) error {
return node.Ping(ctx).Err()
})
Expect(err).NotTo(HaveOccurred())
@ -566,12 +568,12 @@ var _ = Describe("ClusterClient", func() {
},
}
_ = client.ForEachNode(func(node *redis.Client) error {
_ = client.ForEachNode(ctx, func(ctx context.Context, node *redis.Client) error {
node.AddHook(nodeHook)
return nil
})
err = client.Ping().Err()
err = client.Ping(ctx).Err()
Expect(err).NotTo(HaveOccurred())
Expect(stack).To(Equal([]string{
"cluster.BeforeProcess",
@ -587,11 +589,11 @@ var _ = Describe("ClusterClient", func() {
})
It("supports Pipeline hook", func() {
err := client.Ping().Err()
err := client.Ping(ctx).Err()
Expect(err).NotTo(HaveOccurred())
err = client.ForEachNode(func(node *redis.Client) error {
return node.Ping().Err()
err = client.ForEachNode(ctx, func(ctx context.Context, node *redis.Client) error {
return node.Ping(ctx).Err()
})
Expect(err).NotTo(HaveOccurred())
@ -612,7 +614,7 @@ var _ = Describe("ClusterClient", func() {
},
})
_ = client.ForEachNode(func(node *redis.Client) error {
_ = client.ForEachNode(ctx, func(ctx context.Context, node *redis.Client) error {
node.AddHook(&hook{
beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) {
Expect(cmds).To(HaveLen(1))
@ -630,8 +632,8 @@ var _ = Describe("ClusterClient", func() {
return nil
})
_, err = client.Pipelined(func(pipe redis.Pipeliner) error {
pipe.Ping()
_, err = client.Pipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Ping(ctx)
return nil
})
Expect(err).NotTo(HaveOccurred())
@ -644,11 +646,11 @@ var _ = Describe("ClusterClient", func() {
})
It("supports TxPipeline hook", func() {
err := client.Ping().Err()
err := client.Ping(ctx).Err()
Expect(err).NotTo(HaveOccurred())
err = client.ForEachNode(func(node *redis.Client) error {
return node.Ping().Err()
err = client.ForEachNode(ctx, func(ctx context.Context, node *redis.Client) error {
return node.Ping(ctx).Err()
})
Expect(err).NotTo(HaveOccurred())
@ -669,7 +671,7 @@ var _ = Describe("ClusterClient", func() {
},
})
_ = client.ForEachNode(func(node *redis.Client) error {
_ = client.ForEachNode(ctx, func(ctx context.Context, node *redis.Client) error {
node.AddHook(&hook{
beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) {
Expect(cmds).To(HaveLen(3))
@ -687,8 +689,8 @@ var _ = Describe("ClusterClient", func() {
return nil
})
_, err = client.TxPipelined(func(pipe redis.Pipeliner) error {
pipe.Ping()
_, err = client.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Ping(ctx)
return nil
})
Expect(err).NotTo(HaveOccurred())
@ -704,17 +706,17 @@ var _ = Describe("ClusterClient", func() {
Describe("ClusterClient", func() {
BeforeEach(func() {
opt = redisClusterOptions()
client = cluster.newClusterClient(opt)
client = cluster.newClusterClient(ctx, opt)
err := client.ForEachMaster(func(master *redis.Client) error {
return master.FlushDB().Err()
err := client.ForEachMaster(ctx, func(ctx context.Context, master *redis.Client) error {
return master.FlushDB(ctx).Err()
})
Expect(err).NotTo(HaveOccurred())
})
AfterEach(func() {
_ = client.ForEachMaster(func(master *redis.Client) error {
return master.FlushDB().Err()
_ = client.ForEachMaster(ctx, func(ctx context.Context, master *redis.Client) error {
return master.FlushDB(ctx).Err()
})
Expect(client.Close()).NotTo(HaveOccurred())
})
@ -727,13 +729,13 @@ var _ = Describe("ClusterClient", func() {
It("returns an error when there are no attempts left", func() {
opt := redisClusterOptions()
opt.MaxRedirects = -1
client := cluster.newClusterClient(opt)
client := cluster.newClusterClient(ctx, opt)
Eventually(func() error {
return client.SwapNodes("A")
return client.SwapNodes(ctx, "A")
}, 30*time.Second).ShouldNot(HaveOccurred())
err := client.Get("A").Err()
err := client.Get(ctx, "A").Err()
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("MOVED"))
@ -742,21 +744,21 @@ var _ = Describe("ClusterClient", func() {
It("calls fn for every master node", func() {
for i := 0; i < 10; i++ {
Expect(client.Set(strconv.Itoa(i), "", 0).Err()).NotTo(HaveOccurred())
Expect(client.Set(ctx, strconv.Itoa(i), "", 0).Err()).NotTo(HaveOccurred())
}
err := client.ForEachMaster(func(master *redis.Client) error {
return master.FlushDB().Err()
err := client.ForEachMaster(ctx, func(ctx context.Context, master *redis.Client) error {
return master.FlushDB(ctx).Err()
})
Expect(err).NotTo(HaveOccurred())
size, err := client.DBSize().Result()
size, err := client.DBSize(ctx).Result()
Expect(err).NotTo(HaveOccurred())
Expect(size).To(Equal(int64(0)))
})
It("should CLUSTER SLOTS", func() {
res, err := client.ClusterSlots().Result()
res, err := client.ClusterSlots(ctx).Result()
Expect(err).NotTo(HaveOccurred())
Expect(res).To(HaveLen(3))
@ -795,49 +797,49 @@ var _ = Describe("ClusterClient", func() {
})
It("should CLUSTER NODES", func() {
res, err := client.ClusterNodes().Result()
res, err := client.ClusterNodes(ctx).Result()
Expect(err).NotTo(HaveOccurred())
Expect(len(res)).To(BeNumerically(">", 400))
})
It("should CLUSTER INFO", func() {
res, err := client.ClusterInfo().Result()
res, err := client.ClusterInfo(ctx).Result()
Expect(err).NotTo(HaveOccurred())
Expect(res).To(ContainSubstring("cluster_known_nodes:6"))
})
It("should CLUSTER KEYSLOT", func() {
hashSlot, err := client.ClusterKeySlot("somekey").Result()
hashSlot, err := client.ClusterKeySlot(ctx, "somekey").Result()
Expect(err).NotTo(HaveOccurred())
Expect(hashSlot).To(Equal(int64(hashtag.Slot("somekey"))))
})
It("should CLUSTER GETKEYSINSLOT", func() {
keys, err := client.ClusterGetKeysInSlot(hashtag.Slot("somekey"), 1).Result()
keys, err := client.ClusterGetKeysInSlot(ctx, hashtag.Slot("somekey"), 1).Result()
Expect(err).NotTo(HaveOccurred())
Expect(len(keys)).To(Equal(0))
})
It("should CLUSTER COUNT-FAILURE-REPORTS", func() {
n, err := client.ClusterCountFailureReports(cluster.nodeIDs[0]).Result()
n, err := client.ClusterCountFailureReports(ctx, cluster.nodeIDs[0]).Result()
Expect(err).NotTo(HaveOccurred())
Expect(n).To(Equal(int64(0)))
})
It("should CLUSTER COUNTKEYSINSLOT", func() {
n, err := client.ClusterCountKeysInSlot(10).Result()
n, err := client.ClusterCountKeysInSlot(ctx, 10).Result()
Expect(err).NotTo(HaveOccurred())
Expect(n).To(Equal(int64(0)))
})
It("should CLUSTER SAVECONFIG", func() {
res, err := client.ClusterSaveConfig().Result()
res, err := client.ClusterSaveConfig(ctx).Result()
Expect(err).NotTo(HaveOccurred())
Expect(res).To(Equal("OK"))
})
It("should CLUSTER SLAVES", func() {
nodesList, err := client.ClusterSlaves(cluster.nodeIDs[0]).Result()
nodesList, err := client.ClusterSlaves(ctx, cluster.nodeIDs[0]).Result()
Expect(err).NotTo(HaveOccurred())
Expect(nodesList).Should(ContainElement(ContainSubstring("slave")))
Expect(nodesList).Should(HaveLen(1))
@ -847,7 +849,7 @@ var _ = Describe("ClusterClient", func() {
const nkeys = 100
for i := 0; i < nkeys; i++ {
err := client.Set(fmt.Sprintf("key%d", i), "value", 0).Err()
err := client.Set(ctx, fmt.Sprintf("key%d", i), "value", 0).Err()
Expect(err).NotTo(HaveOccurred())
}
@ -862,7 +864,7 @@ var _ = Describe("ClusterClient", func() {
}
for i := 0; i < nkeys*10; i++ {
key := client.RandomKey().Val()
key := client.RandomKey(ctx).Val()
addKey(key)
}
@ -879,40 +881,40 @@ var _ = Describe("ClusterClient", func() {
opt = redisClusterOptions()
opt.MinRetryBackoff = 250 * time.Millisecond
opt.MaxRetryBackoff = time.Second
client = cluster.newClusterClient(opt)
client = cluster.newClusterClient(ctx, opt)
err := client.ForEachMaster(func(master *redis.Client) error {
return master.FlushDB().Err()
err := client.ForEachMaster(ctx, func(ctx context.Context, master *redis.Client) error {
return master.FlushDB(ctx).Err()
})
Expect(err).NotTo(HaveOccurred())
err = client.ForEachSlave(func(slave *redis.Client) error {
err = client.ForEachSlave(ctx, func(ctx context.Context, slave *redis.Client) error {
defer GinkgoRecover()
Eventually(func() int64 {
return slave.DBSize().Val()
return slave.DBSize(ctx).Val()
}, "30s").Should(Equal(int64(0)))
return nil
})
Expect(err).NotTo(HaveOccurred())
state, err := client.LoadState()
state, err := client.LoadState(ctx)
Eventually(func() bool {
state, err = client.LoadState()
state, err = client.LoadState(ctx)
if err != nil {
return false
}
return state.IsConsistent()
return state.IsConsistent(ctx)
}, "30s").Should(BeTrue())
for _, slave := range state.Slaves {
err = slave.Client.ClusterFailover().Err()
err = slave.Client.ClusterFailover(ctx).Err()
Expect(err).NotTo(HaveOccurred())
Eventually(func() bool {
state, _ := client.LoadState()
return state.IsConsistent()
state, _ := client.LoadState(ctx)
return state.IsConsistent(ctx)
}, "30s").Should(BeTrue())
}
})
@ -929,16 +931,16 @@ var _ = Describe("ClusterClient", func() {
BeforeEach(func() {
opt = redisClusterOptions()
opt.RouteByLatency = true
client = cluster.newClusterClient(opt)
client = cluster.newClusterClient(ctx, opt)
err := client.ForEachMaster(func(master *redis.Client) error {
return master.FlushDB().Err()
err := client.ForEachMaster(ctx, func(ctx context.Context, master *redis.Client) error {
return master.FlushDB(ctx).Err()
})
Expect(err).NotTo(HaveOccurred())
err = client.ForEachSlave(func(slave *redis.Client) error {
err = client.ForEachSlave(ctx, func(ctx context.Context, slave *redis.Client) error {
Eventually(func() int64 {
return client.DBSize().Val()
return client.DBSize(ctx).Val()
}, 30*time.Second).Should(Equal(int64(0)))
return nil
})
@ -946,8 +948,8 @@ var _ = Describe("ClusterClient", func() {
})
AfterEach(func() {
err := client.ForEachSlave(func(slave *redis.Client) error {
return slave.ReadWrite().Err()
err := client.ForEachSlave(ctx, func(ctx context.Context, slave *redis.Client) error {
return slave.ReadWrite(ctx).Err()
})
Expect(err).NotTo(HaveOccurred())
@ -985,16 +987,16 @@ var _ = Describe("ClusterClient", func() {
}}
return slots, nil
}
client = cluster.newClusterClient(opt)
client = cluster.newClusterClient(ctx, opt)
err := client.ForEachMaster(func(master *redis.Client) error {
return master.FlushDB().Err()
err := client.ForEachMaster(ctx, func(ctx context.Context, master *redis.Client) error {
return master.FlushDB(ctx).Err()
})
Expect(err).NotTo(HaveOccurred())
err = client.ForEachSlave(func(slave *redis.Client) error {
err = client.ForEachSlave(ctx, func(ctx context.Context, slave *redis.Client) error {
Eventually(func() int64 {
return client.DBSize().Val()
return client.DBSize(ctx).Val()
}, 30*time.Second).Should(Equal(int64(0)))
return nil
})
@ -1039,16 +1041,16 @@ var _ = Describe("ClusterClient", func() {
}}
return slots, nil
}
client = cluster.newClusterClient(opt)
client = cluster.newClusterClient(ctx, opt)
err := client.ForEachMaster(func(master *redis.Client) error {
return master.FlushDB().Err()
err := client.ForEachMaster(ctx, func(ctx context.Context, master *redis.Client) error {
return master.FlushDB(ctx).Err()
})
Expect(err).NotTo(HaveOccurred())
err = client.ForEachSlave(func(slave *redis.Client) error {
err = client.ForEachSlave(ctx, func(ctx context.Context, slave *redis.Client) error {
Eventually(func() int64 {
return client.DBSize().Val()
return client.DBSize(ctx).Val()
}, 30*time.Second).Should(Equal(int64(0)))
return nil
})
@ -1078,13 +1080,13 @@ var _ = Describe("ClusterClient without nodes", func() {
})
It("Ping returns an error", func() {
err := client.Ping().Err()
err := client.Ping(ctx).Err()
Expect(err).To(MatchError("redis: cluster has no nodes"))
})
It("pipeline returns an error", func() {
_, err := client.Pipelined(func(pipe redis.Pipeliner) error {
pipe.Ping()
_, err := client.Pipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Ping(ctx)
return nil
})
Expect(err).To(MatchError("redis: cluster has no nodes"))
@ -1105,13 +1107,13 @@ var _ = Describe("ClusterClient without valid nodes", func() {
})
It("returns an error", func() {
err := client.Ping().Err()
err := client.Ping(ctx).Err()
Expect(err).To(MatchError("ERR This instance has cluster support disabled"))
})
It("pipeline returns an error", func() {
_, err := client.Pipelined(func(pipe redis.Pipeliner) error {
pipe.Ping()
_, err := client.Pipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Ping(ctx)
return nil
})
Expect(err).To(MatchError("ERR This instance has cluster support disabled"))
@ -1123,7 +1125,7 @@ var _ = Describe("ClusterClient with unavailable Cluster", func() {
BeforeEach(func() {
for _, node := range cluster.clients {
err := node.ClientPause(5 * time.Second).Err()
err := node.ClientPause(ctx, 5*time.Second).Err()
Expect(err).NotTo(HaveOccurred())
}
@ -1139,11 +1141,11 @@ var _ = Describe("ClusterClient with unavailable Cluster", func() {
})
It("recovers when Cluster recovers", func() {
err := client.Ping().Err()
err := client.Ping(ctx).Err()
Expect(err).To(HaveOccurred())
Eventually(func() error {
return client.Ping().Err()
return client.Ping(ctx).Err()
}, "30s").ShouldNot(HaveOccurred())
})
})
@ -1157,14 +1159,14 @@ var _ = Describe("ClusterClient timeout", func() {
testTimeout := func() {
It("Ping timeouts", func() {
err := client.Ping().Err()
err := client.Ping(ctx).Err()
Expect(err).To(HaveOccurred())
Expect(err.(net.Error).Timeout()).To(BeTrue())
})
It("Pipeline timeouts", func() {
_, err := client.Pipelined(func(pipe redis.Pipeliner) error {
pipe.Ping()
_, err := client.Pipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Ping(ctx)
return nil
})
Expect(err).To(HaveOccurred())
@ -1172,17 +1174,17 @@ var _ = Describe("ClusterClient timeout", func() {
})
It("Tx timeouts", func() {
err := client.Watch(func(tx *redis.Tx) error {
return tx.Ping().Err()
err := client.Watch(ctx, func(tx *redis.Tx) error {
return tx.Ping(ctx).Err()
}, "foo")
Expect(err).To(HaveOccurred())
Expect(err.(net.Error).Timeout()).To(BeTrue())
})
It("Tx Pipeline timeouts", func() {
err := client.Watch(func(tx *redis.Tx) error {
_, err := tx.TxPipelined(func(pipe redis.Pipeliner) error {
pipe.Ping()
err := client.Watch(ctx, func(tx *redis.Tx) error {
_, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Ping(ctx)
return nil
})
return err
@ -1200,19 +1202,19 @@ var _ = Describe("ClusterClient timeout", func() {
opt.ReadTimeout = 250 * time.Millisecond
opt.WriteTimeout = 250 * time.Millisecond
opt.MaxRedirects = 1
client = cluster.newClusterClient(opt)
client = cluster.newClusterClient(ctx, opt)
err := client.ForEachNode(func(client *redis.Client) error {
return client.ClientPause(pause).Err()
err := client.ForEachNode(ctx, func(ctx context.Context, client *redis.Client) error {
return client.ClientPause(ctx, pause).Err()
})
Expect(err).NotTo(HaveOccurred())
})
AfterEach(func() {
_ = client.ForEachNode(func(client *redis.Client) error {
_ = client.ForEachNode(ctx, func(ctx context.Context, client *redis.Client) error {
defer GinkgoRecover()
Eventually(func() error {
return client.Ping().Err()
return client.Ping(ctx).Err()
}, 2*pause).ShouldNot(HaveOccurred())
return nil
})

View File

@ -1,6 +1,7 @@
package redis
import (
"context"
"fmt"
"net"
"strconv"
@ -95,6 +96,7 @@ func cmdFirstKeyPos(cmd Cmder, info *CommandInfo) int {
//------------------------------------------------------------------------------
type baseCmd struct {
ctx context.Context
args []interface{}
err error
@ -147,9 +149,12 @@ type Cmd struct {
val interface{}
}
func NewCmd(args ...interface{}) *Cmd {
func NewCmd(ctx context.Context, args ...interface{}) *Cmd {
return &Cmd{
baseCmd: baseCmd{args: args},
baseCmd: baseCmd{
ctx: ctx,
args: args,
},
}
}
@ -308,9 +313,12 @@ type SliceCmd struct {
var _ Cmder = (*SliceCmd)(nil)
func NewSliceCmd(args ...interface{}) *SliceCmd {
func NewSliceCmd(ctx context.Context, args ...interface{}) *SliceCmd {
return &SliceCmd{
baseCmd: baseCmd{args: args},
baseCmd: baseCmd{
ctx: ctx,
args: args,
},
}
}
@ -346,9 +354,12 @@ type StatusCmd struct {
var _ Cmder = (*StatusCmd)(nil)
func NewStatusCmd(args ...interface{}) *StatusCmd {
func NewStatusCmd(ctx context.Context, args ...interface{}) *StatusCmd {
return &StatusCmd{
baseCmd: baseCmd{args: args},
baseCmd: baseCmd{
ctx: ctx,
args: args,
},
}
}
@ -379,9 +390,12 @@ type IntCmd struct {
var _ Cmder = (*IntCmd)(nil)
func NewIntCmd(args ...interface{}) *IntCmd {
func NewIntCmd(ctx context.Context, args ...interface{}) *IntCmd {
return &IntCmd{
baseCmd: baseCmd{args: args},
baseCmd: baseCmd{
ctx: ctx,
args: args,
},
}
}
@ -416,9 +430,12 @@ type IntSliceCmd struct {
var _ Cmder = (*IntSliceCmd)(nil)
func NewIntSliceCmd(args ...interface{}) *IntSliceCmd {
func NewIntSliceCmd(ctx context.Context, args ...interface{}) *IntSliceCmd {
return &IntSliceCmd{
baseCmd: baseCmd{args: args},
baseCmd: baseCmd{
ctx: ctx,
args: args,
},
}
}
@ -460,9 +477,12 @@ type DurationCmd struct {
var _ Cmder = (*DurationCmd)(nil)
func NewDurationCmd(precision time.Duration, args ...interface{}) *DurationCmd {
func NewDurationCmd(ctx context.Context, precision time.Duration, args ...interface{}) *DurationCmd {
return &DurationCmd{
baseCmd: baseCmd{args: args},
baseCmd: baseCmd{
ctx: ctx,
args: args,
},
precision: precision,
}
}
@ -506,9 +526,12 @@ type TimeCmd struct {
var _ Cmder = (*TimeCmd)(nil)
func NewTimeCmd(args ...interface{}) *TimeCmd {
func NewTimeCmd(ctx context.Context, args ...interface{}) *TimeCmd {
return &TimeCmd{
baseCmd: baseCmd{args: args},
baseCmd: baseCmd{
ctx: ctx,
args: args,
},
}
}
@ -556,9 +579,12 @@ type BoolCmd struct {
var _ Cmder = (*BoolCmd)(nil)
func NewBoolCmd(args ...interface{}) *BoolCmd {
func NewBoolCmd(ctx context.Context, args ...interface{}) *BoolCmd {
return &BoolCmd{
baseCmd: baseCmd{args: args},
baseCmd: baseCmd{
ctx: ctx,
args: args,
},
}
}
@ -610,9 +636,12 @@ type StringCmd struct {
var _ Cmder = (*StringCmd)(nil)
func NewStringCmd(args ...interface{}) *StringCmd {
func NewStringCmd(ctx context.Context, args ...interface{}) *StringCmd {
return &StringCmd{
baseCmd: baseCmd{args: args},
baseCmd: baseCmd{
ctx: ctx,
args: args,
},
}
}
@ -700,9 +729,12 @@ type FloatCmd struct {
var _ Cmder = (*FloatCmd)(nil)
func NewFloatCmd(args ...interface{}) *FloatCmd {
func NewFloatCmd(ctx context.Context, args ...interface{}) *FloatCmd {
return &FloatCmd{
baseCmd: baseCmd{args: args},
baseCmd: baseCmd{
ctx: ctx,
args: args,
},
}
}
@ -733,9 +765,12 @@ type StringSliceCmd struct {
var _ Cmder = (*StringSliceCmd)(nil)
func NewStringSliceCmd(args ...interface{}) *StringSliceCmd {
func NewStringSliceCmd(ctx context.Context, args ...interface{}) *StringSliceCmd {
return &StringSliceCmd{
baseCmd: baseCmd{args: args},
baseCmd: baseCmd{
ctx: ctx,
args: args,
},
}
}
@ -783,9 +818,12 @@ type BoolSliceCmd struct {
var _ Cmder = (*BoolSliceCmd)(nil)
func NewBoolSliceCmd(args ...interface{}) *BoolSliceCmd {
func NewBoolSliceCmd(ctx context.Context, args ...interface{}) *BoolSliceCmd {
return &BoolSliceCmd{
baseCmd: baseCmd{args: args},
baseCmd: baseCmd{
ctx: ctx,
args: args,
},
}
}
@ -826,9 +864,12 @@ type StringStringMapCmd struct {
var _ Cmder = (*StringStringMapCmd)(nil)
func NewStringStringMapCmd(args ...interface{}) *StringStringMapCmd {
func NewStringStringMapCmd(ctx context.Context, args ...interface{}) *StringStringMapCmd {
return &StringStringMapCmd{
baseCmd: baseCmd{args: args},
baseCmd: baseCmd{
ctx: ctx,
args: args,
},
}
}
@ -875,9 +916,12 @@ type StringIntMapCmd struct {
var _ Cmder = (*StringIntMapCmd)(nil)
func NewStringIntMapCmd(args ...interface{}) *StringIntMapCmd {
func NewStringIntMapCmd(ctx context.Context, args ...interface{}) *StringIntMapCmd {
return &StringIntMapCmd{
baseCmd: baseCmd{args: args},
baseCmd: baseCmd{
ctx: ctx,
args: args,
},
}
}
@ -924,9 +968,12 @@ type StringStructMapCmd struct {
var _ Cmder = (*StringStructMapCmd)(nil)
func NewStringStructMapCmd(args ...interface{}) *StringStructMapCmd {
func NewStringStructMapCmd(ctx context.Context, args ...interface{}) *StringStructMapCmd {
return &StringStructMapCmd{
baseCmd: baseCmd{args: args},
baseCmd: baseCmd{
ctx: ctx,
args: args,
},
}
}
@ -972,9 +1019,12 @@ type XMessageSliceCmd struct {
var _ Cmder = (*XMessageSliceCmd)(nil)
func NewXMessageSliceCmd(args ...interface{}) *XMessageSliceCmd {
func NewXMessageSliceCmd(ctx context.Context, args ...interface{}) *XMessageSliceCmd {
return &XMessageSliceCmd{
baseCmd: baseCmd{args: args},
baseCmd: baseCmd{
ctx: ctx,
args: args,
},
}
}
@ -1069,9 +1119,12 @@ type XStreamSliceCmd struct {
var _ Cmder = (*XStreamSliceCmd)(nil)
func NewXStreamSliceCmd(args ...interface{}) *XStreamSliceCmd {
func NewXStreamSliceCmd(ctx context.Context, args ...interface{}) *XStreamSliceCmd {
return &XStreamSliceCmd{
baseCmd: baseCmd{args: args},
baseCmd: baseCmd{
ctx: ctx,
args: args,
},
}
}
@ -1138,9 +1191,12 @@ type XPendingCmd struct {
var _ Cmder = (*XPendingCmd)(nil)
func NewXPendingCmd(args ...interface{}) *XPendingCmd {
func NewXPendingCmd(ctx context.Context, args ...interface{}) *XPendingCmd {
return &XPendingCmd{
baseCmd: baseCmd{args: args},
baseCmd: baseCmd{
ctx: ctx,
args: args,
},
}
}
@ -1237,9 +1293,12 @@ type XPendingExtCmd struct {
var _ Cmder = (*XPendingExtCmd)(nil)
func NewXPendingExtCmd(args ...interface{}) *XPendingExtCmd {
func NewXPendingExtCmd(ctx context.Context, args ...interface{}) *XPendingExtCmd {
return &XPendingExtCmd{
baseCmd: baseCmd{args: args},
baseCmd: baseCmd{
ctx: ctx,
args: args,
},
}
}
@ -1317,9 +1376,12 @@ type XInfoGroups struct {
var _ Cmder = (*XInfoGroupsCmd)(nil)
func NewXInfoGroupsCmd(stream string) *XInfoGroupsCmd {
func NewXInfoGroupsCmd(ctx context.Context, stream string) *XInfoGroupsCmd {
return &XInfoGroupsCmd{
baseCmd: baseCmd{args: []interface{}{"xinfo", "groups", stream}},
baseCmd: baseCmd{
ctx: ctx,
args: []interface{}{"xinfo", "groups", stream},
},
}
}
@ -1401,9 +1463,12 @@ type ZSliceCmd struct {
var _ Cmder = (*ZSliceCmd)(nil)
func NewZSliceCmd(args ...interface{}) *ZSliceCmd {
func NewZSliceCmd(ctx context.Context, args ...interface{}) *ZSliceCmd {
return &ZSliceCmd{
baseCmd: baseCmd{args: args},
baseCmd: baseCmd{
ctx: ctx,
args: args,
},
}
}
@ -1453,9 +1518,12 @@ type ZWithKeyCmd struct {
var _ Cmder = (*ZWithKeyCmd)(nil)
func NewZWithKeyCmd(args ...interface{}) *ZWithKeyCmd {
func NewZWithKeyCmd(ctx context.Context, args ...interface{}) *ZWithKeyCmd {
return &ZWithKeyCmd{
baseCmd: baseCmd{args: args},
baseCmd: baseCmd{
ctx: ctx,
args: args,
},
}
}
@ -1508,14 +1576,17 @@ type ScanCmd struct {
page []string
cursor uint64
process func(cmd Cmder) error
process cmdable
}
var _ Cmder = (*ScanCmd)(nil)
func NewScanCmd(process func(cmd Cmder) error, args ...interface{}) *ScanCmd {
func NewScanCmd(ctx context.Context, process cmdable, args ...interface{}) *ScanCmd {
return &ScanCmd{
baseCmd: baseCmd{args: args},
baseCmd: baseCmd{
ctx: ctx,
args: args,
},
process: process,
}
}
@ -1565,9 +1636,12 @@ type ClusterSlotsCmd struct {
var _ Cmder = (*ClusterSlotsCmd)(nil)
func NewClusterSlotsCmd(args ...interface{}) *ClusterSlotsCmd {
func NewClusterSlotsCmd(ctx context.Context, args ...interface{}) *ClusterSlotsCmd {
return &ClusterSlotsCmd{
baseCmd: baseCmd{args: args},
baseCmd: baseCmd{
ctx: ctx,
args: args,
},
}
}
@ -1682,10 +1756,13 @@ type GeoLocationCmd struct {
var _ Cmder = (*GeoLocationCmd)(nil)
func NewGeoLocationCmd(q *GeoRadiusQuery, args ...interface{}) *GeoLocationCmd {
func NewGeoLocationCmd(ctx context.Context, q *GeoRadiusQuery, args ...interface{}) *GeoLocationCmd {
return &GeoLocationCmd{
baseCmd: baseCmd{args: geoLocationArgs(q, args...)},
q: q,
baseCmd: baseCmd{
ctx: ctx,
args: geoLocationArgs(q, args...),
},
q: q,
}
}
@ -1826,9 +1903,12 @@ type GeoPosCmd struct {
var _ Cmder = (*GeoPosCmd)(nil)
func NewGeoPosCmd(args ...interface{}) *GeoPosCmd {
func NewGeoPosCmd(ctx context.Context, args ...interface{}) *GeoPosCmd {
return &GeoPosCmd{
baseCmd: baseCmd{args: args},
baseCmd: baseCmd{
ctx: ctx,
args: args,
},
}
}
@ -1899,9 +1979,12 @@ type CommandsInfoCmd struct {
var _ Cmder = (*CommandsInfoCmd)(nil)
func NewCommandsInfoCmd(args ...interface{}) *CommandsInfoCmd {
func NewCommandsInfoCmd(ctx context.Context, args ...interface{}) *CommandsInfoCmd {
return &CommandsInfoCmd{
baseCmd: baseCmd{args: args},
baseCmd: baseCmd{
ctx: ctx,
args: args,
},
}
}

View File

@ -15,7 +15,7 @@ var _ = Describe("Cmd", func() {
BeforeEach(func() {
client = redis.NewClient(redisOptions())
Expect(client.FlushDB().Err()).NotTo(HaveOccurred())
Expect(client.FlushDB(ctx).Err()).NotTo(HaveOccurred())
})
AfterEach(func() {
@ -23,19 +23,19 @@ var _ = Describe("Cmd", func() {
})
It("implements Stringer", func() {
set := client.Set("foo", "bar", 0)
set := client.Set(ctx, "foo", "bar", 0)
Expect(set.String()).To(Equal("set foo bar: OK"))
get := client.Get("foo")
get := client.Get(ctx, "foo")
Expect(get.String()).To(Equal("get foo: bar"))
})
It("has val/err", func() {
set := client.Set("key", "hello", 0)
set := client.Set(ctx, "key", "hello", 0)
Expect(set.Err()).NotTo(HaveOccurred())
Expect(set.Val()).To(Equal("OK"))
get := client.Get("key")
get := client.Get(ctx, "key")
Expect(get.Err()).NotTo(HaveOccurred())
Expect(get.Val()).To(Equal("hello"))
@ -44,18 +44,18 @@ var _ = Describe("Cmd", func() {
})
It("has helpers", func() {
set := client.Set("key", "10", 0)
set := client.Set(ctx, "key", "10", 0)
Expect(set.Err()).NotTo(HaveOccurred())
n, err := client.Get("key").Int64()
n, err := client.Get(ctx, "key").Int64()
Expect(err).NotTo(HaveOccurred())
Expect(n).To(Equal(int64(10)))
un, err := client.Get("key").Uint64()
un, err := client.Get(ctx, "key").Uint64()
Expect(err).NotTo(HaveOccurred())
Expect(un).To(Equal(uint64(10)))
f, err := client.Get("key").Float64()
f, err := client.Get(ctx, "key").Float64()
Expect(err).NotTo(HaveOccurred())
Expect(f).To(Equal(float64(10)))
})
@ -63,10 +63,10 @@ var _ = Describe("Cmd", func() {
It("supports float32", func() {
f := float32(66.97)
err := client.Set("float_key", f, 0).Err()
err := client.Set(ctx, "float_key", f, 0).Err()
Expect(err).NotTo(HaveOccurred())
val, err := client.Get("float_key").Float32()
val, err := client.Get(ctx, "float_key").Float32()
Expect(err).NotTo(HaveOccurred())
Expect(val).To(Equal(f))
})
@ -74,14 +74,14 @@ var _ = Describe("Cmd", func() {
It("supports time.Time", func() {
tm := time.Date(2019, 01, 01, 0, 0, 0, 0, time.UTC)
err := client.Set("time_key", tm, 0).Err()
err := client.Set(ctx, "time_key", tm, 0).Err()
Expect(err).NotTo(HaveOccurred())
s, err := client.Get("time_key").Result()
s, err := client.Get(ctx, "time_key").Result()
Expect(err).NotTo(HaveOccurred())
Expect(s).To(Equal("2019-01-01T00:00:00Z"))
tm2, err := client.Get("time_key").Time()
tm2, err := client.Get(ctx, "time_key").Time()
Expect(err).NotTo(HaveOccurred())
Expect(tm2).To(BeTemporally("==", tm))
})

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -37,7 +37,7 @@ func Example_instrumentation() {
})
rdb.AddHook(redisHook{})
rdb.Ping()
rdb.Ping(ctx)
// Output: starting processing: <ping: >
// finished processing: <ping: PONG>
}
@ -48,9 +48,9 @@ func ExamplePipeline_instrumentation() {
})
rdb.AddHook(redisHook{})
rdb.Pipelined(func(pipe redis.Pipeliner) error {
pipe.Ping()
pipe.Ping()
rdb.Pipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Ping(ctx)
pipe.Ping(ctx)
return nil
})
// Output: pipeline starting processing: [ping: ping: ]
@ -63,9 +63,9 @@ func ExampleWatch_instrumentation() {
})
rdb.AddHook(redisHook{})
rdb.Watch(func(tx *redis.Tx) error {
tx.Ping()
tx.Ping()
rdb.Watch(ctx, func(tx *redis.Tx) error {
tx.Ping(ctx)
tx.Ping(ctx)
return nil
}, "foo")
// Output:

View File

@ -1,6 +1,7 @@
package redis_test
import (
"context"
"errors"
"fmt"
"sync"
@ -9,6 +10,7 @@ import (
"github.com/go-redis/redis/v7"
)
var ctx = context.Background()
var rdb *redis.Client
func init() {
@ -29,7 +31,7 @@ func ExampleNewClient() {
DB: 0, // use default DB
})
pong, err := rdb.Ping().Result()
pong, err := rdb.Ping(ctx).Result()
fmt.Println(pong, err)
// Output: PONG <nil>
}
@ -58,7 +60,7 @@ func ExampleNewFailoverClient() {
MasterName: "master",
SentinelAddrs: []string{":26379"},
})
rdb.Ping()
rdb.Ping(ctx)
}
func ExampleNewClusterClient() {
@ -67,7 +69,7 @@ func ExampleNewClusterClient() {
rdb := redis.NewClusterClient(&redis.ClusterOptions{
Addrs: []string{":7000", ":7001", ":7002", ":7003", ":7004", ":7005"},
})
rdb.Ping()
rdb.Ping(ctx)
}
// Following example creates a cluster from 2 master nodes and 2 slave nodes
@ -106,11 +108,11 @@ func ExampleNewClusterClient_manualSetup() {
ClusterSlots: clusterSlots,
RouteRandomly: true,
})
rdb.Ping()
rdb.Ping(ctx)
// ReloadState reloads cluster state. It calls ClusterSlots func
// to get cluster slots information.
err := rdb.ReloadState()
err := rdb.ReloadState(ctx)
if err != nil {
panic(err)
}
@ -124,22 +126,22 @@ func ExampleNewRing() {
"shard3": ":7002",
},
})
rdb.Ping()
rdb.Ping(ctx)
}
func ExampleClient() {
err := rdb.Set("key", "value", 0).Err()
err := rdb.Set(ctx, "key", "value", 0).Err()
if err != nil {
panic(err)
}
val, err := rdb.Get("key").Result()
val, err := rdb.Get(ctx, "key").Result()
if err != nil {
panic(err)
}
fmt.Println("key", val)
val2, err := rdb.Get("missing_key").Result()
val2, err := rdb.Get(ctx, "missing_key").Result()
if err == redis.Nil {
fmt.Println("missing_key does not exist")
} else if err != nil {
@ -154,17 +156,17 @@ func ExampleClient() {
func ExampleConn() {
conn := rdb.Conn()
err := conn.ClientSetName("foobar").Err()
err := conn.ClientSetName(ctx, "foobar").Err()
if err != nil {
panic(err)
}
// Open other connections.
for i := 0; i < 10; i++ {
go rdb.Ping()
go rdb.Ping(ctx)
}
s, err := conn.ClientGetName().Result()
s, err := conn.ClientGetName(ctx).Result()
if err != nil {
panic(err)
}
@ -175,20 +177,20 @@ func ExampleConn() {
func ExampleClient_Set() {
// Last argument is expiration. Zero means the key has no
// expiration time.
err := rdb.Set("key", "value", 0).Err()
err := rdb.Set(ctx, "key", "value", 0).Err()
if err != nil {
panic(err)
}
// key2 will expire in an hour.
err = rdb.Set("key2", "value", time.Hour).Err()
err = rdb.Set(ctx, "key2", "value", time.Hour).Err()
if err != nil {
panic(err)
}
}
func ExampleClient_Incr() {
result, err := rdb.Incr("counter").Result()
result, err := rdb.Incr(ctx, "counter").Result()
if err != nil {
panic(err)
}
@ -198,12 +200,12 @@ func ExampleClient_Incr() {
}
func ExampleClient_BLPop() {
if err := rdb.RPush("queue", "message").Err(); err != nil {
if err := rdb.RPush(ctx, "queue", "message").Err(); err != nil {
panic(err)
}
// use `rdb.BLPop(0, "queue")` for infinite waiting time
result, err := rdb.BLPop(1*time.Second, "queue").Result()
result, err := rdb.BLPop(ctx, 1*time.Second, "queue").Result()
if err != nil {
panic(err)
}
@ -213,9 +215,9 @@ func ExampleClient_BLPop() {
}
func ExampleClient_Scan() {
rdb.FlushDB()
rdb.FlushDB(ctx)
for i := 0; i < 33; i++ {
err := rdb.Set(fmt.Sprintf("key%d", i), "value", 0).Err()
err := rdb.Set(ctx, fmt.Sprintf("key%d", i), "value", 0).Err()
if err != nil {
panic(err)
}
@ -226,7 +228,7 @@ func ExampleClient_Scan() {
for {
var keys []string
var err error
keys, cursor, err = rdb.Scan(cursor, "key*", 10).Result()
keys, cursor, err = rdb.Scan(ctx, cursor, "key*", 10).Result()
if err != nil {
panic(err)
}
@ -242,9 +244,9 @@ func ExampleClient_Scan() {
func ExampleClient_Pipelined() {
var incr *redis.IntCmd
_, err := rdb.Pipelined(func(pipe redis.Pipeliner) error {
incr = pipe.Incr("pipelined_counter")
pipe.Expire("pipelined_counter", time.Hour)
_, err := rdb.Pipelined(ctx, func(pipe redis.Pipeliner) error {
incr = pipe.Incr(ctx, "pipelined_counter")
pipe.Expire(ctx, "pipelined_counter", time.Hour)
return nil
})
fmt.Println(incr.Val(), err)
@ -254,8 +256,8 @@ func ExampleClient_Pipelined() {
func ExampleClient_Pipeline() {
pipe := rdb.Pipeline()
incr := pipe.Incr("pipeline_counter")
pipe.Expire("pipeline_counter", time.Hour)
incr := pipe.Incr(ctx, "pipeline_counter")
pipe.Expire(ctx, "pipeline_counter", time.Hour)
// Execute
//
@ -263,16 +265,16 @@ func ExampleClient_Pipeline() {
// EXPIRE pipeline_counts 3600
//
// using one rdb-server roundtrip.
_, err := pipe.Exec()
_, err := pipe.Exec(ctx)
fmt.Println(incr.Val(), err)
// Output: 1 <nil>
}
func ExampleClient_TxPipelined() {
var incr *redis.IntCmd
_, err := rdb.TxPipelined(func(pipe redis.Pipeliner) error {
incr = pipe.Incr("tx_pipelined_counter")
pipe.Expire("tx_pipelined_counter", time.Hour)
_, err := rdb.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
incr = pipe.Incr(ctx, "tx_pipelined_counter")
pipe.Expire(ctx, "tx_pipelined_counter", time.Hour)
return nil
})
fmt.Println(incr.Val(), err)
@ -282,8 +284,8 @@ func ExampleClient_TxPipelined() {
func ExampleClient_TxPipeline() {
pipe := rdb.TxPipeline()
incr := pipe.Incr("tx_pipeline_counter")
pipe.Expire("tx_pipeline_counter", time.Hour)
incr := pipe.Incr(ctx, "tx_pipeline_counter")
pipe.Expire(ctx, "tx_pipeline_counter", time.Hour)
// Execute
//
@ -293,7 +295,7 @@ func ExampleClient_TxPipeline() {
// EXEC
//
// using one rdb-server roundtrip.
_, err := pipe.Exec()
_, err := pipe.Exec(ctx)
fmt.Println(incr.Val(), err)
// Output: 1 <nil>
}
@ -305,7 +307,7 @@ func ExampleClient_Watch() {
increment := func(key string) error {
txf := func(tx *redis.Tx) error {
// get current value or zero
n, err := tx.Get(key).Int()
n, err := tx.Get(ctx, key).Int()
if err != nil && err != redis.Nil {
return err
}
@ -314,16 +316,16 @@ func ExampleClient_Watch() {
n++
// runs only if the watched keys remain unchanged
_, err = tx.TxPipelined(func(pipe redis.Pipeliner) error {
_, err = tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
// pipe handles the error case
pipe.Set(key, n, 0)
pipe.Set(ctx, key, n, 0)
return nil
})
return err
}
for retries := routineCount; retries > 0; retries-- {
err := rdb.Watch(txf, key)
err := rdb.Watch(ctx, txf, key)
if err != redis.TxFailedErr {
return err
}
@ -345,16 +347,16 @@ func ExampleClient_Watch() {
}
wg.Wait()
n, err := rdb.Get("counter3").Int()
n, err := rdb.Get(ctx, "counter3").Int()
fmt.Println("ended with", n, err)
// Output: ended with 100 <nil>
}
func ExamplePubSub() {
pubsub := rdb.Subscribe("mychannel1")
pubsub := rdb.Subscribe(ctx, "mychannel1")
// Wait for confirmation that subscription is created before publishing anything.
_, err := pubsub.Receive()
_, err := pubsub.Receive(ctx)
if err != nil {
panic(err)
}
@ -363,7 +365,7 @@ func ExamplePubSub() {
ch := pubsub.Channel()
// Publish a message.
err = rdb.Publish("mychannel1", "hello").Err()
err = rdb.Publish(ctx, "mychannel1", "hello").Err()
if err != nil {
panic(err)
}
@ -382,12 +384,12 @@ func ExamplePubSub() {
}
func ExamplePubSub_Receive() {
pubsub := rdb.Subscribe("mychannel2")
pubsub := rdb.Subscribe(ctx, "mychannel2")
defer pubsub.Close()
for i := 0; i < 2; i++ {
// ReceiveTimeout is a low level API. Use ReceiveMessage instead.
msgi, err := pubsub.ReceiveTimeout(time.Second)
msgi, err := pubsub.ReceiveTimeout(ctx, time.Second)
if err != nil {
break
}
@ -396,7 +398,7 @@ func ExamplePubSub_Receive() {
case *redis.Subscription:
fmt.Println("subscribed to", msg.Channel)
_, err := rdb.Publish("mychannel2", "hello").Result()
_, err := rdb.Publish(ctx, "mychannel2", "hello").Result()
if err != nil {
panic(err)
}
@ -419,15 +421,15 @@ func ExampleScript() {
return false
`)
n, err := IncrByXX.Run(rdb, []string{"xx_counter"}, 2).Result()
n, err := IncrByXX.Run(ctx, rdb, []string{"xx_counter"}, 2).Result()
fmt.Println(n, err)
err = rdb.Set("xx_counter", "40", 0).Err()
err = rdb.Set(ctx, "xx_counter", "40", 0).Err()
if err != nil {
panic(err)
}
n, err = IncrByXX.Run(rdb, []string{"xx_counter"}, 2).Result()
n, err = IncrByXX.Run(ctx, rdb, []string{"xx_counter"}, 2).Result()
fmt.Println(n, err)
// Output: <nil> redis: nil
@ -435,26 +437,26 @@ func ExampleScript() {
}
func Example_customCommand() {
Get := func(rdb *redis.Client, key string) *redis.StringCmd {
cmd := redis.NewStringCmd("get", key)
rdb.Process(cmd)
Get := func(ctx context.Context, rdb *redis.Client, key string) *redis.StringCmd {
cmd := redis.NewStringCmd(ctx, "get", key)
rdb.Process(ctx, cmd)
return cmd
}
v, err := Get(rdb, "key_does_not_exist").Result()
v, err := Get(ctx, rdb, "key_does_not_exist").Result()
fmt.Printf("%q %s", v, err)
// Output: "" redis: nil
}
func Example_customCommand2() {
v, err := rdb.Do("get", "key_does_not_exist").Text()
v, err := rdb.Do(ctx, "get", "key_does_not_exist").Text()
fmt.Printf("%q %s", v, err)
// Output: "" redis: nil
}
func ExampleScanIterator() {
iter := rdb.Scan(0, "", 0).Iterator()
for iter.Next() {
iter := rdb.Scan(ctx, 0, "", 0).Iterator()
for iter.Next(ctx) {
fmt.Println(iter.Val())
}
if err := iter.Err(); err != nil {
@ -463,8 +465,8 @@ func ExampleScanIterator() {
}
func ExampleScanCmd_Iterator() {
iter := rdb.Scan(0, "", 0).Iterator()
for iter.Next() {
iter := rdb.Scan(ctx, 0, "", 0).Iterator()
for iter.Next(ctx) {
fmt.Println(iter.Val())
}
if err := iter.Err(); err != nil {
@ -478,7 +480,7 @@ func ExampleNewUniversalClient_simple() {
})
defer rdb.Close()
rdb.Ping()
rdb.Ping(ctx)
}
func ExampleNewUniversalClient_failover() {
@ -488,7 +490,7 @@ func ExampleNewUniversalClient_failover() {
})
defer rdb.Close()
rdb.Ping()
rdb.Ping(ctx)
}
func ExampleNewUniversalClient_cluster() {
@ -497,5 +499,5 @@ func ExampleNewUniversalClient_cluster() {
})
defer rdb.Close()
rdb.Ping()
rdb.Ping(ctx)
}

View File

@ -1,6 +1,7 @@
package redis
import (
"context"
"fmt"
"net"
"strings"
@ -17,12 +18,12 @@ func (c *PubSub) SetNetConn(netConn net.Conn) {
c.cn = pool.NewConn(netConn)
}
func (c *ClusterClient) LoadState() (*clusterState, error) {
return c.loadState()
func (c *ClusterClient) LoadState(ctx context.Context) (*clusterState, error) {
return c.loadState(ctx)
}
func (c *ClusterClient) SlotAddrs(slot int) []string {
state, err := c.state.Get()
func (c *ClusterClient) SlotAddrs(ctx context.Context, slot int) []string {
state, err := c.state.Get(ctx)
if err != nil {
panic(err)
}
@ -34,8 +35,8 @@ func (c *ClusterClient) SlotAddrs(slot int) []string {
return addrs
}
func (c *ClusterClient) Nodes(key string) ([]*clusterNode, error) {
state, err := c.state.Reload()
func (c *ClusterClient) Nodes(ctx context.Context, key string) ([]*clusterNode, error) {
state, err := c.state.Reload(ctx)
if err != nil {
return nil, err
}
@ -48,8 +49,8 @@ func (c *ClusterClient) Nodes(key string) ([]*clusterNode, error) {
return nodes, nil
}
func (c *ClusterClient) SwapNodes(key string) error {
nodes, err := c.Nodes(key)
func (c *ClusterClient) SwapNodes(ctx context.Context, key string) error {
nodes, err := c.Nodes(ctx, key)
if err != nil {
return err
}
@ -57,12 +58,12 @@ func (c *ClusterClient) SwapNodes(key string) error {
return nil
}
func (state *clusterState) IsConsistent() bool {
func (state *clusterState) IsConsistent(ctx context.Context) bool {
if len(state.Masters) < 3 {
return false
}
for _, master := range state.Masters {
s := master.Client.Info("replication").Val()
s := master.Client.Info(ctx, "replication").Val()
if !strings.Contains(s, "role:master") {
return false
}
@ -72,7 +73,7 @@ func (state *clusterState) IsConsistent() bool {
return false
}
for _, slave := range state.Slaves {
s := slave.Client.Info("replication").Val()
s := slave.Client.Info(ctx, "replication").Val()
if !strings.Contains(s, "role:slave") {
return false
}

3
go.mod
View File

@ -5,11 +5,12 @@ require (
github.com/kr/pretty v0.1.0 // indirect
github.com/onsi/ginkgo v1.10.1
github.com/onsi/gomega v1.7.0
go.opentelemetry.io/otel v0.2.3
golang.org/x/net v0.0.0-20190923162816-aa69164e4478 // indirect
golang.org/x/sys v0.0.0-20191010194322-b09406accb47 // indirect
golang.org/x/text v0.3.2 // indirect
google.golang.org/grpc v1.27.1
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect
gopkg.in/yaml.v2 v2.2.4 // indirect
)
go 1.11

54
go.sum
View File

@ -1,9 +1,24 @@
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/DataDog/sketches-go v0.0.0-20190923095040-43f19ad77ff7/go.mod h1:Q5DbzQ+3AkgGwymQO7aZFNP7ns2lZKGtvRBzRXfdi60=
github.com/benbjohnson/clock v1.0.0/go.mod h1:bGMdMPoPVvcYyt1gHDf4J2KE153Yf9BuiUKYMaxlTDM=
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs=
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
@ -16,13 +31,33 @@ github.com/onsi/ginkgo v1.10.1 h1:q/mM8GF/n0shIN8SaAZ0V+jnLPzen6WIVZdiwrRlMlo=
github.com/onsi/ginkgo v1.10.1/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
github.com/onsi/gomega v1.7.0 h1:XPnZz8VVBHjVsy1vzJmRwIcSwiUO+JFfrv/xGiigmME=
github.com/onsi/gomega v1.7.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
github.com/opentracing/opentracing-go v1.1.1-0.20190913142402-a7454ce5950e/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
go.opentelemetry.io v0.1.0 h1:EANZoRCOP+A3faIlw/iN6YEWoYb1vleZRKm1EvH8T48=
go.opentelemetry.io/otel v0.2.3 h1:o97YpRYk0PyhCyuanlJY0DepUgAlyzl3rJ+4kb+456c=
go.opentelemetry.io/otel v0.2.3/go.mod h1:OgNpQOjrlt33Ew6Ds0mGjmcTQg/rhUctsbkRdk/g1fw=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd h1:nTDtHvHSdCn1m6ITfMRqtOd/9+7a3s8RBNOZ3eYZzJA=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190923162816-aa69164e4478 h1:l5EDrHhldLYb3ZRHDUhXF7Om7MvYXnkV9/iQNo1lX6g=
golang.org/x/net v0.0.0-20190923162816-aa69164e4478/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f h1:wMNYb4v58l5UBM7MYRLPG6ZhfOqbKu7X5eyFl8ZhKvA=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e h1:o3PsSEY8E4eXWkXrIP9YJALUkVZqzHJT5DOasTyn8Vs=
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@ -33,6 +68,20 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
google.golang.org/genproto v0.0.0-20191009194640-548a555dbc03/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc=
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg=
google.golang.org/grpc v1.27.1 h1:zvIju4sqAGvwKspUQOhwnpcqSbzi7/H6QomNNjTL4sk=
google.golang.org/grpc v1.27.1/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
@ -43,5 +92,10 @@ gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkep
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
gopkg.in/yaml.v2 v2.2.1 h1:mUhvW9EsL+naU5Q3cakzfE91YhliOondGd6ZrsDBHQE=
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.4 h1:/eiJrUcujPVeJ3xlSWaiNi3uSVmDGBK1pDHUHAnao1I=
gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.7 h1:VUgggvou5XRW9mHwD/yXxIYSMtY0zoKQf/v226p2nyo=
gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=

View File

@ -6,6 +6,7 @@ import (
"sync/atomic"
"time"
"github.com/go-redis/redis/v7/internal"
"github.com/go-redis/redis/v7/internal/proto"
)
@ -58,31 +59,35 @@ func (cn *Conn) RemoteAddr() net.Addr {
}
func (cn *Conn) WithReader(ctx context.Context, timeout time.Duration, fn func(rd *proto.Reader) error) error {
err := cn.netConn.SetReadDeadline(cn.deadline(ctx, timeout))
if err != nil {
return err
}
return fn(cn.rd)
return internal.WithSpan(ctx, "with_reader", func(ctx context.Context) error {
err := cn.netConn.SetReadDeadline(cn.deadline(ctx, timeout))
if err != nil {
return err
}
return fn(cn.rd)
})
}
func (cn *Conn) WithWriter(
ctx context.Context, timeout time.Duration, fn func(wr *proto.Writer) error,
) error {
err := cn.netConn.SetWriteDeadline(cn.deadline(ctx, timeout))
if err != nil {
return err
}
return internal.WithSpan(ctx, "with_writer", func(ctx context.Context) error {
err := cn.netConn.SetWriteDeadline(cn.deadline(ctx, timeout))
if err != nil {
return err
}
if cn.wr.Buffered() > 0 {
cn.wr.Reset(cn.netConn)
}
if cn.wr.Buffered() > 0 {
cn.wr.Reset(cn.netConn)
}
err = fn(cn.wr)
if err != nil {
return err
}
err = fn(cn.wr)
if err != nil {
return err
}
return cn.wr.Flush()
return cn.wr.Flush()
})
}
func (cn *Conn) Close() error {

View File

@ -2,21 +2,28 @@ package internal
import (
"context"
"reflect"
"time"
"github.com/go-redis/redis/v7/internal/util"
"go.opentelemetry.io/otel/api/core"
"go.opentelemetry.io/otel/api/global"
"go.opentelemetry.io/otel/api/trace"
"google.golang.org/grpc/codes"
)
func Sleep(ctx context.Context, dur time.Duration) error {
t := time.NewTimer(dur)
defer t.Stop()
return WithSpan(ctx, "sleep", func(ctx context.Context) error {
t := time.NewTimer(dur)
defer t.Stop()
select {
case <-t.C:
return nil
case <-ctx.Done():
return ctx.Err()
}
select {
case <-t.C:
return nil
case <-ctx.Done():
return ctx.Err()
}
})
}
func ToLower(s string) string {
@ -54,3 +61,27 @@ func Unwrap(err error) error {
}
return u.Unwrap()
}
var (
logTypeKey = core.Key("log.type")
logMessageKey = core.Key("log.message")
)
func WithSpan(ctx context.Context, name string, fn func(context.Context) error) error {
if !trace.SpanFromContext(ctx).IsRecording() {
return fn(ctx)
}
ctx, span := global.TraceProvider().Tracer("go-redis").Start(ctx, name)
defer span.End()
if err := fn(ctx); err != nil {
span.SetStatus(codes.Internal)
span.AddEvent(ctx, "error",
logTypeKey.String(reflect.TypeOf(err).String()),
logMessageKey.String(err.Error()),
)
return err
}
return nil
}

View File

@ -1,6 +1,7 @@
package redis
import (
"context"
"sync"
)
@ -21,7 +22,7 @@ func (it *ScanIterator) Err() error {
}
// Next advances the cursor and returns true if more values can be read.
func (it *ScanIterator) Next() bool {
func (it *ScanIterator) Next(ctx context.Context) bool {
it.mu.Lock()
defer it.mu.Unlock()
@ -49,7 +50,7 @@ func (it *ScanIterator) Next() bool {
it.cmd.args[2] = it.cmd.cursor
}
err := it.cmd.process(it.cmd)
err := it.cmd.process(ctx, it.cmd)
if err != nil {
return false
}

View File

@ -15,21 +15,21 @@ var _ = Describe("ScanIterator", func() {
var seed = func(n int) error {
pipe := client.Pipeline()
for i := 1; i <= n; i++ {
pipe.Set(fmt.Sprintf("K%02d", i), "x", 0).Err()
pipe.Set(ctx, fmt.Sprintf("K%02d", i), "x", 0).Err()
}
_, err := pipe.Exec()
_, err := pipe.Exec(ctx)
return err
}
var extraSeed = func(n int, m int) error {
pipe := client.Pipeline()
for i := 1; i <= m; i++ {
pipe.Set(fmt.Sprintf("A%02d", i), "x", 0).Err()
pipe.Set(ctx, fmt.Sprintf("A%02d", i), "x", 0).Err()
}
for i := 1; i <= n; i++ {
pipe.Set(fmt.Sprintf("K%02d", i), "x", 0).Err()
pipe.Set(ctx, fmt.Sprintf("K%02d", i), "x", 0).Err()
}
_, err := pipe.Exec()
_, err := pipe.Exec(ctx)
return err
}
@ -37,15 +37,15 @@ var _ = Describe("ScanIterator", func() {
var hashSeed = func(n int) error {
pipe := client.Pipeline()
for i := 1; i <= n; i++ {
pipe.HSet(hashKey, fmt.Sprintf("K%02d", i), "x").Err()
pipe.HSet(ctx, hashKey, fmt.Sprintf("K%02d", i), "x").Err()
}
_, err := pipe.Exec()
_, err := pipe.Exec(ctx)
return err
}
BeforeEach(func() {
client = redis.NewClient(redisOptions())
Expect(client.FlushDB().Err()).NotTo(HaveOccurred())
Expect(client.FlushDB(ctx).Err()).NotTo(HaveOccurred())
})
AfterEach(func() {
@ -53,8 +53,8 @@ var _ = Describe("ScanIterator", func() {
})
It("should scan across empty DBs", func() {
iter := client.Scan(0, "", 10).Iterator()
Expect(iter.Next()).To(BeFalse())
iter := client.Scan(ctx, 0, "", 10).Iterator()
Expect(iter.Next(ctx)).To(BeFalse())
Expect(iter.Err()).NotTo(HaveOccurred())
})
@ -62,8 +62,8 @@ var _ = Describe("ScanIterator", func() {
Expect(seed(7)).NotTo(HaveOccurred())
var vals []string
iter := client.Scan(0, "", 0).Iterator()
for iter.Next() {
iter := client.Scan(ctx, 0, "", 0).Iterator()
for iter.Next(ctx) {
vals = append(vals, iter.Val())
}
Expect(iter.Err()).NotTo(HaveOccurred())
@ -74,8 +74,8 @@ var _ = Describe("ScanIterator", func() {
Expect(seed(71)).NotTo(HaveOccurred())
var vals []string
iter := client.Scan(0, "", 10).Iterator()
for iter.Next() {
iter := client.Scan(ctx, 0, "", 10).Iterator()
for iter.Next(ctx) {
vals = append(vals, iter.Val())
}
Expect(iter.Err()).NotTo(HaveOccurred())
@ -88,8 +88,8 @@ var _ = Describe("ScanIterator", func() {
Expect(hashSeed(71)).NotTo(HaveOccurred())
var vals []string
iter := client.HScan(hashKey, 0, "", 10).Iterator()
for iter.Next() {
iter := client.HScan(ctx, hashKey, 0, "", 10).Iterator()
for iter.Next(ctx) {
vals = append(vals, iter.Val())
}
Expect(iter.Err()).NotTo(HaveOccurred())
@ -102,8 +102,8 @@ var _ = Describe("ScanIterator", func() {
Expect(seed(20)).NotTo(HaveOccurred())
var vals []string
iter := client.Scan(0, "", 10).Iterator()
for iter.Next() {
iter := client.Scan(ctx, 0, "", 10).Iterator()
for iter.Next(ctx) {
vals = append(vals, iter.Val())
}
Expect(iter.Err()).NotTo(HaveOccurred())
@ -114,8 +114,8 @@ var _ = Describe("ScanIterator", func() {
Expect(seed(33)).NotTo(HaveOccurred())
var vals []string
iter := client.Scan(0, "K*2*", 10).Iterator()
for iter.Next() {
iter := client.Scan(ctx, 0, "K*2*", 10).Iterator()
for iter.Next(ctx) {
vals = append(vals, iter.Val())
}
Expect(iter.Err()).NotTo(HaveOccurred())
@ -126,8 +126,8 @@ var _ = Describe("ScanIterator", func() {
Expect(extraSeed(2, 10)).NotTo(HaveOccurred())
var vals []string
iter := client.Scan(0, "K*", 1).Iterator()
for iter.Next() {
iter := client.Scan(ctx, 0, "K*", 1).Iterator()
for iter.Next(ctx) {
vals = append(vals, iter.Val())
}
Expect(iter.Err()).NotTo(HaveOccurred())

View File

@ -32,10 +32,10 @@ const (
const (
sentinelName = "mymaster"
sentinelMasterPort = "8123"
sentinelSlave1Port = "8124"
sentinelSlave2Port = "8125"
sentinelPort = "8126"
sentinelMasterPort = "9123"
sentinelSlave1Port = "9124"
sentinelSlave2Port = "9125"
sentinelPort = "9126"
)
var (
@ -80,7 +80,7 @@ var _ = BeforeSuite(func() {
sentinelSlave2Port, "--slaveof", "127.0.0.1", sentinelMasterPort)
Expect(err).NotTo(HaveOccurred())
Expect(startCluster(cluster)).NotTo(HaveOccurred())
Expect(startCluster(ctx, cluster)).NotTo(HaveOccurred())
})
var _ = AfterSuite(func() {
@ -223,7 +223,7 @@ func connectTo(port string) (*redis.Client, error) {
})
err := eventually(func() error {
return client.Ping().Err()
return client.Ping(ctx).Err()
}, 30*time.Second)
if err != nil {
return nil, err
@ -243,7 +243,7 @@ func (p *redisProcess) Close() error {
}
err := eventually(func() error {
if err := p.Client.Ping().Err(); err != nil {
if err := p.Client.Ping(ctx).Err(); err != nil {
return nil
}
return errors.New("client is not shutdown")
@ -313,12 +313,12 @@ func startSentinel(port, masterName, masterPort string) (*redisProcess, error) {
return nil, err
}
for _, cmd := range []*redis.StatusCmd{
redis.NewStatusCmd("SENTINEL", "MONITOR", masterName, "127.0.0.1", masterPort, "1"),
redis.NewStatusCmd("SENTINEL", "SET", masterName, "down-after-milliseconds", "500"),
redis.NewStatusCmd("SENTINEL", "SET", masterName, "failover-timeout", "1000"),
redis.NewStatusCmd("SENTINEL", "SET", masterName, "parallel-syncs", "1"),
redis.NewStatusCmd(ctx, "SENTINEL", "MONITOR", masterName, "127.0.0.1", masterPort, "1"),
redis.NewStatusCmd(ctx, "SENTINEL", "SET", masterName, "down-after-milliseconds", "500"),
redis.NewStatusCmd(ctx, "SENTINEL", "SET", masterName, "failover-timeout", "1000"),
redis.NewStatusCmd(ctx, "SENTINEL", "SET", masterName, "parallel-syncs", "1"),
} {
client.Process(cmd)
client.Process(ctx, cmd)
if err := cmd.Err(); err != nil {
process.Kill()
return nil, err

View File

@ -12,6 +12,7 @@ import (
"strings"
"time"
"github.com/go-redis/redis/v7/internal"
"github.com/go-redis/redis/v7/internal/pool"
)
@ -231,7 +232,13 @@ func ParseURL(redisURL string) (*Options, error) {
func newConnPool(opt *Options) *pool.ConnPool {
return pool.NewConnPool(&pool.Options{
Dialer: func(ctx context.Context) (net.Conn, error) {
return opt.Dialer(ctx, opt.Network, opt.Addr)
var conn net.Conn
err := internal.WithSpan(ctx, "dialer", func(ctx context.Context) error {
var err error
conn, err = opt.Dialer(ctx, opt.Network, opt.Addr)
return err
})
return conn, err
},
PoolSize: opt.PoolSize,
MinIdleConns: opt.MinIdleConns,

View File

@ -24,12 +24,11 @@ type pipelineExecer func(context.Context, []Cmder) error
// depends of your batch size and/or use TxPipeline.
type Pipeliner interface {
StatefulCmdable
Do(args ...interface{}) *Cmd
Process(cmd Cmder) error
Do(ctx context.Context, args ...interface{}) *Cmd
Process(ctx context.Context, cmd Cmder) error
Close() error
Discard() error
Exec() ([]Cmder, error)
ExecContext(ctx context.Context) ([]Cmder, error)
Exec(ctx context.Context) ([]Cmder, error)
}
var _ Pipeliner = (*Pipeline)(nil)
@ -54,14 +53,14 @@ func (c *Pipeline) init() {
c.statefulCmdable = c.Process
}
func (c *Pipeline) Do(args ...interface{}) *Cmd {
cmd := NewCmd(args...)
_ = c.Process(cmd)
func (c *Pipeline) Do(ctx context.Context, args ...interface{}) *Cmd {
cmd := NewCmd(ctx, args...)
_ = c.Process(ctx, cmd)
return cmd
}
// Process queues the cmd for later execution.
func (c *Pipeline) Process(cmd Cmder) error {
func (c *Pipeline) Process(ctx context.Context, cmd Cmder) error {
c.mu.Lock()
c.cmds = append(c.cmds, cmd)
c.mu.Unlock()
@ -98,11 +97,7 @@ func (c *Pipeline) discard() error {
//
// Exec always returns list of commands and error of the first failed
// command if any.
func (c *Pipeline) Exec() ([]Cmder, error) {
return c.ExecContext(c.ctx)
}
func (c *Pipeline) ExecContext(ctx context.Context) ([]Cmder, error) {
func (c *Pipeline) Exec(ctx context.Context) ([]Cmder, error) {
c.mu.Lock()
defer c.mu.Unlock()
@ -120,11 +115,11 @@ func (c *Pipeline) ExecContext(ctx context.Context) ([]Cmder, error) {
return cmds, c.exec(ctx, cmds)
}
func (c *Pipeline) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) {
func (c *Pipeline) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
if err := fn(c); err != nil {
return nil, err
}
cmds, err := c.Exec()
cmds, err := c.Exec(ctx)
_ = c.Close()
return cmds, err
}
@ -133,8 +128,8 @@ func (c *Pipeline) Pipeline() Pipeliner {
return c
}
func (c *Pipeline) TxPipelined(fn func(Pipeliner) error) ([]Cmder, error) {
return c.Pipelined(fn)
func (c *Pipeline) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
return c.Pipelined(ctx, fn)
}
func (c *Pipeline) TxPipeline() Pipeliner {

View File

@ -13,7 +13,7 @@ var _ = Describe("pipelining", func() {
BeforeEach(func() {
client = redis.NewClient(redisOptions())
Expect(client.FlushDB().Err()).NotTo(HaveOccurred())
Expect(client.FlushDB(ctx).Err()).NotTo(HaveOccurred())
})
AfterEach(func() {
@ -22,8 +22,8 @@ var _ = Describe("pipelining", func() {
It("supports block style", func() {
var get *redis.StringCmd
cmds, err := client.Pipelined(func(pipe redis.Pipeliner) error {
get = pipe.Get("foo")
cmds, err := client.Pipelined(ctx, func(pipe redis.Pipeliner) error {
get = pipe.Get(ctx, "foo")
return nil
})
Expect(err).To(Equal(redis.Nil))
@ -35,24 +35,24 @@ var _ = Describe("pipelining", func() {
assertPipeline := func() {
It("returns no errors when there are no commands", func() {
_, err := pipe.Exec()
_, err := pipe.Exec(ctx)
Expect(err).NotTo(HaveOccurred())
})
It("discards queued commands", func() {
pipe.Get("key")
pipe.Get(ctx, "key")
pipe.Discard()
cmds, err := pipe.Exec()
cmds, err := pipe.Exec(ctx)
Expect(err).NotTo(HaveOccurred())
Expect(cmds).To(BeNil())
})
It("handles val/err", func() {
err := client.Set("key", "value", 0).Err()
err := client.Set(ctx, "key", "value", 0).Err()
Expect(err).NotTo(HaveOccurred())
get := pipe.Get("key")
cmds, err := pipe.Exec()
get := pipe.Get(ctx, "key")
cmds, err := pipe.Exec(ctx)
Expect(err).NotTo(HaveOccurred())
Expect(cmds).To(HaveLen(1))
@ -62,8 +62,8 @@ var _ = Describe("pipelining", func() {
})
It("supports custom command", func() {
pipe.Do("ping")
cmds, err := pipe.Exec()
pipe.Do(ctx, "ping")
cmds, err := pipe.Exec(ctx)
Expect(err).NotTo(HaveOccurred())
Expect(cmds).To(HaveLen(1))
})

View File

@ -27,7 +27,7 @@ var _ = Describe("pool", func() {
It("respects max size", func() {
perform(1000, func(id int) {
val, err := client.Ping().Result()
val, err := client.Ping(ctx).Result()
Expect(err).NotTo(HaveOccurred())
Expect(val).To(Equal("PONG"))
})
@ -42,9 +42,9 @@ var _ = Describe("pool", func() {
perform(1000, func(id int) {
var ping *redis.StatusCmd
err := client.Watch(func(tx *redis.Tx) error {
cmds, err := tx.Pipelined(func(pipe redis.Pipeliner) error {
ping = pipe.Ping()
err := client.Watch(ctx, func(tx *redis.Tx) error {
cmds, err := tx.Pipelined(ctx, func(pipe redis.Pipeliner) error {
ping = pipe.Ping(ctx)
return nil
})
Expect(err).NotTo(HaveOccurred())
@ -66,8 +66,8 @@ var _ = Describe("pool", func() {
It("respects max size on pipelines", func() {
perform(1000, func(id int) {
pipe := client.Pipeline()
ping := pipe.Ping()
cmds, err := pipe.Exec()
ping := pipe.Ping(ctx)
cmds, err := pipe.Exec(ctx)
Expect(err).NotTo(HaveOccurred())
Expect(cmds).To(HaveLen(1))
Expect(ping.Err()).NotTo(HaveOccurred())
@ -87,10 +87,10 @@ var _ = Describe("pool", func() {
cn.SetNetConn(&badConn{})
client.Pool().Put(cn)
err = client.Ping().Err()
err = client.Ping(ctx).Err()
Expect(err).To(MatchError("bad connection"))
val, err := client.Ping().Result()
val, err := client.Ping(ctx).Result()
Expect(err).NotTo(HaveOccurred())
Expect(val).To(Equal("PONG"))
@ -106,7 +106,7 @@ var _ = Describe("pool", func() {
It("reuses connections", func() {
for i := 0; i < 100; i++ {
val, err := client.Ping().Result()
val, err := client.Ping(ctx).Result()
Expect(err).NotTo(HaveOccurred())
Expect(val).To(Equal("PONG"))
}
@ -122,7 +122,7 @@ var _ = Describe("pool", func() {
})
It("removes idle connections", func() {
err := client.Ping().Err()
err := client.Ping(ctx).Err()
Expect(err).NotTo(HaveOccurred())
stats := client.PoolStats()

106
pubsub.go
View File

@ -26,7 +26,7 @@ var errPingTimeout = errors.New("redis: ping timeout")
type PubSub struct {
opt *Options
newConn func([]string) (*pool.Conn, error)
newConn func(ctx context.Context, channels []string) (*pool.Conn, error)
closeConn func(*pool.Conn) error
mu sync.Mutex
@ -55,14 +55,14 @@ func (c *PubSub) init() {
c.exit = make(chan struct{})
}
func (c *PubSub) connWithLock() (*pool.Conn, error) {
func (c *PubSub) connWithLock(ctx context.Context) (*pool.Conn, error) {
c.mu.Lock()
cn, err := c.conn(nil)
cn, err := c.conn(ctx, nil)
c.mu.Unlock()
return cn, err
}
func (c *PubSub) conn(newChannels []string) (*pool.Conn, error) {
func (c *PubSub) conn(ctx context.Context, newChannels []string) (*pool.Conn, error) {
if c.closed {
return nil, pool.ErrClosed
}
@ -73,12 +73,12 @@ func (c *PubSub) conn(newChannels []string) (*pool.Conn, error) {
channels := mapKeys(c.channels)
channels = append(channels, newChannels...)
cn, err := c.newConn(channels)
cn, err := c.newConn(ctx, channels)
if err != nil {
return nil, err
}
if err := c.resubscribe(cn); err != nil {
if err := c.resubscribe(ctx, cn); err != nil {
_ = c.closeConn(cn)
return nil, err
}
@ -93,15 +93,15 @@ func (c *PubSub) writeCmd(ctx context.Context, cn *pool.Conn, cmd Cmder) error {
})
}
func (c *PubSub) resubscribe(cn *pool.Conn) error {
func (c *PubSub) resubscribe(ctx context.Context, cn *pool.Conn) error {
var firstErr error
if len(c.channels) > 0 {
firstErr = c._subscribe(cn, "subscribe", mapKeys(c.channels))
firstErr = c._subscribe(ctx, cn, "subscribe", mapKeys(c.channels))
}
if len(c.patterns) > 0 {
err := c._subscribe(cn, "psubscribe", mapKeys(c.patterns))
err := c._subscribe(ctx, cn, "psubscribe", mapKeys(c.patterns))
if err != nil && firstErr == nil {
firstErr = err
}
@ -121,35 +121,40 @@ func mapKeys(m map[string]struct{}) []string {
}
func (c *PubSub) _subscribe(
cn *pool.Conn, redisCmd string, channels []string,
ctx context.Context, cn *pool.Conn, redisCmd string, channels []string,
) error {
args := make([]interface{}, 0, 1+len(channels))
args = append(args, redisCmd)
for _, channel := range channels {
args = append(args, channel)
}
cmd := NewSliceCmd(args...)
return c.writeCmd(context.TODO(), cn, cmd)
cmd := NewSliceCmd(ctx, args...)
return c.writeCmd(ctx, cn, cmd)
}
func (c *PubSub) releaseConnWithLock(cn *pool.Conn, err error, allowTimeout bool) {
func (c *PubSub) releaseConnWithLock(
ctx context.Context,
cn *pool.Conn,
err error,
allowTimeout bool,
) {
c.mu.Lock()
c.releaseConn(cn, err, allowTimeout)
c.releaseConn(ctx, cn, err, allowTimeout)
c.mu.Unlock()
}
func (c *PubSub) releaseConn(cn *pool.Conn, err error, allowTimeout bool) {
func (c *PubSub) releaseConn(ctx context.Context, cn *pool.Conn, err error, allowTimeout bool) {
if c.cn != cn {
return
}
if isBadConn(err, allowTimeout) {
c.reconnect(err)
c.reconnect(ctx, err)
}
}
func (c *PubSub) reconnect(reason error) {
func (c *PubSub) reconnect(ctx context.Context, reason error) {
_ = c.closeTheCn(reason)
_, _ = c.conn(nil)
_, _ = c.conn(ctx, nil)
}
func (c *PubSub) closeTheCn(reason error) error {
@ -179,11 +184,11 @@ func (c *PubSub) Close() error {
// Subscribe the client to the specified channels. It returns
// empty subscription if there are no channels.
func (c *PubSub) Subscribe(channels ...string) error {
func (c *PubSub) Subscribe(ctx context.Context, channels ...string) error {
c.mu.Lock()
defer c.mu.Unlock()
err := c.subscribe("subscribe", channels...)
err := c.subscribe(ctx, "subscribe", channels...)
if c.channels == nil {
c.channels = make(map[string]struct{})
}
@ -195,11 +200,11 @@ func (c *PubSub) Subscribe(channels ...string) error {
// PSubscribe the client to the given patterns. It returns
// empty subscription if there are no patterns.
func (c *PubSub) PSubscribe(patterns ...string) error {
func (c *PubSub) PSubscribe(ctx context.Context, patterns ...string) error {
c.mu.Lock()
defer c.mu.Unlock()
err := c.subscribe("psubscribe", patterns...)
err := c.subscribe(ctx, "psubscribe", patterns...)
if c.patterns == nil {
c.patterns = make(map[string]struct{})
}
@ -211,55 +216,55 @@ func (c *PubSub) PSubscribe(patterns ...string) error {
// Unsubscribe the client from the given channels, or from all of
// them if none is given.
func (c *PubSub) Unsubscribe(channels ...string) error {
func (c *PubSub) Unsubscribe(ctx context.Context, channels ...string) error {
c.mu.Lock()
defer c.mu.Unlock()
for _, channel := range channels {
delete(c.channels, channel)
}
err := c.subscribe("unsubscribe", channels...)
err := c.subscribe(ctx, "unsubscribe", channels...)
return err
}
// PUnsubscribe the client from the given patterns, or from all of
// them if none is given.
func (c *PubSub) PUnsubscribe(patterns ...string) error {
func (c *PubSub) PUnsubscribe(ctx context.Context, patterns ...string) error {
c.mu.Lock()
defer c.mu.Unlock()
for _, pattern := range patterns {
delete(c.patterns, pattern)
}
err := c.subscribe("punsubscribe", patterns...)
err := c.subscribe(ctx, "punsubscribe", patterns...)
return err
}
func (c *PubSub) subscribe(redisCmd string, channels ...string) error {
cn, err := c.conn(channels)
func (c *PubSub) subscribe(ctx context.Context, redisCmd string, channels ...string) error {
cn, err := c.conn(ctx, channels)
if err != nil {
return err
}
err = c._subscribe(cn, redisCmd, channels)
c.releaseConn(cn, err, false)
err = c._subscribe(ctx, cn, redisCmd, channels)
c.releaseConn(ctx, cn, err, false)
return err
}
func (c *PubSub) Ping(payload ...string) error {
func (c *PubSub) Ping(ctx context.Context, payload ...string) error {
args := []interface{}{"ping"}
if len(payload) == 1 {
args = append(args, payload[0])
}
cmd := NewCmd(args...)
cmd := NewCmd(ctx, args...)
cn, err := c.connWithLock()
cn, err := c.connWithLock(ctx)
if err != nil {
return err
}
err = c.writeCmd(context.TODO(), cn, cmd)
c.releaseConnWithLock(cn, err, false)
err = c.writeCmd(ctx, cn, cmd)
c.releaseConnWithLock(ctx, cn, err, false)
return err
}
@ -340,21 +345,21 @@ func (c *PubSub) newMessage(reply interface{}) (interface{}, error) {
// ReceiveTimeout acts like Receive but returns an error if message
// is not received in time. This is low-level API and in most cases
// Channel should be used instead.
func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) {
func (c *PubSub) ReceiveTimeout(ctx context.Context, timeout time.Duration) (interface{}, error) {
if c.cmd == nil {
c.cmd = NewCmd()
c.cmd = NewCmd(ctx)
}
cn, err := c.connWithLock()
cn, err := c.connWithLock(ctx)
if err != nil {
return nil, err
}
err = cn.WithReader(context.TODO(), timeout, func(rd *proto.Reader) error {
err = cn.WithReader(ctx, timeout, func(rd *proto.Reader) error {
return c.cmd.readReply(rd)
})
c.releaseConnWithLock(cn, err, timeout > 0)
c.releaseConnWithLock(ctx, cn, err, timeout > 0)
if err != nil {
return nil, err
}
@ -365,16 +370,16 @@ func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) {
// Receive returns a message as a Subscription, Message, Pong or error.
// See PubSub example for details. This is low-level API and in most cases
// Channel should be used instead.
func (c *PubSub) Receive() (interface{}, error) {
return c.ReceiveTimeout(0)
func (c *PubSub) Receive(ctx context.Context) (interface{}, error) {
return c.ReceiveTimeout(ctx, 0)
}
// ReceiveMessage returns a Message or error ignoring Subscription and Pong
// messages. This is low-level API and in most cases Channel should be used
// instead.
func (c *PubSub) ReceiveMessage() (*Message, error) {
func (c *PubSub) ReceiveMessage(ctx context.Context) (*Message, error) {
for {
msg, err := c.Receive()
msg, err := c.Receive(ctx)
if err != nil {
return nil, err
}
@ -427,7 +432,7 @@ func (c *PubSub) ChannelSize(size int) <-chan *Message {
// reconnections.
//
// ChannelWithSubscriptions can not be used together with Channel or ChannelSize.
func (c *PubSub) ChannelWithSubscriptions(size int) <-chan interface{} {
func (c *PubSub) ChannelWithSubscriptions(ctx context.Context, size int) <-chan interface{} {
c.chOnce.Do(func() {
c.initPing()
c.initAllChan(size)
@ -444,6 +449,7 @@ func (c *PubSub) ChannelWithSubscriptions(size int) <-chan interface{} {
}
func (c *PubSub) initPing() {
ctx := context.TODO()
c.ping = make(chan struct{}, 1)
go func() {
timer := time.NewTimer(pingTimeout)
@ -459,7 +465,7 @@ func (c *PubSub) initPing() {
<-timer.C
}
case <-timer.C:
pingErr := c.Ping()
pingErr := c.Ping(ctx)
if healthy {
healthy = false
} else {
@ -467,7 +473,7 @@ func (c *PubSub) initPing() {
pingErr = errPingTimeout
}
c.mu.Lock()
c.reconnect(pingErr)
c.reconnect(ctx, pingErr)
healthy = true
c.mu.Unlock()
}
@ -480,6 +486,7 @@ func (c *PubSub) initPing() {
// initMsgChan must be in sync with initAllChan.
func (c *PubSub) initMsgChan(size int) {
ctx := context.TODO()
c.msgCh = make(chan *Message, size)
go func() {
timer := time.NewTimer(pingTimeout)
@ -487,7 +494,7 @@ func (c *PubSub) initMsgChan(size int) {
var errCount int
for {
msg, err := c.Receive()
msg, err := c.Receive(ctx)
if err != nil {
if err == pool.ErrClosed {
close(c.msgCh)
@ -533,6 +540,7 @@ func (c *PubSub) initMsgChan(size int) {
// initAllChan must be in sync with initMsgChan.
func (c *PubSub) initAllChan(size int) {
ctx := context.TODO()
c.allCh = make(chan interface{}, size)
go func() {
timer := time.NewTimer(pingTimeout)
@ -540,7 +548,7 @@ func (c *PubSub) initAllChan(size int) {
var errCount int
for {
msg, err := c.Receive()
msg, err := c.Receive(ctx)
if err != nil {
if err == pool.ErrClosed {
close(c.allCh)

View File

@ -20,7 +20,7 @@ var _ = Describe("PubSub", func() {
opt.MinIdleConns = 0
opt.MaxConnAge = 0
client = redis.NewClient(opt)
Expect(client.FlushDB().Err()).NotTo(HaveOccurred())
Expect(client.FlushDB(ctx).Err()).NotTo(HaveOccurred())
})
AfterEach(func() {
@ -28,18 +28,18 @@ var _ = Describe("PubSub", func() {
})
It("implements Stringer", func() {
pubsub := client.PSubscribe("mychannel*")
pubsub := client.PSubscribe(ctx, "mychannel*")
defer pubsub.Close()
Expect(pubsub.String()).To(Equal("PubSub(mychannel*)"))
})
It("should support pattern matching", func() {
pubsub := client.PSubscribe("mychannel*")
pubsub := client.PSubscribe(ctx, "mychannel*")
defer pubsub.Close()
{
msgi, err := pubsub.ReceiveTimeout(time.Second)
msgi, err := pubsub.ReceiveTimeout(ctx, time.Second)
Expect(err).NotTo(HaveOccurred())
subscr := msgi.(*redis.Subscription)
Expect(subscr.Kind).To(Equal("psubscribe"))
@ -48,19 +48,19 @@ var _ = Describe("PubSub", func() {
}
{
msgi, err := pubsub.ReceiveTimeout(time.Second)
msgi, err := pubsub.ReceiveTimeout(ctx, time.Second)
Expect(err.(net.Error).Timeout()).To(Equal(true))
Expect(msgi).To(BeNil())
}
n, err := client.Publish("mychannel1", "hello").Result()
n, err := client.Publish(ctx, "mychannel1", "hello").Result()
Expect(err).NotTo(HaveOccurred())
Expect(n).To(Equal(int64(1)))
Expect(pubsub.PUnsubscribe("mychannel*")).NotTo(HaveOccurred())
Expect(pubsub.PUnsubscribe(ctx, "mychannel*")).NotTo(HaveOccurred())
{
msgi, err := pubsub.ReceiveTimeout(time.Second)
msgi, err := pubsub.ReceiveTimeout(ctx, time.Second)
Expect(err).NotTo(HaveOccurred())
subscr := msgi.(*redis.Message)
Expect(subscr.Channel).To(Equal("mychannel1"))
@ -69,7 +69,7 @@ var _ = Describe("PubSub", func() {
}
{
msgi, err := pubsub.ReceiveTimeout(time.Second)
msgi, err := pubsub.ReceiveTimeout(ctx, time.Second)
Expect(err).NotTo(HaveOccurred())
subscr := msgi.(*redis.Subscription)
Expect(subscr.Kind).To(Equal("punsubscribe"))
@ -82,31 +82,31 @@ var _ = Describe("PubSub", func() {
})
It("should pub/sub channels", func() {
channels, err := client.PubSubChannels("mychannel*").Result()
channels, err := client.PubSubChannels(ctx, "mychannel*").Result()
Expect(err).NotTo(HaveOccurred())
Expect(channels).To(BeEmpty())
pubsub := client.Subscribe("mychannel", "mychannel2")
pubsub := client.Subscribe(ctx, "mychannel", "mychannel2")
defer pubsub.Close()
channels, err = client.PubSubChannels("mychannel*").Result()
channels, err = client.PubSubChannels(ctx, "mychannel*").Result()
Expect(err).NotTo(HaveOccurred())
Expect(channels).To(ConsistOf([]string{"mychannel", "mychannel2"}))
channels, err = client.PubSubChannels("").Result()
channels, err = client.PubSubChannels(ctx, "").Result()
Expect(err).NotTo(HaveOccurred())
Expect(channels).To(BeEmpty())
channels, err = client.PubSubChannels("*").Result()
channels, err = client.PubSubChannels(ctx, "*").Result()
Expect(err).NotTo(HaveOccurred())
Expect(len(channels)).To(BeNumerically(">=", 2))
})
It("should return the numbers of subscribers", func() {
pubsub := client.Subscribe("mychannel", "mychannel2")
pubsub := client.Subscribe(ctx, "mychannel", "mychannel2")
defer pubsub.Close()
channels, err := client.PubSubNumSub("mychannel", "mychannel2", "mychannel3").Result()
channels, err := client.PubSubNumSub(ctx, "mychannel", "mychannel2", "mychannel3").Result()
Expect(err).NotTo(HaveOccurred())
Expect(channels).To(Equal(map[string]int64{
"mychannel": 1,
@ -116,24 +116,24 @@ var _ = Describe("PubSub", func() {
})
It("should return the numbers of subscribers by pattern", func() {
num, err := client.PubSubNumPat().Result()
num, err := client.PubSubNumPat(ctx).Result()
Expect(err).NotTo(HaveOccurred())
Expect(num).To(Equal(int64(0)))
pubsub := client.PSubscribe("*")
pubsub := client.PSubscribe(ctx, "*")
defer pubsub.Close()
num, err = client.PubSubNumPat().Result()
num, err = client.PubSubNumPat(ctx).Result()
Expect(err).NotTo(HaveOccurred())
Expect(num).To(Equal(int64(1)))
})
It("should pub/sub", func() {
pubsub := client.Subscribe("mychannel", "mychannel2")
pubsub := client.Subscribe(ctx, "mychannel", "mychannel2")
defer pubsub.Close()
{
msgi, err := pubsub.ReceiveTimeout(time.Second)
msgi, err := pubsub.ReceiveTimeout(ctx, time.Second)
Expect(err).NotTo(HaveOccurred())
subscr := msgi.(*redis.Subscription)
Expect(subscr.Kind).To(Equal("subscribe"))
@ -142,7 +142,7 @@ var _ = Describe("PubSub", func() {
}
{
msgi, err := pubsub.ReceiveTimeout(time.Second)
msgi, err := pubsub.ReceiveTimeout(ctx, time.Second)
Expect(err).NotTo(HaveOccurred())
subscr := msgi.(*redis.Subscription)
Expect(subscr.Kind).To(Equal("subscribe"))
@ -151,23 +151,23 @@ var _ = Describe("PubSub", func() {
}
{
msgi, err := pubsub.ReceiveTimeout(time.Second)
msgi, err := pubsub.ReceiveTimeout(ctx, time.Second)
Expect(err.(net.Error).Timeout()).To(Equal(true))
Expect(msgi).NotTo(HaveOccurred())
}
n, err := client.Publish("mychannel", "hello").Result()
n, err := client.Publish(ctx, "mychannel", "hello").Result()
Expect(err).NotTo(HaveOccurred())
Expect(n).To(Equal(int64(1)))
n, err = client.Publish("mychannel2", "hello2").Result()
n, err = client.Publish(ctx, "mychannel2", "hello2").Result()
Expect(err).NotTo(HaveOccurred())
Expect(n).To(Equal(int64(1)))
Expect(pubsub.Unsubscribe("mychannel", "mychannel2")).NotTo(HaveOccurred())
Expect(pubsub.Unsubscribe(ctx, "mychannel", "mychannel2")).NotTo(HaveOccurred())
{
msgi, err := pubsub.ReceiveTimeout(time.Second)
msgi, err := pubsub.ReceiveTimeout(ctx, time.Second)
Expect(err).NotTo(HaveOccurred())
msg := msgi.(*redis.Message)
Expect(msg.Channel).To(Equal("mychannel"))
@ -175,7 +175,7 @@ var _ = Describe("PubSub", func() {
}
{
msgi, err := pubsub.ReceiveTimeout(time.Second)
msgi, err := pubsub.ReceiveTimeout(ctx, time.Second)
Expect(err).NotTo(HaveOccurred())
msg := msgi.(*redis.Message)
Expect(msg.Channel).To(Equal("mychannel2"))
@ -183,7 +183,7 @@ var _ = Describe("PubSub", func() {
}
{
msgi, err := pubsub.ReceiveTimeout(time.Second)
msgi, err := pubsub.ReceiveTimeout(ctx, time.Second)
Expect(err).NotTo(HaveOccurred())
subscr := msgi.(*redis.Subscription)
Expect(subscr.Kind).To(Equal("unsubscribe"))
@ -192,7 +192,7 @@ var _ = Describe("PubSub", func() {
}
{
msgi, err := pubsub.ReceiveTimeout(time.Second)
msgi, err := pubsub.ReceiveTimeout(ctx, time.Second)
Expect(err).NotTo(HaveOccurred())
subscr := msgi.(*redis.Subscription)
Expect(subscr.Kind).To(Equal("unsubscribe"))
@ -205,42 +205,42 @@ var _ = Describe("PubSub", func() {
})
It("should ping/pong", func() {
pubsub := client.Subscribe("mychannel")
pubsub := client.Subscribe(ctx, "mychannel")
defer pubsub.Close()
_, err := pubsub.ReceiveTimeout(time.Second)
_, err := pubsub.ReceiveTimeout(ctx, time.Second)
Expect(err).NotTo(HaveOccurred())
err = pubsub.Ping("")
err = pubsub.Ping(ctx, "")
Expect(err).NotTo(HaveOccurred())
msgi, err := pubsub.ReceiveTimeout(time.Second)
msgi, err := pubsub.ReceiveTimeout(ctx, time.Second)
Expect(err).NotTo(HaveOccurred())
pong := msgi.(*redis.Pong)
Expect(pong.Payload).To(Equal(""))
})
It("should ping/pong with payload", func() {
pubsub := client.Subscribe("mychannel")
pubsub := client.Subscribe(ctx, "mychannel")
defer pubsub.Close()
_, err := pubsub.ReceiveTimeout(time.Second)
_, err := pubsub.ReceiveTimeout(ctx, time.Second)
Expect(err).NotTo(HaveOccurred())
err = pubsub.Ping("hello")
err = pubsub.Ping(ctx, "hello")
Expect(err).NotTo(HaveOccurred())
msgi, err := pubsub.ReceiveTimeout(time.Second)
msgi, err := pubsub.ReceiveTimeout(ctx, time.Second)
Expect(err).NotTo(HaveOccurred())
pong := msgi.(*redis.Pong)
Expect(pong.Payload).To(Equal("hello"))
})
It("should multi-ReceiveMessage", func() {
pubsub := client.Subscribe("mychannel")
pubsub := client.Subscribe(ctx, "mychannel")
defer pubsub.Close()
subscr, err := pubsub.ReceiveTimeout(time.Second)
subscr, err := pubsub.ReceiveTimeout(ctx, time.Second)
Expect(err).NotTo(HaveOccurred())
Expect(subscr).To(Equal(&redis.Subscription{
Kind: "subscribe",
@ -248,25 +248,25 @@ var _ = Describe("PubSub", func() {
Count: 1,
}))
err = client.Publish("mychannel", "hello").Err()
err = client.Publish(ctx, "mychannel", "hello").Err()
Expect(err).NotTo(HaveOccurred())
err = client.Publish("mychannel", "world").Err()
err = client.Publish(ctx, "mychannel", "world").Err()
Expect(err).NotTo(HaveOccurred())
msg, err := pubsub.ReceiveMessage()
msg, err := pubsub.ReceiveMessage(ctx)
Expect(err).NotTo(HaveOccurred())
Expect(msg.Channel).To(Equal("mychannel"))
Expect(msg.Payload).To(Equal("hello"))
msg, err = pubsub.ReceiveMessage()
msg, err = pubsub.ReceiveMessage(ctx)
Expect(err).NotTo(HaveOccurred())
Expect(msg.Channel).To(Equal("mychannel"))
Expect(msg.Payload).To(Equal("world"))
})
It("returns an error when subscribe fails", func() {
pubsub := client.Subscribe()
pubsub := client.Subscribe(ctx)
defer pubsub.Close()
pubsub.SetNetConn(&badConn{
@ -274,10 +274,10 @@ var _ = Describe("PubSub", func() {
writeErr: io.EOF,
})
err := pubsub.Subscribe("mychannel")
err := pubsub.Subscribe(ctx, "mychannel")
Expect(err).To(MatchError("EOF"))
err = pubsub.Subscribe("mychannel")
err = pubsub.Subscribe(ctx, "mychannel")
Expect(err).NotTo(HaveOccurred())
})
@ -293,16 +293,16 @@ var _ = Describe("PubSub", func() {
defer GinkgoRecover()
Eventually(step).Should(Receive())
err := client.Publish("mychannel", "hello").Err()
err := client.Publish(ctx, "mychannel", "hello").Err()
Expect(err).NotTo(HaveOccurred())
step <- struct{}{}
}()
_, err := pubsub.ReceiveMessage()
_, err := pubsub.ReceiveMessage(ctx)
Expect(err).To(Equal(io.EOF))
step <- struct{}{}
msg, err := pubsub.ReceiveMessage()
msg, err := pubsub.ReceiveMessage(ctx)
Expect(err).NotTo(HaveOccurred())
Expect(msg.Channel).To(Equal("mychannel"))
Expect(msg.Payload).To(Equal("hello"))
@ -311,10 +311,10 @@ var _ = Describe("PubSub", func() {
}
It("Subscribe should reconnect on ReceiveMessage error", func() {
pubsub := client.Subscribe("mychannel")
pubsub := client.Subscribe(ctx, "mychannel")
defer pubsub.Close()
subscr, err := pubsub.ReceiveTimeout(time.Second)
subscr, err := pubsub.ReceiveTimeout(ctx, time.Second)
Expect(err).NotTo(HaveOccurred())
Expect(subscr).To(Equal(&redis.Subscription{
Kind: "subscribe",
@ -326,10 +326,10 @@ var _ = Describe("PubSub", func() {
})
It("PSubscribe should reconnect on ReceiveMessage error", func() {
pubsub := client.PSubscribe("mychannel")
pubsub := client.PSubscribe(ctx, "mychannel")
defer pubsub.Close()
subscr, err := pubsub.ReceiveTimeout(time.Second)
subscr, err := pubsub.ReceiveTimeout(ctx, time.Second)
Expect(err).NotTo(HaveOccurred())
Expect(subscr).To(Equal(&redis.Subscription{
Kind: "psubscribe",
@ -341,7 +341,7 @@ var _ = Describe("PubSub", func() {
})
It("should return on Close", func() {
pubsub := client.Subscribe("mychannel")
pubsub := client.Subscribe(ctx, "mychannel")
defer pubsub.Close()
var wg sync.WaitGroup
@ -352,7 +352,7 @@ var _ = Describe("PubSub", func() {
wg.Done()
defer wg.Done()
_, err := pubsub.ReceiveMessage()
_, err := pubsub.ReceiveMessage(ctx)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(SatisfyAny(
Equal("redis: client is closed"),
@ -371,7 +371,7 @@ var _ = Describe("PubSub", func() {
It("should ReceiveMessage without a subscription", func() {
timeout := 100 * time.Millisecond
pubsub := client.Subscribe()
pubsub := client.Subscribe(ctx)
defer pubsub.Close()
var wg sync.WaitGroup
@ -382,16 +382,16 @@ var _ = Describe("PubSub", func() {
time.Sleep(timeout)
err := pubsub.Subscribe("mychannel")
err := pubsub.Subscribe(ctx, "mychannel")
Expect(err).NotTo(HaveOccurred())
time.Sleep(timeout)
err = client.Publish("mychannel", "hello").Err()
err = client.Publish(ctx, "mychannel", "hello").Err()
Expect(err).NotTo(HaveOccurred())
}()
msg, err := pubsub.ReceiveMessage()
msg, err := pubsub.ReceiveMessage(ctx)
Expect(err).NotTo(HaveOccurred())
Expect(msg.Channel).To(Equal("mychannel"))
Expect(msg.Payload).To(Equal("hello"))
@ -400,13 +400,13 @@ var _ = Describe("PubSub", func() {
})
It("handles big message payload", func() {
pubsub := client.Subscribe("mychannel")
pubsub := client.Subscribe(ctx, "mychannel")
defer pubsub.Close()
ch := pubsub.Channel()
bigVal := bigVal()
err := client.Publish("mychannel", bigVal).Err()
err := client.Publish(ctx, "mychannel", bigVal).Err()
Expect(err).NotTo(HaveOccurred())
var msg *redis.Message
@ -418,7 +418,7 @@ var _ = Describe("PubSub", func() {
It("supports concurrent Ping and Receive", func() {
const N = 100
pubsub := client.Subscribe("mychannel")
pubsub := client.Subscribe(ctx, "mychannel")
defer pubsub.Close()
done := make(chan struct{})
@ -426,14 +426,14 @@ var _ = Describe("PubSub", func() {
defer GinkgoRecover()
for i := 0; i < N; i++ {
_, err := pubsub.ReceiveTimeout(5 * time.Second)
_, err := pubsub.ReceiveTimeout(ctx, 5*time.Second)
Expect(err).NotTo(HaveOccurred())
}
close(done)
}()
for i := 0; i < N; i++ {
err := pubsub.Ping()
err := pubsub.Ping(ctx)
Expect(err).NotTo(HaveOccurred())
}

View File

@ -2,7 +2,6 @@ package redis_test
import (
"bytes"
"context"
"fmt"
"net"
"strconv"
@ -22,7 +21,7 @@ var _ = Describe("races", func() {
BeforeEach(func() {
client = redis.NewClient(redisOptions())
Expect(client.FlushDB().Err()).To(BeNil())
Expect(client.FlushDB(ctx).Err()).To(BeNil())
C, N = 10, 1000
if testing.Short() {
@ -40,7 +39,7 @@ var _ = Describe("races", func() {
perform(C, func(id int) {
for i := 0; i < N; i++ {
msg := fmt.Sprintf("echo %d %d", id, i)
echo, err := client.Echo(msg).Result()
echo, err := client.Echo(ctx, msg).Result()
Expect(err).NotTo(HaveOccurred())
Expect(echo).To(Equal(msg))
}
@ -52,12 +51,12 @@ var _ = Describe("races", func() {
perform(C, func(id int) {
for i := 0; i < N; i++ {
err := client.Incr(key).Err()
err := client.Incr(ctx, key).Err()
Expect(err).NotTo(HaveOccurred())
}
})
val, err := client.Get(key).Int64()
val, err := client.Get(ctx, key).Int64()
Expect(err).NotTo(HaveOccurred())
Expect(val).To(Equal(int64(C * N)))
})
@ -66,6 +65,7 @@ var _ = Describe("races", func() {
perform(C, func(id int) {
for i := 0; i < N; i++ {
err := client.Set(
ctx,
fmt.Sprintf("keys.key-%d-%d", id, i),
fmt.Sprintf("hello-%d-%d", id, i),
0,
@ -74,7 +74,7 @@ var _ = Describe("races", func() {
}
})
keys := client.Keys("keys.*")
keys := client.Keys(ctx, "keys.*")
Expect(keys.Err()).NotTo(HaveOccurred())
Expect(len(keys.Val())).To(Equal(C * N))
})
@ -86,12 +86,12 @@ var _ = Describe("races", func() {
key := fmt.Sprintf("keys.key-%d", i)
keys = append(keys, key)
err := client.Set(key, fmt.Sprintf("hello-%d", i), 0).Err()
err := client.Set(ctx, key, fmt.Sprintf("hello-%d", i), 0).Err()
Expect(err).NotTo(HaveOccurred())
}
keys = append(keys, "non-existent-key")
vals, err := client.MGet(keys...).Result()
vals, err := client.MGet(ctx, keys...).Result()
Expect(err).NotTo(HaveOccurred())
Expect(len(vals)).To(Equal(N + 2))
@ -109,7 +109,7 @@ var _ = Describe("races", func() {
bigVal := bigVal()
err := client.Set("key", bigVal, 0).Err()
err := client.Set(ctx, "key", bigVal, 0).Err()
Expect(err).NotTo(HaveOccurred())
// Reconnect to get new connection.
@ -118,7 +118,7 @@ var _ = Describe("races", func() {
perform(C, func(id int) {
for i := 0; i < N; i++ {
got, err := client.Get("key").Bytes()
got, err := client.Get(ctx, "key").Bytes()
Expect(err).NotTo(HaveOccurred())
Expect(got).To(Equal(bigVal))
}
@ -131,14 +131,14 @@ var _ = Describe("races", func() {
bigVal := bigVal()
perform(C, func(id int) {
for i := 0; i < N; i++ {
err := client.Set("key", bigVal, 0).Err()
err := client.Set(ctx, "key", bigVal, 0).Err()
Expect(err).NotTo(HaveOccurred())
}
})
})
It("should select db", func() {
err := client.Set("db", 1, 0).Err()
err := client.Set(ctx, "db", 1, 0).Err()
Expect(err).NotTo(HaveOccurred())
perform(C, func(id int) {
@ -146,10 +146,10 @@ var _ = Describe("races", func() {
opt.DB = id
client := redis.NewClient(opt)
for i := 0; i < N; i++ {
err := client.Set("db", id, 0).Err()
err := client.Set(ctx, "db", id, 0).Err()
Expect(err).NotTo(HaveOccurred())
n, err := client.Get("db").Int64()
n, err := client.Get(ctx, "db").Int64()
Expect(err).NotTo(HaveOccurred())
Expect(n).To(Equal(int64(id)))
}
@ -157,7 +157,7 @@ var _ = Describe("races", func() {
Expect(err).NotTo(HaveOccurred())
})
n, err := client.Get("db").Int64()
n, err := client.Get(ctx, "db").Int64()
Expect(err).NotTo(HaveOccurred())
Expect(n).To(Equal(int64(1)))
})
@ -170,7 +170,7 @@ var _ = Describe("races", func() {
client := redis.NewClient(opt)
perform(C, func(id int) {
err := client.Ping().Err()
err := client.Ping(ctx).Err()
Expect(err).To(HaveOccurred())
Expect(err.(net.Error).Timeout()).To(BeTrue())
})
@ -181,21 +181,21 @@ var _ = Describe("races", func() {
})
It("should Watch/Unwatch", func() {
err := client.Set("key", "0", 0).Err()
err := client.Set(ctx, "key", "0", 0).Err()
Expect(err).NotTo(HaveOccurred())
perform(C, func(id int) {
for i := 0; i < N; i++ {
err := client.Watch(func(tx *redis.Tx) error {
val, err := tx.Get("key").Result()
err := client.Watch(ctx, func(tx *redis.Tx) error {
val, err := tx.Get(ctx, "key").Result()
Expect(err).NotTo(HaveOccurred())
Expect(val).NotTo(Equal(redis.Nil))
num, err := strconv.ParseInt(val, 10, 64)
Expect(err).NotTo(HaveOccurred())
cmds, err := tx.TxPipelined(func(pipe redis.Pipeliner) error {
pipe.Set("key", strconv.FormatInt(num+1, 10), 0)
cmds, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Set(ctx, "key", strconv.FormatInt(num+1, 10), 0)
return nil
})
Expect(cmds).To(HaveLen(1))
@ -209,7 +209,7 @@ var _ = Describe("races", func() {
}
})
val, err := client.Get("key").Int64()
val, err := client.Get(ctx, "key").Int64()
Expect(err).NotTo(HaveOccurred())
Expect(val).To(Equal(int64(C * N)))
})
@ -218,10 +218,10 @@ var _ = Describe("races", func() {
perform(C, func(id int) {
pipe := client.Pipeline()
for i := 0; i < N; i++ {
pipe.Echo(fmt.Sprint(i))
pipe.Echo(ctx, fmt.Sprint(i))
}
cmds, err := pipe.Exec()
cmds, err := pipe.Exec(ctx)
Expect(err).NotTo(HaveOccurred())
Expect(cmds).To(HaveLen(N))
@ -234,14 +234,14 @@ var _ = Describe("races", func() {
It("should Pipeline", func() {
pipe := client.Pipeline()
perform(N, func(id int) {
pipe.Incr("key")
pipe.Incr(ctx, "key")
})
cmds, err := pipe.Exec()
cmds, err := pipe.Exec(ctx)
Expect(err).NotTo(HaveOccurred())
Expect(cmds).To(HaveLen(N))
n, err := client.Get("key").Int64()
n, err := client.Get(ctx, "key").Int64()
Expect(err).NotTo(HaveOccurred())
Expect(n).To(Equal(int64(N)))
})
@ -249,14 +249,14 @@ var _ = Describe("races", func() {
It("should TxPipeline", func() {
pipe := client.TxPipeline()
perform(N, func(id int) {
pipe.Incr("key")
pipe.Incr(ctx, "key")
})
cmds, err := pipe.Exec()
cmds, err := pipe.Exec(ctx)
Expect(err).NotTo(HaveOccurred())
Expect(cmds).To(HaveLen(N))
n, err := client.Get("key").Int64()
n, err := client.Get(ctx, "key").Int64()
Expect(err).NotTo(HaveOccurred())
Expect(n).To(Equal(int64(N)))
})
@ -265,7 +265,7 @@ var _ = Describe("races", func() {
var received uint32
wg := performAsync(C, func(id int) {
for {
v, err := client.BLPop(3*time.Second, "list").Result()
v, err := client.BLPop(ctx, 3*time.Second, "list").Result()
if err != nil {
break
}
@ -276,7 +276,7 @@ var _ = Describe("races", func() {
perform(C, func(id int) {
for i := 0; i < N; i++ {
err := client.LPush("list", "hello").Err()
err := client.LPush(ctx, "list", "hello").Err()
Expect(err).NotTo(HaveOccurred())
}
})
@ -287,7 +287,7 @@ var _ = Describe("races", func() {
It("should WithContext", func() {
perform(C, func(_ int) {
err := client.WithContext(context.Background()).Ping().Err()
err := client.WithContext(ctx).Ping(ctx).Err()
Expect(err).NotTo(HaveOccurred())
})
})
@ -299,7 +299,7 @@ var _ = Describe("cluster races", func() {
BeforeEach(func() {
opt := redisClusterOptions()
client = cluster.newClusterClient(opt)
client = cluster.newClusterClient(ctx, opt)
C, N = 10, 1000
if testing.Short() {
@ -317,7 +317,7 @@ var _ = Describe("cluster races", func() {
perform(C, func(id int) {
for i := 0; i < N; i++ {
msg := fmt.Sprintf("echo %d %d", id, i)
echo, err := client.Echo(msg).Result()
echo, err := client.Echo(ctx, msg).Result()
Expect(err).NotTo(HaveOccurred())
Expect(echo).To(Equal(msg))
}
@ -328,7 +328,7 @@ var _ = Describe("cluster races", func() {
perform(C, func(id int) {
for i := 0; i < N; i++ {
key := fmt.Sprintf("key_%d_%d", id, i)
_, err := client.Get(key).Result()
_, err := client.Get(ctx, key).Result()
Expect(err).To(Equal(redis.Nil))
}
})
@ -339,12 +339,12 @@ var _ = Describe("cluster races", func() {
perform(C, func(id int) {
for i := 0; i < N; i++ {
err := client.Incr(key).Err()
err := client.Incr(ctx, key).Err()
Expect(err).NotTo(HaveOccurred())
}
})
val, err := client.Get(key).Int64()
val, err := client.Get(ctx, key).Int64()
Expect(err).NotTo(HaveOccurred())
Expect(val).To(Equal(int64(C * N)))
})

152
redis.go
View File

@ -131,7 +131,7 @@ func (hs hooks) afterProcessPipeline(ctx context.Context, cmds []Cmder) error {
func (hs hooks) processTxPipeline(
ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error,
) error {
cmds = wrapMultiExec(cmds)
cmds = wrapMultiExec(ctx, cmds)
return hs.processPipeline(ctx, cmds, fn)
}
@ -201,6 +201,7 @@ func (c *baseClient) getConn(ctx context.Context) (*pool.Conn, error) {
}
return nil, err
}
return cn, nil
}
@ -210,7 +211,13 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
return nil, err
}
err = c.initConn(ctx, cn)
if cn.Inited {
return cn, nil
}
err = internal.WithSpan(ctx, "init_conn", func(ctx context.Context) error {
return c.initConn(ctx, cn)
})
if err != nil {
c.connPool.Remove(cn, err)
if err := internal.Unwrap(err); err != nil {
@ -239,17 +246,17 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
connPool.SetConn(cn)
conn := newConn(ctx, c.opt, connPool)
_, err := conn.Pipelined(func(pipe Pipeliner) error {
_, err := conn.Pipelined(ctx, func(pipe Pipeliner) error {
if c.opt.Password != "" {
pipe.Auth(c.opt.Password)
pipe.Auth(ctx, c.opt.Password)
}
if c.opt.DB > 0 {
pipe.Select(c.opt.DB)
pipe.Select(ctx, c.opt.DB)
}
if c.opt.readOnly {
pipe.ReadOnly()
pipe.ReadOnly(ctx)
}
return nil
@ -279,16 +286,18 @@ func (c *baseClient) releaseConn(cn *pool.Conn, err error) {
func (c *baseClient) withConn(
ctx context.Context, fn func(context.Context, *pool.Conn) error,
) error {
cn, err := c.getConn(ctx)
if err != nil {
return err
}
defer func() {
c.releaseConn(cn, err)
}()
return internal.WithSpan(ctx, "with_conn", func(ctx context.Context) error {
cn, err := c.getConn(ctx)
if err != nil {
return err
}
defer func() {
c.releaseConn(cn, err)
}()
err = fn(ctx, cn)
return err
err = fn(ctx, cn)
return err
})
}
func (c *baseClient) process(ctx context.Context, cmd Cmder) error {
@ -303,32 +312,43 @@ func (c *baseClient) process(ctx context.Context, cmd Cmder) error {
func (c *baseClient) _process(ctx context.Context, cmd Cmder) error {
var lastErr error
for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
if attempt > 0 {
if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil {
return err
}
}
attempt := attempt
retryTimeout := true
lastErr = c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
return writeCmd(wr, cmd)
var retry bool
err := internal.WithSpan(ctx, "process", func(ctx context.Context) error {
if attempt > 0 {
if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil {
return err
}
}
retryTimeout := true
err := c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
return writeCmd(wr, cmd)
})
if err != nil {
return err
}
err = cn.WithReader(ctx, c.cmdTimeout(cmd), cmd.readReply)
if err != nil {
retryTimeout = cmd.readTimeout() == nil
return err
}
return nil
})
if err != nil {
return err
if err == nil {
return nil
}
err = cn.WithReader(ctx, c.cmdTimeout(cmd), cmd.readReply)
if err != nil {
retryTimeout = cmd.readTimeout() == nil
return err
}
return nil
retry = isRetryableError(err, retryTimeout)
return err
})
if lastErr == nil || !isRetryableError(lastErr, retryTimeout) {
return lastErr
if err == nil || !retry {
return err
}
lastErr = err
}
return lastErr
}
@ -465,14 +485,14 @@ func (c *baseClient) txPipelineProcessCmds(
return false, err
}
func wrapMultiExec(cmds []Cmder) []Cmder {
func wrapMultiExec(ctx context.Context, cmds []Cmder) []Cmder {
if len(cmds) == 0 {
panic("not reached")
}
cmds = append(cmds, make([]Cmder, 2)...)
copy(cmds[1:], cmds[:len(cmds)-2])
cmds[0] = NewStatusCmd("multi")
cmds[len(cmds)-1] = NewSliceCmd("exec")
cmds[0] = NewStatusCmd(ctx, "multi")
cmds[len(cmds)-1] = NewSliceCmd(ctx, "exec")
return cmds
}
@ -561,26 +581,18 @@ func (c *Client) WithContext(ctx context.Context) *Client {
return clone
}
func (c *Client) Conn() *Conn {
return newConn(c.ctx, c.opt, pool.NewSingleConnPool(c.connPool))
func (c *Client) Conn(ctx context.Context) *Conn {
return newConn(ctx, c.opt, pool.NewSingleConnPool(c.connPool))
}
// Do creates a Cmd from the args and processes the cmd.
func (c *Client) Do(args ...interface{}) *Cmd {
return c.DoContext(c.ctx, args...)
}
func (c *Client) DoContext(ctx context.Context, args ...interface{}) *Cmd {
cmd := NewCmd(args...)
_ = c.ProcessContext(ctx, cmd)
func (c *Client) Do(ctx context.Context, args ...interface{}) *Cmd {
cmd := NewCmd(ctx, args...)
_ = c.Process(ctx, cmd)
return cmd
}
func (c *Client) Process(cmd Cmder) error {
return c.ProcessContext(c.ctx, cmd)
}
func (c *Client) ProcessContext(ctx context.Context, cmd Cmder) error {
func (c *Client) Process(ctx context.Context, cmd Cmder) error {
return c.hooks.process(ctx, cmd, c.baseClient.process)
}
@ -605,8 +617,8 @@ func (c *Client) PoolStats() *PoolStats {
return (*PoolStats)(stats)
}
func (c *Client) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) {
return c.Pipeline().Pipelined(fn)
func (c *Client) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
return c.Pipeline().Pipelined(ctx, fn)
}
func (c *Client) Pipeline() Pipeliner {
@ -618,8 +630,8 @@ func (c *Client) Pipeline() Pipeliner {
return &pipe
}
func (c *Client) TxPipelined(fn func(Pipeliner) error) ([]Cmder, error) {
return c.TxPipeline().Pipelined(fn)
func (c *Client) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
return c.TxPipeline().Pipelined(ctx, fn)
}
// TxPipeline acts like Pipeline, but wraps queued commands with MULTI/EXEC.
@ -636,8 +648,8 @@ func (c *Client) pubSub() *PubSub {
pubsub := &PubSub{
opt: c.opt,
newConn: func(channels []string) (*pool.Conn, error) {
return c.newConn(context.TODO())
newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) {
return c.newConn(ctx)
},
closeConn: c.connPool.CloseConn,
}
@ -671,20 +683,20 @@ func (c *Client) pubSub() *PubSub {
// }
//
// ch := sub.Channel()
func (c *Client) Subscribe(channels ...string) *PubSub {
func (c *Client) Subscribe(ctx context.Context, channels ...string) *PubSub {
pubsub := c.pubSub()
if len(channels) > 0 {
_ = pubsub.Subscribe(channels...)
_ = pubsub.Subscribe(ctx, channels...)
}
return pubsub
}
// PSubscribe subscribes the client to the given patterns.
// Patterns can be omitted to create empty subscription.
func (c *Client) PSubscribe(channels ...string) *PubSub {
func (c *Client) PSubscribe(ctx context.Context, channels ...string) *PubSub {
pubsub := c.pubSub()
if len(channels) > 0 {
_ = pubsub.PSubscribe(channels...)
_ = pubsub.PSubscribe(ctx, channels...)
}
return pubsub
}
@ -718,16 +730,12 @@ func newConn(ctx context.Context, opt *Options, connPool pool.Pooler) *Conn {
return &c
}
func (c *Conn) Process(cmd Cmder) error {
return c.ProcessContext(c.ctx, cmd)
}
func (c *Conn) ProcessContext(ctx context.Context, cmd Cmder) error {
func (c *Conn) Process(ctx context.Context, cmd Cmder) error {
return c.baseClient.process(ctx, cmd)
}
func (c *Conn) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) {
return c.Pipeline().Pipelined(fn)
func (c *Conn) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
return c.Pipeline().Pipelined(ctx, fn)
}
func (c *Conn) Pipeline() Pipeliner {
@ -739,8 +747,8 @@ func (c *Conn) Pipeline() Pipeliner {
return &pipe
}
func (c *Conn) TxPipelined(fn func(Pipeliner) error) ([]Cmder, error) {
return c.TxPipeline().Pipelined(fn)
func (c *Conn) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
return c.TxPipeline().Pipelined(ctx, fn)
}
// TxPipeline acts like Pipeline, but wraps queued commands with MULTI/EXEC.

View File

@ -34,7 +34,7 @@ func TestHookError(t *testing.T) {
})
rdb.AddHook(redisHookError{})
err := rdb.Ping().Err()
err := rdb.Ping(ctx).Err()
if err == nil {
t.Fatalf("got nil, expected an error")
}
@ -52,7 +52,7 @@ var _ = Describe("Client", func() {
BeforeEach(func() {
client = redis.NewClient(redisOptions())
Expect(client.FlushDB().Err()).NotTo(HaveOccurred())
Expect(client.FlushDB(ctx).Err()).NotTo(HaveOccurred())
})
AfterEach(func() {
@ -63,27 +63,27 @@ var _ = Describe("Client", func() {
Expect(client.String()).To(Equal("Redis<:6380 db:15>"))
})
It("supports WithContext", func() {
c, cancel := context.WithCancel(context.Background())
It("supports context", func() {
ctx, cancel := context.WithCancel(ctx)
cancel()
err := client.WithContext(c).Ping().Err()
err := client.Ping(ctx).Err()
Expect(err).To(MatchError("context canceled"))
})
It("supports WithTimeout", func() {
err := client.ClientPause(time.Second).Err()
err := client.ClientPause(ctx, time.Second).Err()
Expect(err).NotTo(HaveOccurred())
err = client.WithTimeout(10 * time.Millisecond).Ping().Err()
err = client.WithTimeout(10 * time.Millisecond).Ping(ctx).Err()
Expect(err).To(HaveOccurred())
err = client.Ping().Err()
err = client.Ping(ctx).Err()
Expect(err).NotTo(HaveOccurred())
})
It("should ping", func() {
val, err := client.Ping().Result()
val, err := client.Ping(ctx).Result()
Expect(err).NotTo(HaveOccurred())
Expect(val).To(Equal("PONG"))
})
@ -102,7 +102,7 @@ var _ = Describe("Client", func() {
},
})
val, err := custom.Ping().Result()
val, err := custom.Ping(ctx).Result()
Expect(err).NotTo(HaveOccurred())
Expect(val).To(Equal("PONG"))
Expect(custom.Close()).NotTo(HaveOccurred())
@ -110,48 +110,48 @@ var _ = Describe("Client", func() {
It("should close", func() {
Expect(client.Close()).NotTo(HaveOccurred())
err := client.Ping().Err()
err := client.Ping(ctx).Err()
Expect(err).To(MatchError("redis: client is closed"))
})
It("should close pubsub without closing the client", func() {
pubsub := client.Subscribe()
pubsub := client.Subscribe(ctx)
Expect(pubsub.Close()).NotTo(HaveOccurred())
_, err := pubsub.Receive()
_, err := pubsub.Receive(ctx)
Expect(err).To(MatchError("redis: client is closed"))
Expect(client.Ping().Err()).NotTo(HaveOccurred())
Expect(client.Ping(ctx).Err()).NotTo(HaveOccurred())
})
It("should close Tx without closing the client", func() {
err := client.Watch(func(tx *redis.Tx) error {
_, err := tx.TxPipelined(func(pipe redis.Pipeliner) error {
pipe.Ping()
err := client.Watch(ctx, func(tx *redis.Tx) error {
_, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Ping(ctx)
return nil
})
return err
})
Expect(err).NotTo(HaveOccurred())
Expect(client.Ping().Err()).NotTo(HaveOccurred())
Expect(client.Ping(ctx).Err()).NotTo(HaveOccurred())
})
It("should close pipeline without closing the client", func() {
pipeline := client.Pipeline()
Expect(pipeline.Close()).NotTo(HaveOccurred())
pipeline.Ping()
_, err := pipeline.Exec()
pipeline.Ping(ctx)
_, err := pipeline.Exec(ctx)
Expect(err).To(MatchError("redis: client is closed"))
Expect(client.Ping().Err()).NotTo(HaveOccurred())
Expect(client.Ping(ctx).Err()).NotTo(HaveOccurred())
})
It("should close pubsub when client is closed", func() {
pubsub := client.Subscribe()
pubsub := client.Subscribe(ctx)
Expect(client.Close()).NotTo(HaveOccurred())
_, err := pubsub.Receive()
_, err := pubsub.Receive(ctx)
Expect(err).To(MatchError("redis: client is closed"))
Expect(pubsub.Close()).NotTo(HaveOccurred())
@ -168,26 +168,26 @@ var _ = Describe("Client", func() {
Addr: redisAddr,
DB: 2,
})
Expect(db2.FlushDB().Err()).NotTo(HaveOccurred())
Expect(db2.Get("db").Err()).To(Equal(redis.Nil))
Expect(db2.Set("db", 2, 0).Err()).NotTo(HaveOccurred())
Expect(db2.FlushDB(ctx).Err()).NotTo(HaveOccurred())
Expect(db2.Get(ctx, "db").Err()).To(Equal(redis.Nil))
Expect(db2.Set(ctx, "db", 2, 0).Err()).NotTo(HaveOccurred())
n, err := db2.Get("db").Int64()
n, err := db2.Get(ctx, "db").Int64()
Expect(err).NotTo(HaveOccurred())
Expect(n).To(Equal(int64(2)))
Expect(client.Get("db").Err()).To(Equal(redis.Nil))
Expect(client.Get(ctx, "db").Err()).To(Equal(redis.Nil))
Expect(db2.FlushDB().Err()).NotTo(HaveOccurred())
Expect(db2.FlushDB(ctx).Err()).NotTo(HaveOccurred())
Expect(db2.Close()).NotTo(HaveOccurred())
})
It("processes custom commands", func() {
cmd := redis.NewCmd("PING")
_ = client.Process(cmd)
cmd := redis.NewCmd(ctx, "PING")
_ = client.Process(ctx, cmd)
// Flush buffers.
Expect(client.Echo("hello").Err()).NotTo(HaveOccurred())
Expect(client.Echo(ctx, "hello").Err()).NotTo(HaveOccurred())
Expect(cmd.Err()).NotTo(HaveOccurred())
Expect(cmd.Val()).To(Equal("PONG"))
@ -202,13 +202,13 @@ var _ = Describe("Client", func() {
})
// Put bad connection in the pool.
cn, err := client.Pool().Get(context.Background())
cn, err := client.Pool().Get(ctx)
Expect(err).NotTo(HaveOccurred())
cn.SetNetConn(&badConn{})
client.Pool().Put(cn)
err = client.Ping().Err()
err = client.Ping(ctx).Err()
Expect(err).NotTo(HaveOccurred())
})
@ -227,12 +227,12 @@ var _ = Describe("Client", func() {
defer clientRetry.Close()
startNoRetry := time.Now()
err := clientNoRetry.Ping().Err()
err := clientNoRetry.Ping(ctx).Err()
Expect(err).To(HaveOccurred())
elapseNoRetry := time.Since(startNoRetry)
startRetry := time.Now()
err = clientRetry.Ping().Err()
err = clientRetry.Ping(ctx).Err()
Expect(err).To(HaveOccurred())
elapseRetry := time.Since(startRetry)
@ -250,7 +250,7 @@ var _ = Describe("Client", func() {
time.Sleep(time.Second)
err = client.Ping().Err()
err = client.Ping(ctx).Err()
Expect(err).NotTo(HaveOccurred())
cn, err = client.Pool().Get(context.Background())
@ -260,11 +260,11 @@ var _ = Describe("Client", func() {
})
It("should process command with special chars", func() {
set := client.Set("key", "hello1\r\nhello2\r\n", 0)
set := client.Set(ctx, "key", "hello1\r\nhello2\r\n", 0)
Expect(set.Err()).NotTo(HaveOccurred())
Expect(set.Val()).To(Equal("OK"))
get := client.Get("key")
get := client.Get(ctx, "key")
Expect(get.Err()).NotTo(HaveOccurred())
Expect(get.Val()).To(Equal("hello1\r\nhello2\r\n"))
})
@ -272,14 +272,14 @@ var _ = Describe("Client", func() {
It("should handle big vals", func() {
bigVal := bytes.Repeat([]byte{'*'}, 2e6)
err := client.Set("key", bigVal, 0).Err()
err := client.Set(ctx, "key", bigVal, 0).Err()
Expect(err).NotTo(HaveOccurred())
// Reconnect to get new connection.
Expect(client.Close()).NotTo(HaveOccurred())
client = redis.NewClient(redisOptions())
got, err := client.Get("key").Bytes()
got, err := client.Get(ctx, "key").Bytes()
Expect(err).NotTo(HaveOccurred())
Expect(got).To(Equal(bigVal))
})
@ -295,14 +295,14 @@ var _ = Describe("Client timeout", func() {
testTimeout := func() {
It("Ping timeouts", func() {
err := client.Ping().Err()
err := client.Ping(ctx).Err()
Expect(err).To(HaveOccurred())
Expect(err.(net.Error).Timeout()).To(BeTrue())
})
It("Pipeline timeouts", func() {
_, err := client.Pipelined(func(pipe redis.Pipeliner) error {
pipe.Ping()
_, err := client.Pipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Ping(ctx)
return nil
})
Expect(err).To(HaveOccurred())
@ -314,26 +314,26 @@ var _ = Describe("Client timeout", func() {
return
}
pubsub := client.Subscribe()
pubsub := client.Subscribe(ctx)
defer pubsub.Close()
err := pubsub.Subscribe("_")
err := pubsub.Subscribe(ctx, "_")
Expect(err).To(HaveOccurred())
Expect(err.(net.Error).Timeout()).To(BeTrue())
})
It("Tx timeouts", func() {
err := client.Watch(func(tx *redis.Tx) error {
return tx.Ping().Err()
err := client.Watch(ctx, func(tx *redis.Tx) error {
return tx.Ping(ctx).Err()
})
Expect(err).To(HaveOccurred())
Expect(err.(net.Error).Timeout()).To(BeTrue())
})
It("Tx Pipeline timeouts", func() {
err := client.Watch(func(tx *redis.Tx) error {
_, err := tx.TxPipelined(func(pipe redis.Pipeliner) error {
pipe.Ping()
err := client.Watch(ctx, func(tx *redis.Tx) error {
_, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Ping(ctx)
return nil
})
return err
@ -373,7 +373,7 @@ var _ = Describe("Client OnConnect", func() {
opt := redisOptions()
opt.DB = 0
opt.OnConnect = func(cn *redis.Conn) error {
return cn.ClientSetName("on_connect").Err()
return cn.ClientSetName(ctx, "on_connect").Err()
}
client = redis.NewClient(opt)
@ -384,7 +384,7 @@ var _ = Describe("Client OnConnect", func() {
})
It("calls OnConnect", func() {
name, err := client.ClientGetName().Result()
name, err := client.ClientGetName(ctx).Result()
Expect(err).NotTo(HaveOccurred())
Expect(name).To(Equal("on_connect"))
})

51
ring.go
View File

@ -261,6 +261,8 @@ func (c *ringShards) Random() (*ringShard, error) {
func (c *ringShards) Heartbeat(frequency time.Duration) {
ticker := time.NewTicker(frequency)
defer ticker.Stop()
ctx := context.TODO()
for range ticker.C {
var rebalance bool
@ -275,7 +277,7 @@ func (c *ringShards) Heartbeat(frequency time.Duration) {
c.mu.RUnlock()
for _, shard := range shards {
err := shard.Client.Ping().Err()
err := shard.Client.Ping(ctx).Err()
if shard.Vote(err == nil || err == pool.ErrPoolTimeout) {
internal.Logger.Printf("ring shard state changed: %s", shard)
rebalance = true
@ -421,21 +423,13 @@ func (c *Ring) WithContext(ctx context.Context) *Ring {
}
// Do creates a Cmd from the args and processes the cmd.
func (c *Ring) Do(args ...interface{}) *Cmd {
return c.DoContext(c.ctx, args...)
}
func (c *Ring) DoContext(ctx context.Context, args ...interface{}) *Cmd {
cmd := NewCmd(args...)
_ = c.ProcessContext(ctx, cmd)
func (c *Ring) Do(ctx context.Context, args ...interface{}) *Cmd {
cmd := NewCmd(ctx, args...)
_ = c.Process(ctx, cmd)
return cmd
}
func (c *Ring) Process(cmd Cmder) error {
return c.ProcessContext(c.ctx, cmd)
}
func (c *Ring) ProcessContext(ctx context.Context, cmd Cmder) error {
func (c *Ring) Process(ctx context.Context, cmd Cmder) error {
return c.hooks.process(ctx, cmd, c.process)
}
@ -469,7 +463,7 @@ func (c *Ring) Len() int {
}
// Subscribe subscribes the client to the specified channels.
func (c *Ring) Subscribe(channels ...string) *PubSub {
func (c *Ring) Subscribe(ctx context.Context, channels ...string) *PubSub {
if len(channels) == 0 {
panic("at least one channel is required")
}
@ -479,11 +473,11 @@ func (c *Ring) Subscribe(channels ...string) *PubSub {
//TODO: return PubSub with sticky error
panic(err)
}
return shard.Client.Subscribe(channels...)
return shard.Client.Subscribe(ctx, channels...)
}
// PSubscribe subscribes the client to the given patterns.
func (c *Ring) PSubscribe(channels ...string) *PubSub {
func (c *Ring) PSubscribe(ctx context.Context, channels ...string) *PubSub {
if len(channels) == 0 {
panic("at least one channel is required")
}
@ -493,12 +487,15 @@ func (c *Ring) PSubscribe(channels ...string) *PubSub {
//TODO: return PubSub with sticky error
panic(err)
}
return shard.Client.PSubscribe(channels...)
return shard.Client.PSubscribe(ctx, channels...)
}
// ForEachShard concurrently calls the fn on each live shard in the ring.
// It returns the first error if any.
func (c *Ring) ForEachShard(fn func(client *Client) error) error {
func (c *Ring) ForEachShard(
ctx context.Context,
fn func(ctx context.Context, client *Client) error,
) error {
shards := c.shards.List()
var wg sync.WaitGroup
errCh := make(chan error, 1)
@ -510,7 +507,7 @@ func (c *Ring) ForEachShard(fn func(client *Client) error) error {
wg.Add(1)
go func(shard *ringShard) {
defer wg.Done()
err := fn(shard.Client)
err := fn(ctx, shard.Client)
if err != nil {
select {
case errCh <- err:
@ -533,7 +530,7 @@ func (c *Ring) cmdsInfo() (map[string]*CommandInfo, error) {
shards := c.shards.List()
firstErr := errRingShardsDown
for _, shard := range shards {
cmdsInfo, err := shard.Client.Command().Result()
cmdsInfo, err := shard.Client.Command(context.TODO()).Result()
if err == nil {
return cmdsInfo, nil
}
@ -589,7 +586,7 @@ func (c *Ring) _process(ctx context.Context, cmd Cmder) error {
return err
}
lastErr = shard.Client.ProcessContext(ctx, cmd)
lastErr = shard.Client.Process(ctx, cmd)
if lastErr == nil || !isRetryableError(lastErr, cmd.readTimeout() == nil) {
return lastErr
}
@ -597,8 +594,8 @@ func (c *Ring) _process(ctx context.Context, cmd Cmder) error {
return lastErr
}
func (c *Ring) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) {
return c.Pipeline().Pipelined(fn)
func (c *Ring) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
return c.Pipeline().Pipelined(ctx, fn)
}
func (c *Ring) Pipeline() Pipeliner {
@ -616,8 +613,8 @@ func (c *Ring) processPipeline(ctx context.Context, cmds []Cmder) error {
})
}
func (c *Ring) TxPipelined(fn func(Pipeliner) error) ([]Cmder, error) {
return c.TxPipeline().Pipelined(fn)
func (c *Ring) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
return c.TxPipeline().Pipelined(ctx, fn)
}
func (c *Ring) TxPipeline() Pipeliner {
@ -688,7 +685,7 @@ func (c *Ring) Close() error {
return c.shards.Close()
}
func (c *Ring) Watch(fn func(*Tx) error, keys ...string) error {
func (c *Ring) Watch(ctx context.Context, fn func(*Tx) error, keys ...string) error {
if len(keys) == 0 {
return fmt.Errorf("redis: Watch requires at least one key")
}
@ -718,7 +715,7 @@ func (c *Ring) Watch(fn func(*Tx) error, keys ...string) error {
}
}
return shards[0].Client.Watch(fn, keys...)
return shards[0].Client.Watch(ctx, fn, keys...)
}
func newConsistentHash(opt *RingOptions) *consistenthash.Map {

View File

@ -22,7 +22,7 @@ var _ = Describe("Redis Ring", func() {
setRingKeys := func() {
for i := 0; i < 100; i++ {
err := ring.Set(fmt.Sprintf("key%d", i), "value", 0).Err()
err := ring.Set(ctx, fmt.Sprintf("key%d", i), "value", 0).Err()
Expect(err).NotTo(HaveOccurred())
}
}
@ -32,8 +32,8 @@ var _ = Describe("Redis Ring", func() {
opt.HeartbeatFrequency = heartbeat
ring = redis.NewRing(opt)
err := ring.ForEachShard(func(cl *redis.Client) error {
return cl.FlushDB().Err()
err := ring.ForEachShard(ctx, func(ctx context.Context, cl *redis.Client) error {
return cl.FlushDB(ctx).Err()
})
Expect(err).NotTo(HaveOccurred())
})
@ -42,11 +42,11 @@ var _ = Describe("Redis Ring", func() {
Expect(ring.Close()).NotTo(HaveOccurred())
})
It("supports WithContext", func() {
c, cancel := context.WithCancel(context.Background())
It("supports context", func() {
ctx, cancel := context.WithCancel(ctx)
cancel()
err := ring.WithContext(c).Ping().Err()
err := ring.Ping(ctx).Err()
Expect(err).To(MatchError("context canceled"))
})
@ -54,8 +54,8 @@ var _ = Describe("Redis Ring", func() {
setRingKeys()
// Both shards should have some keys now.
Expect(ringShard1.Info("keyspace").Val()).To(ContainSubstring("keys=57"))
Expect(ringShard2.Info("keyspace").Val()).To(ContainSubstring("keys=43"))
Expect(ringShard1.Info(ctx, "keyspace").Val()).To(ContainSubstring("keys=57"))
Expect(ringShard2.Info(ctx, "keyspace").Val()).To(ContainSubstring("keys=43"))
})
It("distributes keys when using EVAL", func() {
@ -67,12 +67,12 @@ var _ = Describe("Redis Ring", func() {
var key string
for i := 0; i < 100; i++ {
key = fmt.Sprintf("key%d", i)
err := script.Run(ring, []string{key}, "value").Err()
err := script.Run(ctx, ring, []string{key}, "value").Err()
Expect(err).NotTo(HaveOccurred())
}
Expect(ringShard1.Info("keyspace").Val()).To(ContainSubstring("keys=57"))
Expect(ringShard2.Info("keyspace").Val()).To(ContainSubstring("keys=43"))
Expect(ringShard1.Info(ctx, "keyspace").Val()).To(ContainSubstring("keys=57"))
Expect(ringShard2.Info(ctx, "keyspace").Val()).To(ContainSubstring("keys=43"))
})
It("uses single shard when one of the shards is down", func() {
@ -86,7 +86,7 @@ var _ = Describe("Redis Ring", func() {
setRingKeys()
// RingShard1 should have all keys.
Expect(ringShard1.Info("keyspace").Val()).To(ContainSubstring("keys=100"))
Expect(ringShard1.Info(ctx, "keyspace").Val()).To(ContainSubstring("keys=100"))
// Start ringShard2.
var err error
@ -100,27 +100,27 @@ var _ = Describe("Redis Ring", func() {
setRingKeys()
// RingShard2 should have its keys.
Expect(ringShard2.Info("keyspace").Val()).To(ContainSubstring("keys=43"))
Expect(ringShard2.Info(ctx, "keyspace").Val()).To(ContainSubstring("keys=43"))
})
It("supports hash tags", func() {
for i := 0; i < 100; i++ {
err := ring.Set(fmt.Sprintf("key%d{tag}", i), "value", 0).Err()
err := ring.Set(ctx, fmt.Sprintf("key%d{tag}", i), "value", 0).Err()
Expect(err).NotTo(HaveOccurred())
}
Expect(ringShard1.Info("keyspace").Val()).ToNot(ContainSubstring("keys="))
Expect(ringShard2.Info("keyspace").Val()).To(ContainSubstring("keys=100"))
Expect(ringShard1.Info(ctx, "keyspace").Val()).ToNot(ContainSubstring("keys="))
Expect(ringShard2.Info(ctx, "keyspace").Val()).To(ContainSubstring("keys=100"))
})
Describe("pipeline", func() {
It("distributes keys", func() {
pipe := ring.Pipeline()
for i := 0; i < 100; i++ {
err := pipe.Set(fmt.Sprintf("key%d", i), "value", 0).Err()
err := pipe.Set(ctx, fmt.Sprintf("key%d", i), "value", 0).Err()
Expect(err).NotTo(HaveOccurred())
}
cmds, err := pipe.Exec()
cmds, err := pipe.Exec(ctx)
Expect(err).NotTo(HaveOccurred())
Expect(cmds).To(HaveLen(100))
Expect(pipe.Close()).NotTo(HaveOccurred())
@ -131,8 +131,8 @@ var _ = Describe("Redis Ring", func() {
}
// Both shards should have some keys now.
Expect(ringShard1.Info().Val()).To(ContainSubstring("keys=57"))
Expect(ringShard2.Info().Val()).To(ContainSubstring("keys=43"))
Expect(ringShard1.Info(ctx).Val()).To(ContainSubstring("keys=57"))
Expect(ringShard2.Info(ctx).Val()).To(ContainSubstring("keys=43"))
})
It("is consistent with ring", func() {
@ -144,32 +144,32 @@ var _ = Describe("Redis Ring", func() {
keys = append(keys, string(key))
}
_, err := ring.Pipelined(func(pipe redis.Pipeliner) error {
_, err := ring.Pipelined(ctx, func(pipe redis.Pipeliner) error {
for _, key := range keys {
pipe.Set(key, "value", 0).Err()
pipe.Set(ctx, key, "value", 0).Err()
}
return nil
})
Expect(err).NotTo(HaveOccurred())
for _, key := range keys {
val, err := ring.Get(key).Result()
val, err := ring.Get(ctx, key).Result()
Expect(err).NotTo(HaveOccurred())
Expect(val).To(Equal("value"))
}
})
It("supports hash tags", func() {
_, err := ring.Pipelined(func(pipe redis.Pipeliner) error {
_, err := ring.Pipelined(ctx, func(pipe redis.Pipeliner) error {
for i := 0; i < 100; i++ {
pipe.Set(fmt.Sprintf("key%d{tag}", i), "value", 0).Err()
pipe.Set(ctx, fmt.Sprintf("key%d{tag}", i), "value", 0).Err()
}
return nil
})
Expect(err).NotTo(HaveOccurred())
Expect(ringShard1.Info().Val()).ToNot(ContainSubstring("keys="))
Expect(ringShard2.Info().Val()).To(ContainSubstring("keys=100"))
Expect(ringShard1.Info(ctx).Val()).ToNot(ContainSubstring("keys="))
Expect(ringShard2.Info(ctx).Val()).To(ContainSubstring("keys=100"))
})
})
@ -179,7 +179,7 @@ var _ = Describe("Redis Ring", func() {
opts.Password = "password"
ring = redis.NewRing(opts)
err := ring.Ping().Err()
err := ring.Ping(ctx).Err()
Expect(err).To(MatchError("ERR Client sent AUTH, but no password is set"))
})
@ -205,13 +205,13 @@ var _ = Describe("Redis Ring", func() {
}
ring = redis.NewRing(opts)
err := ring.Ping().Err()
err := ring.Ping(ctx).Err()
Expect(err).To(MatchError("ERR Client sent AUTH, but no password is set"))
})
})
It("supports Process hook", func() {
err := ring.Ping().Err()
err := ring.Ping(ctx).Err()
Expect(err).NotTo(HaveOccurred())
var stack []string
@ -229,7 +229,7 @@ var _ = Describe("Redis Ring", func() {
},
})
ring.ForEachShard(func(shard *redis.Client) error {
ring.ForEachShard(ctx, func(ctx context.Context, shard *redis.Client) error {
shard.AddHook(&hook{
beforeProcess: func(ctx context.Context, cmd redis.Cmder) (context.Context, error) {
Expect(cmd.String()).To(Equal("ping: "))
@ -245,7 +245,7 @@ var _ = Describe("Redis Ring", func() {
return nil
})
err = ring.Ping().Err()
err = ring.Ping(ctx).Err()
Expect(err).NotTo(HaveOccurred())
Expect(stack).To(Equal([]string{
"ring.BeforeProcess",
@ -256,7 +256,7 @@ var _ = Describe("Redis Ring", func() {
})
It("supports Pipeline hook", func() {
err := ring.Ping().Err()
err := ring.Ping(ctx).Err()
Expect(err).NotTo(HaveOccurred())
var stack []string
@ -276,7 +276,7 @@ var _ = Describe("Redis Ring", func() {
},
})
ring.ForEachShard(func(shard *redis.Client) error {
ring.ForEachShard(ctx, func(ctx context.Context, shard *redis.Client) error {
shard.AddHook(&hook{
beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) {
Expect(cmds).To(HaveLen(1))
@ -294,8 +294,8 @@ var _ = Describe("Redis Ring", func() {
return nil
})
_, err = ring.Pipelined(func(pipe redis.Pipeliner) error {
pipe.Ping()
_, err = ring.Pipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Ping(ctx)
return nil
})
Expect(err).NotTo(HaveOccurred())
@ -308,7 +308,7 @@ var _ = Describe("Redis Ring", func() {
})
It("supports TxPipeline hook", func() {
err := ring.Ping().Err()
err := ring.Ping(ctx).Err()
Expect(err).NotTo(HaveOccurred())
var stack []string
@ -328,7 +328,7 @@ var _ = Describe("Redis Ring", func() {
},
})
ring.ForEachShard(func(shard *redis.Client) error {
ring.ForEachShard(ctx, func(ctx context.Context, shard *redis.Client) error {
shard.AddHook(&hook{
beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) {
Expect(cmds).To(HaveLen(3))
@ -346,8 +346,8 @@ var _ = Describe("Redis Ring", func() {
return nil
})
_, err = ring.TxPipelined(func(pipe redis.Pipeliner) error {
pipe.Ping()
_, err = ring.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Ping(ctx)
return nil
})
Expect(err).NotTo(HaveOccurred())
@ -372,13 +372,13 @@ var _ = Describe("empty Redis Ring", func() {
})
It("returns an error", func() {
err := ring.Ping().Err()
err := ring.Ping(ctx).Err()
Expect(err).To(MatchError("redis: all ring shards are down"))
})
It("pipeline returns an error", func() {
_, err := ring.Pipelined(func(pipe redis.Pipeliner) error {
pipe.Ping()
_, err := ring.Pipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Ping(ctx)
return nil
})
Expect(err).To(MatchError("redis: all ring shards are down"))
@ -395,8 +395,8 @@ var _ = Describe("Ring watch", func() {
opt.HeartbeatFrequency = heartbeat
ring = redis.NewRing(opt)
err := ring.ForEachShard(func(cl *redis.Client) error {
return cl.FlushDB().Err()
err := ring.ForEachShard(ctx, func(ctx context.Context, cl *redis.Client) error {
return cl.FlushDB(ctx).Err()
})
Expect(err).NotTo(HaveOccurred())
})
@ -410,14 +410,14 @@ var _ = Describe("Ring watch", func() {
// Transactionally increments key using GET and SET commands.
incr = func(key string) error {
err := ring.Watch(func(tx *redis.Tx) error {
n, err := tx.Get(key).Int64()
err := ring.Watch(ctx, func(tx *redis.Tx) error {
n, err := tx.Get(ctx, key).Int64()
if err != nil && err != redis.Nil {
return err
}
_, err = tx.TxPipelined(func(pipe redis.Pipeliner) error {
pipe.Set(key, strconv.FormatInt(n+1, 10), 0)
_, err = tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Set(ctx, key, strconv.FormatInt(n+1, 10), 0)
return nil
})
return err
@ -441,17 +441,17 @@ var _ = Describe("Ring watch", func() {
}
wg.Wait()
n, err := ring.Get("key").Int64()
n, err := ring.Get(ctx, "key").Int64()
Expect(err).NotTo(HaveOccurred())
Expect(n).To(Equal(int64(100)))
})
It("should discard", func() {
err := ring.Watch(func(tx *redis.Tx) error {
cmds, err := tx.TxPipelined(func(pipe redis.Pipeliner) error {
pipe.Set("key1", "hello1", 0)
err := ring.Watch(ctx, func(tx *redis.Tx) error {
cmds, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Set(ctx, "key1", "hello1", 0)
pipe.Discard()
pipe.Set("key2", "hello2", 0)
pipe.Set(ctx, "key2", "hello2", 0)
return nil
})
Expect(err).NotTo(HaveOccurred())
@ -460,23 +460,23 @@ var _ = Describe("Ring watch", func() {
}, "key1", "key2")
Expect(err).NotTo(HaveOccurred())
get := ring.Get("key1")
get := ring.Get(ctx, "key1")
Expect(get.Err()).To(Equal(redis.Nil))
Expect(get.Val()).To(Equal(""))
get = ring.Get("key2")
get = ring.Get(ctx, "key2")
Expect(get.Err()).NotTo(HaveOccurred())
Expect(get.Val()).To(Equal("hello2"))
})
It("returns no error when there are no commands", func() {
err := ring.Watch(func(tx *redis.Tx) error {
_, err := tx.TxPipelined(func(redis.Pipeliner) error { return nil })
err := ring.Watch(ctx, func(tx *redis.Tx) error {
_, err := tx.TxPipelined(ctx, func(redis.Pipeliner) error { return nil })
return err
}, "key")
Expect(err).NotTo(HaveOccurred())
v, err := ring.Ping().Result()
v, err := ring.Ping(ctx).Result()
Expect(err).NotTo(HaveOccurred())
Expect(v).To(Equal("PONG"))
})
@ -484,10 +484,10 @@ var _ = Describe("Ring watch", func() {
It("should exec bulks", func() {
const N = 20000
err := ring.Watch(func(tx *redis.Tx) error {
cmds, err := tx.TxPipelined(func(pipe redis.Pipeliner) error {
err := ring.Watch(ctx, func(tx *redis.Tx) error {
cmds, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
for i := 0; i < N; i++ {
pipe.Incr("key")
pipe.Incr(ctx, "key")
}
return nil
})
@ -500,7 +500,7 @@ var _ = Describe("Ring watch", func() {
}, "key")
Expect(err).NotTo(HaveOccurred())
num, err := ring.Get("key").Int64()
num, err := ring.Get(ctx, "key").Int64()
Expect(err).NotTo(HaveOccurred())
Expect(num).To(Equal(int64(N)))
})
@ -508,21 +508,21 @@ var _ = Describe("Ring watch", func() {
It("should Watch/Unwatch", func() {
var C, N int
err := ring.Set("key", "0", 0).Err()
err := ring.Set(ctx, "key", "0", 0).Err()
Expect(err).NotTo(HaveOccurred())
perform(C, func(id int) {
for i := 0; i < N; i++ {
err := ring.Watch(func(tx *redis.Tx) error {
val, err := tx.Get("key").Result()
err := ring.Watch(ctx, func(tx *redis.Tx) error {
val, err := tx.Get(ctx, "key").Result()
Expect(err).NotTo(HaveOccurred())
Expect(val).NotTo(Equal(redis.Nil))
num, err := strconv.ParseInt(val, 10, 64)
Expect(err).NotTo(HaveOccurred())
cmds, err := tx.TxPipelined(func(pipe redis.Pipeliner) error {
pipe.Set("key", strconv.FormatInt(num+1, 10), 0)
cmds, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Set(ctx, "key", strconv.FormatInt(num+1, 10), 0)
return nil
})
Expect(cmds).To(HaveLen(1))
@ -536,31 +536,31 @@ var _ = Describe("Ring watch", func() {
}
})
val, err := ring.Get("key").Int64()
val, err := ring.Get(ctx, "key").Int64()
Expect(err).NotTo(HaveOccurred())
Expect(val).To(Equal(int64(C * N)))
})
It("should close Tx without closing the client", func() {
err := ring.Watch(func(tx *redis.Tx) error {
_, err := tx.TxPipelined(func(pipe redis.Pipeliner) error {
pipe.Ping()
err := ring.Watch(ctx, func(tx *redis.Tx) error {
_, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Ping(ctx)
return nil
})
return err
}, "key")
Expect(err).NotTo(HaveOccurred())
Expect(ring.Ping().Err()).NotTo(HaveOccurred())
Expect(ring.Ping(ctx).Err()).NotTo(HaveOccurred())
})
It("respects max size on multi", func() {
perform(1000, func(id int) {
var ping *redis.StatusCmd
err := ring.Watch(func(tx *redis.Tx) error {
cmds, err := tx.TxPipelined(func(pipe redis.Pipeliner) error {
ping = pipe.Ping()
err := ring.Watch(ctx, func(tx *redis.Tx) error {
cmds, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
ping = pipe.Ping(ctx)
return nil
})
Expect(err).NotTo(HaveOccurred())
@ -573,7 +573,7 @@ var _ = Describe("Ring watch", func() {
Expect(ping.Val()).To(Equal("PONG"))
})
ring.ForEachShard(func(cl *redis.Client) error {
ring.ForEachShard(ctx, func(ctx context.Context, cl *redis.Client) error {
defer GinkgoRecover()
pool := cl.Pool()
@ -597,17 +597,17 @@ var _ = Describe("Ring Tx timeout", func() {
testTimeout := func() {
It("Tx timeouts", func() {
err := ring.Watch(func(tx *redis.Tx) error {
return tx.Ping().Err()
err := ring.Watch(ctx, func(tx *redis.Tx) error {
return tx.Ping(ctx).Err()
}, "foo")
Expect(err).To(HaveOccurred())
Expect(err.(net.Error).Timeout()).To(BeTrue())
})
It("Tx Pipeline timeouts", func() {
err := ring.Watch(func(tx *redis.Tx) error {
_, err := tx.TxPipelined(func(pipe redis.Pipeliner) error {
pipe.Ping()
err := ring.Watch(ctx, func(tx *redis.Tx) error {
_, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Ping(ctx)
return nil
})
return err
@ -627,17 +627,17 @@ var _ = Describe("Ring Tx timeout", func() {
opt.HeartbeatFrequency = heartbeat
ring = redis.NewRing(opt)
err := ring.ForEachShard(func(client *redis.Client) error {
return client.ClientPause(pause).Err()
err := ring.ForEachShard(ctx, func(ctx context.Context, client *redis.Client) error {
return client.ClientPause(ctx, pause).Err()
})
Expect(err).NotTo(HaveOccurred())
})
AfterEach(func() {
_ = ring.ForEachShard(func(client *redis.Client) error {
_ = ring.ForEachShard(ctx, func(ctx context.Context, client *redis.Client) error {
defer GinkgoRecover()
Eventually(func() error {
return client.Ping().Err()
return client.Ping(ctx).Err()
}, 2*pause).ShouldNot(HaveOccurred())
return nil
})

View File

@ -1,6 +1,7 @@
package redis
import (
"context"
"crypto/sha1"
"encoding/hex"
"io"
@ -8,10 +9,10 @@ import (
)
type scripter interface {
Eval(script string, keys []string, args ...interface{}) *Cmd
EvalSha(sha1 string, keys []string, args ...interface{}) *Cmd
ScriptExists(hashes ...string) *BoolSliceCmd
ScriptLoad(script string) *StringCmd
Eval(ctx context.Context, script string, keys []string, args ...interface{}) *Cmd
EvalSha(ctx context.Context, sha1 string, keys []string, args ...interface{}) *Cmd
ScriptExists(ctx context.Context, hashes ...string) *BoolSliceCmd
ScriptLoad(ctx context.Context, script string) *StringCmd
}
var _ scripter = (*Client)(nil)
@ -35,28 +36,28 @@ func (s *Script) Hash() string {
return s.hash
}
func (s *Script) Load(c scripter) *StringCmd {
return c.ScriptLoad(s.src)
func (s *Script) Load(ctx context.Context, c scripter) *StringCmd {
return c.ScriptLoad(ctx, s.src)
}
func (s *Script) Exists(c scripter) *BoolSliceCmd {
return c.ScriptExists(s.hash)
func (s *Script) Exists(ctx context.Context, c scripter) *BoolSliceCmd {
return c.ScriptExists(ctx, s.hash)
}
func (s *Script) Eval(c scripter, keys []string, args ...interface{}) *Cmd {
return c.Eval(s.src, keys, args...)
func (s *Script) Eval(ctx context.Context, c scripter, keys []string, args ...interface{}) *Cmd {
return c.Eval(ctx, s.src, keys, args...)
}
func (s *Script) EvalSha(c scripter, keys []string, args ...interface{}) *Cmd {
return c.EvalSha(s.hash, keys, args...)
func (s *Script) EvalSha(ctx context.Context, c scripter, keys []string, args ...interface{}) *Cmd {
return c.EvalSha(ctx, s.hash, keys, args...)
}
// Run optimistically uses EVALSHA to run the script. If script does not exist
// it is retried using EVAL.
func (s *Script) Run(c scripter, keys []string, args ...interface{}) *Cmd {
r := s.EvalSha(c, keys, args...)
func (s *Script) Run(ctx context.Context, c scripter, keys []string, args ...interface{}) *Cmd {
r := s.EvalSha(ctx, c, keys, args...)
if err := r.Err(); err != nil && strings.HasPrefix(err.Error(), "NOSCRIPT ") {
return s.Eval(c, keys, args...)
return s.Eval(ctx, c, keys, args...)
}
return r
}

View File

@ -135,11 +135,7 @@ func (c *SentinelClient) WithContext(ctx context.Context) *SentinelClient {
return &clone
}
func (c *SentinelClient) Process(cmd Cmder) error {
return c.ProcessContext(c.ctx, cmd)
}
func (c *SentinelClient) ProcessContext(ctx context.Context, cmd Cmder) error {
func (c *SentinelClient) Process(ctx context.Context, cmd Cmder) error {
return c.baseClient.process(ctx, cmd)
}
@ -147,8 +143,8 @@ func (c *SentinelClient) pubSub() *PubSub {
pubsub := &PubSub{
opt: c.opt,
newConn: func(channels []string) (*pool.Conn, error) {
return c.newConn(context.TODO())
newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) {
return c.newConn(ctx)
},
closeConn: c.connPool.CloseConn,
}
@ -158,49 +154,49 @@ func (c *SentinelClient) pubSub() *PubSub {
// Ping is used to test if a connection is still alive, or to
// measure latency.
func (c *SentinelClient) Ping() *StringCmd {
cmd := NewStringCmd("ping")
_ = c.Process(cmd)
func (c *SentinelClient) Ping(ctx context.Context) *StringCmd {
cmd := NewStringCmd(ctx, "ping")
_ = c.Process(ctx, cmd)
return cmd
}
// Subscribe subscribes the client to the specified channels.
// Channels can be omitted to create empty subscription.
func (c *SentinelClient) Subscribe(channels ...string) *PubSub {
func (c *SentinelClient) Subscribe(ctx context.Context, channels ...string) *PubSub {
pubsub := c.pubSub()
if len(channels) > 0 {
_ = pubsub.Subscribe(channels...)
_ = pubsub.Subscribe(ctx, channels...)
}
return pubsub
}
// PSubscribe subscribes the client to the given patterns.
// Patterns can be omitted to create empty subscription.
func (c *SentinelClient) PSubscribe(channels ...string) *PubSub {
func (c *SentinelClient) PSubscribe(ctx context.Context, channels ...string) *PubSub {
pubsub := c.pubSub()
if len(channels) > 0 {
_ = pubsub.PSubscribe(channels...)
_ = pubsub.PSubscribe(ctx, channels...)
}
return pubsub
}
func (c *SentinelClient) GetMasterAddrByName(name string) *StringSliceCmd {
cmd := NewStringSliceCmd("sentinel", "get-master-addr-by-name", name)
_ = c.Process(cmd)
func (c *SentinelClient) GetMasterAddrByName(ctx context.Context, name string) *StringSliceCmd {
cmd := NewStringSliceCmd(ctx, "sentinel", "get-master-addr-by-name", name)
_ = c.Process(ctx, cmd)
return cmd
}
func (c *SentinelClient) Sentinels(name string) *SliceCmd {
cmd := NewSliceCmd("sentinel", "sentinels", name)
_ = c.Process(cmd)
func (c *SentinelClient) Sentinels(ctx context.Context, name string) *SliceCmd {
cmd := NewSliceCmd(ctx, "sentinel", "sentinels", name)
_ = c.Process(ctx, cmd)
return cmd
}
// Failover forces a failover as if the master was not reachable, and without
// asking for agreement to other Sentinels.
func (c *SentinelClient) Failover(name string) *StatusCmd {
cmd := NewStatusCmd("sentinel", "failover", name)
_ = c.Process(cmd)
func (c *SentinelClient) Failover(ctx context.Context, name string) *StatusCmd {
cmd := NewStatusCmd(ctx, "sentinel", "failover", name)
_ = c.Process(ctx, cmd)
return cmd
}
@ -208,38 +204,38 @@ func (c *SentinelClient) Failover(name string) *StatusCmd {
// glob-style pattern. The reset process clears any previous state in a master
// (including a failover in progress), and removes every slave and sentinel
// already discovered and associated with the master.
func (c *SentinelClient) Reset(pattern string) *IntCmd {
cmd := NewIntCmd("sentinel", "reset", pattern)
_ = c.Process(cmd)
func (c *SentinelClient) Reset(ctx context.Context, pattern string) *IntCmd {
cmd := NewIntCmd(ctx, "sentinel", "reset", pattern)
_ = c.Process(ctx, cmd)
return cmd
}
// FlushConfig forces Sentinel to rewrite its configuration on disk, including
// the current Sentinel state.
func (c *SentinelClient) FlushConfig() *StatusCmd {
cmd := NewStatusCmd("sentinel", "flushconfig")
_ = c.Process(cmd)
func (c *SentinelClient) FlushConfig(ctx context.Context) *StatusCmd {
cmd := NewStatusCmd(ctx, "sentinel", "flushconfig")
_ = c.Process(ctx, cmd)
return cmd
}
// Master shows the state and info of the specified master.
func (c *SentinelClient) Master(name string) *StringStringMapCmd {
cmd := NewStringStringMapCmd("sentinel", "master", name)
_ = c.Process(cmd)
func (c *SentinelClient) Master(ctx context.Context, name string) *StringStringMapCmd {
cmd := NewStringStringMapCmd(ctx, "sentinel", "master", name)
_ = c.Process(ctx, cmd)
return cmd
}
// Masters shows a list of monitored masters and their state.
func (c *SentinelClient) Masters() *SliceCmd {
cmd := NewSliceCmd("sentinel", "masters")
_ = c.Process(cmd)
func (c *SentinelClient) Masters(ctx context.Context) *SliceCmd {
cmd := NewSliceCmd(ctx, "sentinel", "masters")
_ = c.Process(ctx, cmd)
return cmd
}
// Slaves shows a list of slaves for the specified master and their state.
func (c *SentinelClient) Slaves(name string) *SliceCmd {
cmd := NewSliceCmd("sentinel", "slaves", name)
_ = c.Process(cmd)
func (c *SentinelClient) Slaves(ctx context.Context, name string) *SliceCmd {
cmd := NewSliceCmd(ctx, "sentinel", "slaves", name)
_ = c.Process(ctx, cmd)
return cmd
}
@ -247,33 +243,33 @@ func (c *SentinelClient) Slaves(name string) *SliceCmd {
// quorum needed to failover a master, and the majority needed to authorize the
// failover. This command should be used in monitoring systems to check if a
// Sentinel deployment is ok.
func (c *SentinelClient) CkQuorum(name string) *StringCmd {
cmd := NewStringCmd("sentinel", "ckquorum", name)
_ = c.Process(cmd)
func (c *SentinelClient) CkQuorum(ctx context.Context, name string) *StringCmd {
cmd := NewStringCmd(ctx, "sentinel", "ckquorum", name)
_ = c.Process(ctx, cmd)
return cmd
}
// Monitor tells the Sentinel to start monitoring a new master with the specified
// name, ip, port, and quorum.
func (c *SentinelClient) Monitor(name, ip, port, quorum string) *StringCmd {
cmd := NewStringCmd("sentinel", "monitor", name, ip, port, quorum)
_ = c.Process(cmd)
func (c *SentinelClient) Monitor(ctx context.Context, name, ip, port, quorum string) *StringCmd {
cmd := NewStringCmd(ctx, "sentinel", "monitor", name, ip, port, quorum)
_ = c.Process(ctx, cmd)
return cmd
}
// Set is used in order to change configuration parameters of a specific master.
func (c *SentinelClient) Set(name, option, value string) *StringCmd {
cmd := NewStringCmd("sentinel", "set", name, option, value)
_ = c.Process(cmd)
func (c *SentinelClient) Set(ctx context.Context, name, option, value string) *StringCmd {
cmd := NewStringCmd(ctx, "sentinel", "set", name, option, value)
_ = c.Process(ctx, cmd)
return cmd
}
// Remove is used in order to remove the specified master: the master will no
// longer be monitored, and will totally be removed from the internal state of
// the Sentinel.
func (c *SentinelClient) Remove(name string) *StringCmd {
cmd := NewStringCmd("sentinel", "remove", name)
_ = c.Process(cmd)
func (c *SentinelClient) Remove(ctx context.Context, name string) *StringCmd {
cmd := NewStringCmd(ctx, "sentinel", "remove", name)
_ = c.Process(ctx, cmd)
return cmd
}
@ -325,7 +321,7 @@ func (c *sentinelFailover) Pool() *pool.ConnPool {
}
func (c *sentinelFailover) dial(ctx context.Context, network, _ string) (net.Conn, error) {
addr, err := c.MasterAddr()
addr, err := c.MasterAddr(ctx)
if err != nil {
return nil, err
}
@ -335,8 +331,8 @@ func (c *sentinelFailover) dial(ctx context.Context, network, _ string) (net.Con
return net.DialTimeout("tcp", addr, c.opt.DialTimeout)
}
func (c *sentinelFailover) MasterAddr() (string, error) {
addr, err := c.masterAddr()
func (c *sentinelFailover) MasterAddr(ctx context.Context) (string, error) {
addr, err := c.masterAddr(ctx)
if err != nil {
return "", err
}
@ -344,13 +340,13 @@ func (c *sentinelFailover) MasterAddr() (string, error) {
return addr, nil
}
func (c *sentinelFailover) masterAddr() (string, error) {
func (c *sentinelFailover) masterAddr(ctx context.Context) (string, error) {
c.mu.RLock()
sentinel := c.sentinel
c.mu.RUnlock()
if sentinel != nil {
addr := c.getMasterAddr(sentinel)
addr := c.getMasterAddr(ctx, sentinel)
if addr != "" {
return addr, nil
}
@ -360,7 +356,7 @@ func (c *sentinelFailover) masterAddr() (string, error) {
defer c.mu.Unlock()
if c.sentinel != nil {
addr := c.getMasterAddr(c.sentinel)
addr := c.getMasterAddr(ctx, c.sentinel)
if addr != "" {
return addr, nil
}
@ -388,7 +384,7 @@ func (c *sentinelFailover) masterAddr() (string, error) {
TLSConfig: c.opt.TLSConfig,
})
masterAddr, err := sentinel.GetMasterAddrByName(c.masterName).Result()
masterAddr, err := sentinel.GetMasterAddrByName(ctx, c.masterName).Result()
if err != nil {
internal.Logger.Printf("sentinel: GetMasterAddrByName master=%q failed: %s",
c.masterName, err)
@ -398,7 +394,7 @@ func (c *sentinelFailover) masterAddr() (string, error) {
// Push working sentinel to the top.
c.sentinelAddrs[0], c.sentinelAddrs[i] = c.sentinelAddrs[i], c.sentinelAddrs[0]
c.setSentinel(sentinel)
c.setSentinel(ctx, sentinel)
addr := net.JoinHostPort(masterAddr[0], masterAddr[1])
return addr, nil
@ -407,8 +403,8 @@ func (c *sentinelFailover) masterAddr() (string, error) {
return "", errors.New("redis: all sentinels are unreachable")
}
func (c *sentinelFailover) getMasterAddr(sentinel *SentinelClient) string {
addr, err := sentinel.GetMasterAddrByName(c.masterName).Result()
func (c *sentinelFailover) getMasterAddr(ctx context.Context, sentinel *SentinelClient) string {
addr, err := sentinel.GetMasterAddrByName(ctx, c.masterName).Result()
if err != nil {
internal.Logger.Printf("sentinel: GetMasterAddrByName name=%q failed: %s",
c.masterName, err)
@ -440,19 +436,19 @@ func (c *sentinelFailover) switchMaster(addr string) {
c._masterAddr = addr
}
func (c *sentinelFailover) setSentinel(sentinel *SentinelClient) {
func (c *sentinelFailover) setSentinel(ctx context.Context, sentinel *SentinelClient) {
if c.sentinel != nil {
panic("not reached")
}
c.sentinel = sentinel
c.discoverSentinels()
c.discoverSentinels(ctx)
c.pubsub = sentinel.Subscribe("+switch-master")
c.pubsub = sentinel.Subscribe(ctx, "+switch-master")
go c.listen(c.pubsub)
}
func (c *sentinelFailover) discoverSentinels() {
sentinels, err := c.sentinel.Sentinels(c.masterName).Result()
func (c *sentinelFailover) discoverSentinels(ctx context.Context) {
sentinels, err := c.sentinel.Sentinels(ctx, c.masterName).Result()
if err != nil {
internal.Logger.Printf("sentinel: Sentinels master=%q failed: %s", c.masterName, err)
return

View File

@ -15,7 +15,7 @@ var _ = Describe("Sentinel", func() {
MasterName: sentinelName,
SentinelAddrs: []string{":" + sentinelPort},
})
Expect(client.FlushDB().Err()).NotTo(HaveOccurred())
Expect(client.FlushDB(ctx).Err()).NotTo(HaveOccurred())
})
AfterEach(func() {
@ -24,48 +24,48 @@ var _ = Describe("Sentinel", func() {
It("should facilitate failover", func() {
// Set value on master.
err := client.Set("foo", "master", 0).Err()
err := client.Set(ctx, "foo", "master", 0).Err()
Expect(err).NotTo(HaveOccurred())
// Verify.
val, err := sentinelMaster.Get("foo").Result()
val, err := sentinelMaster.Get(ctx, "foo").Result()
Expect(err).NotTo(HaveOccurred())
Expect(val).To(Equal("master"))
// Create subscription.
ch := client.Subscribe("foo").Channel()
ch := client.Subscribe(ctx, "foo").Channel()
// Wait until replicated.
Eventually(func() string {
return sentinelSlave1.Get("foo").Val()
return sentinelSlave1.Get(ctx, "foo").Val()
}, "1s", "100ms").Should(Equal("master"))
Eventually(func() string {
return sentinelSlave2.Get("foo").Val()
return sentinelSlave2.Get(ctx, "foo").Val()
}, "1s", "100ms").Should(Equal("master"))
// Wait until slaves are picked up by sentinel.
Eventually(func() string {
return sentinel.Info().Val()
return sentinel.Info(ctx).Val()
}, "10s", "100ms").Should(ContainSubstring("slaves=2"))
// Kill master.
sentinelMaster.Shutdown()
sentinelMaster.Shutdown(ctx)
Eventually(func() error {
return sentinelMaster.Ping().Err()
return sentinelMaster.Ping(ctx).Err()
}, "5s", "100ms").Should(HaveOccurred())
// Wait for Redis sentinel to elect new master.
Eventually(func() string {
return sentinelSlave1.Info().Val() + sentinelSlave2.Info().Val()
return sentinelSlave1.Info(ctx).Val() + sentinelSlave2.Info(ctx).Val()
}, "30s", "1s").Should(ContainSubstring("role:master"))
// Check that client picked up new master.
Eventually(func() error {
return client.Get("foo").Err()
return client.Get(ctx, "foo").Err()
}, "5s", "100ms").ShouldNot(HaveOccurred())
// Publish message to check if subscription is renewed.
err = client.Publish("foo", "hello").Err()
err = client.Publish(ctx, "foo", "hello").Err()
Expect(err).NotTo(HaveOccurred())
var msg *redis.Message
@ -82,7 +82,7 @@ var _ = Describe("Sentinel", func() {
SentinelAddrs: []string{":" + sentinelPort},
DB: 1,
})
err := client.Ping().Err()
err := client.Ping(ctx).Err()
Expect(err).NotTo(HaveOccurred())
})
})

42
tx.go
View File

@ -55,11 +55,7 @@ func (c *Tx) WithContext(ctx context.Context) *Tx {
return &clone
}
func (c *Tx) Process(cmd Cmder) error {
return c.ProcessContext(c.ctx, cmd)
}
func (c *Tx) ProcessContext(ctx context.Context, cmd Cmder) error {
func (c *Tx) Process(ctx context.Context, cmd Cmder) error {
return c.hooks.process(ctx, cmd, c.baseClient.process)
}
@ -67,52 +63,48 @@ func (c *Tx) ProcessContext(ctx context.Context, cmd Cmder) error {
// for conditional execution if there are any keys.
//
// The transaction is automatically closed when fn exits.
func (c *Client) Watch(fn func(*Tx) error, keys ...string) error {
return c.WatchContext(c.ctx, fn, keys...)
}
func (c *Client) WatchContext(ctx context.Context, fn func(*Tx) error, keys ...string) error {
func (c *Client) Watch(ctx context.Context, fn func(*Tx) error, keys ...string) error {
tx := c.newTx(ctx)
if len(keys) > 0 {
if err := tx.Watch(keys...).Err(); err != nil {
_ = tx.Close()
if err := tx.Watch(ctx, keys...).Err(); err != nil {
_ = tx.Close(ctx)
return err
}
}
err := fn(tx)
_ = tx.Close()
_ = tx.Close(ctx)
return err
}
// Close closes the transaction, releasing any open resources.
func (c *Tx) Close() error {
_ = c.Unwatch().Err()
func (c *Tx) Close(ctx context.Context) error {
_ = c.Unwatch(ctx).Err()
return c.baseClient.Close()
}
// Watch marks the keys to be watched for conditional execution
// of a transaction.
func (c *Tx) Watch(keys ...string) *StatusCmd {
func (c *Tx) Watch(ctx context.Context, keys ...string) *StatusCmd {
args := make([]interface{}, 1+len(keys))
args[0] = "watch"
for i, key := range keys {
args[1+i] = key
}
cmd := NewStatusCmd(args...)
_ = c.Process(cmd)
cmd := NewStatusCmd(ctx, args...)
_ = c.Process(ctx, cmd)
return cmd
}
// Unwatch flushes all the previously watched keys for a transaction.
func (c *Tx) Unwatch(keys ...string) *StatusCmd {
func (c *Tx) Unwatch(ctx context.Context, keys ...string) *StatusCmd {
args := make([]interface{}, 1+len(keys))
args[0] = "unwatch"
for i, key := range keys {
args[1+i] = key
}
cmd := NewStatusCmd(args...)
_ = c.Process(cmd)
cmd := NewStatusCmd(ctx, args...)
_ = c.Process(ctx, cmd)
return cmd
}
@ -130,8 +122,8 @@ func (c *Tx) Pipeline() Pipeliner {
// Pipelined executes commands queued in the fn outside of the transaction.
// Use TxPipelined if you need transactional behavior.
func (c *Tx) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) {
return c.Pipeline().Pipelined(fn)
func (c *Tx) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
return c.Pipeline().Pipelined(ctx, fn)
}
// TxPipelined executes commands queued in the fn in the transaction.
@ -142,8 +134,8 @@ func (c *Tx) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) {
// Exec always returns list of commands. If transaction fails
// TxFailedErr is returned. Otherwise Exec returns an error of the first
// failed command or nil.
func (c *Tx) TxPipelined(fn func(Pipeliner) error) ([]Cmder, error) {
return c.TxPipeline().Pipelined(fn)
func (c *Tx) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
return c.TxPipeline().Pipelined(ctx, fn)
}
// TxPipeline creates a pipeline. Usually it is more convenient to use TxPipelined.

View File

@ -16,7 +16,7 @@ var _ = Describe("Tx", func() {
BeforeEach(func() {
client = redis.NewClient(redisOptions())
Expect(client.FlushDB().Err()).NotTo(HaveOccurred())
Expect(client.FlushDB(ctx).Err()).NotTo(HaveOccurred())
})
AfterEach(func() {
@ -28,14 +28,14 @@ var _ = Describe("Tx", func() {
// Transactionally increments key using GET and SET commands.
incr = func(key string) error {
err := client.Watch(func(tx *redis.Tx) error {
n, err := tx.Get(key).Int64()
err := client.Watch(ctx, func(tx *redis.Tx) error {
n, err := tx.Get(ctx, key).Int64()
if err != nil && err != redis.Nil {
return err
}
_, err = tx.TxPipelined(func(pipe redis.Pipeliner) error {
pipe.Set(key, strconv.FormatInt(n+1, 10), 0)
_, err = tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Set(ctx, key, strconv.FormatInt(n+1, 10), 0)
return nil
})
return err
@ -59,17 +59,17 @@ var _ = Describe("Tx", func() {
}
wg.Wait()
n, err := client.Get("key").Int64()
n, err := client.Get(ctx, "key").Int64()
Expect(err).NotTo(HaveOccurred())
Expect(n).To(Equal(int64(100)))
})
It("should discard", func() {
err := client.Watch(func(tx *redis.Tx) error {
cmds, err := tx.TxPipelined(func(pipe redis.Pipeliner) error {
pipe.Set("key1", "hello1", 0)
err := client.Watch(ctx, func(tx *redis.Tx) error {
cmds, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Set(ctx, "key1", "hello1", 0)
pipe.Discard()
pipe.Set("key2", "hello2", 0)
pipe.Set(ctx, "key2", "hello2", 0)
return nil
})
Expect(err).NotTo(HaveOccurred())
@ -78,23 +78,23 @@ var _ = Describe("Tx", func() {
}, "key1", "key2")
Expect(err).NotTo(HaveOccurred())
get := client.Get("key1")
get := client.Get(ctx, "key1")
Expect(get.Err()).To(Equal(redis.Nil))
Expect(get.Val()).To(Equal(""))
get = client.Get("key2")
get = client.Get(ctx, "key2")
Expect(get.Err()).NotTo(HaveOccurred())
Expect(get.Val()).To(Equal("hello2"))
})
It("returns no error when there are no commands", func() {
err := client.Watch(func(tx *redis.Tx) error {
_, err := tx.TxPipelined(func(redis.Pipeliner) error { return nil })
err := client.Watch(ctx, func(tx *redis.Tx) error {
_, err := tx.TxPipelined(ctx, func(redis.Pipeliner) error { return nil })
return err
})
Expect(err).NotTo(HaveOccurred())
v, err := client.Ping().Result()
v, err := client.Ping(ctx).Result()
Expect(err).NotTo(HaveOccurred())
Expect(v).To(Equal("PONG"))
})
@ -102,10 +102,10 @@ var _ = Describe("Tx", func() {
It("should exec bulks", func() {
const N = 20000
err := client.Watch(func(tx *redis.Tx) error {
cmds, err := tx.TxPipelined(func(pipe redis.Pipeliner) error {
err := client.Watch(ctx, func(tx *redis.Tx) error {
cmds, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
for i := 0; i < N; i++ {
pipe.Incr("key")
pipe.Incr(ctx, "key")
}
return nil
})
@ -118,7 +118,7 @@ var _ = Describe("Tx", func() {
})
Expect(err).NotTo(HaveOccurred())
num, err := client.Get("key").Int64()
num, err := client.Get(ctx, "key").Int64()
Expect(err).NotTo(HaveOccurred())
Expect(num).To(Equal(int64(N)))
})
@ -132,9 +132,9 @@ var _ = Describe("Tx", func() {
client.Pool().Put(cn)
do := func() error {
err := client.Watch(func(tx *redis.Tx) error {
_, err := tx.TxPipelined(func(pipe redis.Pipeliner) error {
pipe.Ping()
err := client.Watch(ctx, func(tx *redis.Tx) error {
_, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Ping(ctx)
return nil
})
return err

View File

@ -164,13 +164,11 @@ type UniversalClient interface {
Cmdable
Context() context.Context
AddHook(Hook)
Watch(fn func(*Tx) error, keys ...string) error
Do(args ...interface{}) *Cmd
DoContext(ctx context.Context, args ...interface{}) *Cmd
Process(cmd Cmder) error
ProcessContext(ctx context.Context, cmd Cmder) error
Subscribe(channels ...string) *PubSub
PSubscribe(channels ...string) *PubSub
Watch(ctx context.Context, fn func(*Tx) error, keys ...string) error
Do(ctx context.Context, args ...interface{}) *Cmd
Process(ctx context.Context, cmd Cmder) error
Subscribe(ctx context.Context, channels ...string) *PubSub
PSubscribe(ctx context.Context, channels ...string) *PubSub
Close() error
}

View File

@ -21,21 +21,21 @@ var _ = Describe("UniversalClient", func() {
MasterName: sentinelName,
Addrs: []string{":" + sentinelPort},
})
Expect(client.Ping().Err()).NotTo(HaveOccurred())
Expect(client.Ping(ctx).Err()).NotTo(HaveOccurred())
})
It("should connect to simple servers", func() {
client = redis.NewUniversalClient(&redis.UniversalOptions{
Addrs: []string{redisAddr},
})
Expect(client.Ping().Err()).NotTo(HaveOccurred())
Expect(client.Ping(ctx).Err()).NotTo(HaveOccurred())
})
It("should connect to clusters", func() {
client = redis.NewUniversalClient(&redis.UniversalOptions{
Addrs: cluster.addrs(),
})
Expect(client.Ping().Err()).NotTo(HaveOccurred())
Expect(client.Ping(ctx).Err()).NotTo(HaveOccurred())
})
})