Merge branch 'v8'

This commit is contained in:
Vladimir Mihailenco 2020-05-21 10:16:44 +03:00
commit 4440575966
47 changed files with 3373 additions and 3122 deletions

View File

@ -8,6 +8,7 @@ go:
- 1.11.x - 1.11.x
- 1.12.x - 1.12.x
- 1.13.x - 1.13.x
- 1.14.x
- tip - tip
matrix: matrix:

View File

@ -1,48 +1,48 @@
# Redis client for Golang # Redis client for Golang
[![Build Status](https://travis-ci.org/go-redis/redis.png?branch=master)](https://travis-ci.org/go-redis/redis) [![Build Status](https://travis-ci.org/go-redis/redis.png?branch=master)](https://travis-ci.org/go-redis/redis)
[![GoDoc](https://godoc.org/github.com/go-redis/redis?status.svg)](https://godoc.org/github.com/go-redis/redis) [![GoDoc](https://godoc.org/github.com/go-redis/redis?status.svg)](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc)
[![Airbrake](https://img.shields.io/badge/kudos-airbrake.io-orange.svg)](https://airbrake.io) [![Airbrake](https://img.shields.io/badge/kudos-airbrake.io-orange.svg)](https://airbrake.io)
Supports: Supports:
- Redis 3 commands except QUIT, MONITOR, SLOWLOG and SYNC. - Redis 3 commands except QUIT, MONITOR, SLOWLOG and SYNC.
- Automatic connection pooling with [circuit breaker](https://en.wikipedia.org/wiki/Circuit_breaker_design_pattern) support. - Automatic connection pooling with [circuit breaker](https://en.wikipedia.org/wiki/Circuit_breaker_design_pattern) support.
- [Pub/Sub](https://godoc.org/github.com/go-redis/redis#PubSub). - [Pub/Sub](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc#PubSub).
- [Transactions](https://godoc.org/github.com/go-redis/redis#example-Client-TxPipeline). - [Transactions](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc#example-Client-TxPipeline).
- [Pipeline](https://godoc.org/github.com/go-redis/redis#example-Client-Pipeline) and [TxPipeline](https://godoc.org/github.com/go-redis/redis#example-Client-TxPipeline). - [Pipeline](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc#example-Client-Pipeline) and [TxPipeline](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc#example-Client-TxPipeline).
- [Scripting](https://godoc.org/github.com/go-redis/redis#Script). - [Scripting](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc#Script).
- [Timeouts](https://godoc.org/github.com/go-redis/redis#Options). - [Timeouts](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc#Options).
- [Redis Sentinel](https://godoc.org/github.com/go-redis/redis#NewFailoverClient). - [Redis Sentinel](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc#NewFailoverClient).
- [Redis Cluster](https://godoc.org/github.com/go-redis/redis#NewClusterClient). - [Redis Cluster](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc#NewClusterClient).
- [Cluster of Redis Servers](https://godoc.org/github.com/go-redis/redis#example-NewClusterClient--ManualSetup) without using cluster mode and Redis Sentinel. - [Cluster of Redis Servers](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc#example-NewClusterClient--ManualSetup) without using cluster mode and Redis Sentinel.
- [Ring](https://godoc.org/github.com/go-redis/redis#NewRing). - [Ring](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc#NewRing).
- [Instrumentation](https://godoc.org/github.com/go-redis/redis#ex-package--Instrumentation). - [Instrumentation](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc#ex-package--Instrumentation).
- [Cache friendly](https://github.com/go-redis/cache). - [Cache friendly](https://github.com/go-redis/cache).
- [Rate limiting](https://github.com/go-redis/redis_rate). - [Rate limiting](https://github.com/go-redis/redis_rate).
- [Distributed Locks](https://github.com/bsm/redislock). - [Distributed Locks](https://github.com/bsm/redislock).
API docs: https://godoc.org/github.com/go-redis/redis. API docs: https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc.
Examples: https://godoc.org/github.com/go-redis/redis#pkg-examples. Examples: https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc#pkg-examples.
## Installation ## Installation
go-redis requires a Go version with [Modules](https://github.com/golang/go/wiki/Modules) support and uses import versioning. So please make sure to initialize a Go module before installing go-redis: go-redis requires a Go version with [Modules](https://github.com/golang/go/wiki/Modules) support and uses import versioning. So please make sure to initialize a Go module before installing go-redis:
``` shell ```shell
go mod init github.com/my/repo go mod init github.com/my/repo
go get github.com/go-redis/redis/v7 go get github.com/go-redis/redis/v8
``` ```
Import: Import:
``` go ```go
import "github.com/go-redis/redis/v7" import "github.com/go-redis/redis/v8"
``` ```
## Quickstart ## Quickstart
``` go ```go
func ExampleNewClient() { func ExampleNewClient() {
client := redis.NewClient(&redis.Options{ client := redis.NewClient(&redis.Options{
Addr: "localhost:6379", Addr: "localhost:6379",
@ -87,13 +87,13 @@ func ExampleClient() {
## Howto ## Howto
Please go through [examples](https://godoc.org/github.com/go-redis/redis#pkg-examples) to get an idea how to use this package. Please go through [examples](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc#pkg-examples) to get an idea how to use this package.
## Look and feel ## Look and feel
Some corner cases: Some corner cases:
``` go ```go
// SET key value EX 10 NX // SET key value EX 10 NX
set, err := client.SetNX("key", "value", 10*time.Second).Result() set, err := client.SetNX("key", "value", 10*time.Second).Result()

View File

@ -8,10 +8,10 @@ import (
"testing" "testing"
"time" "time"
"github.com/go-redis/redis/v7" "github.com/go-redis/redis/v8"
) )
func benchmarkRedisClient(poolSize int) *redis.Client { func benchmarkRedisClient(ctx context.Context, poolSize int) *redis.Client {
client := redis.NewClient(&redis.Options{ client := redis.NewClient(&redis.Options{
Addr: ":6379", Addr: ":6379",
DialTimeout: time.Second, DialTimeout: time.Second,
@ -19,21 +19,22 @@ func benchmarkRedisClient(poolSize int) *redis.Client {
WriteTimeout: time.Second, WriteTimeout: time.Second,
PoolSize: poolSize, PoolSize: poolSize,
}) })
if err := client.FlushDB().Err(); err != nil { if err := client.FlushDB(ctx).Err(); err != nil {
panic(err) panic(err)
} }
return client return client
} }
func BenchmarkRedisPing(b *testing.B) { func BenchmarkRedisPing(b *testing.B) {
client := benchmarkRedisClient(10) ctx := context.Background()
client := benchmarkRedisClient(ctx, 10)
defer client.Close() defer client.Close()
b.ResetTimer() b.ResetTimer()
b.RunParallel(func(pb *testing.PB) { b.RunParallel(func(pb *testing.PB) {
for pb.Next() { for pb.Next() {
if err := client.Ping().Err(); err != nil { if err := client.Ping(ctx).Err(); err != nil {
b.Fatal(err) b.Fatal(err)
} }
} }
@ -41,14 +42,15 @@ func BenchmarkRedisPing(b *testing.B) {
} }
func BenchmarkRedisGetNil(b *testing.B) { func BenchmarkRedisGetNil(b *testing.B) {
client := benchmarkRedisClient(10) ctx := context.Background()
client := benchmarkRedisClient(ctx, 10)
defer client.Close() defer client.Close()
b.ResetTimer() b.ResetTimer()
b.RunParallel(func(pb *testing.PB) { b.RunParallel(func(pb *testing.PB) {
for pb.Next() { for pb.Next() {
if err := client.Get("key").Err(); err != redis.Nil { if err := client.Get(ctx, "key").Err(); err != redis.Nil {
b.Fatal(err) b.Fatal(err)
} }
} }
@ -80,7 +82,8 @@ func BenchmarkRedisSetString(b *testing.B) {
} }
for _, bm := range benchmarks { for _, bm := range benchmarks {
b.Run(bm.String(), func(b *testing.B) { b.Run(bm.String(), func(b *testing.B) {
client := benchmarkRedisClient(bm.poolSize) ctx := context.Background()
client := benchmarkRedisClient(ctx, bm.poolSize)
defer client.Close() defer client.Close()
value := strings.Repeat("1", bm.valueSize) value := strings.Repeat("1", bm.valueSize)
@ -89,7 +92,7 @@ func BenchmarkRedisSetString(b *testing.B) {
b.RunParallel(func(pb *testing.PB) { b.RunParallel(func(pb *testing.PB) {
for pb.Next() { for pb.Next() {
err := client.Set("key", value, 0).Err() err := client.Set(ctx, "key", value, 0).Err()
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }
@ -100,7 +103,8 @@ func BenchmarkRedisSetString(b *testing.B) {
} }
func BenchmarkRedisSetGetBytes(b *testing.B) { func BenchmarkRedisSetGetBytes(b *testing.B) {
client := benchmarkRedisClient(10) ctx := context.Background()
client := benchmarkRedisClient(ctx, 10)
defer client.Close() defer client.Close()
value := bytes.Repeat([]byte{'1'}, 10000) value := bytes.Repeat([]byte{'1'}, 10000)
@ -109,11 +113,11 @@ func BenchmarkRedisSetGetBytes(b *testing.B) {
b.RunParallel(func(pb *testing.PB) { b.RunParallel(func(pb *testing.PB) {
for pb.Next() { for pb.Next() {
if err := client.Set("key", value, 0).Err(); err != nil { if err := client.Set(ctx, "key", value, 0).Err(); err != nil {
b.Fatal(err) b.Fatal(err)
} }
got, err := client.Get("key").Bytes() got, err := client.Get(ctx, "key").Bytes()
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }
@ -125,10 +129,11 @@ func BenchmarkRedisSetGetBytes(b *testing.B) {
} }
func BenchmarkRedisMGet(b *testing.B) { func BenchmarkRedisMGet(b *testing.B) {
client := benchmarkRedisClient(10) ctx := context.Background()
client := benchmarkRedisClient(ctx, 10)
defer client.Close() defer client.Close()
if err := client.MSet("key1", "hello1", "key2", "hello2").Err(); err != nil { if err := client.MSet(ctx, "key1", "hello1", "key2", "hello2").Err(); err != nil {
b.Fatal(err) b.Fatal(err)
} }
@ -136,7 +141,7 @@ func BenchmarkRedisMGet(b *testing.B) {
b.RunParallel(func(pb *testing.PB) { b.RunParallel(func(pb *testing.PB) {
for pb.Next() { for pb.Next() {
if err := client.MGet("key1", "key2").Err(); err != nil { if err := client.MGet(ctx, "key1", "key2").Err(); err != nil {
b.Fatal(err) b.Fatal(err)
} }
} }
@ -144,17 +149,18 @@ func BenchmarkRedisMGet(b *testing.B) {
} }
func BenchmarkSetExpire(b *testing.B) { func BenchmarkSetExpire(b *testing.B) {
client := benchmarkRedisClient(10) ctx := context.Background()
client := benchmarkRedisClient(ctx, 10)
defer client.Close() defer client.Close()
b.ResetTimer() b.ResetTimer()
b.RunParallel(func(pb *testing.PB) { b.RunParallel(func(pb *testing.PB) {
for pb.Next() { for pb.Next() {
if err := client.Set("key", "hello", 0).Err(); err != nil { if err := client.Set(ctx, "key", "hello", 0).Err(); err != nil {
b.Fatal(err) b.Fatal(err)
} }
if err := client.Expire("key", time.Second).Err(); err != nil { if err := client.Expire(ctx, "key", time.Second).Err(); err != nil {
b.Fatal(err) b.Fatal(err)
} }
} }
@ -162,16 +168,17 @@ func BenchmarkSetExpire(b *testing.B) {
} }
func BenchmarkPipeline(b *testing.B) { func BenchmarkPipeline(b *testing.B) {
client := benchmarkRedisClient(10) ctx := context.Background()
client := benchmarkRedisClient(ctx, 10)
defer client.Close() defer client.Close()
b.ResetTimer() b.ResetTimer()
b.RunParallel(func(pb *testing.PB) { b.RunParallel(func(pb *testing.PB) {
for pb.Next() { for pb.Next() {
_, err := client.Pipelined(func(pipe redis.Pipeliner) error { _, err := client.Pipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Set("key", "hello", 0) pipe.Set(ctx, "key", "hello", 0)
pipe.Expire("key", time.Second) pipe.Expire(ctx, "key", time.Second)
return nil return nil
}) })
if err != nil { if err != nil {
@ -182,14 +189,15 @@ func BenchmarkPipeline(b *testing.B) {
} }
func BenchmarkZAdd(b *testing.B) { func BenchmarkZAdd(b *testing.B) {
client := benchmarkRedisClient(10) ctx := context.Background()
client := benchmarkRedisClient(ctx, 10)
defer client.Close() defer client.Close()
b.ResetTimer() b.ResetTimer()
b.RunParallel(func(pb *testing.PB) { b.RunParallel(func(pb *testing.PB) {
for pb.Next() { for pb.Next() {
err := client.ZAdd("key", &redis.Z{ err := client.ZAdd(ctx, "key", &redis.Z{
Score: float64(1), Score: float64(1),
Member: "hello", Member: "hello",
}).Err() }).Err()
@ -203,10 +211,9 @@ func BenchmarkZAdd(b *testing.B) {
var clientSink *redis.Client var clientSink *redis.Client
func BenchmarkWithContext(b *testing.B) { func BenchmarkWithContext(b *testing.B) {
rdb := benchmarkRedisClient(10)
defer rdb.Close()
ctx := context.Background() ctx := context.Background()
rdb := benchmarkRedisClient(ctx, 10)
defer rdb.Close()
b.ResetTimer() b.ResetTimer()
b.ReportAllocs() b.ReportAllocs()
@ -219,11 +226,10 @@ func BenchmarkWithContext(b *testing.B) {
var ringSink *redis.Ring var ringSink *redis.Ring
func BenchmarkRingWithContext(b *testing.B) { func BenchmarkRingWithContext(b *testing.B) {
ctx := context.Background()
rdb := redis.NewRing(&redis.RingOptions{}) rdb := redis.NewRing(&redis.RingOptions{})
defer rdb.Close() defer rdb.Close()
ctx := context.Background()
b.ResetTimer() b.ResetTimer()
b.ReportAllocs() b.ReportAllocs()
@ -248,20 +254,21 @@ func BenchmarkClusterPing(b *testing.B) {
b.Skip("skipping in short mode") b.Skip("skipping in short mode")
} }
ctx := context.Background()
cluster := newClusterScenario() cluster := newClusterScenario()
if err := startCluster(cluster); err != nil { if err := startCluster(ctx, cluster); err != nil {
b.Fatal(err) b.Fatal(err)
} }
defer stopCluster(cluster) defer stopCluster(cluster)
client := cluster.newClusterClient(redisClusterOptions()) client := cluster.newClusterClient(ctx, redisClusterOptions())
defer client.Close() defer client.Close()
b.ResetTimer() b.ResetTimer()
b.RunParallel(func(pb *testing.PB) { b.RunParallel(func(pb *testing.PB) {
for pb.Next() { for pb.Next() {
err := client.Ping().Err() err := client.Ping(ctx).Err()
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }
@ -274,13 +281,14 @@ func BenchmarkClusterSetString(b *testing.B) {
b.Skip("skipping in short mode") b.Skip("skipping in short mode")
} }
ctx := context.Background()
cluster := newClusterScenario() cluster := newClusterScenario()
if err := startCluster(cluster); err != nil { if err := startCluster(ctx, cluster); err != nil {
b.Fatal(err) b.Fatal(err)
} }
defer stopCluster(cluster) defer stopCluster(cluster)
client := cluster.newClusterClient(redisClusterOptions()) client := cluster.newClusterClient(ctx, redisClusterOptions())
defer client.Close() defer client.Close()
value := string(bytes.Repeat([]byte{'1'}, 10000)) value := string(bytes.Repeat([]byte{'1'}, 10000))
@ -289,7 +297,7 @@ func BenchmarkClusterSetString(b *testing.B) {
b.RunParallel(func(pb *testing.PB) { b.RunParallel(func(pb *testing.PB) {
for pb.Next() { for pb.Next() {
err := client.Set("key", value, 0).Err() err := client.Set(ctx, "key", value, 0).Err()
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }
@ -302,19 +310,20 @@ func BenchmarkClusterReloadState(b *testing.B) {
b.Skip("skipping in short mode") b.Skip("skipping in short mode")
} }
ctx := context.Background()
cluster := newClusterScenario() cluster := newClusterScenario()
if err := startCluster(cluster); err != nil { if err := startCluster(ctx, cluster); err != nil {
b.Fatal(err) b.Fatal(err)
} }
defer stopCluster(cluster) defer stopCluster(cluster)
client := cluster.newClusterClient(redisClusterOptions()) client := cluster.newClusterClient(ctx, redisClusterOptions())
defer client.Close() defer client.Close()
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
err := client.ReloadState() err := client.ReloadState(ctx)
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }
@ -324,11 +333,10 @@ func BenchmarkClusterReloadState(b *testing.B) {
var clusterSink *redis.ClusterClient var clusterSink *redis.ClusterClient
func BenchmarkClusterWithContext(b *testing.B) { func BenchmarkClusterWithContext(b *testing.B) {
ctx := context.Background()
rdb := redis.NewClusterClient(&redis.ClusterOptions{}) rdb := redis.NewClusterClient(&redis.ClusterOptions{})
defer rdb.Close() defer rdb.Close()
ctx := context.Background()
b.ResetTimer() b.ResetTimer()
b.ReportAllocs() b.ReportAllocs()

View File

@ -13,10 +13,10 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/go-redis/redis/v7/internal" "github.com/go-redis/redis/v8/internal"
"github.com/go-redis/redis/v7/internal/hashtag" "github.com/go-redis/redis/v8/internal/hashtag"
"github.com/go-redis/redis/v7/internal/pool" "github.com/go-redis/redis/v8/internal/pool"
"github.com/go-redis/redis/v7/internal/proto" "github.com/go-redis/redis/v8/internal/proto"
) )
var errClusterNoNodes = fmt.Errorf("redis: cluster has no nodes") var errClusterNoNodes = fmt.Errorf("redis: cluster has no nodes")
@ -200,7 +200,7 @@ func (n *clusterNode) updateLatency() {
var latency uint32 var latency uint32
for i := 0; i < probes; i++ { for i := 0; i < probes; i++ {
start := time.Now() start := time.Now()
n.Client.Ping() n.Client.Ping(context.TODO())
probe := uint32(time.Since(start) / time.Microsecond) probe := uint32(time.Since(start) / time.Microsecond)
latency = (latency + probe) / 2 latency = (latency + probe) / 2
} }
@ -597,20 +597,20 @@ func (c *clusterState) slotNodes(slot int) []*clusterNode {
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
type clusterStateHolder struct { type clusterStateHolder struct {
load func() (*clusterState, error) load func(ctx context.Context) (*clusterState, error)
state atomic.Value state atomic.Value
reloading uint32 // atomic reloading uint32 // atomic
} }
func newClusterStateHolder(fn func() (*clusterState, error)) *clusterStateHolder { func newClusterStateHolder(fn func(ctx context.Context) (*clusterState, error)) *clusterStateHolder {
return &clusterStateHolder{ return &clusterStateHolder{
load: fn, load: fn,
} }
} }
func (c *clusterStateHolder) Reload() (*clusterState, error) { func (c *clusterStateHolder) Reload(ctx context.Context) (*clusterState, error) {
state, err := c.load() state, err := c.load(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -618,14 +618,14 @@ func (c *clusterStateHolder) Reload() (*clusterState, error) {
return state, nil return state, nil
} }
func (c *clusterStateHolder) LazyReload() { func (c *clusterStateHolder) LazyReload(ctx context.Context) {
if !atomic.CompareAndSwapUint32(&c.reloading, 0, 1) { if !atomic.CompareAndSwapUint32(&c.reloading, 0, 1) {
return return
} }
go func() { go func() {
defer atomic.StoreUint32(&c.reloading, 0) defer atomic.StoreUint32(&c.reloading, 0)
_, err := c.Reload() _, err := c.Reload(ctx)
if err != nil { if err != nil {
return return
} }
@ -633,24 +633,24 @@ func (c *clusterStateHolder) LazyReload() {
}() }()
} }
func (c *clusterStateHolder) Get() (*clusterState, error) { func (c *clusterStateHolder) Get(ctx context.Context) (*clusterState, error) {
v := c.state.Load() v := c.state.Load()
if v != nil { if v != nil {
state := v.(*clusterState) state := v.(*clusterState)
if time.Since(state.createdAt) > time.Minute { if time.Since(state.createdAt) > time.Minute {
c.LazyReload() c.LazyReload(ctx)
} }
return state, nil return state, nil
} }
return c.Reload() return c.Reload(ctx)
} }
func (c *clusterStateHolder) ReloadOrGet() (*clusterState, error) { func (c *clusterStateHolder) ReloadOrGet(ctx context.Context) (*clusterState, error) {
state, err := c.Reload() state, err := c.Reload(ctx)
if err == nil { if err == nil {
return state, nil return state, nil
} }
return c.Get() return c.Get(ctx)
} }
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
@ -717,8 +717,8 @@ func (c *ClusterClient) Options() *ClusterOptions {
// ReloadState reloads cluster state. If available it calls ClusterSlots func // ReloadState reloads cluster state. If available it calls ClusterSlots func
// to get cluster slots information. // to get cluster slots information.
func (c *ClusterClient) ReloadState() error { func (c *ClusterClient) ReloadState(ctx context.Context) error {
_, err := c.state.Reload() _, err := c.state.Reload(ctx)
return err return err
} }
@ -731,21 +731,13 @@ func (c *ClusterClient) Close() error {
} }
// Do creates a Cmd from the args and processes the cmd. // Do creates a Cmd from the args and processes the cmd.
func (c *ClusterClient) Do(args ...interface{}) *Cmd { func (c *ClusterClient) Do(ctx context.Context, args ...interface{}) *Cmd {
return c.DoContext(c.ctx, args...) cmd := NewCmd(ctx, args...)
} _ = c.Process(ctx, cmd)
func (c *ClusterClient) DoContext(ctx context.Context, args ...interface{}) *Cmd {
cmd := NewCmd(args...)
_ = c.ProcessContext(ctx, cmd)
return cmd return cmd
} }
func (c *ClusterClient) Process(cmd Cmder) error { func (c *ClusterClient) Process(ctx context.Context, cmd Cmder) error {
return c.ProcessContext(c.ctx, cmd)
}
func (c *ClusterClient) ProcessContext(ctx context.Context, cmd Cmder) error {
return c.hooks.process(ctx, cmd, c.process) return c.hooks.process(ctx, cmd, c.process)
} }
@ -774,7 +766,7 @@ func (c *ClusterClient) _process(ctx context.Context, cmd Cmder) error {
if node == nil { if node == nil {
var err error var err error
node, err = c.cmdNode(cmdInfo, slot) node, err = c.cmdNode(ctx, cmdInfo, slot)
if err != nil { if err != nil {
return err return err
} }
@ -782,13 +774,13 @@ func (c *ClusterClient) _process(ctx context.Context, cmd Cmder) error {
if ask { if ask {
pipe := node.Client.Pipeline() pipe := node.Client.Pipeline()
_ = pipe.Process(NewCmd("asking")) _ = pipe.Process(ctx, NewCmd(ctx, "asking"))
_ = pipe.Process(cmd) _ = pipe.Process(ctx, cmd)
_, lastErr = pipe.ExecContext(ctx) _, lastErr = pipe.Exec(ctx)
_ = pipe.Close() _ = pipe.Close()
ask = false ask = false
} else { } else {
lastErr = node.Client.ProcessContext(ctx, cmd) lastErr = node.Client.Process(ctx, cmd)
} }
// If there is no error - we are done. // If there is no error - we are done.
@ -796,7 +788,7 @@ func (c *ClusterClient) _process(ctx context.Context, cmd Cmder) error {
return nil return nil
} }
if lastErr != Nil { if lastErr != Nil {
c.state.LazyReload() c.state.LazyReload(ctx)
} }
if lastErr == pool.ErrClosed || isReadOnlyError(lastErr) { if lastErr == pool.ErrClosed || isReadOnlyError(lastErr) {
node = nil node = nil
@ -841,8 +833,11 @@ func (c *ClusterClient) _process(ctx context.Context, cmd Cmder) error {
// ForEachMaster concurrently calls the fn on each master node in the cluster. // ForEachMaster concurrently calls the fn on each master node in the cluster.
// It returns the first error if any. // It returns the first error if any.
func (c *ClusterClient) ForEachMaster(fn func(client *Client) error) error { func (c *ClusterClient) ForEachMaster(
state, err := c.state.ReloadOrGet() ctx context.Context,
fn func(ctx context.Context, client *Client) error,
) error {
state, err := c.state.ReloadOrGet(ctx)
if err != nil { if err != nil {
return err return err
} }
@ -854,7 +849,7 @@ func (c *ClusterClient) ForEachMaster(fn func(client *Client) error) error {
wg.Add(1) wg.Add(1)
go func(node *clusterNode) { go func(node *clusterNode) {
defer wg.Done() defer wg.Done()
err := fn(node.Client) err := fn(ctx, node.Client)
if err != nil { if err != nil {
select { select {
case errCh <- err: case errCh <- err:
@ -876,8 +871,11 @@ func (c *ClusterClient) ForEachMaster(fn func(client *Client) error) error {
// ForEachSlave concurrently calls the fn on each slave node in the cluster. // ForEachSlave concurrently calls the fn on each slave node in the cluster.
// It returns the first error if any. // It returns the first error if any.
func (c *ClusterClient) ForEachSlave(fn func(client *Client) error) error { func (c *ClusterClient) ForEachSlave(
state, err := c.state.ReloadOrGet() ctx context.Context,
fn func(ctx context.Context, client *Client) error,
) error {
state, err := c.state.ReloadOrGet(ctx)
if err != nil { if err != nil {
return err return err
} }
@ -889,7 +887,7 @@ func (c *ClusterClient) ForEachSlave(fn func(client *Client) error) error {
wg.Add(1) wg.Add(1)
go func(node *clusterNode) { go func(node *clusterNode) {
defer wg.Done() defer wg.Done()
err := fn(node.Client) err := fn(ctx, node.Client)
if err != nil { if err != nil {
select { select {
case errCh <- err: case errCh <- err:
@ -911,8 +909,11 @@ func (c *ClusterClient) ForEachSlave(fn func(client *Client) error) error {
// ForEachNode concurrently calls the fn on each known node in the cluster. // ForEachNode concurrently calls the fn on each known node in the cluster.
// It returns the first error if any. // It returns the first error if any.
func (c *ClusterClient) ForEachNode(fn func(client *Client) error) error { func (c *ClusterClient) ForEachNode(
state, err := c.state.ReloadOrGet() ctx context.Context,
fn func(ctx context.Context, client *Client) error,
) error {
state, err := c.state.ReloadOrGet(ctx)
if err != nil { if err != nil {
return err return err
} }
@ -922,7 +923,7 @@ func (c *ClusterClient) ForEachNode(fn func(client *Client) error) error {
worker := func(node *clusterNode) { worker := func(node *clusterNode) {
defer wg.Done() defer wg.Done()
err := fn(node.Client) err := fn(ctx, node.Client)
if err != nil { if err != nil {
select { select {
case errCh <- err: case errCh <- err:
@ -954,7 +955,7 @@ func (c *ClusterClient) ForEachNode(fn func(client *Client) error) error {
func (c *ClusterClient) PoolStats() *PoolStats { func (c *ClusterClient) PoolStats() *PoolStats {
var acc PoolStats var acc PoolStats
state, _ := c.state.Get() state, _ := c.state.Get(context.TODO())
if state == nil { if state == nil {
return &acc return &acc
} }
@ -984,7 +985,7 @@ func (c *ClusterClient) PoolStats() *PoolStats {
return &acc return &acc
} }
func (c *ClusterClient) loadState() (*clusterState, error) { func (c *ClusterClient) loadState(ctx context.Context) (*clusterState, error) {
if c.opt.ClusterSlots != nil { if c.opt.ClusterSlots != nil {
slots, err := c.opt.ClusterSlots() slots, err := c.opt.ClusterSlots()
if err != nil { if err != nil {
@ -1008,7 +1009,7 @@ func (c *ClusterClient) loadState() (*clusterState, error) {
continue continue
} }
slots, err := node.Client.ClusterSlots().Result() slots, err := node.Client.ClusterSlots(ctx).Result()
if err != nil { if err != nil {
if firstErr == nil { if firstErr == nil {
firstErr = err firstErr = err
@ -1051,8 +1052,8 @@ func (c *ClusterClient) Pipeline() Pipeliner {
return &pipe return &pipe
} }
func (c *ClusterClient) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) { func (c *ClusterClient) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
return c.Pipeline().Pipelined(fn) return c.Pipeline().Pipelined(ctx, fn)
} }
func (c *ClusterClient) processPipeline(ctx context.Context, cmds []Cmder) error { func (c *ClusterClient) processPipeline(ctx context.Context, cmds []Cmder) error {
@ -1061,7 +1062,7 @@ func (c *ClusterClient) processPipeline(ctx context.Context, cmds []Cmder) error
func (c *ClusterClient) _processPipeline(ctx context.Context, cmds []Cmder) error { func (c *ClusterClient) _processPipeline(ctx context.Context, cmds []Cmder) error {
cmdsMap := newCmdsMap() cmdsMap := newCmdsMap()
err := c.mapCmdsByNode(cmdsMap, cmds) err := c.mapCmdsByNode(ctx, cmdsMap, cmds)
if err != nil { if err != nil {
setCmdsErr(cmds, err) setCmdsErr(cmds, err)
return err return err
@ -1088,7 +1089,7 @@ func (c *ClusterClient) _processPipeline(ctx context.Context, cmds []Cmder) erro
return return
} }
if attempt < c.opt.MaxRedirects { if attempt < c.opt.MaxRedirects {
if err := c.mapCmdsByNode(failedCmds, cmds); err != nil { if err := c.mapCmdsByNode(ctx, failedCmds, cmds); err != nil {
setCmdsErr(cmds, err) setCmdsErr(cmds, err)
} }
} else { } else {
@ -1107,8 +1108,8 @@ func (c *ClusterClient) _processPipeline(ctx context.Context, cmds []Cmder) erro
return cmdsFirstErr(cmds) return cmdsFirstErr(cmds)
} }
func (c *ClusterClient) mapCmdsByNode(cmdsMap *cmdsMap, cmds []Cmder) error { func (c *ClusterClient) mapCmdsByNode(ctx context.Context, cmdsMap *cmdsMap, cmds []Cmder) error {
state, err := c.state.Get() state, err := c.state.Get(ctx)
if err != nil { if err != nil {
return err return err
} }
@ -1159,21 +1160,25 @@ func (c *ClusterClient) _processPipelineNode(
} }
return cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error { return cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
return c.pipelineReadCmds(node, rd, cmds, failedCmds) return c.pipelineReadCmds(ctx, node, rd, cmds, failedCmds)
}) })
}) })
}) })
} }
func (c *ClusterClient) pipelineReadCmds( func (c *ClusterClient) pipelineReadCmds(
node *clusterNode, rd *proto.Reader, cmds []Cmder, failedCmds *cmdsMap, ctx context.Context,
node *clusterNode,
rd *proto.Reader,
cmds []Cmder,
failedCmds *cmdsMap,
) error { ) error {
for _, cmd := range cmds { for _, cmd := range cmds {
err := cmd.readReply(rd) err := cmd.readReply(rd)
if err == nil { if err == nil {
continue continue
} }
if c.checkMovedErr(cmd, err, failedCmds) { if c.checkMovedErr(ctx, cmd, err, failedCmds) {
continue continue
} }
@ -1190,7 +1195,7 @@ func (c *ClusterClient) pipelineReadCmds(
} }
func (c *ClusterClient) checkMovedErr( func (c *ClusterClient) checkMovedErr(
cmd Cmder, err error, failedCmds *cmdsMap, ctx context.Context, cmd Cmder, err error, failedCmds *cmdsMap,
) bool { ) bool {
moved, ask, addr := isMovedError(err) moved, ask, addr := isMovedError(err)
if !moved && !ask { if !moved && !ask {
@ -1203,13 +1208,13 @@ func (c *ClusterClient) checkMovedErr(
} }
if moved { if moved {
c.state.LazyReload() c.state.LazyReload(ctx)
failedCmds.Add(node, cmd) failedCmds.Add(node, cmd)
return true return true
} }
if ask { if ask {
failedCmds.Add(node, NewCmd("asking"), cmd) failedCmds.Add(node, NewCmd(ctx, "asking"), cmd)
return true return true
} }
@ -1226,8 +1231,8 @@ func (c *ClusterClient) TxPipeline() Pipeliner {
return &pipe return &pipe
} }
func (c *ClusterClient) TxPipelined(fn func(Pipeliner) error) ([]Cmder, error) { func (c *ClusterClient) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
return c.TxPipeline().Pipelined(fn) return c.TxPipeline().Pipelined(ctx, fn)
} }
func (c *ClusterClient) processTxPipeline(ctx context.Context, cmds []Cmder) error { func (c *ClusterClient) processTxPipeline(ctx context.Context, cmds []Cmder) error {
@ -1235,7 +1240,7 @@ func (c *ClusterClient) processTxPipeline(ctx context.Context, cmds []Cmder) err
} }
func (c *ClusterClient) _processTxPipeline(ctx context.Context, cmds []Cmder) error { func (c *ClusterClient) _processTxPipeline(ctx context.Context, cmds []Cmder) error {
state, err := c.state.Get() state, err := c.state.Get(ctx)
if err != nil { if err != nil {
setCmdsErr(cmds, err) setCmdsErr(cmds, err)
return err return err
@ -1271,7 +1276,7 @@ func (c *ClusterClient) _processTxPipeline(ctx context.Context, cmds []Cmder) er
return return
} }
if attempt < c.opt.MaxRedirects { if attempt < c.opt.MaxRedirects {
if err := c.mapCmdsByNode(failedCmds, cmds); err != nil { if err := c.mapCmdsByNode(ctx, failedCmds, cmds); err != nil {
setCmdsErr(cmds, err) setCmdsErr(cmds, err)
} }
} else { } else {
@ -1317,11 +1322,11 @@ func (c *ClusterClient) _processTxPipelineNode(
// Trim multi and exec. // Trim multi and exec.
cmds = cmds[1 : len(cmds)-1] cmds = cmds[1 : len(cmds)-1]
err := c.txPipelineReadQueued(rd, statusCmd, cmds, failedCmds) err := c.txPipelineReadQueued(ctx, rd, statusCmd, cmds, failedCmds)
if err != nil { if err != nil {
moved, ask, addr := isMovedError(err) moved, ask, addr := isMovedError(err)
if moved || ask { if moved || ask {
return c.cmdsMoved(cmds, moved, ask, addr, failedCmds) return c.cmdsMoved(ctx, cmds, moved, ask, addr, failedCmds)
} }
return err return err
} }
@ -1333,7 +1338,11 @@ func (c *ClusterClient) _processTxPipelineNode(
} }
func (c *ClusterClient) txPipelineReadQueued( func (c *ClusterClient) txPipelineReadQueued(
rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder, failedCmds *cmdsMap, ctx context.Context,
rd *proto.Reader,
statusCmd *StatusCmd,
cmds []Cmder,
failedCmds *cmdsMap,
) error { ) error {
// Parse queued replies. // Parse queued replies.
if err := statusCmd.readReply(rd); err != nil { if err := statusCmd.readReply(rd); err != nil {
@ -1342,7 +1351,7 @@ func (c *ClusterClient) txPipelineReadQueued(
for _, cmd := range cmds { for _, cmd := range cmds {
err := statusCmd.readReply(rd) err := statusCmd.readReply(rd)
if err == nil || c.checkMovedErr(cmd, err, failedCmds) || isRedisError(err) { if err == nil || c.checkMovedErr(ctx, cmd, err, failedCmds) || isRedisError(err) {
continue continue
} }
return err return err
@ -1370,7 +1379,10 @@ func (c *ClusterClient) txPipelineReadQueued(
} }
func (c *ClusterClient) cmdsMoved( func (c *ClusterClient) cmdsMoved(
cmds []Cmder, moved, ask bool, addr string, failedCmds *cmdsMap, ctx context.Context, cmds []Cmder,
moved, ask bool,
addr string,
failedCmds *cmdsMap,
) error { ) error {
node, err := c.nodes.Get(addr) node, err := c.nodes.Get(addr)
if err != nil { if err != nil {
@ -1378,7 +1390,7 @@ func (c *ClusterClient) cmdsMoved(
} }
if moved { if moved {
c.state.LazyReload() c.state.LazyReload(ctx)
for _, cmd := range cmds { for _, cmd := range cmds {
failedCmds.Add(node, cmd) failedCmds.Add(node, cmd)
} }
@ -1387,7 +1399,7 @@ func (c *ClusterClient) cmdsMoved(
if ask { if ask {
for _, cmd := range cmds { for _, cmd := range cmds {
failedCmds.Add(node, NewCmd("asking"), cmd) failedCmds.Add(node, NewCmd(ctx, "asking"), cmd)
} }
return nil return nil
} }
@ -1395,11 +1407,7 @@ func (c *ClusterClient) cmdsMoved(
return nil return nil
} }
func (c *ClusterClient) Watch(fn func(*Tx) error, keys ...string) error { func (c *ClusterClient) Watch(ctx context.Context, fn func(*Tx) error, keys ...string) error {
return c.WatchContext(c.ctx, fn, keys...)
}
func (c *ClusterClient) WatchContext(ctx context.Context, fn func(*Tx) error, keys ...string) error {
if len(keys) == 0 { if len(keys) == 0 {
return fmt.Errorf("redis: Watch requires at least one key") return fmt.Errorf("redis: Watch requires at least one key")
} }
@ -1412,7 +1420,7 @@ func (c *ClusterClient) WatchContext(ctx context.Context, fn func(*Tx) error, ke
} }
} }
node, err := c.slotMasterNode(slot) node, err := c.slotMasterNode(ctx, slot)
if err != nil { if err != nil {
return err return err
} }
@ -1424,12 +1432,12 @@ func (c *ClusterClient) WatchContext(ctx context.Context, fn func(*Tx) error, ke
} }
} }
err = node.Client.WatchContext(ctx, fn, keys...) err = node.Client.Watch(ctx, fn, keys...)
if err == nil { if err == nil {
break break
} }
if err != Nil { if err != Nil {
c.state.LazyReload() c.state.LazyReload(ctx)
} }
moved, ask, addr := isMovedError(err) moved, ask, addr := isMovedError(err)
@ -1442,7 +1450,7 @@ func (c *ClusterClient) WatchContext(ctx context.Context, fn func(*Tx) error, ke
} }
if err == pool.ErrClosed || isReadOnlyError(err) { if err == pool.ErrClosed || isReadOnlyError(err) {
node, err = c.slotMasterNode(slot) node, err = c.slotMasterNode(ctx, slot)
if err != nil { if err != nil {
return err return err
} }
@ -1464,7 +1472,7 @@ func (c *ClusterClient) pubSub() *PubSub {
pubsub := &PubSub{ pubsub := &PubSub{
opt: c.opt.clientOptions(), opt: c.opt.clientOptions(),
newConn: func(channels []string) (*pool.Conn, error) { newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) {
if node != nil { if node != nil {
panic("node != nil") panic("node != nil")
} }
@ -1472,7 +1480,7 @@ func (c *ClusterClient) pubSub() *PubSub {
var err error var err error
if len(channels) > 0 { if len(channels) > 0 {
slot := hashtag.Slot(channels[0]) slot := hashtag.Slot(channels[0])
node, err = c.slotMasterNode(slot) node, err = c.slotMasterNode(ctx, slot)
} else { } else {
node, err = c.nodes.Random() node, err = c.nodes.Random()
} }
@ -1502,20 +1510,20 @@ func (c *ClusterClient) pubSub() *PubSub {
// Subscribe subscribes the client to the specified channels. // Subscribe subscribes the client to the specified channels.
// Channels can be omitted to create empty subscription. // Channels can be omitted to create empty subscription.
func (c *ClusterClient) Subscribe(channels ...string) *PubSub { func (c *ClusterClient) Subscribe(ctx context.Context, channels ...string) *PubSub {
pubsub := c.pubSub() pubsub := c.pubSub()
if len(channels) > 0 { if len(channels) > 0 {
_ = pubsub.Subscribe(channels...) _ = pubsub.Subscribe(ctx, channels...)
} }
return pubsub return pubsub
} }
// PSubscribe subscribes the client to the given patterns. // PSubscribe subscribes the client to the given patterns.
// Patterns can be omitted to create empty subscription. // Patterns can be omitted to create empty subscription.
func (c *ClusterClient) PSubscribe(channels ...string) *PubSub { func (c *ClusterClient) PSubscribe(ctx context.Context, channels ...string) *PubSub {
pubsub := c.pubSub() pubsub := c.pubSub()
if len(channels) > 0 { if len(channels) > 0 {
_ = pubsub.PSubscribe(channels...) _ = pubsub.PSubscribe(ctx, channels...)
} }
return pubsub return pubsub
} }
@ -1540,7 +1548,7 @@ func (c *ClusterClient) cmdsInfo() (map[string]*CommandInfo, error) {
continue continue
} }
info, err := node.Client.Command().Result() info, err := node.Client.Command(context.TODO()).Result()
if err == nil { if err == nil {
return info, nil return info, nil
} }
@ -1582,8 +1590,12 @@ func cmdSlot(cmd Cmder, pos int) int {
return hashtag.Slot(firstKey) return hashtag.Slot(firstKey)
} }
func (c *ClusterClient) cmdNode(cmdInfo *CommandInfo, slot int) (*clusterNode, error) { func (c *ClusterClient) cmdNode(
state, err := c.state.Get() ctx context.Context,
cmdInfo *CommandInfo,
slot int,
) (*clusterNode, error) {
state, err := c.state.Get(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -1604,8 +1616,8 @@ func (c *clusterClient) slotReadOnlyNode(state *clusterState, slot int) (*cluste
return state.slotSlaveNode(slot) return state.slotSlaveNode(slot)
} }
func (c *ClusterClient) slotMasterNode(slot int) (*clusterNode, error) { func (c *ClusterClient) slotMasterNode(ctx context.Context, slot int) (*clusterNode, error) {
state, err := c.state.Get() state, err := c.state.Get(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -1,12 +1,15 @@
package redis package redis
import "sync/atomic" import (
"context"
"sync/atomic"
)
func (c *ClusterClient) DBSize() *IntCmd { func (c *ClusterClient) DBSize(ctx context.Context) *IntCmd {
cmd := NewIntCmd("dbsize") cmd := NewIntCmd(ctx, "dbsize")
var size int64 var size int64
err := c.ForEachMaster(func(master *Client) error { err := c.ForEachMaster(ctx, func(ctx context.Context, master *Client) error {
n, err := master.DBSize().Result() n, err := master.DBSize(ctx).Result()
if err != nil { if err != nil {
return err return err
} }

View File

@ -9,8 +9,8 @@ import (
"sync" "sync"
"time" "time"
"github.com/go-redis/redis/v7" "github.com/go-redis/redis/v8"
"github.com/go-redis/redis/v7/internal/hashtag" "github.com/go-redis/redis/v8/internal/hashtag"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
@ -53,7 +53,9 @@ func (s *clusterScenario) newClusterClientUnsafe(opt *redis.ClusterOptions) *red
} }
func (s *clusterScenario) newClusterClient(opt *redis.ClusterOptions) *redis.ClusterClient { func (s *clusterScenario) newClusterClient(
ctx context.Context, opt *redis.ClusterOptions,
) *redis.ClusterClient {
client := s.newClusterClientUnsafe(opt) client := s.newClusterClientUnsafe(opt)
err := eventually(func() error { err := eventually(func() error {
@ -61,12 +63,12 @@ func (s *clusterScenario) newClusterClient(opt *redis.ClusterOptions) *redis.Clu
return nil return nil
} }
state, err := client.LoadState() state, err := client.LoadState(ctx)
if err != nil { if err != nil {
return err return err
} }
if !state.IsConsistent() { if !state.IsConsistent(ctx) {
return fmt.Errorf("cluster state is not consistent") return fmt.Errorf("cluster state is not consistent")
} }
@ -79,7 +81,7 @@ func (s *clusterScenario) newClusterClient(opt *redis.ClusterOptions) *redis.Clu
return client return client
} }
func startCluster(scenario *clusterScenario) error { func startCluster(ctx context.Context, scenario *clusterScenario) error {
// Start processes and collect node ids // Start processes and collect node ids
for pos, port := range scenario.ports { for pos, port := range scenario.ports {
process, err := startRedis(port, "--cluster-enabled", "yes") process, err := startRedis(port, "--cluster-enabled", "yes")
@ -91,7 +93,7 @@ func startCluster(scenario *clusterScenario) error {
Addr: ":" + port, Addr: ":" + port,
}) })
info, err := client.ClusterNodes().Result() info, err := client.ClusterNodes(ctx).Result()
if err != nil { if err != nil {
return err return err
} }
@ -103,7 +105,7 @@ func startCluster(scenario *clusterScenario) error {
// Meet cluster nodes. // Meet cluster nodes.
for _, client := range scenario.clients { for _, client := range scenario.clients {
err := client.ClusterMeet("127.0.0.1", scenario.ports[0]).Err() err := client.ClusterMeet(ctx, "127.0.0.1", scenario.ports[0]).Err()
if err != nil { if err != nil {
return err return err
} }
@ -112,7 +114,7 @@ func startCluster(scenario *clusterScenario) error {
// Bootstrap masters. // Bootstrap masters.
slots := []int{0, 5000, 10000, 16384} slots := []int{0, 5000, 10000, 16384}
for pos, master := range scenario.masters() { for pos, master := range scenario.masters() {
err := master.ClusterAddSlotsRange(slots[pos], slots[pos+1]-1).Err() err := master.ClusterAddSlotsRange(ctx, slots[pos], slots[pos+1]-1).Err()
if err != nil { if err != nil {
return err return err
} }
@ -124,7 +126,7 @@ func startCluster(scenario *clusterScenario) error {
// Wait until master is available // Wait until master is available
err := eventually(func() error { err := eventually(func() error {
s := slave.ClusterNodes().Val() s := slave.ClusterNodes(ctx).Val()
wanted := masterID wanted := masterID
if !strings.Contains(s, wanted) { if !strings.Contains(s, wanted) {
return fmt.Errorf("%q does not contain %q", s, wanted) return fmt.Errorf("%q does not contain %q", s, wanted)
@ -135,7 +137,7 @@ func startCluster(scenario *clusterScenario) error {
return err return err
} }
err = slave.ClusterReplicate(masterID).Err() err = slave.ClusterReplicate(ctx, masterID).Err()
if err != nil { if err != nil {
return err return err
} }
@ -175,7 +177,7 @@ func startCluster(scenario *clusterScenario) error {
}} }}
for _, client := range scenario.clients { for _, client := range scenario.clients {
err := eventually(func() error { err := eventually(func() error {
res, err := client.ClusterSlots().Result() res, err := client.ClusterSlots(ctx).Result()
if err != nil { if err != nil {
return err return err
} }
@ -243,48 +245,48 @@ var _ = Describe("ClusterClient", func() {
assertClusterClient := func() { assertClusterClient := func() {
It("supports WithContext", func() { It("supports WithContext", func() {
c, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(ctx)
cancel() cancel()
err := client.WithContext(c).Ping().Err() err := client.Ping(ctx).Err()
Expect(err).To(MatchError("context canceled")) Expect(err).To(MatchError("context canceled"))
}) })
It("should GET/SET/DEL", func() { It("should GET/SET/DEL", func() {
err := client.Get("A").Err() err := client.Get(ctx, "A").Err()
Expect(err).To(Equal(redis.Nil)) Expect(err).To(Equal(redis.Nil))
err = client.Set("A", "VALUE", 0).Err() err = client.Set(ctx, "A", "VALUE", 0).Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Eventually(func() string { Eventually(func() string {
return client.Get("A").Val() return client.Get(ctx, "A").Val()
}, 30*time.Second).Should(Equal("VALUE")) }, 30*time.Second).Should(Equal("VALUE"))
cnt, err := client.Del("A").Result() cnt, err := client.Del(ctx, "A").Result()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(cnt).To(Equal(int64(1))) Expect(cnt).To(Equal(int64(1)))
}) })
It("GET follows redirects", func() { It("GET follows redirects", func() {
err := client.Set("A", "VALUE", 0).Err() err := client.Set(ctx, "A", "VALUE", 0).Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
if !failover { if !failover {
Eventually(func() int64 { Eventually(func() int64 {
nodes, err := client.Nodes("A") nodes, err := client.Nodes(ctx, "A")
if err != nil { if err != nil {
return 0 return 0
} }
return nodes[1].Client.DBSize().Val() return nodes[1].Client.DBSize(ctx).Val()
}, 30*time.Second).Should(Equal(int64(1))) }, 30*time.Second).Should(Equal(int64(1)))
Eventually(func() error { Eventually(func() error {
return client.SwapNodes("A") return client.SwapNodes(ctx, "A")
}, 30*time.Second).ShouldNot(HaveOccurred()) }, 30*time.Second).ShouldNot(HaveOccurred())
} }
v, err := client.Get("A").Result() v, err := client.Get(ctx, "A").Result()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(v).To(Equal("VALUE")) Expect(v).To(Equal("VALUE"))
}) })
@ -292,28 +294,28 @@ var _ = Describe("ClusterClient", func() {
It("SET follows redirects", func() { It("SET follows redirects", func() {
if !failover { if !failover {
Eventually(func() error { Eventually(func() error {
return client.SwapNodes("A") return client.SwapNodes(ctx, "A")
}, 30*time.Second).ShouldNot(HaveOccurred()) }, 30*time.Second).ShouldNot(HaveOccurred())
} }
err := client.Set("A", "VALUE", 0).Err() err := client.Set(ctx, "A", "VALUE", 0).Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
v, err := client.Get("A").Result() v, err := client.Get(ctx, "A").Result()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(v).To(Equal("VALUE")) Expect(v).To(Equal("VALUE"))
}) })
It("distributes keys", func() { It("distributes keys", func() {
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
err := client.Set(fmt.Sprintf("key%d", i), "value", 0).Err() err := client.Set(ctx, fmt.Sprintf("key%d", i), "value", 0).Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
} }
client.ForEachMaster(func(master *redis.Client) error { client.ForEachMaster(ctx, func(ctx context.Context, master *redis.Client) error {
defer GinkgoRecover() defer GinkgoRecover()
Eventually(func() string { Eventually(func() string {
return master.Info("keyspace").Val() return master.Info(ctx, "keyspace").Val()
}, 30*time.Second).Should(Or( }, 30*time.Second).Should(Or(
ContainSubstring("keys=31"), ContainSubstring("keys=31"),
ContainSubstring("keys=29"), ContainSubstring("keys=29"),
@ -332,14 +334,14 @@ var _ = Describe("ClusterClient", func() {
var key string var key string
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
key = fmt.Sprintf("key%d", i) key = fmt.Sprintf("key%d", i)
err := script.Run(client, []string{key}, "value").Err() err := script.Run(ctx, client, []string{key}, "value").Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
} }
client.ForEachMaster(func(master *redis.Client) error { client.ForEachMaster(ctx, func(ctx context.Context, master *redis.Client) error {
defer GinkgoRecover() defer GinkgoRecover()
Eventually(func() string { Eventually(func() string {
return master.Info("keyspace").Val() return master.Info(ctx, "keyspace").Val()
}, 30*time.Second).Should(Or( }, 30*time.Second).Should(Or(
ContainSubstring("keys=31"), ContainSubstring("keys=31"),
ContainSubstring("keys=29"), ContainSubstring("keys=29"),
@ -354,14 +356,14 @@ var _ = Describe("ClusterClient", func() {
// Transactionally increments key using GET and SET commands. // Transactionally increments key using GET and SET commands.
incr = func(key string) error { incr = func(key string) error {
err := client.Watch(func(tx *redis.Tx) error { err := client.Watch(ctx, func(tx *redis.Tx) error {
n, err := tx.Get(key).Int64() n, err := tx.Get(ctx, key).Int64()
if err != nil && err != redis.Nil { if err != nil && err != redis.Nil {
return err return err
} }
_, err = tx.TxPipelined(func(pipe redis.Pipeliner) error { _, err = tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Set(key, strconv.FormatInt(n+1, 10), 0) pipe.Set(ctx, key, strconv.FormatInt(n+1, 10), 0)
return nil return nil
}) })
return err return err
@ -386,7 +388,7 @@ var _ = Describe("ClusterClient", func() {
wg.Wait() wg.Wait()
Eventually(func() string { Eventually(func() string {
return client.Get("key").Val() return client.Get(ctx, "key").Val()
}, 30*time.Second).Should(Equal("100")) }, 30*time.Second).Should(Equal("100"))
}) })
@ -400,23 +402,23 @@ var _ = Describe("ClusterClient", func() {
if !failover { if !failover {
for _, key := range keys { for _, key := range keys {
Eventually(func() error { Eventually(func() error {
return client.SwapNodes(key) return client.SwapNodes(ctx, key)
}, 30*time.Second).ShouldNot(HaveOccurred()) }, 30*time.Second).ShouldNot(HaveOccurred())
} }
} }
for i, key := range keys { for i, key := range keys {
pipe.Set(key, key+"_value", 0) pipe.Set(ctx, key, key+"_value", 0)
pipe.Expire(key, time.Duration(i+1)*time.Hour) pipe.Expire(ctx, key, time.Duration(i+1)*time.Hour)
} }
cmds, err := pipe.Exec() cmds, err := pipe.Exec(ctx)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(cmds).To(HaveLen(14)) Expect(cmds).To(HaveLen(14))
_ = client.ForEachNode(func(node *redis.Client) error { _ = client.ForEachNode(ctx, func(ctx context.Context, node *redis.Client) error {
defer GinkgoRecover() defer GinkgoRecover()
Eventually(func() int64 { Eventually(func() int64 {
return node.DBSize().Val() return node.DBSize(ctx).Val()
}, 30*time.Second).ShouldNot(BeZero()) }, 30*time.Second).ShouldNot(BeZero())
return nil return nil
}) })
@ -424,16 +426,16 @@ var _ = Describe("ClusterClient", func() {
if !failover { if !failover {
for _, key := range keys { for _, key := range keys {
Eventually(func() error { Eventually(func() error {
return client.SwapNodes(key) return client.SwapNodes(ctx, key)
}, 30*time.Second).ShouldNot(HaveOccurred()) }, 30*time.Second).ShouldNot(HaveOccurred())
} }
} }
for _, key := range keys { for _, key := range keys {
pipe.Get(key) pipe.Get(ctx, key)
pipe.TTL(key) pipe.TTL(ctx, key)
} }
cmds, err = pipe.Exec() cmds, err = pipe.Exec(ctx)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(cmds).To(HaveLen(14)) Expect(cmds).To(HaveLen(14))
@ -448,15 +450,15 @@ var _ = Describe("ClusterClient", func() {
}) })
It("works with missing keys", func() { It("works with missing keys", func() {
pipe.Set("A", "A_value", 0) pipe.Set(ctx, "A", "A_value", 0)
pipe.Set("C", "C_value", 0) pipe.Set(ctx, "C", "C_value", 0)
_, err := pipe.Exec() _, err := pipe.Exec(ctx)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
a := pipe.Get("A") a := pipe.Get(ctx, "A")
b := pipe.Get("B") b := pipe.Get(ctx, "B")
c := pipe.Get("C") c := pipe.Get(ctx, "C")
cmds, err := pipe.Exec() cmds, err := pipe.Exec(ctx)
Expect(err).To(Equal(redis.Nil)) Expect(err).To(Equal(redis.Nil))
Expect(cmds).To(HaveLen(3)) Expect(cmds).To(HaveLen(3))
@ -497,16 +499,16 @@ var _ = Describe("ClusterClient", func() {
}) })
It("supports PubSub", func() { It("supports PubSub", func() {
pubsub := client.Subscribe("mychannel") pubsub := client.Subscribe(ctx, "mychannel")
defer pubsub.Close() defer pubsub.Close()
Eventually(func() error { Eventually(func() error {
_, err := client.Publish("mychannel", "hello").Result() _, err := client.Publish(ctx, "mychannel", "hello").Result()
if err != nil { if err != nil {
return err return err
} }
msg, err := pubsub.ReceiveTimeout(time.Second) msg, err := pubsub.ReceiveTimeout(ctx, time.Second)
if err != nil { if err != nil {
return err return err
} }
@ -521,19 +523,19 @@ var _ = Describe("ClusterClient", func() {
}) })
It("supports PubSub.Ping without channels", func() { It("supports PubSub.Ping without channels", func() {
pubsub := client.Subscribe() pubsub := client.Subscribe(ctx)
defer pubsub.Close() defer pubsub.Close()
err := pubsub.Ping() err := pubsub.Ping(ctx)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
}) })
It("supports Process hook", func() { It("supports Process hook", func() {
err := client.Ping().Err() err := client.Ping(ctx).Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
err = client.ForEachNode(func(node *redis.Client) error { err = client.ForEachNode(ctx, func(ctx context.Context, node *redis.Client) error {
return node.Ping().Err() return node.Ping(ctx).Err()
}) })
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
@ -566,12 +568,12 @@ var _ = Describe("ClusterClient", func() {
}, },
} }
_ = client.ForEachNode(func(node *redis.Client) error { _ = client.ForEachNode(ctx, func(ctx context.Context, node *redis.Client) error {
node.AddHook(nodeHook) node.AddHook(nodeHook)
return nil return nil
}) })
err = client.Ping().Err() err = client.Ping(ctx).Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(stack).To(Equal([]string{ Expect(stack).To(Equal([]string{
"cluster.BeforeProcess", "cluster.BeforeProcess",
@ -587,11 +589,11 @@ var _ = Describe("ClusterClient", func() {
}) })
It("supports Pipeline hook", func() { It("supports Pipeline hook", func() {
err := client.Ping().Err() err := client.Ping(ctx).Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
err = client.ForEachNode(func(node *redis.Client) error { err = client.ForEachNode(ctx, func(ctx context.Context, node *redis.Client) error {
return node.Ping().Err() return node.Ping(ctx).Err()
}) })
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
@ -612,7 +614,7 @@ var _ = Describe("ClusterClient", func() {
}, },
}) })
_ = client.ForEachNode(func(node *redis.Client) error { _ = client.ForEachNode(ctx, func(ctx context.Context, node *redis.Client) error {
node.AddHook(&hook{ node.AddHook(&hook{
beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) {
Expect(cmds).To(HaveLen(1)) Expect(cmds).To(HaveLen(1))
@ -630,8 +632,8 @@ var _ = Describe("ClusterClient", func() {
return nil return nil
}) })
_, err = client.Pipelined(func(pipe redis.Pipeliner) error { _, err = client.Pipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Ping() pipe.Ping(ctx)
return nil return nil
}) })
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
@ -644,11 +646,11 @@ var _ = Describe("ClusterClient", func() {
}) })
It("supports TxPipeline hook", func() { It("supports TxPipeline hook", func() {
err := client.Ping().Err() err := client.Ping(ctx).Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
err = client.ForEachNode(func(node *redis.Client) error { err = client.ForEachNode(ctx, func(ctx context.Context, node *redis.Client) error {
return node.Ping().Err() return node.Ping(ctx).Err()
}) })
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
@ -669,7 +671,7 @@ var _ = Describe("ClusterClient", func() {
}, },
}) })
_ = client.ForEachNode(func(node *redis.Client) error { _ = client.ForEachNode(ctx, func(ctx context.Context, node *redis.Client) error {
node.AddHook(&hook{ node.AddHook(&hook{
beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) {
Expect(cmds).To(HaveLen(3)) Expect(cmds).To(HaveLen(3))
@ -687,8 +689,8 @@ var _ = Describe("ClusterClient", func() {
return nil return nil
}) })
_, err = client.TxPipelined(func(pipe redis.Pipeliner) error { _, err = client.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Ping() pipe.Ping(ctx)
return nil return nil
}) })
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
@ -704,17 +706,17 @@ var _ = Describe("ClusterClient", func() {
Describe("ClusterClient", func() { Describe("ClusterClient", func() {
BeforeEach(func() { BeforeEach(func() {
opt = redisClusterOptions() opt = redisClusterOptions()
client = cluster.newClusterClient(opt) client = cluster.newClusterClient(ctx, opt)
err := client.ForEachMaster(func(master *redis.Client) error { err := client.ForEachMaster(ctx, func(ctx context.Context, master *redis.Client) error {
return master.FlushDB().Err() return master.FlushDB(ctx).Err()
}) })
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
}) })
AfterEach(func() { AfterEach(func() {
_ = client.ForEachMaster(func(master *redis.Client) error { _ = client.ForEachMaster(ctx, func(ctx context.Context, master *redis.Client) error {
return master.FlushDB().Err() return master.FlushDB(ctx).Err()
}) })
Expect(client.Close()).NotTo(HaveOccurred()) Expect(client.Close()).NotTo(HaveOccurred())
}) })
@ -727,13 +729,13 @@ var _ = Describe("ClusterClient", func() {
It("returns an error when there are no attempts left", func() { It("returns an error when there are no attempts left", func() {
opt := redisClusterOptions() opt := redisClusterOptions()
opt.MaxRedirects = -1 opt.MaxRedirects = -1
client := cluster.newClusterClient(opt) client := cluster.newClusterClient(ctx, opt)
Eventually(func() error { Eventually(func() error {
return client.SwapNodes("A") return client.SwapNodes(ctx, "A")
}, 30*time.Second).ShouldNot(HaveOccurred()) }, 30*time.Second).ShouldNot(HaveOccurred())
err := client.Get("A").Err() err := client.Get(ctx, "A").Err()
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("MOVED")) Expect(err.Error()).To(ContainSubstring("MOVED"))
@ -742,21 +744,21 @@ var _ = Describe("ClusterClient", func() {
It("calls fn for every master node", func() { It("calls fn for every master node", func() {
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
Expect(client.Set(strconv.Itoa(i), "", 0).Err()).NotTo(HaveOccurred()) Expect(client.Set(ctx, strconv.Itoa(i), "", 0).Err()).NotTo(HaveOccurred())
} }
err := client.ForEachMaster(func(master *redis.Client) error { err := client.ForEachMaster(ctx, func(ctx context.Context, master *redis.Client) error {
return master.FlushDB().Err() return master.FlushDB(ctx).Err()
}) })
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
size, err := client.DBSize().Result() size, err := client.DBSize(ctx).Result()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(size).To(Equal(int64(0))) Expect(size).To(Equal(int64(0)))
}) })
It("should CLUSTER SLOTS", func() { It("should CLUSTER SLOTS", func() {
res, err := client.ClusterSlots().Result() res, err := client.ClusterSlots(ctx).Result()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(res).To(HaveLen(3)) Expect(res).To(HaveLen(3))
@ -795,49 +797,49 @@ var _ = Describe("ClusterClient", func() {
}) })
It("should CLUSTER NODES", func() { It("should CLUSTER NODES", func() {
res, err := client.ClusterNodes().Result() res, err := client.ClusterNodes(ctx).Result()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(len(res)).To(BeNumerically(">", 400)) Expect(len(res)).To(BeNumerically(">", 400))
}) })
It("should CLUSTER INFO", func() { It("should CLUSTER INFO", func() {
res, err := client.ClusterInfo().Result() res, err := client.ClusterInfo(ctx).Result()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(res).To(ContainSubstring("cluster_known_nodes:6")) Expect(res).To(ContainSubstring("cluster_known_nodes:6"))
}) })
It("should CLUSTER KEYSLOT", func() { It("should CLUSTER KEYSLOT", func() {
hashSlot, err := client.ClusterKeySlot("somekey").Result() hashSlot, err := client.ClusterKeySlot(ctx, "somekey").Result()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(hashSlot).To(Equal(int64(hashtag.Slot("somekey")))) Expect(hashSlot).To(Equal(int64(hashtag.Slot("somekey"))))
}) })
It("should CLUSTER GETKEYSINSLOT", func() { It("should CLUSTER GETKEYSINSLOT", func() {
keys, err := client.ClusterGetKeysInSlot(hashtag.Slot("somekey"), 1).Result() keys, err := client.ClusterGetKeysInSlot(ctx, hashtag.Slot("somekey"), 1).Result()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(len(keys)).To(Equal(0)) Expect(len(keys)).To(Equal(0))
}) })
It("should CLUSTER COUNT-FAILURE-REPORTS", func() { It("should CLUSTER COUNT-FAILURE-REPORTS", func() {
n, err := client.ClusterCountFailureReports(cluster.nodeIDs[0]).Result() n, err := client.ClusterCountFailureReports(ctx, cluster.nodeIDs[0]).Result()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(n).To(Equal(int64(0))) Expect(n).To(Equal(int64(0)))
}) })
It("should CLUSTER COUNTKEYSINSLOT", func() { It("should CLUSTER COUNTKEYSINSLOT", func() {
n, err := client.ClusterCountKeysInSlot(10).Result() n, err := client.ClusterCountKeysInSlot(ctx, 10).Result()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(n).To(Equal(int64(0))) Expect(n).To(Equal(int64(0)))
}) })
It("should CLUSTER SAVECONFIG", func() { It("should CLUSTER SAVECONFIG", func() {
res, err := client.ClusterSaveConfig().Result() res, err := client.ClusterSaveConfig(ctx).Result()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(res).To(Equal("OK")) Expect(res).To(Equal("OK"))
}) })
It("should CLUSTER SLAVES", func() { It("should CLUSTER SLAVES", func() {
nodesList, err := client.ClusterSlaves(cluster.nodeIDs[0]).Result() nodesList, err := client.ClusterSlaves(ctx, cluster.nodeIDs[0]).Result()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(nodesList).Should(ContainElement(ContainSubstring("slave"))) Expect(nodesList).Should(ContainElement(ContainSubstring("slave")))
Expect(nodesList).Should(HaveLen(1)) Expect(nodesList).Should(HaveLen(1))
@ -847,7 +849,7 @@ var _ = Describe("ClusterClient", func() {
const nkeys = 100 const nkeys = 100
for i := 0; i < nkeys; i++ { for i := 0; i < nkeys; i++ {
err := client.Set(fmt.Sprintf("key%d", i), "value", 0).Err() err := client.Set(ctx, fmt.Sprintf("key%d", i), "value", 0).Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
} }
@ -862,7 +864,7 @@ var _ = Describe("ClusterClient", func() {
} }
for i := 0; i < nkeys*10; i++ { for i := 0; i < nkeys*10; i++ {
key := client.RandomKey().Val() key := client.RandomKey(ctx).Val()
addKey(key) addKey(key)
} }
@ -879,40 +881,40 @@ var _ = Describe("ClusterClient", func() {
opt = redisClusterOptions() opt = redisClusterOptions()
opt.MinRetryBackoff = 250 * time.Millisecond opt.MinRetryBackoff = 250 * time.Millisecond
opt.MaxRetryBackoff = time.Second opt.MaxRetryBackoff = time.Second
client = cluster.newClusterClient(opt) client = cluster.newClusterClient(ctx, opt)
err := client.ForEachMaster(func(master *redis.Client) error { err := client.ForEachMaster(ctx, func(ctx context.Context, master *redis.Client) error {
return master.FlushDB().Err() return master.FlushDB(ctx).Err()
}) })
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
err = client.ForEachSlave(func(slave *redis.Client) error { err = client.ForEachSlave(ctx, func(ctx context.Context, slave *redis.Client) error {
defer GinkgoRecover() defer GinkgoRecover()
Eventually(func() int64 { Eventually(func() int64 {
return slave.DBSize().Val() return slave.DBSize(ctx).Val()
}, "30s").Should(Equal(int64(0))) }, "30s").Should(Equal(int64(0)))
return nil return nil
}) })
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
state, err := client.LoadState() state, err := client.LoadState(ctx)
Eventually(func() bool { Eventually(func() bool {
state, err = client.LoadState() state, err = client.LoadState(ctx)
if err != nil { if err != nil {
return false return false
} }
return state.IsConsistent() return state.IsConsistent(ctx)
}, "30s").Should(BeTrue()) }, "30s").Should(BeTrue())
for _, slave := range state.Slaves { for _, slave := range state.Slaves {
err = slave.Client.ClusterFailover().Err() err = slave.Client.ClusterFailover(ctx).Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Eventually(func() bool { Eventually(func() bool {
state, _ := client.LoadState() state, _ := client.LoadState(ctx)
return state.IsConsistent() return state.IsConsistent(ctx)
}, "30s").Should(BeTrue()) }, "30s").Should(BeTrue())
} }
}) })
@ -929,16 +931,16 @@ var _ = Describe("ClusterClient", func() {
BeforeEach(func() { BeforeEach(func() {
opt = redisClusterOptions() opt = redisClusterOptions()
opt.RouteByLatency = true opt.RouteByLatency = true
client = cluster.newClusterClient(opt) client = cluster.newClusterClient(ctx, opt)
err := client.ForEachMaster(func(master *redis.Client) error { err := client.ForEachMaster(ctx, func(ctx context.Context, master *redis.Client) error {
return master.FlushDB().Err() return master.FlushDB(ctx).Err()
}) })
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
err = client.ForEachSlave(func(slave *redis.Client) error { err = client.ForEachSlave(ctx, func(ctx context.Context, slave *redis.Client) error {
Eventually(func() int64 { Eventually(func() int64 {
return client.DBSize().Val() return client.DBSize(ctx).Val()
}, 30*time.Second).Should(Equal(int64(0))) }, 30*time.Second).Should(Equal(int64(0)))
return nil return nil
}) })
@ -946,8 +948,8 @@ var _ = Describe("ClusterClient", func() {
}) })
AfterEach(func() { AfterEach(func() {
err := client.ForEachSlave(func(slave *redis.Client) error { err := client.ForEachSlave(ctx, func(ctx context.Context, slave *redis.Client) error {
return slave.ReadWrite().Err() return slave.ReadWrite(ctx).Err()
}) })
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
@ -985,16 +987,16 @@ var _ = Describe("ClusterClient", func() {
}} }}
return slots, nil return slots, nil
} }
client = cluster.newClusterClient(opt) client = cluster.newClusterClient(ctx, opt)
err := client.ForEachMaster(func(master *redis.Client) error { err := client.ForEachMaster(ctx, func(ctx context.Context, master *redis.Client) error {
return master.FlushDB().Err() return master.FlushDB(ctx).Err()
}) })
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
err = client.ForEachSlave(func(slave *redis.Client) error { err = client.ForEachSlave(ctx, func(ctx context.Context, slave *redis.Client) error {
Eventually(func() int64 { Eventually(func() int64 {
return client.DBSize().Val() return client.DBSize(ctx).Val()
}, 30*time.Second).Should(Equal(int64(0))) }, 30*time.Second).Should(Equal(int64(0)))
return nil return nil
}) })
@ -1039,16 +1041,16 @@ var _ = Describe("ClusterClient", func() {
}} }}
return slots, nil return slots, nil
} }
client = cluster.newClusterClient(opt) client = cluster.newClusterClient(ctx, opt)
err := client.ForEachMaster(func(master *redis.Client) error { err := client.ForEachMaster(ctx, func(ctx context.Context, master *redis.Client) error {
return master.FlushDB().Err() return master.FlushDB(ctx).Err()
}) })
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
err = client.ForEachSlave(func(slave *redis.Client) error { err = client.ForEachSlave(ctx, func(ctx context.Context, slave *redis.Client) error {
Eventually(func() int64 { Eventually(func() int64 {
return client.DBSize().Val() return client.DBSize(ctx).Val()
}, 30*time.Second).Should(Equal(int64(0))) }, 30*time.Second).Should(Equal(int64(0)))
return nil return nil
}) })
@ -1078,13 +1080,13 @@ var _ = Describe("ClusterClient without nodes", func() {
}) })
It("Ping returns an error", func() { It("Ping returns an error", func() {
err := client.Ping().Err() err := client.Ping(ctx).Err()
Expect(err).To(MatchError("redis: cluster has no nodes")) Expect(err).To(MatchError("redis: cluster has no nodes"))
}) })
It("pipeline returns an error", func() { It("pipeline returns an error", func() {
_, err := client.Pipelined(func(pipe redis.Pipeliner) error { _, err := client.Pipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Ping() pipe.Ping(ctx)
return nil return nil
}) })
Expect(err).To(MatchError("redis: cluster has no nodes")) Expect(err).To(MatchError("redis: cluster has no nodes"))
@ -1105,13 +1107,13 @@ var _ = Describe("ClusterClient without valid nodes", func() {
}) })
It("returns an error", func() { It("returns an error", func() {
err := client.Ping().Err() err := client.Ping(ctx).Err()
Expect(err).To(MatchError("ERR This instance has cluster support disabled")) Expect(err).To(MatchError("ERR This instance has cluster support disabled"))
}) })
It("pipeline returns an error", func() { It("pipeline returns an error", func() {
_, err := client.Pipelined(func(pipe redis.Pipeliner) error { _, err := client.Pipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Ping() pipe.Ping(ctx)
return nil return nil
}) })
Expect(err).To(MatchError("ERR This instance has cluster support disabled")) Expect(err).To(MatchError("ERR This instance has cluster support disabled"))
@ -1123,7 +1125,7 @@ var _ = Describe("ClusterClient with unavailable Cluster", func() {
BeforeEach(func() { BeforeEach(func() {
for _, node := range cluster.clients { for _, node := range cluster.clients {
err := node.ClientPause(5 * time.Second).Err() err := node.ClientPause(ctx, 5*time.Second).Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
} }
@ -1139,11 +1141,11 @@ var _ = Describe("ClusterClient with unavailable Cluster", func() {
}) })
It("recovers when Cluster recovers", func() { It("recovers when Cluster recovers", func() {
err := client.Ping().Err() err := client.Ping(ctx).Err()
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
Eventually(func() error { Eventually(func() error {
return client.Ping().Err() return client.Ping(ctx).Err()
}, "30s").ShouldNot(HaveOccurred()) }, "30s").ShouldNot(HaveOccurred())
}) })
}) })
@ -1157,14 +1159,14 @@ var _ = Describe("ClusterClient timeout", func() {
testTimeout := func() { testTimeout := func() {
It("Ping timeouts", func() { It("Ping timeouts", func() {
err := client.Ping().Err() err := client.Ping(ctx).Err()
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
Expect(err.(net.Error).Timeout()).To(BeTrue()) Expect(err.(net.Error).Timeout()).To(BeTrue())
}) })
It("Pipeline timeouts", func() { It("Pipeline timeouts", func() {
_, err := client.Pipelined(func(pipe redis.Pipeliner) error { _, err := client.Pipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Ping() pipe.Ping(ctx)
return nil return nil
}) })
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
@ -1172,17 +1174,17 @@ var _ = Describe("ClusterClient timeout", func() {
}) })
It("Tx timeouts", func() { It("Tx timeouts", func() {
err := client.Watch(func(tx *redis.Tx) error { err := client.Watch(ctx, func(tx *redis.Tx) error {
return tx.Ping().Err() return tx.Ping(ctx).Err()
}, "foo") }, "foo")
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
Expect(err.(net.Error).Timeout()).To(BeTrue()) Expect(err.(net.Error).Timeout()).To(BeTrue())
}) })
It("Tx Pipeline timeouts", func() { It("Tx Pipeline timeouts", func() {
err := client.Watch(func(tx *redis.Tx) error { err := client.Watch(ctx, func(tx *redis.Tx) error {
_, err := tx.TxPipelined(func(pipe redis.Pipeliner) error { _, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Ping() pipe.Ping(ctx)
return nil return nil
}) })
return err return err
@ -1200,19 +1202,19 @@ var _ = Describe("ClusterClient timeout", func() {
opt.ReadTimeout = 250 * time.Millisecond opt.ReadTimeout = 250 * time.Millisecond
opt.WriteTimeout = 250 * time.Millisecond opt.WriteTimeout = 250 * time.Millisecond
opt.MaxRedirects = 1 opt.MaxRedirects = 1
client = cluster.newClusterClient(opt) client = cluster.newClusterClient(ctx, opt)
err := client.ForEachNode(func(client *redis.Client) error { err := client.ForEachNode(ctx, func(ctx context.Context, client *redis.Client) error {
return client.ClientPause(pause).Err() return client.ClientPause(ctx, pause).Err()
}) })
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
}) })
AfterEach(func() { AfterEach(func() {
_ = client.ForEachNode(func(client *redis.Client) error { _ = client.ForEachNode(ctx, func(ctx context.Context, client *redis.Client) error {
defer GinkgoRecover() defer GinkgoRecover()
Eventually(func() error { Eventually(func() error {
return client.Ping().Err() return client.Ping(ctx).Err()
}, 2*pause).ShouldNot(HaveOccurred()) }, 2*pause).ShouldNot(HaveOccurred())
return nil return nil
}) })

View File

@ -1,19 +1,20 @@
package redis package redis
import ( import (
"context"
"fmt" "fmt"
"net" "net"
"strconv" "strconv"
"strings"
"time" "time"
"github.com/go-redis/redis/v7/internal" "github.com/go-redis/redis/v8/internal"
"github.com/go-redis/redis/v7/internal/proto" "github.com/go-redis/redis/v8/internal/proto"
"github.com/go-redis/redis/v7/internal/util" "github.com/go-redis/redis/v8/internal/util"
) )
type Cmder interface { type Cmder interface {
Name() string Name() string
FullName() string
Args() []interface{} Args() []interface{}
String() string String() string
stringArg(int) string stringArg(int) string
@ -55,26 +56,6 @@ func writeCmd(wr *proto.Writer, cmd Cmder) error {
return wr.WriteArgs(cmd.Args()) return wr.WriteArgs(cmd.Args())
} }
func cmdString(cmd Cmder, val interface{}) string {
ss := make([]string, 0, len(cmd.Args()))
for _, arg := range cmd.Args() {
ss = append(ss, fmt.Sprint(arg))
}
s := strings.Join(ss, " ")
if err := cmd.Err(); err != nil {
return s + ": " + err.Error()
}
if val != nil {
switch vv := val.(type) {
case []byte:
return s + ": " + string(vv)
default:
return s + ": " + fmt.Sprint(val)
}
}
return s
}
func cmdFirstKeyPos(cmd Cmder, info *CommandInfo) int { func cmdFirstKeyPos(cmd Cmder, info *CommandInfo) int {
switch cmd.Name() { switch cmd.Name() {
case "eval", "evalsha": case "eval", "evalsha":
@ -92,9 +73,69 @@ func cmdFirstKeyPos(cmd Cmder, info *CommandInfo) int {
return int(info.FirstKeyPos) return int(info.FirstKeyPos)
} }
func cmdString(cmd Cmder, val interface{}) string {
b := make([]byte, 0, 32)
for i, arg := range cmd.Args() {
if i > 0 {
b = append(b, ' ')
}
b = appendArg(b, arg)
}
if err := cmd.Err(); err != nil {
b = append(b, ": "...)
b = append(b, err.Error()...)
} else if val != nil {
b = append(b, ": "...)
switch val := val.(type) {
case []byte:
b = append(b, val...)
default:
b = appendArg(b, val)
}
}
return string(b)
}
func appendArg(b []byte, v interface{}) []byte {
switch v := v.(type) {
case nil:
return append(b, "<nil>"...)
case string:
return append(b, v...)
case []byte:
return append(b, v...)
case int:
return strconv.AppendInt(b, int64(v), 10)
case int32:
return strconv.AppendInt(b, int64(v), 10)
case int64:
return strconv.AppendInt(b, v, 10)
case uint:
return strconv.AppendUint(b, uint64(v), 10)
case uint32:
return strconv.AppendUint(b, uint64(v), 10)
case uint64:
return strconv.AppendUint(b, v, 10)
case bool:
if v {
return append(b, "true"...)
}
return append(b, "false"...)
case time.Time:
return v.AppendFormat(b, time.RFC3339Nano)
default:
return append(b, fmt.Sprint(v)...)
}
}
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
type baseCmd struct { type baseCmd struct {
ctx context.Context
args []interface{} args []interface{}
err error err error
@ -111,6 +152,21 @@ func (cmd *baseCmd) Name() string {
return internal.ToLower(cmd.stringArg(0)) return internal.ToLower(cmd.stringArg(0))
} }
func (cmd *baseCmd) FullName() string {
switch name := cmd.Name(); name {
case "cluster", "command":
if len(cmd.args) == 1 {
return name
}
if s2, ok := cmd.args[1].(string); ok {
return name + " " + s2
}
return name
default:
return name
}
}
func (cmd *baseCmd) Args() []interface{} { func (cmd *baseCmd) Args() []interface{} {
return cmd.args return cmd.args
} }
@ -147,9 +203,12 @@ type Cmd struct {
val interface{} val interface{}
} }
func NewCmd(args ...interface{}) *Cmd { func NewCmd(ctx context.Context, args ...interface{}) *Cmd {
return &Cmd{ return &Cmd{
baseCmd: baseCmd{args: args}, baseCmd: baseCmd{
ctx: ctx,
args: args,
},
} }
} }
@ -308,9 +367,12 @@ type SliceCmd struct {
var _ Cmder = (*SliceCmd)(nil) var _ Cmder = (*SliceCmd)(nil)
func NewSliceCmd(args ...interface{}) *SliceCmd { func NewSliceCmd(ctx context.Context, args ...interface{}) *SliceCmd {
return &SliceCmd{ return &SliceCmd{
baseCmd: baseCmd{args: args}, baseCmd: baseCmd{
ctx: ctx,
args: args,
},
} }
} }
@ -346,9 +408,12 @@ type StatusCmd struct {
var _ Cmder = (*StatusCmd)(nil) var _ Cmder = (*StatusCmd)(nil)
func NewStatusCmd(args ...interface{}) *StatusCmd { func NewStatusCmd(ctx context.Context, args ...interface{}) *StatusCmd {
return &StatusCmd{ return &StatusCmd{
baseCmd: baseCmd{args: args}, baseCmd: baseCmd{
ctx: ctx,
args: args,
},
} }
} }
@ -379,9 +444,12 @@ type IntCmd struct {
var _ Cmder = (*IntCmd)(nil) var _ Cmder = (*IntCmd)(nil)
func NewIntCmd(args ...interface{}) *IntCmd { func NewIntCmd(ctx context.Context, args ...interface{}) *IntCmd {
return &IntCmd{ return &IntCmd{
baseCmd: baseCmd{args: args}, baseCmd: baseCmd{
ctx: ctx,
args: args,
},
} }
} }
@ -416,9 +484,12 @@ type IntSliceCmd struct {
var _ Cmder = (*IntSliceCmd)(nil) var _ Cmder = (*IntSliceCmd)(nil)
func NewIntSliceCmd(args ...interface{}) *IntSliceCmd { func NewIntSliceCmd(ctx context.Context, args ...interface{}) *IntSliceCmd {
return &IntSliceCmd{ return &IntSliceCmd{
baseCmd: baseCmd{args: args}, baseCmd: baseCmd{
ctx: ctx,
args: args,
},
} }
} }
@ -460,9 +531,12 @@ type DurationCmd struct {
var _ Cmder = (*DurationCmd)(nil) var _ Cmder = (*DurationCmd)(nil)
func NewDurationCmd(precision time.Duration, args ...interface{}) *DurationCmd { func NewDurationCmd(ctx context.Context, precision time.Duration, args ...interface{}) *DurationCmd {
return &DurationCmd{ return &DurationCmd{
baseCmd: baseCmd{args: args}, baseCmd: baseCmd{
ctx: ctx,
args: args,
},
precision: precision, precision: precision,
} }
} }
@ -506,9 +580,12 @@ type TimeCmd struct {
var _ Cmder = (*TimeCmd)(nil) var _ Cmder = (*TimeCmd)(nil)
func NewTimeCmd(args ...interface{}) *TimeCmd { func NewTimeCmd(ctx context.Context, args ...interface{}) *TimeCmd {
return &TimeCmd{ return &TimeCmd{
baseCmd: baseCmd{args: args}, baseCmd: baseCmd{
ctx: ctx,
args: args,
},
} }
} }
@ -556,9 +633,12 @@ type BoolCmd struct {
var _ Cmder = (*BoolCmd)(nil) var _ Cmder = (*BoolCmd)(nil)
func NewBoolCmd(args ...interface{}) *BoolCmd { func NewBoolCmd(ctx context.Context, args ...interface{}) *BoolCmd {
return &BoolCmd{ return &BoolCmd{
baseCmd: baseCmd{args: args}, baseCmd: baseCmd{
ctx: ctx,
args: args,
},
} }
} }
@ -610,9 +690,12 @@ type StringCmd struct {
var _ Cmder = (*StringCmd)(nil) var _ Cmder = (*StringCmd)(nil)
func NewStringCmd(args ...interface{}) *StringCmd { func NewStringCmd(ctx context.Context, args ...interface{}) *StringCmd {
return &StringCmd{ return &StringCmd{
baseCmd: baseCmd{args: args}, baseCmd: baseCmd{
ctx: ctx,
args: args,
},
} }
} }
@ -700,9 +783,12 @@ type FloatCmd struct {
var _ Cmder = (*FloatCmd)(nil) var _ Cmder = (*FloatCmd)(nil)
func NewFloatCmd(args ...interface{}) *FloatCmd { func NewFloatCmd(ctx context.Context, args ...interface{}) *FloatCmd {
return &FloatCmd{ return &FloatCmd{
baseCmd: baseCmd{args: args}, baseCmd: baseCmd{
ctx: ctx,
args: args,
},
} }
} }
@ -733,9 +819,12 @@ type StringSliceCmd struct {
var _ Cmder = (*StringSliceCmd)(nil) var _ Cmder = (*StringSliceCmd)(nil)
func NewStringSliceCmd(args ...interface{}) *StringSliceCmd { func NewStringSliceCmd(ctx context.Context, args ...interface{}) *StringSliceCmd {
return &StringSliceCmd{ return &StringSliceCmd{
baseCmd: baseCmd{args: args}, baseCmd: baseCmd{
ctx: ctx,
args: args,
},
} }
} }
@ -783,9 +872,12 @@ type BoolSliceCmd struct {
var _ Cmder = (*BoolSliceCmd)(nil) var _ Cmder = (*BoolSliceCmd)(nil)
func NewBoolSliceCmd(args ...interface{}) *BoolSliceCmd { func NewBoolSliceCmd(ctx context.Context, args ...interface{}) *BoolSliceCmd {
return &BoolSliceCmd{ return &BoolSliceCmd{
baseCmd: baseCmd{args: args}, baseCmd: baseCmd{
ctx: ctx,
args: args,
},
} }
} }
@ -826,9 +918,12 @@ type StringStringMapCmd struct {
var _ Cmder = (*StringStringMapCmd)(nil) var _ Cmder = (*StringStringMapCmd)(nil)
func NewStringStringMapCmd(args ...interface{}) *StringStringMapCmd { func NewStringStringMapCmd(ctx context.Context, args ...interface{}) *StringStringMapCmd {
return &StringStringMapCmd{ return &StringStringMapCmd{
baseCmd: baseCmd{args: args}, baseCmd: baseCmd{
ctx: ctx,
args: args,
},
} }
} }
@ -875,9 +970,12 @@ type StringIntMapCmd struct {
var _ Cmder = (*StringIntMapCmd)(nil) var _ Cmder = (*StringIntMapCmd)(nil)
func NewStringIntMapCmd(args ...interface{}) *StringIntMapCmd { func NewStringIntMapCmd(ctx context.Context, args ...interface{}) *StringIntMapCmd {
return &StringIntMapCmd{ return &StringIntMapCmd{
baseCmd: baseCmd{args: args}, baseCmd: baseCmd{
ctx: ctx,
args: args,
},
} }
} }
@ -924,9 +1022,12 @@ type StringStructMapCmd struct {
var _ Cmder = (*StringStructMapCmd)(nil) var _ Cmder = (*StringStructMapCmd)(nil)
func NewStringStructMapCmd(args ...interface{}) *StringStructMapCmd { func NewStringStructMapCmd(ctx context.Context, args ...interface{}) *StringStructMapCmd {
return &StringStructMapCmd{ return &StringStructMapCmd{
baseCmd: baseCmd{args: args}, baseCmd: baseCmd{
ctx: ctx,
args: args,
},
} }
} }
@ -972,9 +1073,12 @@ type XMessageSliceCmd struct {
var _ Cmder = (*XMessageSliceCmd)(nil) var _ Cmder = (*XMessageSliceCmd)(nil)
func NewXMessageSliceCmd(args ...interface{}) *XMessageSliceCmd { func NewXMessageSliceCmd(ctx context.Context, args ...interface{}) *XMessageSliceCmd {
return &XMessageSliceCmd{ return &XMessageSliceCmd{
baseCmd: baseCmd{args: args}, baseCmd: baseCmd{
ctx: ctx,
args: args,
},
} }
} }
@ -1069,9 +1173,12 @@ type XStreamSliceCmd struct {
var _ Cmder = (*XStreamSliceCmd)(nil) var _ Cmder = (*XStreamSliceCmd)(nil)
func NewXStreamSliceCmd(args ...interface{}) *XStreamSliceCmd { func NewXStreamSliceCmd(ctx context.Context, args ...interface{}) *XStreamSliceCmd {
return &XStreamSliceCmd{ return &XStreamSliceCmd{
baseCmd: baseCmd{args: args}, baseCmd: baseCmd{
ctx: ctx,
args: args,
},
} }
} }
@ -1138,9 +1245,12 @@ type XPendingCmd struct {
var _ Cmder = (*XPendingCmd)(nil) var _ Cmder = (*XPendingCmd)(nil)
func NewXPendingCmd(args ...interface{}) *XPendingCmd { func NewXPendingCmd(ctx context.Context, args ...interface{}) *XPendingCmd {
return &XPendingCmd{ return &XPendingCmd{
baseCmd: baseCmd{args: args}, baseCmd: baseCmd{
ctx: ctx,
args: args,
},
} }
} }
@ -1237,9 +1347,12 @@ type XPendingExtCmd struct {
var _ Cmder = (*XPendingExtCmd)(nil) var _ Cmder = (*XPendingExtCmd)(nil)
func NewXPendingExtCmd(args ...interface{}) *XPendingExtCmd { func NewXPendingExtCmd(ctx context.Context, args ...interface{}) *XPendingExtCmd {
return &XPendingExtCmd{ return &XPendingExtCmd{
baseCmd: baseCmd{args: args}, baseCmd: baseCmd{
ctx: ctx,
args: args,
},
} }
} }
@ -1317,9 +1430,12 @@ type XInfoGroups struct {
var _ Cmder = (*XInfoGroupsCmd)(nil) var _ Cmder = (*XInfoGroupsCmd)(nil)
func NewXInfoGroupsCmd(stream string) *XInfoGroupsCmd { func NewXInfoGroupsCmd(ctx context.Context, stream string) *XInfoGroupsCmd {
return &XInfoGroupsCmd{ return &XInfoGroupsCmd{
baseCmd: baseCmd{args: []interface{}{"xinfo", "groups", stream}}, baseCmd: baseCmd{
ctx: ctx,
args: []interface{}{"xinfo", "groups", stream},
},
} }
} }
@ -1401,9 +1517,12 @@ type ZSliceCmd struct {
var _ Cmder = (*ZSliceCmd)(nil) var _ Cmder = (*ZSliceCmd)(nil)
func NewZSliceCmd(args ...interface{}) *ZSliceCmd { func NewZSliceCmd(ctx context.Context, args ...interface{}) *ZSliceCmd {
return &ZSliceCmd{ return &ZSliceCmd{
baseCmd: baseCmd{args: args}, baseCmd: baseCmd{
ctx: ctx,
args: args,
},
} }
} }
@ -1453,9 +1572,12 @@ type ZWithKeyCmd struct {
var _ Cmder = (*ZWithKeyCmd)(nil) var _ Cmder = (*ZWithKeyCmd)(nil)
func NewZWithKeyCmd(args ...interface{}) *ZWithKeyCmd { func NewZWithKeyCmd(ctx context.Context, args ...interface{}) *ZWithKeyCmd {
return &ZWithKeyCmd{ return &ZWithKeyCmd{
baseCmd: baseCmd{args: args}, baseCmd: baseCmd{
ctx: ctx,
args: args,
},
} }
} }
@ -1508,14 +1630,17 @@ type ScanCmd struct {
page []string page []string
cursor uint64 cursor uint64
process func(cmd Cmder) error process cmdable
} }
var _ Cmder = (*ScanCmd)(nil) var _ Cmder = (*ScanCmd)(nil)
func NewScanCmd(process func(cmd Cmder) error, args ...interface{}) *ScanCmd { func NewScanCmd(ctx context.Context, process cmdable, args ...interface{}) *ScanCmd {
return &ScanCmd{ return &ScanCmd{
baseCmd: baseCmd{args: args}, baseCmd: baseCmd{
ctx: ctx,
args: args,
},
process: process, process: process,
} }
} }
@ -1565,9 +1690,12 @@ type ClusterSlotsCmd struct {
var _ Cmder = (*ClusterSlotsCmd)(nil) var _ Cmder = (*ClusterSlotsCmd)(nil)
func NewClusterSlotsCmd(args ...interface{}) *ClusterSlotsCmd { func NewClusterSlotsCmd(ctx context.Context, args ...interface{}) *ClusterSlotsCmd {
return &ClusterSlotsCmd{ return &ClusterSlotsCmd{
baseCmd: baseCmd{args: args}, baseCmd: baseCmd{
ctx: ctx,
args: args,
},
} }
} }
@ -1682,9 +1810,12 @@ type GeoLocationCmd struct {
var _ Cmder = (*GeoLocationCmd)(nil) var _ Cmder = (*GeoLocationCmd)(nil)
func NewGeoLocationCmd(q *GeoRadiusQuery, args ...interface{}) *GeoLocationCmd { func NewGeoLocationCmd(ctx context.Context, q *GeoRadiusQuery, args ...interface{}) *GeoLocationCmd {
return &GeoLocationCmd{ return &GeoLocationCmd{
baseCmd: baseCmd{args: geoLocationArgs(q, args...)}, baseCmd: baseCmd{
ctx: ctx,
args: geoLocationArgs(q, args...),
},
q: q, q: q,
} }
} }
@ -1826,9 +1957,12 @@ type GeoPosCmd struct {
var _ Cmder = (*GeoPosCmd)(nil) var _ Cmder = (*GeoPosCmd)(nil)
func NewGeoPosCmd(args ...interface{}) *GeoPosCmd { func NewGeoPosCmd(ctx context.Context, args ...interface{}) *GeoPosCmd {
return &GeoPosCmd{ return &GeoPosCmd{
baseCmd: baseCmd{args: args}, baseCmd: baseCmd{
ctx: ctx,
args: args,
},
} }
} }
@ -1899,9 +2033,12 @@ type CommandsInfoCmd struct {
var _ Cmder = (*CommandsInfoCmd)(nil) var _ Cmder = (*CommandsInfoCmd)(nil)
func NewCommandsInfoCmd(args ...interface{}) *CommandsInfoCmd { func NewCommandsInfoCmd(ctx context.Context, args ...interface{}) *CommandsInfoCmd {
return &CommandsInfoCmd{ return &CommandsInfoCmd{
baseCmd: baseCmd{args: args}, baseCmd: baseCmd{
ctx: ctx,
args: args,
},
} }
} }

View File

@ -4,7 +4,7 @@ import (
"errors" "errors"
"time" "time"
redis "github.com/go-redis/redis/v7" redis "github.com/go-redis/redis/v8"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
@ -15,7 +15,7 @@ var _ = Describe("Cmd", func() {
BeforeEach(func() { BeforeEach(func() {
client = redis.NewClient(redisOptions()) client = redis.NewClient(redisOptions())
Expect(client.FlushDB().Err()).NotTo(HaveOccurred()) Expect(client.FlushDB(ctx).Err()).NotTo(HaveOccurred())
}) })
AfterEach(func() { AfterEach(func() {
@ -23,19 +23,19 @@ var _ = Describe("Cmd", func() {
}) })
It("implements Stringer", func() { It("implements Stringer", func() {
set := client.Set("foo", "bar", 0) set := client.Set(ctx, "foo", "bar", 0)
Expect(set.String()).To(Equal("set foo bar: OK")) Expect(set.String()).To(Equal("set foo bar: OK"))
get := client.Get("foo") get := client.Get(ctx, "foo")
Expect(get.String()).To(Equal("get foo: bar")) Expect(get.String()).To(Equal("get foo: bar"))
}) })
It("has val/err", func() { It("has val/err", func() {
set := client.Set("key", "hello", 0) set := client.Set(ctx, "key", "hello", 0)
Expect(set.Err()).NotTo(HaveOccurred()) Expect(set.Err()).NotTo(HaveOccurred())
Expect(set.Val()).To(Equal("OK")) Expect(set.Val()).To(Equal("OK"))
get := client.Get("key") get := client.Get(ctx, "key")
Expect(get.Err()).NotTo(HaveOccurred()) Expect(get.Err()).NotTo(HaveOccurred())
Expect(get.Val()).To(Equal("hello")) Expect(get.Val()).To(Equal("hello"))
@ -44,18 +44,18 @@ var _ = Describe("Cmd", func() {
}) })
It("has helpers", func() { It("has helpers", func() {
set := client.Set("key", "10", 0) set := client.Set(ctx, "key", "10", 0)
Expect(set.Err()).NotTo(HaveOccurred()) Expect(set.Err()).NotTo(HaveOccurred())
n, err := client.Get("key").Int64() n, err := client.Get(ctx, "key").Int64()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(n).To(Equal(int64(10))) Expect(n).To(Equal(int64(10)))
un, err := client.Get("key").Uint64() un, err := client.Get(ctx, "key").Uint64()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(un).To(Equal(uint64(10))) Expect(un).To(Equal(uint64(10)))
f, err := client.Get("key").Float64() f, err := client.Get(ctx, "key").Float64()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(f).To(Equal(float64(10))) Expect(f).To(Equal(float64(10)))
}) })
@ -63,10 +63,10 @@ var _ = Describe("Cmd", func() {
It("supports float32", func() { It("supports float32", func() {
f := float32(66.97) f := float32(66.97)
err := client.Set("float_key", f, 0).Err() err := client.Set(ctx, "float_key", f, 0).Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
val, err := client.Get("float_key").Float32() val, err := client.Get(ctx, "float_key").Float32()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(val).To(Equal(f)) Expect(val).To(Equal(f))
}) })
@ -74,14 +74,14 @@ var _ = Describe("Cmd", func() {
It("supports time.Time", func() { It("supports time.Time", func() {
tm := time.Date(2019, 01, 01, 9, 45, 10, 222125, time.UTC) tm := time.Date(2019, 01, 01, 9, 45, 10, 222125, time.UTC)
err := client.Set("time_key", tm, 0).Err() err := client.Set(ctx, "time_key", tm, 0).Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
s, err := client.Get("time_key").Result() s, err := client.Get(ctx, "time_key").Result()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(s).To(Equal("2019-01-01T09:45:10.000222125Z")) Expect(s).To(Equal("2019-01-01T09:45:10.000222125Z"))
tm2, err := client.Get("time_key").Time() tm2, err := client.Get(ctx, "time_key").Time()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(tm2).To(BeTemporally("==", tm)) Expect(tm2).To(BeTemporally("==", tm))
}) })

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -6,8 +6,8 @@ import (
"net" "net"
"strings" "strings"
"github.com/go-redis/redis/v7/internal/pool" "github.com/go-redis/redis/v8/internal/pool"
"github.com/go-redis/redis/v7/internal/proto" "github.com/go-redis/redis/v8/internal/proto"
) )
var ErrClosed = pool.ErrClosed var ErrClosed = pool.ErrClosed

View File

@ -4,7 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"github.com/go-redis/redis/v7" "github.com/go-redis/redis/v8"
) )
type redisHook struct{} type redisHook struct{}
@ -37,7 +37,7 @@ func Example_instrumentation() {
}) })
rdb.AddHook(redisHook{}) rdb.AddHook(redisHook{})
rdb.Ping() rdb.Ping(ctx)
// Output: starting processing: <ping: > // Output: starting processing: <ping: >
// finished processing: <ping: PONG> // finished processing: <ping: PONG>
} }
@ -48,9 +48,9 @@ func ExamplePipeline_instrumentation() {
}) })
rdb.AddHook(redisHook{}) rdb.AddHook(redisHook{})
rdb.Pipelined(func(pipe redis.Pipeliner) error { rdb.Pipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Ping() pipe.Ping(ctx)
pipe.Ping() pipe.Ping(ctx)
return nil return nil
}) })
// Output: pipeline starting processing: [ping: ping: ] // Output: pipeline starting processing: [ping: ping: ]
@ -63,9 +63,9 @@ func ExampleWatch_instrumentation() {
}) })
rdb.AddHook(redisHook{}) rdb.AddHook(redisHook{})
rdb.Watch(func(tx *redis.Tx) error { rdb.Watch(ctx, func(tx *redis.Tx) error {
tx.Ping() tx.Ping(ctx)
tx.Ping() tx.Ping(ctx)
return nil return nil
}, "foo") }, "foo")
// Output: // Output:

View File

@ -1,14 +1,16 @@
package redis_test package redis_test
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"sync" "sync"
"time" "time"
"github.com/go-redis/redis/v7" "github.com/go-redis/redis/v8"
) )
var ctx = context.Background()
var rdb *redis.Client var rdb *redis.Client
func init() { func init() {
@ -29,7 +31,7 @@ func ExampleNewClient() {
DB: 0, // use default DB DB: 0, // use default DB
}) })
pong, err := rdb.Ping().Result() pong, err := rdb.Ping(ctx).Result()
fmt.Println(pong, err) fmt.Println(pong, err)
// Output: PONG <nil> // Output: PONG <nil>
} }
@ -58,7 +60,7 @@ func ExampleNewFailoverClient() {
MasterName: "master", MasterName: "master",
SentinelAddrs: []string{":26379"}, SentinelAddrs: []string{":26379"},
}) })
rdb.Ping() rdb.Ping(ctx)
} }
func ExampleNewClusterClient() { func ExampleNewClusterClient() {
@ -67,7 +69,7 @@ func ExampleNewClusterClient() {
rdb := redis.NewClusterClient(&redis.ClusterOptions{ rdb := redis.NewClusterClient(&redis.ClusterOptions{
Addrs: []string{":7000", ":7001", ":7002", ":7003", ":7004", ":7005"}, Addrs: []string{":7000", ":7001", ":7002", ":7003", ":7004", ":7005"},
}) })
rdb.Ping() rdb.Ping(ctx)
} }
// Following example creates a cluster from 2 master nodes and 2 slave nodes // Following example creates a cluster from 2 master nodes and 2 slave nodes
@ -106,11 +108,11 @@ func ExampleNewClusterClient_manualSetup() {
ClusterSlots: clusterSlots, ClusterSlots: clusterSlots,
RouteRandomly: true, RouteRandomly: true,
}) })
rdb.Ping() rdb.Ping(ctx)
// ReloadState reloads cluster state. It calls ClusterSlots func // ReloadState reloads cluster state. It calls ClusterSlots func
// to get cluster slots information. // to get cluster slots information.
err := rdb.ReloadState() err := rdb.ReloadState(ctx)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -124,22 +126,22 @@ func ExampleNewRing() {
"shard3": ":7002", "shard3": ":7002",
}, },
}) })
rdb.Ping() rdb.Ping(ctx)
} }
func ExampleClient() { func ExampleClient() {
err := rdb.Set("key", "value", 0).Err() err := rdb.Set(ctx, "key", "value", 0).Err()
if err != nil { if err != nil {
panic(err) panic(err)
} }
val, err := rdb.Get("key").Result() val, err := rdb.Get(ctx, "key").Result()
if err != nil { if err != nil {
panic(err) panic(err)
} }
fmt.Println("key", val) fmt.Println("key", val)
val2, err := rdb.Get("missing_key").Result() val2, err := rdb.Get(ctx, "missing_key").Result()
if err == redis.Nil { if err == redis.Nil {
fmt.Println("missing_key does not exist") fmt.Println("missing_key does not exist")
} else if err != nil { } else if err != nil {
@ -152,19 +154,19 @@ func ExampleClient() {
} }
func ExampleConn() { func ExampleConn() {
conn := rdb.Conn() conn := rdb.Conn(context.Background())
err := conn.ClientSetName("foobar").Err() err := conn.ClientSetName(ctx, "foobar").Err()
if err != nil { if err != nil {
panic(err) panic(err)
} }
// Open other connections. // Open other connections.
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
go rdb.Ping() go rdb.Ping(ctx)
} }
s, err := conn.ClientGetName().Result() s, err := conn.ClientGetName(ctx).Result()
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -175,20 +177,20 @@ func ExampleConn() {
func ExampleClient_Set() { func ExampleClient_Set() {
// Last argument is expiration. Zero means the key has no // Last argument is expiration. Zero means the key has no
// expiration time. // expiration time.
err := rdb.Set("key", "value", 0).Err() err := rdb.Set(ctx, "key", "value", 0).Err()
if err != nil { if err != nil {
panic(err) panic(err)
} }
// key2 will expire in an hour. // key2 will expire in an hour.
err = rdb.Set("key2", "value", time.Hour).Err() err = rdb.Set(ctx, "key2", "value", time.Hour).Err()
if err != nil { if err != nil {
panic(err) panic(err)
} }
} }
func ExampleClient_Incr() { func ExampleClient_Incr() {
result, err := rdb.Incr("counter").Result() result, err := rdb.Incr(ctx, "counter").Result()
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -198,12 +200,12 @@ func ExampleClient_Incr() {
} }
func ExampleClient_BLPop() { func ExampleClient_BLPop() {
if err := rdb.RPush("queue", "message").Err(); err != nil { if err := rdb.RPush(ctx, "queue", "message").Err(); err != nil {
panic(err) panic(err)
} }
// use `rdb.BLPop(0, "queue")` for infinite waiting time // use `rdb.BLPop(0, "queue")` for infinite waiting time
result, err := rdb.BLPop(1*time.Second, "queue").Result() result, err := rdb.BLPop(ctx, 1*time.Second, "queue").Result()
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -213,9 +215,9 @@ func ExampleClient_BLPop() {
} }
func ExampleClient_Scan() { func ExampleClient_Scan() {
rdb.FlushDB() rdb.FlushDB(ctx)
for i := 0; i < 33; i++ { for i := 0; i < 33; i++ {
err := rdb.Set(fmt.Sprintf("key%d", i), "value", 0).Err() err := rdb.Set(ctx, fmt.Sprintf("key%d", i), "value", 0).Err()
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -226,7 +228,7 @@ func ExampleClient_Scan() {
for { for {
var keys []string var keys []string
var err error var err error
keys, cursor, err = rdb.Scan(cursor, "key*", 10).Result() keys, cursor, err = rdb.Scan(ctx, cursor, "key*", 10).Result()
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -242,9 +244,9 @@ func ExampleClient_Scan() {
func ExampleClient_Pipelined() { func ExampleClient_Pipelined() {
var incr *redis.IntCmd var incr *redis.IntCmd
_, err := rdb.Pipelined(func(pipe redis.Pipeliner) error { _, err := rdb.Pipelined(ctx, func(pipe redis.Pipeliner) error {
incr = pipe.Incr("pipelined_counter") incr = pipe.Incr(ctx, "pipelined_counter")
pipe.Expire("pipelined_counter", time.Hour) pipe.Expire(ctx, "pipelined_counter", time.Hour)
return nil return nil
}) })
fmt.Println(incr.Val(), err) fmt.Println(incr.Val(), err)
@ -254,8 +256,8 @@ func ExampleClient_Pipelined() {
func ExampleClient_Pipeline() { func ExampleClient_Pipeline() {
pipe := rdb.Pipeline() pipe := rdb.Pipeline()
incr := pipe.Incr("pipeline_counter") incr := pipe.Incr(ctx, "pipeline_counter")
pipe.Expire("pipeline_counter", time.Hour) pipe.Expire(ctx, "pipeline_counter", time.Hour)
// Execute // Execute
// //
@ -263,16 +265,16 @@ func ExampleClient_Pipeline() {
// EXPIRE pipeline_counts 3600 // EXPIRE pipeline_counts 3600
// //
// using one rdb-server roundtrip. // using one rdb-server roundtrip.
_, err := pipe.Exec() _, err := pipe.Exec(ctx)
fmt.Println(incr.Val(), err) fmt.Println(incr.Val(), err)
// Output: 1 <nil> // Output: 1 <nil>
} }
func ExampleClient_TxPipelined() { func ExampleClient_TxPipelined() {
var incr *redis.IntCmd var incr *redis.IntCmd
_, err := rdb.TxPipelined(func(pipe redis.Pipeliner) error { _, err := rdb.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
incr = pipe.Incr("tx_pipelined_counter") incr = pipe.Incr(ctx, "tx_pipelined_counter")
pipe.Expire("tx_pipelined_counter", time.Hour) pipe.Expire(ctx, "tx_pipelined_counter", time.Hour)
return nil return nil
}) })
fmt.Println(incr.Val(), err) fmt.Println(incr.Val(), err)
@ -282,8 +284,8 @@ func ExampleClient_TxPipelined() {
func ExampleClient_TxPipeline() { func ExampleClient_TxPipeline() {
pipe := rdb.TxPipeline() pipe := rdb.TxPipeline()
incr := pipe.Incr("tx_pipeline_counter") incr := pipe.Incr(ctx, "tx_pipeline_counter")
pipe.Expire("tx_pipeline_counter", time.Hour) pipe.Expire(ctx, "tx_pipeline_counter", time.Hour)
// Execute // Execute
// //
@ -293,7 +295,7 @@ func ExampleClient_TxPipeline() {
// EXEC // EXEC
// //
// using one rdb-server roundtrip. // using one rdb-server roundtrip.
_, err := pipe.Exec() _, err := pipe.Exec(ctx)
fmt.Println(incr.Val(), err) fmt.Println(incr.Val(), err)
// Output: 1 <nil> // Output: 1 <nil>
} }
@ -305,7 +307,7 @@ func ExampleClient_Watch() {
increment := func(key string) error { increment := func(key string) error {
txf := func(tx *redis.Tx) error { txf := func(tx *redis.Tx) error {
// get current value or zero // get current value or zero
n, err := tx.Get(key).Int() n, err := tx.Get(ctx, key).Int()
if err != nil && err != redis.Nil { if err != nil && err != redis.Nil {
return err return err
} }
@ -314,16 +316,16 @@ func ExampleClient_Watch() {
n++ n++
// runs only if the watched keys remain unchanged // runs only if the watched keys remain unchanged
_, err = tx.TxPipelined(func(pipe redis.Pipeliner) error { _, err = tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
// pipe handles the error case // pipe handles the error case
pipe.Set(key, n, 0) pipe.Set(ctx, key, n, 0)
return nil return nil
}) })
return err return err
} }
for retries := routineCount; retries > 0; retries-- { for retries := routineCount; retries > 0; retries-- {
err := rdb.Watch(txf, key) err := rdb.Watch(ctx, txf, key)
if err != redis.TxFailedErr { if err != redis.TxFailedErr {
return err return err
} }
@ -345,16 +347,16 @@ func ExampleClient_Watch() {
} }
wg.Wait() wg.Wait()
n, err := rdb.Get("counter3").Int() n, err := rdb.Get(ctx, "counter3").Int()
fmt.Println("ended with", n, err) fmt.Println("ended with", n, err)
// Output: ended with 100 <nil> // Output: ended with 100 <nil>
} }
func ExamplePubSub() { func ExamplePubSub() {
pubsub := rdb.Subscribe("mychannel1") pubsub := rdb.Subscribe(ctx, "mychannel1")
// Wait for confirmation that subscription is created before publishing anything. // Wait for confirmation that subscription is created before publishing anything.
_, err := pubsub.Receive() _, err := pubsub.Receive(ctx)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -363,7 +365,7 @@ func ExamplePubSub() {
ch := pubsub.Channel() ch := pubsub.Channel()
// Publish a message. // Publish a message.
err = rdb.Publish("mychannel1", "hello").Err() err = rdb.Publish(ctx, "mychannel1", "hello").Err()
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -382,12 +384,12 @@ func ExamplePubSub() {
} }
func ExamplePubSub_Receive() { func ExamplePubSub_Receive() {
pubsub := rdb.Subscribe("mychannel2") pubsub := rdb.Subscribe(ctx, "mychannel2")
defer pubsub.Close() defer pubsub.Close()
for i := 0; i < 2; i++ { for i := 0; i < 2; i++ {
// ReceiveTimeout is a low level API. Use ReceiveMessage instead. // ReceiveTimeout is a low level API. Use ReceiveMessage instead.
msgi, err := pubsub.ReceiveTimeout(time.Second) msgi, err := pubsub.ReceiveTimeout(ctx, time.Second)
if err != nil { if err != nil {
break break
} }
@ -396,7 +398,7 @@ func ExamplePubSub_Receive() {
case *redis.Subscription: case *redis.Subscription:
fmt.Println("subscribed to", msg.Channel) fmt.Println("subscribed to", msg.Channel)
_, err := rdb.Publish("mychannel2", "hello").Result() _, err := rdb.Publish(ctx, "mychannel2", "hello").Result()
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -419,15 +421,15 @@ func ExampleScript() {
return false return false
`) `)
n, err := IncrByXX.Run(rdb, []string{"xx_counter"}, 2).Result() n, err := IncrByXX.Run(ctx, rdb, []string{"xx_counter"}, 2).Result()
fmt.Println(n, err) fmt.Println(n, err)
err = rdb.Set("xx_counter", "40", 0).Err() err = rdb.Set(ctx, "xx_counter", "40", 0).Err()
if err != nil { if err != nil {
panic(err) panic(err)
} }
n, err = IncrByXX.Run(rdb, []string{"xx_counter"}, 2).Result() n, err = IncrByXX.Run(ctx, rdb, []string{"xx_counter"}, 2).Result()
fmt.Println(n, err) fmt.Println(n, err)
// Output: <nil> redis: nil // Output: <nil> redis: nil
@ -435,26 +437,26 @@ func ExampleScript() {
} }
func Example_customCommand() { func Example_customCommand() {
Get := func(rdb *redis.Client, key string) *redis.StringCmd { Get := func(ctx context.Context, rdb *redis.Client, key string) *redis.StringCmd {
cmd := redis.NewStringCmd("get", key) cmd := redis.NewStringCmd(ctx, "get", key)
rdb.Process(cmd) rdb.Process(ctx, cmd)
return cmd return cmd
} }
v, err := Get(rdb, "key_does_not_exist").Result() v, err := Get(ctx, rdb, "key_does_not_exist").Result()
fmt.Printf("%q %s", v, err) fmt.Printf("%q %s", v, err)
// Output: "" redis: nil // Output: "" redis: nil
} }
func Example_customCommand2() { func Example_customCommand2() {
v, err := rdb.Do("get", "key_does_not_exist").Text() v, err := rdb.Do(ctx, "get", "key_does_not_exist").Text()
fmt.Printf("%q %s", v, err) fmt.Printf("%q %s", v, err)
// Output: "" redis: nil // Output: "" redis: nil
} }
func ExampleScanIterator() { func ExampleScanIterator() {
iter := rdb.Scan(0, "", 0).Iterator() iter := rdb.Scan(ctx, 0, "", 0).Iterator()
for iter.Next() { for iter.Next(ctx) {
fmt.Println(iter.Val()) fmt.Println(iter.Val())
} }
if err := iter.Err(); err != nil { if err := iter.Err(); err != nil {
@ -463,8 +465,8 @@ func ExampleScanIterator() {
} }
func ExampleScanCmd_Iterator() { func ExampleScanCmd_Iterator() {
iter := rdb.Scan(0, "", 0).Iterator() iter := rdb.Scan(ctx, 0, "", 0).Iterator()
for iter.Next() { for iter.Next(ctx) {
fmt.Println(iter.Val()) fmt.Println(iter.Val())
} }
if err := iter.Err(); err != nil { if err := iter.Err(); err != nil {
@ -478,7 +480,7 @@ func ExampleNewUniversalClient_simple() {
}) })
defer rdb.Close() defer rdb.Close()
rdb.Ping() rdb.Ping(ctx)
} }
func ExampleNewUniversalClient_failover() { func ExampleNewUniversalClient_failover() {
@ -488,7 +490,7 @@ func ExampleNewUniversalClient_failover() {
}) })
defer rdb.Close() defer rdb.Close()
rdb.Ping() rdb.Ping(ctx)
} }
func ExampleNewUniversalClient_cluster() { func ExampleNewUniversalClient_cluster() {
@ -497,5 +499,5 @@ func ExampleNewUniversalClient_cluster() {
}) })
defer rdb.Close() defer rdb.Close()
rdb.Ping() rdb.Ping(ctx)
} }

View File

@ -1,12 +1,13 @@
package redis package redis
import ( import (
"context"
"fmt" "fmt"
"net" "net"
"strings" "strings"
"github.com/go-redis/redis/v7/internal/hashtag" "github.com/go-redis/redis/v8/internal/hashtag"
"github.com/go-redis/redis/v7/internal/pool" "github.com/go-redis/redis/v8/internal/pool"
) )
func (c *baseClient) Pool() pool.Pooler { func (c *baseClient) Pool() pool.Pooler {
@ -17,12 +18,12 @@ func (c *PubSub) SetNetConn(netConn net.Conn) {
c.cn = pool.NewConn(netConn) c.cn = pool.NewConn(netConn)
} }
func (c *ClusterClient) LoadState() (*clusterState, error) { func (c *ClusterClient) LoadState(ctx context.Context) (*clusterState, error) {
return c.loadState() return c.loadState(ctx)
} }
func (c *ClusterClient) SlotAddrs(slot int) []string { func (c *ClusterClient) SlotAddrs(ctx context.Context, slot int) []string {
state, err := c.state.Get() state, err := c.state.Get(ctx)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -34,8 +35,8 @@ func (c *ClusterClient) SlotAddrs(slot int) []string {
return addrs return addrs
} }
func (c *ClusterClient) Nodes(key string) ([]*clusterNode, error) { func (c *ClusterClient) Nodes(ctx context.Context, key string) ([]*clusterNode, error) {
state, err := c.state.Reload() state, err := c.state.Reload(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -48,8 +49,8 @@ func (c *ClusterClient) Nodes(key string) ([]*clusterNode, error) {
return nodes, nil return nodes, nil
} }
func (c *ClusterClient) SwapNodes(key string) error { func (c *ClusterClient) SwapNodes(ctx context.Context, key string) error {
nodes, err := c.Nodes(key) nodes, err := c.Nodes(ctx, key)
if err != nil { if err != nil {
return err return err
} }
@ -57,12 +58,12 @@ func (c *ClusterClient) SwapNodes(key string) error {
return nil return nil
} }
func (state *clusterState) IsConsistent() bool { func (state *clusterState) IsConsistent(ctx context.Context) bool {
if len(state.Masters) < 3 { if len(state.Masters) < 3 {
return false return false
} }
for _, master := range state.Masters { for _, master := range state.Masters {
s := master.Client.Info("replication").Val() s := master.Client.Info(ctx, "replication").Val()
if !strings.Contains(s, "role:master") { if !strings.Contains(s, "role:master") {
return false return false
} }
@ -72,7 +73,7 @@ func (state *clusterState) IsConsistent() bool {
return false return false
} }
for _, slave := range state.Slaves { for _, slave := range state.Slaves {
s := slave.Client.Info("replication").Val() s := slave.Client.Info(ctx, "replication").Val()
if !strings.Contains(s, "role:slave") { if !strings.Contains(s, "role:slave") {
return false return false
} }

8
go.mod
View File

@ -1,15 +1,13 @@
module github.com/go-redis/redis/v7 module github.com/go-redis/redis/v8
require ( require (
github.com/golang/protobuf v1.3.2 // indirect
github.com/kr/pretty v0.1.0 // indirect
github.com/onsi/ginkgo v1.10.1 github.com/onsi/ginkgo v1.10.1
github.com/onsi/gomega v1.7.0 github.com/onsi/gomega v1.7.0
go.opentelemetry.io/otel v0.5.0
golang.org/x/net v0.0.0-20190923162816-aa69164e4478 // indirect golang.org/x/net v0.0.0-20190923162816-aa69164e4478 // indirect
golang.org/x/sys v0.0.0-20191010194322-b09406accb47 // indirect golang.org/x/sys v0.0.0-20191010194322-b09406accb47 // indirect
golang.org/x/text v0.3.2 // indirect golang.org/x/text v0.3.2 // indirect
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect google.golang.org/grpc v1.29.1 // indirect
gopkg.in/yaml.v2 v2.2.4 // indirect
) )
go 1.11 go 1.11

72
go.sum
View File

@ -1,9 +1,34 @@
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/DataDog/sketches-go v0.0.0-20190923095040-43f19ad77ff7 h1:qELHH0AWCvf98Yf+CNIJx9vOZOfHFDDzgDRYsnNk/vs=
github.com/DataDog/sketches-go v0.0.0-20190923095040-43f19ad77ff7/go.mod h1:Q5DbzQ+3AkgGwymQO7aZFNP7ns2lZKGtvRBzRXfdi60=
github.com/benbjohnson/clock v1.0.0 h1:78Jk/r6m4wCi6sndMpty7A//t4dw/RW5fV4ZgDVfX1w=
github.com/benbjohnson/clock v1.0.0/go.mod h1:bGMdMPoPVvcYyt1gHDf4J2KE153Yf9BuiUKYMaxlTDM=
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I= github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM= github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs= github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs=
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.3 h1:gyjaxf+svBWX08ZjK86iN9geUJF0H6gp2IRKX6Nf6/I=
github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4=
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/gofuzz v1.0.0 h1:A8PeW59pxE9IoFRqBp37U+mSNaQoZ46F1f0f863XSXw=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
@ -16,13 +41,34 @@ github.com/onsi/ginkgo v1.10.1 h1:q/mM8GF/n0shIN8SaAZ0V+jnLPzen6WIVZdiwrRlMlo=
github.com/onsi/ginkgo v1.10.1/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.10.1/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
github.com/onsi/gomega v1.7.0 h1:XPnZz8VVBHjVsy1vzJmRwIcSwiUO+JFfrv/xGiigmME= github.com/onsi/gomega v1.7.0 h1:XPnZz8VVBHjVsy1vzJmRwIcSwiUO+JFfrv/xGiigmME=
github.com/onsi/gomega v1.7.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= github.com/onsi/gomega v1.7.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
github.com/opentracing/opentracing-go v1.1.1-0.20190913142402-a7454ce5950e/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
go.opentelemetry.io/otel v0.5.0 h1:tdIR1veg/z+VRJaw/6SIxz+QX3l+m+BDleYLTs+GC1g=
go.opentelemetry.io/otel v0.5.0/go.mod h1:jzBIgIzK43Iu1BpDAXwqOd6UPsSAk+ewVZ5ofSXw4Ek=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd h1:nTDtHvHSdCn1m6ITfMRqtOd/9+7a3s8RBNOZ3eYZzJA= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd h1:nTDtHvHSdCn1m6ITfMRqtOd/9+7a3s8RBNOZ3eYZzJA=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190923162816-aa69164e4478 h1:l5EDrHhldLYb3ZRHDUhXF7Om7MvYXnkV9/iQNo1lX6g= golang.org/x/net v0.0.0-20190923162816-aa69164e4478 h1:l5EDrHhldLYb3ZRHDUhXF7Om7MvYXnkV9/iQNo1lX6g=
golang.org/x/net v0.0.0-20190923162816-aa69164e4478/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190923162816-aa69164e4478/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f h1:wMNYb4v58l5UBM7MYRLPG6ZhfOqbKu7X5eyFl8ZhKvA= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f h1:wMNYb4v58l5UBM7MYRLPG6ZhfOqbKu7X5eyFl8ZhKvA=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e h1:o3PsSEY8E4eXWkXrIP9YJALUkVZqzHJT5DOasTyn8Vs= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e h1:o3PsSEY8E4eXWkXrIP9YJALUkVZqzHJT5DOasTyn8Vs=
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@ -33,6 +79,25 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
google.golang.org/genproto v0.0.0-20191009194640-548a555dbc03 h1:4HYDjxeNXAOTv3o1N2tjo8UUSlhQgAD52FVkwxnWgM8=
google.golang.org/genproto v0.0.0-20191009194640-548a555dbc03/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc=
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg=
google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY=
google.golang.org/grpc v1.27.1 h1:zvIju4sqAGvwKspUQOhwnpcqSbzi7/H6QomNNjTL4sk=
google.golang.org/grpc v1.27.1/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk=
google.golang.org/grpc v1.29.1 h1:EC2SB8S04d2r73uptxphDSUG+kTKVgjRPF+N3xpxRB4=
google.golang.org/grpc v1.29.1/go.mod h1:itym6AZVZYACWQqET3MqgPpjcuV5QH3BxFS3IjizoKk=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
@ -43,5 +108,8 @@ gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkep
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
gopkg.in/yaml.v2 v2.2.1 h1:mUhvW9EsL+naU5Q3cakzfE91YhliOondGd6ZrsDBHQE= gopkg.in/yaml.v2 v2.2.1 h1:mUhvW9EsL+naU5Q3cakzfE91YhliOondGd6ZrsDBHQE=
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.4 h1:/eiJrUcujPVeJ3xlSWaiNi3uSVmDGBK1pDHUHAnao1I= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.7 h1:VUgggvou5XRW9mHwD/yXxIYSMtY0zoKQf/v226p2nyo=
gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=

View File

@ -6,7 +6,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/go-redis/redis/v7/internal/pool" "github.com/go-redis/redis/v8/internal/pool"
) )
type poolGetPutBenchmark struct { type poolGetPutBenchmark struct {

View File

@ -6,7 +6,8 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/go-redis/redis/v7/internal/proto" "github.com/go-redis/redis/v8/internal"
"github.com/go-redis/redis/v8/internal/proto"
) )
var noDeadline = time.Time{} var noDeadline = time.Time{}
@ -58,16 +59,19 @@ func (cn *Conn) RemoteAddr() net.Addr {
} }
func (cn *Conn) WithReader(ctx context.Context, timeout time.Duration, fn func(rd *proto.Reader) error) error { func (cn *Conn) WithReader(ctx context.Context, timeout time.Duration, fn func(rd *proto.Reader) error) error {
return internal.WithSpan(ctx, "with_reader", func(ctx context.Context) error {
err := cn.netConn.SetReadDeadline(cn.deadline(ctx, timeout)) err := cn.netConn.SetReadDeadline(cn.deadline(ctx, timeout))
if err != nil { if err != nil {
return err return err
} }
return fn(cn.rd) return fn(cn.rd)
})
} }
func (cn *Conn) WithWriter( func (cn *Conn) WithWriter(
ctx context.Context, timeout time.Duration, fn func(wr *proto.Writer) error, ctx context.Context, timeout time.Duration, fn func(wr *proto.Writer) error,
) error { ) error {
return internal.WithSpan(ctx, "with_writer", func(ctx context.Context) error {
err := cn.netConn.SetWriteDeadline(cn.deadline(ctx, timeout)) err := cn.netConn.SetWriteDeadline(cn.deadline(ctx, timeout))
if err != nil { if err != nil {
return err return err
@ -83,6 +87,7 @@ func (cn *Conn) WithWriter(
} }
return cn.wr.Flush() return cn.wr.Flush()
})
} }
func (cn *Conn) Close() error { func (cn *Conn) Close() error {

View File

@ -8,7 +8,7 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/go-redis/redis/v7/internal" "github.com/go-redis/redis/v8/internal"
) )
var ErrClosed = errors.New("redis: client is closed") var ErrClosed = errors.New("redis: client is closed")

View File

@ -6,7 +6,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/go-redis/redis/v7/internal/pool" "github.com/go-redis/redis/v8/internal/pool"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"

View File

@ -5,7 +5,7 @@ import (
"fmt" "fmt"
"io" "io"
"github.com/go-redis/redis/v7/internal/util" "github.com/go-redis/redis/v8/internal/util"
) )
const ( const (

View File

@ -4,7 +4,7 @@ import (
"bytes" "bytes"
"testing" "testing"
"github.com/go-redis/redis/v7/internal/proto" "github.com/go-redis/redis/v8/internal/proto"
) )
func BenchmarkReader_ParseReply_Status(b *testing.B) { func BenchmarkReader_ParseReply_Status(b *testing.B) {

View File

@ -5,7 +5,7 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"github.com/go-redis/redis/v7/internal/util" "github.com/go-redis/redis/v8/internal/util"
) )
func Scan(b []byte, v interface{}) error { func Scan(b []byte, v interface{}) error {

View File

@ -7,7 +7,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/go-redis/redis/v7/internal/proto" "github.com/go-redis/redis/v8/internal/proto"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"

View File

@ -8,7 +8,7 @@ import (
"strconv" "strconv"
"time" "time"
"github.com/go-redis/redis/v7/internal/util" "github.com/go-redis/redis/v8/internal/util"
) )
type Writer struct { type Writer struct {
@ -93,7 +93,8 @@ func (w *Writer) writeArg(v interface{}) error {
} }
return w.int(0) return w.int(0)
case time.Time: case time.Time:
return w.string(v.Format(time.RFC3339Nano)) w.numBuf = v.AppendFormat(w.numBuf[:0], time.RFC3339Nano)
return w.bytes(w.numBuf)
case encoding.BinaryMarshaler: case encoding.BinaryMarshaler:
b, err := v.MarshalBinary() b, err := v.MarshalBinary()
if err != nil { if err != nil {

View File

@ -4,10 +4,13 @@ import (
"context" "context"
"time" "time"
"github.com/go-redis/redis/v7/internal/util" "github.com/go-redis/redis/v8/internal/util"
"go.opentelemetry.io/otel/api/global"
"go.opentelemetry.io/otel/api/trace"
) )
func Sleep(ctx context.Context, dur time.Duration) error { func Sleep(ctx context.Context, dur time.Duration) error {
return WithSpan(ctx, "sleep", func(ctx context.Context) error {
t := time.NewTimer(dur) t := time.NewTimer(dur)
defer t.Stop() defer t.Stop()
@ -17,6 +20,7 @@ func Sleep(ctx context.Context, dur time.Duration) error {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return ctx.Err()
} }
})
} }
func ToLower(s string) string { func ToLower(s string) string {
@ -54,3 +58,14 @@ func Unwrap(err error) error {
} }
return u.Unwrap() return u.Unwrap()
} }
func WithSpan(ctx context.Context, name string, fn func(context.Context) error) error {
if !trace.SpanFromContext(ctx).IsRecording() {
return fn(ctx)
}
ctx, span := global.Tracer("go-redis").Start(ctx, name)
defer span.End()
return fn(ctx)
}

View File

@ -1,6 +1,7 @@
package redis package redis
import ( import (
"context"
"sync" "sync"
) )
@ -21,7 +22,7 @@ func (it *ScanIterator) Err() error {
} }
// Next advances the cursor and returns true if more values can be read. // Next advances the cursor and returns true if more values can be read.
func (it *ScanIterator) Next() bool { func (it *ScanIterator) Next(ctx context.Context) bool {
it.mu.Lock() it.mu.Lock()
defer it.mu.Unlock() defer it.mu.Unlock()
@ -49,7 +50,7 @@ func (it *ScanIterator) Next() bool {
it.cmd.args[2] = it.cmd.cursor it.cmd.args[2] = it.cmd.cursor
} }
err := it.cmd.process(it.cmd) err := it.cmd.process(ctx, it.cmd)
if err != nil { if err != nil {
return false return false
} }

View File

@ -3,7 +3,7 @@ package redis_test
import ( import (
"fmt" "fmt"
"github.com/go-redis/redis/v7" "github.com/go-redis/redis/v8"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
@ -15,21 +15,21 @@ var _ = Describe("ScanIterator", func() {
var seed = func(n int) error { var seed = func(n int) error {
pipe := client.Pipeline() pipe := client.Pipeline()
for i := 1; i <= n; i++ { for i := 1; i <= n; i++ {
pipe.Set(fmt.Sprintf("K%02d", i), "x", 0).Err() pipe.Set(ctx, fmt.Sprintf("K%02d", i), "x", 0).Err()
} }
_, err := pipe.Exec() _, err := pipe.Exec(ctx)
return err return err
} }
var extraSeed = func(n int, m int) error { var extraSeed = func(n int, m int) error {
pipe := client.Pipeline() pipe := client.Pipeline()
for i := 1; i <= m; i++ { for i := 1; i <= m; i++ {
pipe.Set(fmt.Sprintf("A%02d", i), "x", 0).Err() pipe.Set(ctx, fmt.Sprintf("A%02d", i), "x", 0).Err()
} }
for i := 1; i <= n; i++ { for i := 1; i <= n; i++ {
pipe.Set(fmt.Sprintf("K%02d", i), "x", 0).Err() pipe.Set(ctx, fmt.Sprintf("K%02d", i), "x", 0).Err()
} }
_, err := pipe.Exec() _, err := pipe.Exec(ctx)
return err return err
} }
@ -37,15 +37,15 @@ var _ = Describe("ScanIterator", func() {
var hashSeed = func(n int) error { var hashSeed = func(n int) error {
pipe := client.Pipeline() pipe := client.Pipeline()
for i := 1; i <= n; i++ { for i := 1; i <= n; i++ {
pipe.HSet(hashKey, fmt.Sprintf("K%02d", i), "x").Err() pipe.HSet(ctx, hashKey, fmt.Sprintf("K%02d", i), "x").Err()
} }
_, err := pipe.Exec() _, err := pipe.Exec(ctx)
return err return err
} }
BeforeEach(func() { BeforeEach(func() {
client = redis.NewClient(redisOptions()) client = redis.NewClient(redisOptions())
Expect(client.FlushDB().Err()).NotTo(HaveOccurred()) Expect(client.FlushDB(ctx).Err()).NotTo(HaveOccurred())
}) })
AfterEach(func() { AfterEach(func() {
@ -53,8 +53,8 @@ var _ = Describe("ScanIterator", func() {
}) })
It("should scan across empty DBs", func() { It("should scan across empty DBs", func() {
iter := client.Scan(0, "", 10).Iterator() iter := client.Scan(ctx, 0, "", 10).Iterator()
Expect(iter.Next()).To(BeFalse()) Expect(iter.Next(ctx)).To(BeFalse())
Expect(iter.Err()).NotTo(HaveOccurred()) Expect(iter.Err()).NotTo(HaveOccurred())
}) })
@ -62,8 +62,8 @@ var _ = Describe("ScanIterator", func() {
Expect(seed(7)).NotTo(HaveOccurred()) Expect(seed(7)).NotTo(HaveOccurred())
var vals []string var vals []string
iter := client.Scan(0, "", 0).Iterator() iter := client.Scan(ctx, 0, "", 0).Iterator()
for iter.Next() { for iter.Next(ctx) {
vals = append(vals, iter.Val()) vals = append(vals, iter.Val())
} }
Expect(iter.Err()).NotTo(HaveOccurred()) Expect(iter.Err()).NotTo(HaveOccurred())
@ -74,8 +74,8 @@ var _ = Describe("ScanIterator", func() {
Expect(seed(71)).NotTo(HaveOccurred()) Expect(seed(71)).NotTo(HaveOccurred())
var vals []string var vals []string
iter := client.Scan(0, "", 10).Iterator() iter := client.Scan(ctx, 0, "", 10).Iterator()
for iter.Next() { for iter.Next(ctx) {
vals = append(vals, iter.Val()) vals = append(vals, iter.Val())
} }
Expect(iter.Err()).NotTo(HaveOccurred()) Expect(iter.Err()).NotTo(HaveOccurred())
@ -88,8 +88,8 @@ var _ = Describe("ScanIterator", func() {
Expect(hashSeed(71)).NotTo(HaveOccurred()) Expect(hashSeed(71)).NotTo(HaveOccurred())
var vals []string var vals []string
iter := client.HScan(hashKey, 0, "", 10).Iterator() iter := client.HScan(ctx, hashKey, 0, "", 10).Iterator()
for iter.Next() { for iter.Next(ctx) {
vals = append(vals, iter.Val()) vals = append(vals, iter.Val())
} }
Expect(iter.Err()).NotTo(HaveOccurred()) Expect(iter.Err()).NotTo(HaveOccurred())
@ -102,8 +102,8 @@ var _ = Describe("ScanIterator", func() {
Expect(seed(20)).NotTo(HaveOccurred()) Expect(seed(20)).NotTo(HaveOccurred())
var vals []string var vals []string
iter := client.Scan(0, "", 10).Iterator() iter := client.Scan(ctx, 0, "", 10).Iterator()
for iter.Next() { for iter.Next(ctx) {
vals = append(vals, iter.Val()) vals = append(vals, iter.Val())
} }
Expect(iter.Err()).NotTo(HaveOccurred()) Expect(iter.Err()).NotTo(HaveOccurred())
@ -114,8 +114,8 @@ var _ = Describe("ScanIterator", func() {
Expect(seed(33)).NotTo(HaveOccurred()) Expect(seed(33)).NotTo(HaveOccurred())
var vals []string var vals []string
iter := client.Scan(0, "K*2*", 10).Iterator() iter := client.Scan(ctx, 0, "K*2*", 10).Iterator()
for iter.Next() { for iter.Next(ctx) {
vals = append(vals, iter.Val()) vals = append(vals, iter.Val())
} }
Expect(iter.Err()).NotTo(HaveOccurred()) Expect(iter.Err()).NotTo(HaveOccurred())
@ -126,8 +126,8 @@ var _ = Describe("ScanIterator", func() {
Expect(extraSeed(2, 10)).NotTo(HaveOccurred()) Expect(extraSeed(2, 10)).NotTo(HaveOccurred())
var vals []string var vals []string
iter := client.Scan(0, "K*", 1).Iterator() iter := client.Scan(ctx, 0, "K*", 1).Iterator()
for iter.Next() { for iter.Next(ctx) {
vals = append(vals, iter.Val()) vals = append(vals, iter.Val())
} }
Expect(iter.Err()).NotTo(HaveOccurred()) Expect(iter.Err()).NotTo(HaveOccurred())

View File

@ -12,7 +12,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/go-redis/redis/v7" "github.com/go-redis/redis/v8"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
@ -32,10 +32,10 @@ const (
const ( const (
sentinelName = "mymaster" sentinelName = "mymaster"
sentinelMasterPort = "8123" sentinelMasterPort = "9123"
sentinelSlave1Port = "8124" sentinelSlave1Port = "9124"
sentinelSlave2Port = "8125" sentinelSlave2Port = "9125"
sentinelPort = "8126" sentinelPort = "9126"
) )
var ( var (
@ -80,7 +80,7 @@ var _ = BeforeSuite(func() {
sentinelSlave2Port, "--slaveof", "127.0.0.1", sentinelMasterPort) sentinelSlave2Port, "--slaveof", "127.0.0.1", sentinelMasterPort)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(startCluster(cluster)).NotTo(HaveOccurred()) Expect(startCluster(ctx, cluster)).NotTo(HaveOccurred())
}) })
var _ = AfterSuite(func() { var _ = AfterSuite(func() {
@ -223,7 +223,7 @@ func connectTo(port string) (*redis.Client, error) {
}) })
err := eventually(func() error { err := eventually(func() error {
return client.Ping().Err() return client.Ping(ctx).Err()
}, 30*time.Second) }, 30*time.Second)
if err != nil { if err != nil {
return nil, err return nil, err
@ -243,7 +243,7 @@ func (p *redisProcess) Close() error {
} }
err := eventually(func() error { err := eventually(func() error {
if err := p.Client.Ping().Err(); err != nil { if err := p.Client.Ping(ctx).Err(); err != nil {
return nil return nil
} }
return errors.New("client is not shutdown") return errors.New("client is not shutdown")
@ -313,12 +313,12 @@ func startSentinel(port, masterName, masterPort string) (*redisProcess, error) {
return nil, err return nil, err
} }
for _, cmd := range []*redis.StatusCmd{ for _, cmd := range []*redis.StatusCmd{
redis.NewStatusCmd("SENTINEL", "MONITOR", masterName, "127.0.0.1", masterPort, "1"), redis.NewStatusCmd(ctx, "SENTINEL", "MONITOR", masterName, "127.0.0.1", masterPort, "1"),
redis.NewStatusCmd("SENTINEL", "SET", masterName, "down-after-milliseconds", "500"), redis.NewStatusCmd(ctx, "SENTINEL", "SET", masterName, "down-after-milliseconds", "500"),
redis.NewStatusCmd("SENTINEL", "SET", masterName, "failover-timeout", "1000"), redis.NewStatusCmd(ctx, "SENTINEL", "SET", masterName, "failover-timeout", "1000"),
redis.NewStatusCmd("SENTINEL", "SET", masterName, "parallel-syncs", "1"), redis.NewStatusCmd(ctx, "SENTINEL", "SET", masterName, "parallel-syncs", "1"),
} { } {
client.Process(cmd) client.Process(ctx, cmd)
if err := cmd.Err(); err != nil { if err := cmd.Err(); err != nil {
process.Kill() process.Kill()
return nil, err return nil, err

View File

@ -12,7 +12,8 @@ import (
"strings" "strings"
"time" "time"
"github.com/go-redis/redis/v7/internal/pool" "github.com/go-redis/redis/v8/internal"
"github.com/go-redis/redis/v8/internal/pool"
) )
// Limiter is the interface of a rate limiter or a circuit breaker. // Limiter is the interface of a rate limiter or a circuit breaker.
@ -237,7 +238,13 @@ func ParseURL(redisURL string) (*Options, error) {
func newConnPool(opt *Options) *pool.ConnPool { func newConnPool(opt *Options) *pool.ConnPool {
return pool.NewConnPool(&pool.Options{ return pool.NewConnPool(&pool.Options{
Dialer: func(ctx context.Context) (net.Conn, error) { Dialer: func(ctx context.Context) (net.Conn, error) {
return opt.Dialer(ctx, opt.Network, opt.Addr) var conn net.Conn
err := internal.WithSpan(ctx, "dialer", func(ctx context.Context) error {
var err error
conn, err = opt.Dialer(ctx, opt.Network, opt.Addr)
return err
})
return conn, err
}, },
PoolSize: opt.PoolSize, PoolSize: opt.PoolSize,
MinIdleConns: opt.MinIdleConns, MinIdleConns: opt.MinIdleConns,

View File

@ -4,7 +4,7 @@ import (
"context" "context"
"sync" "sync"
"github.com/go-redis/redis/v7/internal/pool" "github.com/go-redis/redis/v8/internal/pool"
) )
type pipelineExecer func(context.Context, []Cmder) error type pipelineExecer func(context.Context, []Cmder) error
@ -24,12 +24,11 @@ type pipelineExecer func(context.Context, []Cmder) error
// depends of your batch size and/or use TxPipeline. // depends of your batch size and/or use TxPipeline.
type Pipeliner interface { type Pipeliner interface {
StatefulCmdable StatefulCmdable
Do(args ...interface{}) *Cmd Do(ctx context.Context, args ...interface{}) *Cmd
Process(cmd Cmder) error Process(ctx context.Context, cmd Cmder) error
Close() error Close() error
Discard() error Discard() error
Exec() ([]Cmder, error) Exec(ctx context.Context) ([]Cmder, error)
ExecContext(ctx context.Context) ([]Cmder, error)
} }
var _ Pipeliner = (*Pipeline)(nil) var _ Pipeliner = (*Pipeline)(nil)
@ -54,14 +53,14 @@ func (c *Pipeline) init() {
c.statefulCmdable = c.Process c.statefulCmdable = c.Process
} }
func (c *Pipeline) Do(args ...interface{}) *Cmd { func (c *Pipeline) Do(ctx context.Context, args ...interface{}) *Cmd {
cmd := NewCmd(args...) cmd := NewCmd(ctx, args...)
_ = c.Process(cmd) _ = c.Process(ctx, cmd)
return cmd return cmd
} }
// Process queues the cmd for later execution. // Process queues the cmd for later execution.
func (c *Pipeline) Process(cmd Cmder) error { func (c *Pipeline) Process(ctx context.Context, cmd Cmder) error {
c.mu.Lock() c.mu.Lock()
c.cmds = append(c.cmds, cmd) c.cmds = append(c.cmds, cmd)
c.mu.Unlock() c.mu.Unlock()
@ -98,11 +97,7 @@ func (c *Pipeline) discard() error {
// //
// Exec always returns list of commands and error of the first failed // Exec always returns list of commands and error of the first failed
// command if any. // command if any.
func (c *Pipeline) Exec() ([]Cmder, error) { func (c *Pipeline) Exec(ctx context.Context) ([]Cmder, error) {
return c.ExecContext(c.ctx)
}
func (c *Pipeline) ExecContext(ctx context.Context) ([]Cmder, error) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
@ -120,11 +115,11 @@ func (c *Pipeline) ExecContext(ctx context.Context) ([]Cmder, error) {
return cmds, c.exec(ctx, cmds) return cmds, c.exec(ctx, cmds)
} }
func (c *Pipeline) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) { func (c *Pipeline) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
if err := fn(c); err != nil { if err := fn(c); err != nil {
return nil, err return nil, err
} }
cmds, err := c.Exec() cmds, err := c.Exec(ctx)
_ = c.Close() _ = c.Close()
return cmds, err return cmds, err
} }
@ -133,8 +128,8 @@ func (c *Pipeline) Pipeline() Pipeliner {
return c return c
} }
func (c *Pipeline) TxPipelined(fn func(Pipeliner) error) ([]Cmder, error) { func (c *Pipeline) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
return c.Pipelined(fn) return c.Pipelined(ctx, fn)
} }
func (c *Pipeline) TxPipeline() Pipeliner { func (c *Pipeline) TxPipeline() Pipeliner {

View File

@ -1,7 +1,7 @@
package redis_test package redis_test
import ( import (
"github.com/go-redis/redis/v7" "github.com/go-redis/redis/v8"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
@ -13,7 +13,7 @@ var _ = Describe("pipelining", func() {
BeforeEach(func() { BeforeEach(func() {
client = redis.NewClient(redisOptions()) client = redis.NewClient(redisOptions())
Expect(client.FlushDB().Err()).NotTo(HaveOccurred()) Expect(client.FlushDB(ctx).Err()).NotTo(HaveOccurred())
}) })
AfterEach(func() { AfterEach(func() {
@ -22,8 +22,8 @@ var _ = Describe("pipelining", func() {
It("supports block style", func() { It("supports block style", func() {
var get *redis.StringCmd var get *redis.StringCmd
cmds, err := client.Pipelined(func(pipe redis.Pipeliner) error { cmds, err := client.Pipelined(ctx, func(pipe redis.Pipeliner) error {
get = pipe.Get("foo") get = pipe.Get(ctx, "foo")
return nil return nil
}) })
Expect(err).To(Equal(redis.Nil)) Expect(err).To(Equal(redis.Nil))
@ -35,24 +35,24 @@ var _ = Describe("pipelining", func() {
assertPipeline := func() { assertPipeline := func() {
It("returns no errors when there are no commands", func() { It("returns no errors when there are no commands", func() {
_, err := pipe.Exec() _, err := pipe.Exec(ctx)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
}) })
It("discards queued commands", func() { It("discards queued commands", func() {
pipe.Get("key") pipe.Get(ctx, "key")
pipe.Discard() pipe.Discard()
cmds, err := pipe.Exec() cmds, err := pipe.Exec(ctx)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(cmds).To(BeNil()) Expect(cmds).To(BeNil())
}) })
It("handles val/err", func() { It("handles val/err", func() {
err := client.Set("key", "value", 0).Err() err := client.Set(ctx, "key", "value", 0).Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
get := pipe.Get("key") get := pipe.Get(ctx, "key")
cmds, err := pipe.Exec() cmds, err := pipe.Exec(ctx)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(cmds).To(HaveLen(1)) Expect(cmds).To(HaveLen(1))
@ -62,8 +62,8 @@ var _ = Describe("pipelining", func() {
}) })
It("supports custom command", func() { It("supports custom command", func() {
pipe.Do("ping") pipe.Do(ctx, "ping")
cmds, err := pipe.Exec() cmds, err := pipe.Exec(ctx)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(cmds).To(HaveLen(1)) Expect(cmds).To(HaveLen(1))
}) })

View File

@ -4,7 +4,7 @@ import (
"context" "context"
"time" "time"
"github.com/go-redis/redis/v7" "github.com/go-redis/redis/v8"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
@ -27,7 +27,7 @@ var _ = Describe("pool", func() {
It("respects max size", func() { It("respects max size", func() {
perform(1000, func(id int) { perform(1000, func(id int) {
val, err := client.Ping().Result() val, err := client.Ping(ctx).Result()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(val).To(Equal("PONG")) Expect(val).To(Equal("PONG"))
}) })
@ -42,9 +42,9 @@ var _ = Describe("pool", func() {
perform(1000, func(id int) { perform(1000, func(id int) {
var ping *redis.StatusCmd var ping *redis.StatusCmd
err := client.Watch(func(tx *redis.Tx) error { err := client.Watch(ctx, func(tx *redis.Tx) error {
cmds, err := tx.Pipelined(func(pipe redis.Pipeliner) error { cmds, err := tx.Pipelined(ctx, func(pipe redis.Pipeliner) error {
ping = pipe.Ping() ping = pipe.Ping(ctx)
return nil return nil
}) })
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
@ -66,8 +66,8 @@ var _ = Describe("pool", func() {
It("respects max size on pipelines", func() { It("respects max size on pipelines", func() {
perform(1000, func(id int) { perform(1000, func(id int) {
pipe := client.Pipeline() pipe := client.Pipeline()
ping := pipe.Ping() ping := pipe.Ping(ctx)
cmds, err := pipe.Exec() cmds, err := pipe.Exec(ctx)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(cmds).To(HaveLen(1)) Expect(cmds).To(HaveLen(1))
Expect(ping.Err()).NotTo(HaveOccurred()) Expect(ping.Err()).NotTo(HaveOccurred())
@ -87,10 +87,10 @@ var _ = Describe("pool", func() {
cn.SetNetConn(&badConn{}) cn.SetNetConn(&badConn{})
client.Pool().Put(cn) client.Pool().Put(cn)
err = client.Ping().Err() err = client.Ping(ctx).Err()
Expect(err).To(MatchError("bad connection")) Expect(err).To(MatchError("bad connection"))
val, err := client.Ping().Result() val, err := client.Ping(ctx).Result()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(val).To(Equal("PONG")) Expect(val).To(Equal("PONG"))
@ -106,7 +106,7 @@ var _ = Describe("pool", func() {
It("reuses connections", func() { It("reuses connections", func() {
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
val, err := client.Ping().Result() val, err := client.Ping(ctx).Result()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(val).To(Equal("PONG")) Expect(val).To(Equal("PONG"))
} }
@ -122,7 +122,7 @@ var _ = Describe("pool", func() {
}) })
It("removes idle connections", func() { It("removes idle connections", func() {
err := client.Ping().Err() err := client.Ping(ctx).Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
stats := client.PoolStats() stats := client.PoolStats()

112
pubsub.go
View File

@ -8,9 +8,9 @@ import (
"sync" "sync"
"time" "time"
"github.com/go-redis/redis/v7/internal" "github.com/go-redis/redis/v8/internal"
"github.com/go-redis/redis/v7/internal/pool" "github.com/go-redis/redis/v8/internal/pool"
"github.com/go-redis/redis/v7/internal/proto" "github.com/go-redis/redis/v8/internal/proto"
) )
const pingTimeout = 30 * time.Second const pingTimeout = 30 * time.Second
@ -26,7 +26,7 @@ var errPingTimeout = errors.New("redis: ping timeout")
type PubSub struct { type PubSub struct {
opt *Options opt *Options
newConn func([]string) (*pool.Conn, error) newConn func(ctx context.Context, channels []string) (*pool.Conn, error)
closeConn func(*pool.Conn) error closeConn func(*pool.Conn) error
mu sync.Mutex mu sync.Mutex
@ -55,14 +55,14 @@ func (c *PubSub) init() {
c.exit = make(chan struct{}) c.exit = make(chan struct{})
} }
func (c *PubSub) connWithLock() (*pool.Conn, error) { func (c *PubSub) connWithLock(ctx context.Context) (*pool.Conn, error) {
c.mu.Lock() c.mu.Lock()
cn, err := c.conn(nil) cn, err := c.conn(ctx, nil)
c.mu.Unlock() c.mu.Unlock()
return cn, err return cn, err
} }
func (c *PubSub) conn(newChannels []string) (*pool.Conn, error) { func (c *PubSub) conn(ctx context.Context, newChannels []string) (*pool.Conn, error) {
if c.closed { if c.closed {
return nil, pool.ErrClosed return nil, pool.ErrClosed
} }
@ -73,12 +73,12 @@ func (c *PubSub) conn(newChannels []string) (*pool.Conn, error) {
channels := mapKeys(c.channels) channels := mapKeys(c.channels)
channels = append(channels, newChannels...) channels = append(channels, newChannels...)
cn, err := c.newConn(channels) cn, err := c.newConn(ctx, channels)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := c.resubscribe(cn); err != nil { if err := c.resubscribe(ctx, cn); err != nil {
_ = c.closeConn(cn) _ = c.closeConn(cn)
return nil, err return nil, err
} }
@ -93,15 +93,15 @@ func (c *PubSub) writeCmd(ctx context.Context, cn *pool.Conn, cmd Cmder) error {
}) })
} }
func (c *PubSub) resubscribe(cn *pool.Conn) error { func (c *PubSub) resubscribe(ctx context.Context, cn *pool.Conn) error {
var firstErr error var firstErr error
if len(c.channels) > 0 { if len(c.channels) > 0 {
firstErr = c._subscribe(cn, "subscribe", mapKeys(c.channels)) firstErr = c._subscribe(ctx, cn, "subscribe", mapKeys(c.channels))
} }
if len(c.patterns) > 0 { if len(c.patterns) > 0 {
err := c._subscribe(cn, "psubscribe", mapKeys(c.patterns)) err := c._subscribe(ctx, cn, "psubscribe", mapKeys(c.patterns))
if err != nil && firstErr == nil { if err != nil && firstErr == nil {
firstErr = err firstErr = err
} }
@ -121,35 +121,40 @@ func mapKeys(m map[string]struct{}) []string {
} }
func (c *PubSub) _subscribe( func (c *PubSub) _subscribe(
cn *pool.Conn, redisCmd string, channels []string, ctx context.Context, cn *pool.Conn, redisCmd string, channels []string,
) error { ) error {
args := make([]interface{}, 0, 1+len(channels)) args := make([]interface{}, 0, 1+len(channels))
args = append(args, redisCmd) args = append(args, redisCmd)
for _, channel := range channels { for _, channel := range channels {
args = append(args, channel) args = append(args, channel)
} }
cmd := NewSliceCmd(args...) cmd := NewSliceCmd(ctx, args...)
return c.writeCmd(context.TODO(), cn, cmd) return c.writeCmd(ctx, cn, cmd)
} }
func (c *PubSub) releaseConnWithLock(cn *pool.Conn, err error, allowTimeout bool) { func (c *PubSub) releaseConnWithLock(
ctx context.Context,
cn *pool.Conn,
err error,
allowTimeout bool,
) {
c.mu.Lock() c.mu.Lock()
c.releaseConn(cn, err, allowTimeout) c.releaseConn(ctx, cn, err, allowTimeout)
c.mu.Unlock() c.mu.Unlock()
} }
func (c *PubSub) releaseConn(cn *pool.Conn, err error, allowTimeout bool) { func (c *PubSub) releaseConn(ctx context.Context, cn *pool.Conn, err error, allowTimeout bool) {
if c.cn != cn { if c.cn != cn {
return return
} }
if isBadConn(err, allowTimeout) { if isBadConn(err, allowTimeout) {
c.reconnect(err) c.reconnect(ctx, err)
} }
} }
func (c *PubSub) reconnect(reason error) { func (c *PubSub) reconnect(ctx context.Context, reason error) {
_ = c.closeTheCn(reason) _ = c.closeTheCn(reason)
_, _ = c.conn(nil) _, _ = c.conn(ctx, nil)
} }
func (c *PubSub) closeTheCn(reason error) error { func (c *PubSub) closeTheCn(reason error) error {
@ -179,11 +184,11 @@ func (c *PubSub) Close() error {
// Subscribe the client to the specified channels. It returns // Subscribe the client to the specified channels. It returns
// empty subscription if there are no channels. // empty subscription if there are no channels.
func (c *PubSub) Subscribe(channels ...string) error { func (c *PubSub) Subscribe(ctx context.Context, channels ...string) error {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
err := c.subscribe("subscribe", channels...) err := c.subscribe(ctx, "subscribe", channels...)
if c.channels == nil { if c.channels == nil {
c.channels = make(map[string]struct{}) c.channels = make(map[string]struct{})
} }
@ -195,11 +200,11 @@ func (c *PubSub) Subscribe(channels ...string) error {
// PSubscribe the client to the given patterns. It returns // PSubscribe the client to the given patterns. It returns
// empty subscription if there are no patterns. // empty subscription if there are no patterns.
func (c *PubSub) PSubscribe(patterns ...string) error { func (c *PubSub) PSubscribe(ctx context.Context, patterns ...string) error {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
err := c.subscribe("psubscribe", patterns...) err := c.subscribe(ctx, "psubscribe", patterns...)
if c.patterns == nil { if c.patterns == nil {
c.patterns = make(map[string]struct{}) c.patterns = make(map[string]struct{})
} }
@ -211,55 +216,55 @@ func (c *PubSub) PSubscribe(patterns ...string) error {
// Unsubscribe the client from the given channels, or from all of // Unsubscribe the client from the given channels, or from all of
// them if none is given. // them if none is given.
func (c *PubSub) Unsubscribe(channels ...string) error { func (c *PubSub) Unsubscribe(ctx context.Context, channels ...string) error {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
for _, channel := range channels { for _, channel := range channels {
delete(c.channels, channel) delete(c.channels, channel)
} }
err := c.subscribe("unsubscribe", channels...) err := c.subscribe(ctx, "unsubscribe", channels...)
return err return err
} }
// PUnsubscribe the client from the given patterns, or from all of // PUnsubscribe the client from the given patterns, or from all of
// them if none is given. // them if none is given.
func (c *PubSub) PUnsubscribe(patterns ...string) error { func (c *PubSub) PUnsubscribe(ctx context.Context, patterns ...string) error {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
for _, pattern := range patterns { for _, pattern := range patterns {
delete(c.patterns, pattern) delete(c.patterns, pattern)
} }
err := c.subscribe("punsubscribe", patterns...) err := c.subscribe(ctx, "punsubscribe", patterns...)
return err return err
} }
func (c *PubSub) subscribe(redisCmd string, channels ...string) error { func (c *PubSub) subscribe(ctx context.Context, redisCmd string, channels ...string) error {
cn, err := c.conn(channels) cn, err := c.conn(ctx, channels)
if err != nil { if err != nil {
return err return err
} }
err = c._subscribe(cn, redisCmd, channels) err = c._subscribe(ctx, cn, redisCmd, channels)
c.releaseConn(cn, err, false) c.releaseConn(ctx, cn, err, false)
return err return err
} }
func (c *PubSub) Ping(payload ...string) error { func (c *PubSub) Ping(ctx context.Context, payload ...string) error {
args := []interface{}{"ping"} args := []interface{}{"ping"}
if len(payload) == 1 { if len(payload) == 1 {
args = append(args, payload[0]) args = append(args, payload[0])
} }
cmd := NewCmd(args...) cmd := NewCmd(ctx, args...)
cn, err := c.connWithLock() cn, err := c.connWithLock(ctx)
if err != nil { if err != nil {
return err return err
} }
err = c.writeCmd(context.TODO(), cn, cmd) err = c.writeCmd(ctx, cn, cmd)
c.releaseConnWithLock(cn, err, false) c.releaseConnWithLock(ctx, cn, err, false)
return err return err
} }
@ -342,21 +347,21 @@ func (c *PubSub) newMessage(reply interface{}) (interface{}, error) {
// ReceiveTimeout acts like Receive but returns an error if message // ReceiveTimeout acts like Receive but returns an error if message
// is not received in time. This is low-level API and in most cases // is not received in time. This is low-level API and in most cases
// Channel should be used instead. // Channel should be used instead.
func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) { func (c *PubSub) ReceiveTimeout(ctx context.Context, timeout time.Duration) (interface{}, error) {
if c.cmd == nil { if c.cmd == nil {
c.cmd = NewCmd() c.cmd = NewCmd(ctx)
} }
cn, err := c.connWithLock() cn, err := c.connWithLock(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = cn.WithReader(context.TODO(), timeout, func(rd *proto.Reader) error { err = cn.WithReader(ctx, timeout, func(rd *proto.Reader) error {
return c.cmd.readReply(rd) return c.cmd.readReply(rd)
}) })
c.releaseConnWithLock(cn, err, timeout > 0) c.releaseConnWithLock(ctx, cn, err, timeout > 0)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -367,16 +372,16 @@ func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) {
// Receive returns a message as a Subscription, Message, Pong or error. // Receive returns a message as a Subscription, Message, Pong or error.
// See PubSub example for details. This is low-level API and in most cases // See PubSub example for details. This is low-level API and in most cases
// Channel should be used instead. // Channel should be used instead.
func (c *PubSub) Receive() (interface{}, error) { func (c *PubSub) Receive(ctx context.Context) (interface{}, error) {
return c.ReceiveTimeout(0) return c.ReceiveTimeout(ctx, 0)
} }
// ReceiveMessage returns a Message or error ignoring Subscription and Pong // ReceiveMessage returns a Message or error ignoring Subscription and Pong
// messages. This is low-level API and in most cases Channel should be used // messages. This is low-level API and in most cases Channel should be used
// instead. // instead.
func (c *PubSub) ReceiveMessage() (*Message, error) { func (c *PubSub) ReceiveMessage(ctx context.Context) (*Message, error) {
for { for {
msg, err := c.Receive() msg, err := c.Receive(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -429,7 +434,7 @@ func (c *PubSub) ChannelSize(size int) <-chan *Message {
// reconnections. // reconnections.
// //
// ChannelWithSubscriptions can not be used together with Channel or ChannelSize. // ChannelWithSubscriptions can not be used together with Channel or ChannelSize.
func (c *PubSub) ChannelWithSubscriptions(size int) <-chan interface{} { func (c *PubSub) ChannelWithSubscriptions(ctx context.Context, size int) <-chan interface{} {
c.chOnce.Do(func() { c.chOnce.Do(func() {
c.initPing() c.initPing()
c.initAllChan(size) c.initAllChan(size)
@ -446,6 +451,7 @@ func (c *PubSub) ChannelWithSubscriptions(size int) <-chan interface{} {
} }
func (c *PubSub) initPing() { func (c *PubSub) initPing() {
ctx := context.TODO()
c.ping = make(chan struct{}, 1) c.ping = make(chan struct{}, 1)
go func() { go func() {
timer := time.NewTimer(pingTimeout) timer := time.NewTimer(pingTimeout)
@ -461,7 +467,7 @@ func (c *PubSub) initPing() {
<-timer.C <-timer.C
} }
case <-timer.C: case <-timer.C:
pingErr := c.Ping() pingErr := c.Ping(ctx)
if healthy { if healthy {
healthy = false healthy = false
} else { } else {
@ -469,7 +475,7 @@ func (c *PubSub) initPing() {
pingErr = errPingTimeout pingErr = errPingTimeout
} }
c.mu.Lock() c.mu.Lock()
c.reconnect(pingErr) c.reconnect(ctx, pingErr)
healthy = true healthy = true
c.mu.Unlock() c.mu.Unlock()
} }
@ -482,6 +488,7 @@ func (c *PubSub) initPing() {
// initMsgChan must be in sync with initAllChan. // initMsgChan must be in sync with initAllChan.
func (c *PubSub) initMsgChan(size int) { func (c *PubSub) initMsgChan(size int) {
ctx := context.TODO()
c.msgCh = make(chan *Message, size) c.msgCh = make(chan *Message, size)
go func() { go func() {
timer := time.NewTimer(pingTimeout) timer := time.NewTimer(pingTimeout)
@ -489,7 +496,7 @@ func (c *PubSub) initMsgChan(size int) {
var errCount int var errCount int
for { for {
msg, err := c.Receive() msg, err := c.Receive(ctx)
if err != nil { if err != nil {
if err == pool.ErrClosed { if err == pool.ErrClosed {
close(c.msgCh) close(c.msgCh)
@ -535,6 +542,7 @@ func (c *PubSub) initMsgChan(size int) {
// initAllChan must be in sync with initMsgChan. // initAllChan must be in sync with initMsgChan.
func (c *PubSub) initAllChan(size int) { func (c *PubSub) initAllChan(size int) {
ctx := context.TODO()
c.allCh = make(chan interface{}, size) c.allCh = make(chan interface{}, size)
go func() { go func() {
timer := time.NewTimer(pingTimeout) timer := time.NewTimer(pingTimeout)
@ -542,7 +550,7 @@ func (c *PubSub) initAllChan(size int) {
var errCount int var errCount int
for { for {
msg, err := c.Receive() msg, err := c.Receive(ctx)
if err != nil { if err != nil {
if err == pool.ErrClosed { if err == pool.ErrClosed {
close(c.allCh) close(c.allCh)

View File

@ -6,7 +6,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/go-redis/redis/v7" "github.com/go-redis/redis/v8"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
@ -20,7 +20,7 @@ var _ = Describe("PubSub", func() {
opt.MinIdleConns = 0 opt.MinIdleConns = 0
opt.MaxConnAge = 0 opt.MaxConnAge = 0
client = redis.NewClient(opt) client = redis.NewClient(opt)
Expect(client.FlushDB().Err()).NotTo(HaveOccurred()) Expect(client.FlushDB(ctx).Err()).NotTo(HaveOccurred())
}) })
AfterEach(func() { AfterEach(func() {
@ -28,18 +28,18 @@ var _ = Describe("PubSub", func() {
}) })
It("implements Stringer", func() { It("implements Stringer", func() {
pubsub := client.PSubscribe("mychannel*") pubsub := client.PSubscribe(ctx, "mychannel*")
defer pubsub.Close() defer pubsub.Close()
Expect(pubsub.String()).To(Equal("PubSub(mychannel*)")) Expect(pubsub.String()).To(Equal("PubSub(mychannel*)"))
}) })
It("should support pattern matching", func() { It("should support pattern matching", func() {
pubsub := client.PSubscribe("mychannel*") pubsub := client.PSubscribe(ctx, "mychannel*")
defer pubsub.Close() defer pubsub.Close()
{ {
msgi, err := pubsub.ReceiveTimeout(time.Second) msgi, err := pubsub.ReceiveTimeout(ctx, time.Second)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
subscr := msgi.(*redis.Subscription) subscr := msgi.(*redis.Subscription)
Expect(subscr.Kind).To(Equal("psubscribe")) Expect(subscr.Kind).To(Equal("psubscribe"))
@ -48,19 +48,19 @@ var _ = Describe("PubSub", func() {
} }
{ {
msgi, err := pubsub.ReceiveTimeout(time.Second) msgi, err := pubsub.ReceiveTimeout(ctx, time.Second)
Expect(err.(net.Error).Timeout()).To(Equal(true)) Expect(err.(net.Error).Timeout()).To(Equal(true))
Expect(msgi).To(BeNil()) Expect(msgi).To(BeNil())
} }
n, err := client.Publish("mychannel1", "hello").Result() n, err := client.Publish(ctx, "mychannel1", "hello").Result()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(n).To(Equal(int64(1))) Expect(n).To(Equal(int64(1)))
Expect(pubsub.PUnsubscribe("mychannel*")).NotTo(HaveOccurred()) Expect(pubsub.PUnsubscribe(ctx, "mychannel*")).NotTo(HaveOccurred())
{ {
msgi, err := pubsub.ReceiveTimeout(time.Second) msgi, err := pubsub.ReceiveTimeout(ctx, time.Second)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
subscr := msgi.(*redis.Message) subscr := msgi.(*redis.Message)
Expect(subscr.Channel).To(Equal("mychannel1")) Expect(subscr.Channel).To(Equal("mychannel1"))
@ -69,7 +69,7 @@ var _ = Describe("PubSub", func() {
} }
{ {
msgi, err := pubsub.ReceiveTimeout(time.Second) msgi, err := pubsub.ReceiveTimeout(ctx, time.Second)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
subscr := msgi.(*redis.Subscription) subscr := msgi.(*redis.Subscription)
Expect(subscr.Kind).To(Equal("punsubscribe")) Expect(subscr.Kind).To(Equal("punsubscribe"))
@ -82,31 +82,31 @@ var _ = Describe("PubSub", func() {
}) })
It("should pub/sub channels", func() { It("should pub/sub channels", func() {
channels, err := client.PubSubChannels("mychannel*").Result() channels, err := client.PubSubChannels(ctx, "mychannel*").Result()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(channels).To(BeEmpty()) Expect(channels).To(BeEmpty())
pubsub := client.Subscribe("mychannel", "mychannel2") pubsub := client.Subscribe(ctx, "mychannel", "mychannel2")
defer pubsub.Close() defer pubsub.Close()
channels, err = client.PubSubChannels("mychannel*").Result() channels, err = client.PubSubChannels(ctx, "mychannel*").Result()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(channels).To(ConsistOf([]string{"mychannel", "mychannel2"})) Expect(channels).To(ConsistOf([]string{"mychannel", "mychannel2"}))
channels, err = client.PubSubChannels("").Result() channels, err = client.PubSubChannels(ctx, "").Result()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(channels).To(BeEmpty()) Expect(channels).To(BeEmpty())
channels, err = client.PubSubChannels("*").Result() channels, err = client.PubSubChannels(ctx, "*").Result()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(len(channels)).To(BeNumerically(">=", 2)) Expect(len(channels)).To(BeNumerically(">=", 2))
}) })
It("should return the numbers of subscribers", func() { It("should return the numbers of subscribers", func() {
pubsub := client.Subscribe("mychannel", "mychannel2") pubsub := client.Subscribe(ctx, "mychannel", "mychannel2")
defer pubsub.Close() defer pubsub.Close()
channels, err := client.PubSubNumSub("mychannel", "mychannel2", "mychannel3").Result() channels, err := client.PubSubNumSub(ctx, "mychannel", "mychannel2", "mychannel3").Result()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(channels).To(Equal(map[string]int64{ Expect(channels).To(Equal(map[string]int64{
"mychannel": 1, "mychannel": 1,
@ -116,24 +116,24 @@ var _ = Describe("PubSub", func() {
}) })
It("should return the numbers of subscribers by pattern", func() { It("should return the numbers of subscribers by pattern", func() {
num, err := client.PubSubNumPat().Result() num, err := client.PubSubNumPat(ctx).Result()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(num).To(Equal(int64(0))) Expect(num).To(Equal(int64(0)))
pubsub := client.PSubscribe("*") pubsub := client.PSubscribe(ctx, "*")
defer pubsub.Close() defer pubsub.Close()
num, err = client.PubSubNumPat().Result() num, err = client.PubSubNumPat(ctx).Result()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(num).To(Equal(int64(1))) Expect(num).To(Equal(int64(1)))
}) })
It("should pub/sub", func() { It("should pub/sub", func() {
pubsub := client.Subscribe("mychannel", "mychannel2") pubsub := client.Subscribe(ctx, "mychannel", "mychannel2")
defer pubsub.Close() defer pubsub.Close()
{ {
msgi, err := pubsub.ReceiveTimeout(time.Second) msgi, err := pubsub.ReceiveTimeout(ctx, time.Second)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
subscr := msgi.(*redis.Subscription) subscr := msgi.(*redis.Subscription)
Expect(subscr.Kind).To(Equal("subscribe")) Expect(subscr.Kind).To(Equal("subscribe"))
@ -142,7 +142,7 @@ var _ = Describe("PubSub", func() {
} }
{ {
msgi, err := pubsub.ReceiveTimeout(time.Second) msgi, err := pubsub.ReceiveTimeout(ctx, time.Second)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
subscr := msgi.(*redis.Subscription) subscr := msgi.(*redis.Subscription)
Expect(subscr.Kind).To(Equal("subscribe")) Expect(subscr.Kind).To(Equal("subscribe"))
@ -151,23 +151,23 @@ var _ = Describe("PubSub", func() {
} }
{ {
msgi, err := pubsub.ReceiveTimeout(time.Second) msgi, err := pubsub.ReceiveTimeout(ctx, time.Second)
Expect(err.(net.Error).Timeout()).To(Equal(true)) Expect(err.(net.Error).Timeout()).To(Equal(true))
Expect(msgi).NotTo(HaveOccurred()) Expect(msgi).NotTo(HaveOccurred())
} }
n, err := client.Publish("mychannel", "hello").Result() n, err := client.Publish(ctx, "mychannel", "hello").Result()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(n).To(Equal(int64(1))) Expect(n).To(Equal(int64(1)))
n, err = client.Publish("mychannel2", "hello2").Result() n, err = client.Publish(ctx, "mychannel2", "hello2").Result()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(n).To(Equal(int64(1))) Expect(n).To(Equal(int64(1)))
Expect(pubsub.Unsubscribe("mychannel", "mychannel2")).NotTo(HaveOccurred()) Expect(pubsub.Unsubscribe(ctx, "mychannel", "mychannel2")).NotTo(HaveOccurred())
{ {
msgi, err := pubsub.ReceiveTimeout(time.Second) msgi, err := pubsub.ReceiveTimeout(ctx, time.Second)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
msg := msgi.(*redis.Message) msg := msgi.(*redis.Message)
Expect(msg.Channel).To(Equal("mychannel")) Expect(msg.Channel).To(Equal("mychannel"))
@ -175,7 +175,7 @@ var _ = Describe("PubSub", func() {
} }
{ {
msgi, err := pubsub.ReceiveTimeout(time.Second) msgi, err := pubsub.ReceiveTimeout(ctx, time.Second)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
msg := msgi.(*redis.Message) msg := msgi.(*redis.Message)
Expect(msg.Channel).To(Equal("mychannel2")) Expect(msg.Channel).To(Equal("mychannel2"))
@ -183,7 +183,7 @@ var _ = Describe("PubSub", func() {
} }
{ {
msgi, err := pubsub.ReceiveTimeout(time.Second) msgi, err := pubsub.ReceiveTimeout(ctx, time.Second)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
subscr := msgi.(*redis.Subscription) subscr := msgi.(*redis.Subscription)
Expect(subscr.Kind).To(Equal("unsubscribe")) Expect(subscr.Kind).To(Equal("unsubscribe"))
@ -192,7 +192,7 @@ var _ = Describe("PubSub", func() {
} }
{ {
msgi, err := pubsub.ReceiveTimeout(time.Second) msgi, err := pubsub.ReceiveTimeout(ctx, time.Second)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
subscr := msgi.(*redis.Subscription) subscr := msgi.(*redis.Subscription)
Expect(subscr.Kind).To(Equal("unsubscribe")) Expect(subscr.Kind).To(Equal("unsubscribe"))
@ -205,42 +205,42 @@ var _ = Describe("PubSub", func() {
}) })
It("should ping/pong", func() { It("should ping/pong", func() {
pubsub := client.Subscribe("mychannel") pubsub := client.Subscribe(ctx, "mychannel")
defer pubsub.Close() defer pubsub.Close()
_, err := pubsub.ReceiveTimeout(time.Second) _, err := pubsub.ReceiveTimeout(ctx, time.Second)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
err = pubsub.Ping("") err = pubsub.Ping(ctx, "")
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
msgi, err := pubsub.ReceiveTimeout(time.Second) msgi, err := pubsub.ReceiveTimeout(ctx, time.Second)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
pong := msgi.(*redis.Pong) pong := msgi.(*redis.Pong)
Expect(pong.Payload).To(Equal("")) Expect(pong.Payload).To(Equal(""))
}) })
It("should ping/pong with payload", func() { It("should ping/pong with payload", func() {
pubsub := client.Subscribe("mychannel") pubsub := client.Subscribe(ctx, "mychannel")
defer pubsub.Close() defer pubsub.Close()
_, err := pubsub.ReceiveTimeout(time.Second) _, err := pubsub.ReceiveTimeout(ctx, time.Second)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
err = pubsub.Ping("hello") err = pubsub.Ping(ctx, "hello")
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
msgi, err := pubsub.ReceiveTimeout(time.Second) msgi, err := pubsub.ReceiveTimeout(ctx, time.Second)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
pong := msgi.(*redis.Pong) pong := msgi.(*redis.Pong)
Expect(pong.Payload).To(Equal("hello")) Expect(pong.Payload).To(Equal("hello"))
}) })
It("should multi-ReceiveMessage", func() { It("should multi-ReceiveMessage", func() {
pubsub := client.Subscribe("mychannel") pubsub := client.Subscribe(ctx, "mychannel")
defer pubsub.Close() defer pubsub.Close()
subscr, err := pubsub.ReceiveTimeout(time.Second) subscr, err := pubsub.ReceiveTimeout(ctx, time.Second)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(subscr).To(Equal(&redis.Subscription{ Expect(subscr).To(Equal(&redis.Subscription{
Kind: "subscribe", Kind: "subscribe",
@ -248,25 +248,25 @@ var _ = Describe("PubSub", func() {
Count: 1, Count: 1,
})) }))
err = client.Publish("mychannel", "hello").Err() err = client.Publish(ctx, "mychannel", "hello").Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
err = client.Publish("mychannel", "world").Err() err = client.Publish(ctx, "mychannel", "world").Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
msg, err := pubsub.ReceiveMessage() msg, err := pubsub.ReceiveMessage(ctx)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(msg.Channel).To(Equal("mychannel")) Expect(msg.Channel).To(Equal("mychannel"))
Expect(msg.Payload).To(Equal("hello")) Expect(msg.Payload).To(Equal("hello"))
msg, err = pubsub.ReceiveMessage() msg, err = pubsub.ReceiveMessage(ctx)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(msg.Channel).To(Equal("mychannel")) Expect(msg.Channel).To(Equal("mychannel"))
Expect(msg.Payload).To(Equal("world")) Expect(msg.Payload).To(Equal("world"))
}) })
It("returns an error when subscribe fails", func() { It("returns an error when subscribe fails", func() {
pubsub := client.Subscribe() pubsub := client.Subscribe(ctx)
defer pubsub.Close() defer pubsub.Close()
pubsub.SetNetConn(&badConn{ pubsub.SetNetConn(&badConn{
@ -274,10 +274,10 @@ var _ = Describe("PubSub", func() {
writeErr: io.EOF, writeErr: io.EOF,
}) })
err := pubsub.Subscribe("mychannel") err := pubsub.Subscribe(ctx, "mychannel")
Expect(err).To(MatchError("EOF")) Expect(err).To(MatchError("EOF"))
err = pubsub.Subscribe("mychannel") err = pubsub.Subscribe(ctx, "mychannel")
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
}) })
@ -293,16 +293,16 @@ var _ = Describe("PubSub", func() {
defer GinkgoRecover() defer GinkgoRecover()
Eventually(step).Should(Receive()) Eventually(step).Should(Receive())
err := client.Publish("mychannel", "hello").Err() err := client.Publish(ctx, "mychannel", "hello").Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
step <- struct{}{} step <- struct{}{}
}() }()
_, err := pubsub.ReceiveMessage() _, err := pubsub.ReceiveMessage(ctx)
Expect(err).To(Equal(io.EOF)) Expect(err).To(Equal(io.EOF))
step <- struct{}{} step <- struct{}{}
msg, err := pubsub.ReceiveMessage() msg, err := pubsub.ReceiveMessage(ctx)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(msg.Channel).To(Equal("mychannel")) Expect(msg.Channel).To(Equal("mychannel"))
Expect(msg.Payload).To(Equal("hello")) Expect(msg.Payload).To(Equal("hello"))
@ -311,10 +311,10 @@ var _ = Describe("PubSub", func() {
} }
It("Subscribe should reconnect on ReceiveMessage error", func() { It("Subscribe should reconnect on ReceiveMessage error", func() {
pubsub := client.Subscribe("mychannel") pubsub := client.Subscribe(ctx, "mychannel")
defer pubsub.Close() defer pubsub.Close()
subscr, err := pubsub.ReceiveTimeout(time.Second) subscr, err := pubsub.ReceiveTimeout(ctx, time.Second)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(subscr).To(Equal(&redis.Subscription{ Expect(subscr).To(Equal(&redis.Subscription{
Kind: "subscribe", Kind: "subscribe",
@ -326,10 +326,10 @@ var _ = Describe("PubSub", func() {
}) })
It("PSubscribe should reconnect on ReceiveMessage error", func() { It("PSubscribe should reconnect on ReceiveMessage error", func() {
pubsub := client.PSubscribe("mychannel") pubsub := client.PSubscribe(ctx, "mychannel")
defer pubsub.Close() defer pubsub.Close()
subscr, err := pubsub.ReceiveTimeout(time.Second) subscr, err := pubsub.ReceiveTimeout(ctx, time.Second)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(subscr).To(Equal(&redis.Subscription{ Expect(subscr).To(Equal(&redis.Subscription{
Kind: "psubscribe", Kind: "psubscribe",
@ -341,7 +341,7 @@ var _ = Describe("PubSub", func() {
}) })
It("should return on Close", func() { It("should return on Close", func() {
pubsub := client.Subscribe("mychannel") pubsub := client.Subscribe(ctx, "mychannel")
defer pubsub.Close() defer pubsub.Close()
var wg sync.WaitGroup var wg sync.WaitGroup
@ -352,7 +352,7 @@ var _ = Describe("PubSub", func() {
wg.Done() wg.Done()
defer wg.Done() defer wg.Done()
_, err := pubsub.ReceiveMessage() _, err := pubsub.ReceiveMessage(ctx)
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
Expect(err.Error()).To(SatisfyAny( Expect(err.Error()).To(SatisfyAny(
Equal("redis: client is closed"), Equal("redis: client is closed"),
@ -371,7 +371,7 @@ var _ = Describe("PubSub", func() {
It("should ReceiveMessage without a subscription", func() { It("should ReceiveMessage without a subscription", func() {
timeout := 100 * time.Millisecond timeout := 100 * time.Millisecond
pubsub := client.Subscribe() pubsub := client.Subscribe(ctx)
defer pubsub.Close() defer pubsub.Close()
var wg sync.WaitGroup var wg sync.WaitGroup
@ -382,16 +382,16 @@ var _ = Describe("PubSub", func() {
time.Sleep(timeout) time.Sleep(timeout)
err := pubsub.Subscribe("mychannel") err := pubsub.Subscribe(ctx, "mychannel")
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
time.Sleep(timeout) time.Sleep(timeout)
err = client.Publish("mychannel", "hello").Err() err = client.Publish(ctx, "mychannel", "hello").Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
}() }()
msg, err := pubsub.ReceiveMessage() msg, err := pubsub.ReceiveMessage(ctx)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(msg.Channel).To(Equal("mychannel")) Expect(msg.Channel).To(Equal("mychannel"))
Expect(msg.Payload).To(Equal("hello")) Expect(msg.Payload).To(Equal("hello"))
@ -400,13 +400,13 @@ var _ = Describe("PubSub", func() {
}) })
It("handles big message payload", func() { It("handles big message payload", func() {
pubsub := client.Subscribe("mychannel") pubsub := client.Subscribe(ctx, "mychannel")
defer pubsub.Close() defer pubsub.Close()
ch := pubsub.Channel() ch := pubsub.Channel()
bigVal := bigVal() bigVal := bigVal()
err := client.Publish("mychannel", bigVal).Err() err := client.Publish(ctx, "mychannel", bigVal).Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
var msg *redis.Message var msg *redis.Message
@ -418,7 +418,7 @@ var _ = Describe("PubSub", func() {
It("supports concurrent Ping and Receive", func() { It("supports concurrent Ping and Receive", func() {
const N = 100 const N = 100
pubsub := client.Subscribe("mychannel") pubsub := client.Subscribe(ctx, "mychannel")
defer pubsub.Close() defer pubsub.Close()
done := make(chan struct{}) done := make(chan struct{})
@ -426,14 +426,14 @@ var _ = Describe("PubSub", func() {
defer GinkgoRecover() defer GinkgoRecover()
for i := 0; i < N; i++ { for i := 0; i < N; i++ {
_, err := pubsub.ReceiveTimeout(5 * time.Second) _, err := pubsub.ReceiveTimeout(ctx, 5*time.Second)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
} }
close(done) close(done)
}() }()
for i := 0; i < N; i++ { for i := 0; i < N; i++ {
err := pubsub.Ping() err := pubsub.Ping(ctx)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
} }

View File

@ -2,7 +2,6 @@ package redis_test
import ( import (
"bytes" "bytes"
"context"
"fmt" "fmt"
"net" "net"
"strconv" "strconv"
@ -10,7 +9,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/go-redis/redis/v7" "github.com/go-redis/redis/v8"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
@ -22,7 +21,7 @@ var _ = Describe("races", func() {
BeforeEach(func() { BeforeEach(func() {
client = redis.NewClient(redisOptions()) client = redis.NewClient(redisOptions())
Expect(client.FlushDB().Err()).To(BeNil()) Expect(client.FlushDB(ctx).Err()).To(BeNil())
C, N = 10, 1000 C, N = 10, 1000
if testing.Short() { if testing.Short() {
@ -40,7 +39,7 @@ var _ = Describe("races", func() {
perform(C, func(id int) { perform(C, func(id int) {
for i := 0; i < N; i++ { for i := 0; i < N; i++ {
msg := fmt.Sprintf("echo %d %d", id, i) msg := fmt.Sprintf("echo %d %d", id, i)
echo, err := client.Echo(msg).Result() echo, err := client.Echo(ctx, msg).Result()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(echo).To(Equal(msg)) Expect(echo).To(Equal(msg))
} }
@ -52,12 +51,12 @@ var _ = Describe("races", func() {
perform(C, func(id int) { perform(C, func(id int) {
for i := 0; i < N; i++ { for i := 0; i < N; i++ {
err := client.Incr(key).Err() err := client.Incr(ctx, key).Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
} }
}) })
val, err := client.Get(key).Int64() val, err := client.Get(ctx, key).Int64()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(val).To(Equal(int64(C * N))) Expect(val).To(Equal(int64(C * N)))
}) })
@ -66,6 +65,7 @@ var _ = Describe("races", func() {
perform(C, func(id int) { perform(C, func(id int) {
for i := 0; i < N; i++ { for i := 0; i < N; i++ {
err := client.Set( err := client.Set(
ctx,
fmt.Sprintf("keys.key-%d-%d", id, i), fmt.Sprintf("keys.key-%d-%d", id, i),
fmt.Sprintf("hello-%d-%d", id, i), fmt.Sprintf("hello-%d-%d", id, i),
0, 0,
@ -74,7 +74,7 @@ var _ = Describe("races", func() {
} }
}) })
keys := client.Keys("keys.*") keys := client.Keys(ctx, "keys.*")
Expect(keys.Err()).NotTo(HaveOccurred()) Expect(keys.Err()).NotTo(HaveOccurred())
Expect(len(keys.Val())).To(Equal(C * N)) Expect(len(keys.Val())).To(Equal(C * N))
}) })
@ -86,12 +86,12 @@ var _ = Describe("races", func() {
key := fmt.Sprintf("keys.key-%d", i) key := fmt.Sprintf("keys.key-%d", i)
keys = append(keys, key) keys = append(keys, key)
err := client.Set(key, fmt.Sprintf("hello-%d", i), 0).Err() err := client.Set(ctx, key, fmt.Sprintf("hello-%d", i), 0).Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
} }
keys = append(keys, "non-existent-key") keys = append(keys, "non-existent-key")
vals, err := client.MGet(keys...).Result() vals, err := client.MGet(ctx, keys...).Result()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(len(vals)).To(Equal(N + 2)) Expect(len(vals)).To(Equal(N + 2))
@ -109,7 +109,7 @@ var _ = Describe("races", func() {
bigVal := bigVal() bigVal := bigVal()
err := client.Set("key", bigVal, 0).Err() err := client.Set(ctx, "key", bigVal, 0).Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
// Reconnect to get new connection. // Reconnect to get new connection.
@ -118,7 +118,7 @@ var _ = Describe("races", func() {
perform(C, func(id int) { perform(C, func(id int) {
for i := 0; i < N; i++ { for i := 0; i < N; i++ {
got, err := client.Get("key").Bytes() got, err := client.Get(ctx, "key").Bytes()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(got).To(Equal(bigVal)) Expect(got).To(Equal(bigVal))
} }
@ -131,14 +131,14 @@ var _ = Describe("races", func() {
bigVal := bigVal() bigVal := bigVal()
perform(C, func(id int) { perform(C, func(id int) {
for i := 0; i < N; i++ { for i := 0; i < N; i++ {
err := client.Set("key", bigVal, 0).Err() err := client.Set(ctx, "key", bigVal, 0).Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
} }
}) })
}) })
It("should select db", func() { It("should select db", func() {
err := client.Set("db", 1, 0).Err() err := client.Set(ctx, "db", 1, 0).Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
perform(C, func(id int) { perform(C, func(id int) {
@ -146,10 +146,10 @@ var _ = Describe("races", func() {
opt.DB = id opt.DB = id
client := redis.NewClient(opt) client := redis.NewClient(opt)
for i := 0; i < N; i++ { for i := 0; i < N; i++ {
err := client.Set("db", id, 0).Err() err := client.Set(ctx, "db", id, 0).Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
n, err := client.Get("db").Int64() n, err := client.Get(ctx, "db").Int64()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(n).To(Equal(int64(id))) Expect(n).To(Equal(int64(id)))
} }
@ -157,7 +157,7 @@ var _ = Describe("races", func() {
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
}) })
n, err := client.Get("db").Int64() n, err := client.Get(ctx, "db").Int64()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(n).To(Equal(int64(1))) Expect(n).To(Equal(int64(1)))
}) })
@ -170,7 +170,7 @@ var _ = Describe("races", func() {
client := redis.NewClient(opt) client := redis.NewClient(opt)
perform(C, func(id int) { perform(C, func(id int) {
err := client.Ping().Err() err := client.Ping(ctx).Err()
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
Expect(err.(net.Error).Timeout()).To(BeTrue()) Expect(err.(net.Error).Timeout()).To(BeTrue())
}) })
@ -181,21 +181,21 @@ var _ = Describe("races", func() {
}) })
It("should Watch/Unwatch", func() { It("should Watch/Unwatch", func() {
err := client.Set("key", "0", 0).Err() err := client.Set(ctx, "key", "0", 0).Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
perform(C, func(id int) { perform(C, func(id int) {
for i := 0; i < N; i++ { for i := 0; i < N; i++ {
err := client.Watch(func(tx *redis.Tx) error { err := client.Watch(ctx, func(tx *redis.Tx) error {
val, err := tx.Get("key").Result() val, err := tx.Get(ctx, "key").Result()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(val).NotTo(Equal(redis.Nil)) Expect(val).NotTo(Equal(redis.Nil))
num, err := strconv.ParseInt(val, 10, 64) num, err := strconv.ParseInt(val, 10, 64)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
cmds, err := tx.TxPipelined(func(pipe redis.Pipeliner) error { cmds, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Set("key", strconv.FormatInt(num+1, 10), 0) pipe.Set(ctx, "key", strconv.FormatInt(num+1, 10), 0)
return nil return nil
}) })
Expect(cmds).To(HaveLen(1)) Expect(cmds).To(HaveLen(1))
@ -209,7 +209,7 @@ var _ = Describe("races", func() {
} }
}) })
val, err := client.Get("key").Int64() val, err := client.Get(ctx, "key").Int64()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(val).To(Equal(int64(C * N))) Expect(val).To(Equal(int64(C * N)))
}) })
@ -218,10 +218,10 @@ var _ = Describe("races", func() {
perform(C, func(id int) { perform(C, func(id int) {
pipe := client.Pipeline() pipe := client.Pipeline()
for i := 0; i < N; i++ { for i := 0; i < N; i++ {
pipe.Echo(fmt.Sprint(i)) pipe.Echo(ctx, fmt.Sprint(i))
} }
cmds, err := pipe.Exec() cmds, err := pipe.Exec(ctx)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(cmds).To(HaveLen(N)) Expect(cmds).To(HaveLen(N))
@ -234,14 +234,14 @@ var _ = Describe("races", func() {
It("should Pipeline", func() { It("should Pipeline", func() {
pipe := client.Pipeline() pipe := client.Pipeline()
perform(N, func(id int) { perform(N, func(id int) {
pipe.Incr("key") pipe.Incr(ctx, "key")
}) })
cmds, err := pipe.Exec() cmds, err := pipe.Exec(ctx)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(cmds).To(HaveLen(N)) Expect(cmds).To(HaveLen(N))
n, err := client.Get("key").Int64() n, err := client.Get(ctx, "key").Int64()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(n).To(Equal(int64(N))) Expect(n).To(Equal(int64(N)))
}) })
@ -249,14 +249,14 @@ var _ = Describe("races", func() {
It("should TxPipeline", func() { It("should TxPipeline", func() {
pipe := client.TxPipeline() pipe := client.TxPipeline()
perform(N, func(id int) { perform(N, func(id int) {
pipe.Incr("key") pipe.Incr(ctx, "key")
}) })
cmds, err := pipe.Exec() cmds, err := pipe.Exec(ctx)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(cmds).To(HaveLen(N)) Expect(cmds).To(HaveLen(N))
n, err := client.Get("key").Int64() n, err := client.Get(ctx, "key").Int64()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(n).To(Equal(int64(N))) Expect(n).To(Equal(int64(N)))
}) })
@ -265,7 +265,7 @@ var _ = Describe("races", func() {
var received uint32 var received uint32
wg := performAsync(C, func(id int) { wg := performAsync(C, func(id int) {
for { for {
v, err := client.BLPop(3*time.Second, "list").Result() v, err := client.BLPop(ctx, 3*time.Second, "list").Result()
if err != nil { if err != nil {
break break
} }
@ -276,7 +276,7 @@ var _ = Describe("races", func() {
perform(C, func(id int) { perform(C, func(id int) {
for i := 0; i < N; i++ { for i := 0; i < N; i++ {
err := client.LPush("list", "hello").Err() err := client.LPush(ctx, "list", "hello").Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
} }
}) })
@ -287,7 +287,7 @@ var _ = Describe("races", func() {
It("should WithContext", func() { It("should WithContext", func() {
perform(C, func(_ int) { perform(C, func(_ int) {
err := client.WithContext(context.Background()).Ping().Err() err := client.WithContext(ctx).Ping(ctx).Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
}) })
}) })
@ -299,7 +299,7 @@ var _ = Describe("cluster races", func() {
BeforeEach(func() { BeforeEach(func() {
opt := redisClusterOptions() opt := redisClusterOptions()
client = cluster.newClusterClient(opt) client = cluster.newClusterClient(ctx, opt)
C, N = 10, 1000 C, N = 10, 1000
if testing.Short() { if testing.Short() {
@ -317,7 +317,7 @@ var _ = Describe("cluster races", func() {
perform(C, func(id int) { perform(C, func(id int) {
for i := 0; i < N; i++ { for i := 0; i < N; i++ {
msg := fmt.Sprintf("echo %d %d", id, i) msg := fmt.Sprintf("echo %d %d", id, i)
echo, err := client.Echo(msg).Result() echo, err := client.Echo(ctx, msg).Result()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(echo).To(Equal(msg)) Expect(echo).To(Equal(msg))
} }
@ -328,7 +328,7 @@ var _ = Describe("cluster races", func() {
perform(C, func(id int) { perform(C, func(id int) {
for i := 0; i < N; i++ { for i := 0; i < N; i++ {
key := fmt.Sprintf("key_%d_%d", id, i) key := fmt.Sprintf("key_%d_%d", id, i)
_, err := client.Get(key).Result() _, err := client.Get(ctx, key).Result()
Expect(err).To(Equal(redis.Nil)) Expect(err).To(Equal(redis.Nil))
} }
}) })
@ -339,12 +339,12 @@ var _ = Describe("cluster races", func() {
perform(C, func(id int) { perform(C, func(id int) {
for i := 0; i < N; i++ { for i := 0; i < N; i++ {
err := client.Incr(key).Err() err := client.Incr(ctx, key).Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
} }
}) })
val, err := client.Get(key).Int64() val, err := client.Get(ctx, key).Int64()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(val).To(Equal(int64(C * N))) Expect(val).To(Equal(int64(C * N)))
}) })

106
redis.go
View File

@ -5,9 +5,9 @@ import (
"fmt" "fmt"
"time" "time"
"github.com/go-redis/redis/v7/internal" "github.com/go-redis/redis/v8/internal"
"github.com/go-redis/redis/v7/internal/pool" "github.com/go-redis/redis/v8/internal/pool"
"github.com/go-redis/redis/v7/internal/proto" "github.com/go-redis/redis/v8/internal/proto"
) )
// Nil reply returned by Redis when key does not exist. // Nil reply returned by Redis when key does not exist.
@ -130,7 +130,7 @@ func (hs hooks) afterProcessPipeline(ctx context.Context, cmds []Cmder) error {
func (hs hooks) processTxPipeline( func (hs hooks) processTxPipeline(
ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error, ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error,
) error { ) error {
cmds = wrapMultiExec(cmds) cmds = wrapMultiExec(ctx, cmds)
return hs.processPipeline(ctx, cmds, fn) return hs.processPipeline(ctx, cmds, fn)
} }
@ -200,6 +200,7 @@ func (c *baseClient) getConn(ctx context.Context) (*pool.Conn, error) {
} }
return nil, err return nil, err
} }
return cn, nil return cn, nil
} }
@ -209,7 +210,13 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
return nil, err return nil, err
} }
err = c.initConn(ctx, cn) if cn.Inited {
return cn, nil
}
err = internal.WithSpan(ctx, "init_conn", func(ctx context.Context) error {
return c.initConn(ctx, cn)
})
if err != nil { if err != nil {
c.connPool.Remove(cn, err) c.connPool.Remove(cn, err)
if err := internal.Unwrap(err); err != nil { if err := internal.Unwrap(err); err != nil {
@ -238,21 +245,21 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
connPool.SetConn(cn) connPool.SetConn(cn)
conn := newConn(ctx, c.opt, connPool) conn := newConn(ctx, c.opt, connPool)
_, err := conn.Pipelined(func(pipe Pipeliner) error { _, err := conn.Pipelined(ctx, func(pipe Pipeliner) error {
if c.opt.Password != "" { if c.opt.Password != "" {
if c.opt.Username != "" { if c.opt.Username != "" {
pipe.AuthACL(c.opt.Username, c.opt.Password) pipe.AuthACL(ctx, c.opt.Username, c.opt.Password)
} else { } else {
pipe.Auth(c.opt.Password) pipe.Auth(ctx, c.opt.Password)
} }
} }
if c.opt.DB > 0 { if c.opt.DB > 0 {
pipe.Select(c.opt.DB) pipe.Select(ctx, c.opt.DB)
} }
if c.opt.readOnly { if c.opt.readOnly {
pipe.ReadOnly() pipe.ReadOnly(ctx)
} }
return nil return nil
@ -282,6 +289,7 @@ func (c *baseClient) releaseConn(cn *pool.Conn, err error) {
func (c *baseClient) withConn( func (c *baseClient) withConn(
ctx context.Context, fn func(context.Context, *pool.Conn) error, ctx context.Context, fn func(context.Context, *pool.Conn) error,
) error { ) error {
return internal.WithSpan(ctx, "with_conn", func(ctx context.Context) error {
cn, err := c.getConn(ctx) cn, err := c.getConn(ctx)
if err != nil { if err != nil {
return err return err
@ -292,6 +300,7 @@ func (c *baseClient) withConn(
err = fn(ctx, cn) err = fn(ctx, cn)
return err return err
})
} }
func (c *baseClient) process(ctx context.Context, cmd Cmder) error { func (c *baseClient) process(ctx context.Context, cmd Cmder) error {
@ -306,6 +315,10 @@ func (c *baseClient) process(ctx context.Context, cmd Cmder) error {
func (c *baseClient) _process(ctx context.Context, cmd Cmder) error { func (c *baseClient) _process(ctx context.Context, cmd Cmder) error {
var lastErr error var lastErr error
for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ { for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
attempt := attempt
var retry bool
err := internal.WithSpan(ctx, "process", func(ctx context.Context) error {
if attempt > 0 { if attempt > 0 {
if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil { if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil {
return err return err
@ -313,7 +326,7 @@ func (c *baseClient) _process(ctx context.Context, cmd Cmder) error {
} }
retryTimeout := true retryTimeout := true
lastErr = c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { err := c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
return writeCmd(wr, cmd) return writeCmd(wr, cmd)
}) })
@ -329,9 +342,16 @@ func (c *baseClient) _process(ctx context.Context, cmd Cmder) error {
return nil return nil
}) })
if lastErr == nil || !isRetryableError(lastErr, retryTimeout) { if err == nil {
return lastErr return nil
} }
retry = isRetryableError(err, retryTimeout)
return err
})
if err == nil || !retry {
return err
}
lastErr = err
} }
return lastErr return lastErr
} }
@ -468,14 +488,14 @@ func (c *baseClient) txPipelineProcessCmds(
return false, err return false, err
} }
func wrapMultiExec(cmds []Cmder) []Cmder { func wrapMultiExec(ctx context.Context, cmds []Cmder) []Cmder {
if len(cmds) == 0 { if len(cmds) == 0 {
panic("not reached") panic("not reached")
} }
cmds = append(cmds, make([]Cmder, 2)...) cmds = append(cmds, make([]Cmder, 2)...)
copy(cmds[1:], cmds[:len(cmds)-2]) copy(cmds[1:], cmds[:len(cmds)-2])
cmds[0] = NewStatusCmd("multi") cmds[0] = NewStatusCmd(ctx, "multi")
cmds[len(cmds)-1] = NewSliceCmd("exec") cmds[len(cmds)-1] = NewSliceCmd(ctx, "exec")
return cmds return cmds
} }
@ -564,26 +584,18 @@ func (c *Client) WithContext(ctx context.Context) *Client {
return clone return clone
} }
func (c *Client) Conn() *Conn { func (c *Client) Conn(ctx context.Context) *Conn {
return newConn(c.ctx, c.opt, pool.NewSingleConnPool(c.connPool)) return newConn(ctx, c.opt, pool.NewSingleConnPool(c.connPool))
} }
// Do creates a Cmd from the args and processes the cmd. // Do creates a Cmd from the args and processes the cmd.
func (c *Client) Do(args ...interface{}) *Cmd { func (c *Client) Do(ctx context.Context, args ...interface{}) *Cmd {
return c.DoContext(c.ctx, args...) cmd := NewCmd(ctx, args...)
} _ = c.Process(ctx, cmd)
func (c *Client) DoContext(ctx context.Context, args ...interface{}) *Cmd {
cmd := NewCmd(args...)
_ = c.ProcessContext(ctx, cmd)
return cmd return cmd
} }
func (c *Client) Process(cmd Cmder) error { func (c *Client) Process(ctx context.Context, cmd Cmder) error {
return c.ProcessContext(c.ctx, cmd)
}
func (c *Client) ProcessContext(ctx context.Context, cmd Cmder) error {
return c.hooks.process(ctx, cmd, c.baseClient.process) return c.hooks.process(ctx, cmd, c.baseClient.process)
} }
@ -608,8 +620,8 @@ func (c *Client) PoolStats() *PoolStats {
return (*PoolStats)(stats) return (*PoolStats)(stats)
} }
func (c *Client) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) { func (c *Client) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
return c.Pipeline().Pipelined(fn) return c.Pipeline().Pipelined(ctx, fn)
} }
func (c *Client) Pipeline() Pipeliner { func (c *Client) Pipeline() Pipeliner {
@ -621,8 +633,8 @@ func (c *Client) Pipeline() Pipeliner {
return &pipe return &pipe
} }
func (c *Client) TxPipelined(fn func(Pipeliner) error) ([]Cmder, error) { func (c *Client) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
return c.TxPipeline().Pipelined(fn) return c.TxPipeline().Pipelined(ctx, fn)
} }
// TxPipeline acts like Pipeline, but wraps queued commands with MULTI/EXEC. // TxPipeline acts like Pipeline, but wraps queued commands with MULTI/EXEC.
@ -639,8 +651,8 @@ func (c *Client) pubSub() *PubSub {
pubsub := &PubSub{ pubsub := &PubSub{
opt: c.opt, opt: c.opt,
newConn: func(channels []string) (*pool.Conn, error) { newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) {
return c.newConn(context.TODO()) return c.newConn(ctx)
}, },
closeConn: c.connPool.CloseConn, closeConn: c.connPool.CloseConn,
} }
@ -674,20 +686,20 @@ func (c *Client) pubSub() *PubSub {
// } // }
// //
// ch := sub.Channel() // ch := sub.Channel()
func (c *Client) Subscribe(channels ...string) *PubSub { func (c *Client) Subscribe(ctx context.Context, channels ...string) *PubSub {
pubsub := c.pubSub() pubsub := c.pubSub()
if len(channels) > 0 { if len(channels) > 0 {
_ = pubsub.Subscribe(channels...) _ = pubsub.Subscribe(ctx, channels...)
} }
return pubsub return pubsub
} }
// PSubscribe subscribes the client to the given patterns. // PSubscribe subscribes the client to the given patterns.
// Patterns can be omitted to create empty subscription. // Patterns can be omitted to create empty subscription.
func (c *Client) PSubscribe(channels ...string) *PubSub { func (c *Client) PSubscribe(ctx context.Context, channels ...string) *PubSub {
pubsub := c.pubSub() pubsub := c.pubSub()
if len(channels) > 0 { if len(channels) > 0 {
_ = pubsub.PSubscribe(channels...) _ = pubsub.PSubscribe(ctx, channels...)
} }
return pubsub return pubsub
} }
@ -721,16 +733,12 @@ func newConn(ctx context.Context, opt *Options, connPool pool.Pooler) *Conn {
return &c return &c
} }
func (c *Conn) Process(cmd Cmder) error { func (c *Conn) Process(ctx context.Context, cmd Cmder) error {
return c.ProcessContext(c.ctx, cmd)
}
func (c *Conn) ProcessContext(ctx context.Context, cmd Cmder) error {
return c.baseClient.process(ctx, cmd) return c.baseClient.process(ctx, cmd)
} }
func (c *Conn) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) { func (c *Conn) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
return c.Pipeline().Pipelined(fn) return c.Pipeline().Pipelined(ctx, fn)
} }
func (c *Conn) Pipeline() Pipeliner { func (c *Conn) Pipeline() Pipeliner {
@ -742,8 +750,8 @@ func (c *Conn) Pipeline() Pipeliner {
return &pipe return &pipe
} }
func (c *Conn) TxPipelined(fn func(Pipeliner) error) ([]Cmder, error) { func (c *Conn) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
return c.TxPipeline().Pipelined(fn) return c.TxPipeline().Pipelined(ctx, fn)
} }
// TxPipeline acts like Pipeline, but wraps queued commands with MULTI/EXEC. // TxPipeline acts like Pipeline, but wraps queued commands with MULTI/EXEC.

View File

@ -8,7 +8,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/go-redis/redis/v7" "github.com/go-redis/redis/v8"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
@ -34,7 +34,7 @@ func TestHookError(t *testing.T) {
}) })
rdb.AddHook(redisHookError{}) rdb.AddHook(redisHookError{})
err := rdb.Ping().Err() err := rdb.Ping(ctx).Err()
if err == nil { if err == nil {
t.Fatalf("got nil, expected an error") t.Fatalf("got nil, expected an error")
} }
@ -52,7 +52,7 @@ var _ = Describe("Client", func() {
BeforeEach(func() { BeforeEach(func() {
client = redis.NewClient(redisOptions()) client = redis.NewClient(redisOptions())
Expect(client.FlushDB().Err()).NotTo(HaveOccurred()) Expect(client.FlushDB(ctx).Err()).NotTo(HaveOccurred())
}) })
AfterEach(func() { AfterEach(func() {
@ -63,27 +63,27 @@ var _ = Describe("Client", func() {
Expect(client.String()).To(Equal("Redis<:6380 db:15>")) Expect(client.String()).To(Equal("Redis<:6380 db:15>"))
}) })
It("supports WithContext", func() { It("supports context", func() {
c, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(ctx)
cancel() cancel()
err := client.WithContext(c).Ping().Err() err := client.Ping(ctx).Err()
Expect(err).To(MatchError("context canceled")) Expect(err).To(MatchError("context canceled"))
}) })
It("supports WithTimeout", func() { It("supports WithTimeout", func() {
err := client.ClientPause(time.Second).Err() err := client.ClientPause(ctx, time.Second).Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
err = client.WithTimeout(10 * time.Millisecond).Ping().Err() err = client.WithTimeout(10 * time.Millisecond).Ping(ctx).Err()
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
err = client.Ping().Err() err = client.Ping(ctx).Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
}) })
It("should ping", func() { It("should ping", func() {
val, err := client.Ping().Result() val, err := client.Ping(ctx).Result()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(val).To(Equal("PONG")) Expect(val).To(Equal("PONG"))
}) })
@ -102,7 +102,7 @@ var _ = Describe("Client", func() {
}, },
}) })
val, err := custom.Ping().Result() val, err := custom.Ping(ctx).Result()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(val).To(Equal("PONG")) Expect(val).To(Equal("PONG"))
Expect(custom.Close()).NotTo(HaveOccurred()) Expect(custom.Close()).NotTo(HaveOccurred())
@ -110,48 +110,48 @@ var _ = Describe("Client", func() {
It("should close", func() { It("should close", func() {
Expect(client.Close()).NotTo(HaveOccurred()) Expect(client.Close()).NotTo(HaveOccurred())
err := client.Ping().Err() err := client.Ping(ctx).Err()
Expect(err).To(MatchError("redis: client is closed")) Expect(err).To(MatchError("redis: client is closed"))
}) })
It("should close pubsub without closing the client", func() { It("should close pubsub without closing the client", func() {
pubsub := client.Subscribe() pubsub := client.Subscribe(ctx)
Expect(pubsub.Close()).NotTo(HaveOccurred()) Expect(pubsub.Close()).NotTo(HaveOccurred())
_, err := pubsub.Receive() _, err := pubsub.Receive(ctx)
Expect(err).To(MatchError("redis: client is closed")) Expect(err).To(MatchError("redis: client is closed"))
Expect(client.Ping().Err()).NotTo(HaveOccurred()) Expect(client.Ping(ctx).Err()).NotTo(HaveOccurred())
}) })
It("should close Tx without closing the client", func() { It("should close Tx without closing the client", func() {
err := client.Watch(func(tx *redis.Tx) error { err := client.Watch(ctx, func(tx *redis.Tx) error {
_, err := tx.TxPipelined(func(pipe redis.Pipeliner) error { _, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Ping() pipe.Ping(ctx)
return nil return nil
}) })
return err return err
}) })
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(client.Ping().Err()).NotTo(HaveOccurred()) Expect(client.Ping(ctx).Err()).NotTo(HaveOccurred())
}) })
It("should close pipeline without closing the client", func() { It("should close pipeline without closing the client", func() {
pipeline := client.Pipeline() pipeline := client.Pipeline()
Expect(pipeline.Close()).NotTo(HaveOccurred()) Expect(pipeline.Close()).NotTo(HaveOccurred())
pipeline.Ping() pipeline.Ping(ctx)
_, err := pipeline.Exec() _, err := pipeline.Exec(ctx)
Expect(err).To(MatchError("redis: client is closed")) Expect(err).To(MatchError("redis: client is closed"))
Expect(client.Ping().Err()).NotTo(HaveOccurred()) Expect(client.Ping(ctx).Err()).NotTo(HaveOccurred())
}) })
It("should close pubsub when client is closed", func() { It("should close pubsub when client is closed", func() {
pubsub := client.Subscribe() pubsub := client.Subscribe(ctx)
Expect(client.Close()).NotTo(HaveOccurred()) Expect(client.Close()).NotTo(HaveOccurred())
_, err := pubsub.Receive() _, err := pubsub.Receive(ctx)
Expect(err).To(MatchError("redis: client is closed")) Expect(err).To(MatchError("redis: client is closed"))
Expect(pubsub.Close()).NotTo(HaveOccurred()) Expect(pubsub.Close()).NotTo(HaveOccurred())
@ -168,26 +168,26 @@ var _ = Describe("Client", func() {
Addr: redisAddr, Addr: redisAddr,
DB: 2, DB: 2,
}) })
Expect(db2.FlushDB().Err()).NotTo(HaveOccurred()) Expect(db2.FlushDB(ctx).Err()).NotTo(HaveOccurred())
Expect(db2.Get("db").Err()).To(Equal(redis.Nil)) Expect(db2.Get(ctx, "db").Err()).To(Equal(redis.Nil))
Expect(db2.Set("db", 2, 0).Err()).NotTo(HaveOccurred()) Expect(db2.Set(ctx, "db", 2, 0).Err()).NotTo(HaveOccurred())
n, err := db2.Get("db").Int64() n, err := db2.Get(ctx, "db").Int64()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(n).To(Equal(int64(2))) Expect(n).To(Equal(int64(2)))
Expect(client.Get("db").Err()).To(Equal(redis.Nil)) Expect(client.Get(ctx, "db").Err()).To(Equal(redis.Nil))
Expect(db2.FlushDB().Err()).NotTo(HaveOccurred()) Expect(db2.FlushDB(ctx).Err()).NotTo(HaveOccurred())
Expect(db2.Close()).NotTo(HaveOccurred()) Expect(db2.Close()).NotTo(HaveOccurred())
}) })
It("processes custom commands", func() { It("processes custom commands", func() {
cmd := redis.NewCmd("PING") cmd := redis.NewCmd(ctx, "PING")
_ = client.Process(cmd) _ = client.Process(ctx, cmd)
// Flush buffers. // Flush buffers.
Expect(client.Echo("hello").Err()).NotTo(HaveOccurred()) Expect(client.Echo(ctx, "hello").Err()).NotTo(HaveOccurred())
Expect(cmd.Err()).NotTo(HaveOccurred()) Expect(cmd.Err()).NotTo(HaveOccurred())
Expect(cmd.Val()).To(Equal("PONG")) Expect(cmd.Val()).To(Equal("PONG"))
@ -202,13 +202,13 @@ var _ = Describe("Client", func() {
}) })
// Put bad connection in the pool. // Put bad connection in the pool.
cn, err := client.Pool().Get(context.Background()) cn, err := client.Pool().Get(ctx)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
cn.SetNetConn(&badConn{}) cn.SetNetConn(&badConn{})
client.Pool().Put(cn) client.Pool().Put(cn)
err = client.Ping().Err() err = client.Ping(ctx).Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
}) })
@ -227,12 +227,12 @@ var _ = Describe("Client", func() {
defer clientRetry.Close() defer clientRetry.Close()
startNoRetry := time.Now() startNoRetry := time.Now()
err := clientNoRetry.Ping().Err() err := clientNoRetry.Ping(ctx).Err()
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
elapseNoRetry := time.Since(startNoRetry) elapseNoRetry := time.Since(startNoRetry)
startRetry := time.Now() startRetry := time.Now()
err = clientRetry.Ping().Err() err = clientRetry.Ping(ctx).Err()
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
elapseRetry := time.Since(startRetry) elapseRetry := time.Since(startRetry)
@ -250,7 +250,7 @@ var _ = Describe("Client", func() {
time.Sleep(time.Second) time.Sleep(time.Second)
err = client.Ping().Err() err = client.Ping(ctx).Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
cn, err = client.Pool().Get(context.Background()) cn, err = client.Pool().Get(context.Background())
@ -260,11 +260,11 @@ var _ = Describe("Client", func() {
}) })
It("should process command with special chars", func() { It("should process command with special chars", func() {
set := client.Set("key", "hello1\r\nhello2\r\n", 0) set := client.Set(ctx, "key", "hello1\r\nhello2\r\n", 0)
Expect(set.Err()).NotTo(HaveOccurred()) Expect(set.Err()).NotTo(HaveOccurred())
Expect(set.Val()).To(Equal("OK")) Expect(set.Val()).To(Equal("OK"))
get := client.Get("key") get := client.Get(ctx, "key")
Expect(get.Err()).NotTo(HaveOccurred()) Expect(get.Err()).NotTo(HaveOccurred())
Expect(get.Val()).To(Equal("hello1\r\nhello2\r\n")) Expect(get.Val()).To(Equal("hello1\r\nhello2\r\n"))
}) })
@ -272,14 +272,14 @@ var _ = Describe("Client", func() {
It("should handle big vals", func() { It("should handle big vals", func() {
bigVal := bytes.Repeat([]byte{'*'}, 2e6) bigVal := bytes.Repeat([]byte{'*'}, 2e6)
err := client.Set("key", bigVal, 0).Err() err := client.Set(ctx, "key", bigVal, 0).Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
// Reconnect to get new connection. // Reconnect to get new connection.
Expect(client.Close()).NotTo(HaveOccurred()) Expect(client.Close()).NotTo(HaveOccurred())
client = redis.NewClient(redisOptions()) client = redis.NewClient(redisOptions())
got, err := client.Get("key").Bytes() got, err := client.Get(ctx, "key").Bytes()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(got).To(Equal(bigVal)) Expect(got).To(Equal(bigVal))
}) })
@ -295,14 +295,14 @@ var _ = Describe("Client timeout", func() {
testTimeout := func() { testTimeout := func() {
It("Ping timeouts", func() { It("Ping timeouts", func() {
err := client.Ping().Err() err := client.Ping(ctx).Err()
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
Expect(err.(net.Error).Timeout()).To(BeTrue()) Expect(err.(net.Error).Timeout()).To(BeTrue())
}) })
It("Pipeline timeouts", func() { It("Pipeline timeouts", func() {
_, err := client.Pipelined(func(pipe redis.Pipeliner) error { _, err := client.Pipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Ping() pipe.Ping(ctx)
return nil return nil
}) })
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
@ -314,26 +314,26 @@ var _ = Describe("Client timeout", func() {
return return
} }
pubsub := client.Subscribe() pubsub := client.Subscribe(ctx)
defer pubsub.Close() defer pubsub.Close()
err := pubsub.Subscribe("_") err := pubsub.Subscribe(ctx, "_")
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
Expect(err.(net.Error).Timeout()).To(BeTrue()) Expect(err.(net.Error).Timeout()).To(BeTrue())
}) })
It("Tx timeouts", func() { It("Tx timeouts", func() {
err := client.Watch(func(tx *redis.Tx) error { err := client.Watch(ctx, func(tx *redis.Tx) error {
return tx.Ping().Err() return tx.Ping(ctx).Err()
}) })
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
Expect(err.(net.Error).Timeout()).To(BeTrue()) Expect(err.(net.Error).Timeout()).To(BeTrue())
}) })
It("Tx Pipeline timeouts", func() { It("Tx Pipeline timeouts", func() {
err := client.Watch(func(tx *redis.Tx) error { err := client.Watch(ctx, func(tx *redis.Tx) error {
_, err := tx.TxPipelined(func(pipe redis.Pipeliner) error { _, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Ping() pipe.Ping(ctx)
return nil return nil
}) })
return err return err
@ -373,7 +373,7 @@ var _ = Describe("Client OnConnect", func() {
opt := redisOptions() opt := redisOptions()
opt.DB = 0 opt.DB = 0
opt.OnConnect = func(cn *redis.Conn) error { opt.OnConnect = func(cn *redis.Conn) error {
return cn.ClientSetName("on_connect").Err() return cn.ClientSetName(ctx, "on_connect").Err()
} }
client = redis.NewClient(opt) client = redis.NewClient(opt)
@ -384,7 +384,7 @@ var _ = Describe("Client OnConnect", func() {
}) })
It("calls OnConnect", func() { It("calls OnConnect", func() {
name, err := client.ClientGetName().Result() name, err := client.ClientGetName(ctx).Result()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(name).To(Equal("on_connect")) Expect(name).To(Equal("on_connect"))
}) })

94
ring.go
View File

@ -10,10 +10,10 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/go-redis/redis/v7/internal" "github.com/go-redis/redis/v8/internal"
"github.com/go-redis/redis/v7/internal/consistenthash" "github.com/go-redis/redis/v8/internal/consistenthash"
"github.com/go-redis/redis/v7/internal/hashtag" "github.com/go-redis/redis/v8/internal/hashtag"
"github.com/go-redis/redis/v7/internal/pool" "github.com/go-redis/redis/v8/internal/pool"
) )
// Hash is type of hash function used in consistent hash. // Hash is type of hash function used in consistent hash.
@ -27,10 +27,6 @@ type RingOptions struct {
// Map of name => host:port addresses of ring shards. // Map of name => host:port addresses of ring shards.
Addrs map[string]string Addrs map[string]string
// Map of name => password of ring shards, to allow different shards to have
// different passwords. It will be ignored if the Password field is set.
Passwords map[string]string
// Frequency of PING commands sent to check shards availability. // Frequency of PING commands sent to check shards availability.
// Shard is considered down after 3 subsequent failed checks. // Shard is considered down after 3 subsequent failed checks.
HeartbeatFrequency time.Duration HeartbeatFrequency time.Duration
@ -59,9 +55,6 @@ type RingOptions struct {
// NewClient creates a shard client with provided name and options. // NewClient creates a shard client with provided name and options.
NewClient func(name string, opt *Options) *Client NewClient func(name string, opt *Options) *Client
// Optional hook that is called when a new shard is created.
OnNewShard func(*Client)
// Following options are copied from Options struct. // Following options are copied from Options struct.
OnConnect func(*Conn) error OnConnect func(*Conn) error
@ -108,12 +101,11 @@ func (opt *RingOptions) init() {
} }
} }
func (opt *RingOptions) clientOptions(shard string) *Options { func (opt *RingOptions) clientOptions() *Options {
return &Options{ return &Options{
OnConnect: opt.OnConnect, OnConnect: opt.OnConnect,
DB: opt.DB, DB: opt.DB,
Password: opt.getPassword(shard),
DialTimeout: opt.DialTimeout, DialTimeout: opt.DialTimeout,
ReadTimeout: opt.ReadTimeout, ReadTimeout: opt.ReadTimeout,
@ -128,13 +120,6 @@ func (opt *RingOptions) clientOptions(shard string) *Options {
} }
} }
func (opt *RingOptions) getPassword(shard string) string {
if opt.Password == "" {
return opt.Passwords[shard]
}
return opt.Password
}
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
type ringShard struct { type ringShard struct {
@ -261,6 +246,8 @@ func (c *ringShards) Random() (*ringShard, error) {
func (c *ringShards) Heartbeat(frequency time.Duration) { func (c *ringShards) Heartbeat(frequency time.Duration) {
ticker := time.NewTicker(frequency) ticker := time.NewTicker(frequency)
defer ticker.Stop() defer ticker.Stop()
ctx := context.TODO()
for range ticker.C { for range ticker.C {
var rebalance bool var rebalance bool
@ -275,7 +262,7 @@ func (c *ringShards) Heartbeat(frequency time.Duration) {
c.mu.RUnlock() c.mu.RUnlock()
for _, shard := range shards { for _, shard := range shards {
err := shard.Client.Ping().Err() err := shard.Client.Ping(ctx).Err()
if shard.Vote(err == nil || err == pool.ErrPoolTimeout) { if shard.Vote(err == nil || err == pool.ErrPoolTimeout) {
internal.Logger.Printf("ring shard state changed: %s", shard) internal.Logger.Printf("ring shard state changed: %s", shard)
rebalance = true rebalance = true
@ -391,18 +378,13 @@ func NewRing(opt *RingOptions) *Ring {
} }
func newRingShard(opt *RingOptions, name, addr string) *Client { func newRingShard(opt *RingOptions, name, addr string) *Client {
clopt := opt.clientOptions(name) clopt := opt.clientOptions()
clopt.Addr = addr clopt.Addr = addr
var shard *Client
if opt.NewClient != nil { if opt.NewClient != nil {
shard = opt.NewClient(name, clopt) return opt.NewClient(name, clopt)
} else {
shard = NewClient(clopt)
} }
if opt.OnNewShard != nil { return NewClient(clopt)
opt.OnNewShard(shard)
}
return shard
} }
func (c *Ring) Context() context.Context { func (c *Ring) Context() context.Context {
@ -421,21 +403,13 @@ func (c *Ring) WithContext(ctx context.Context) *Ring {
} }
// Do creates a Cmd from the args and processes the cmd. // Do creates a Cmd from the args and processes the cmd.
func (c *Ring) Do(args ...interface{}) *Cmd { func (c *Ring) Do(ctx context.Context, args ...interface{}) *Cmd {
return c.DoContext(c.ctx, args...) cmd := NewCmd(ctx, args...)
} _ = c.Process(ctx, cmd)
func (c *Ring) DoContext(ctx context.Context, args ...interface{}) *Cmd {
cmd := NewCmd(args...)
_ = c.ProcessContext(ctx, cmd)
return cmd return cmd
} }
func (c *Ring) Process(cmd Cmder) error { func (c *Ring) Process(ctx context.Context, cmd Cmder) error {
return c.ProcessContext(c.ctx, cmd)
}
func (c *Ring) ProcessContext(ctx context.Context, cmd Cmder) error {
return c.hooks.process(ctx, cmd, c.process) return c.hooks.process(ctx, cmd, c.process)
} }
@ -469,7 +443,7 @@ func (c *Ring) Len() int {
} }
// Subscribe subscribes the client to the specified channels. // Subscribe subscribes the client to the specified channels.
func (c *Ring) Subscribe(channels ...string) *PubSub { func (c *Ring) Subscribe(ctx context.Context, channels ...string) *PubSub {
if len(channels) == 0 { if len(channels) == 0 {
panic("at least one channel is required") panic("at least one channel is required")
} }
@ -479,11 +453,11 @@ func (c *Ring) Subscribe(channels ...string) *PubSub {
//TODO: return PubSub with sticky error //TODO: return PubSub with sticky error
panic(err) panic(err)
} }
return shard.Client.Subscribe(channels...) return shard.Client.Subscribe(ctx, channels...)
} }
// PSubscribe subscribes the client to the given patterns. // PSubscribe subscribes the client to the given patterns.
func (c *Ring) PSubscribe(channels ...string) *PubSub { func (c *Ring) PSubscribe(ctx context.Context, channels ...string) *PubSub {
if len(channels) == 0 { if len(channels) == 0 {
panic("at least one channel is required") panic("at least one channel is required")
} }
@ -493,12 +467,15 @@ func (c *Ring) PSubscribe(channels ...string) *PubSub {
//TODO: return PubSub with sticky error //TODO: return PubSub with sticky error
panic(err) panic(err)
} }
return shard.Client.PSubscribe(channels...) return shard.Client.PSubscribe(ctx, channels...)
} }
// ForEachShard concurrently calls the fn on each live shard in the ring. // ForEachShard concurrently calls the fn on each live shard in the ring.
// It returns the first error if any. // It returns the first error if any.
func (c *Ring) ForEachShard(fn func(client *Client) error) error { func (c *Ring) ForEachShard(
ctx context.Context,
fn func(ctx context.Context, client *Client) error,
) error {
shards := c.shards.List() shards := c.shards.List()
var wg sync.WaitGroup var wg sync.WaitGroup
errCh := make(chan error, 1) errCh := make(chan error, 1)
@ -510,7 +487,7 @@ func (c *Ring) ForEachShard(fn func(client *Client) error) error {
wg.Add(1) wg.Add(1)
go func(shard *ringShard) { go func(shard *ringShard) {
defer wg.Done() defer wg.Done()
err := fn(shard.Client) err := fn(ctx, shard.Client)
if err != nil { if err != nil {
select { select {
case errCh <- err: case errCh <- err:
@ -531,9 +508,9 @@ func (c *Ring) ForEachShard(fn func(client *Client) error) error {
func (c *Ring) cmdsInfo() (map[string]*CommandInfo, error) { func (c *Ring) cmdsInfo() (map[string]*CommandInfo, error) {
shards := c.shards.List() shards := c.shards.List()
firstErr := errRingShardsDown var firstErr error
for _, shard := range shards { for _, shard := range shards {
cmdsInfo, err := shard.Client.Command().Result() cmdsInfo, err := shard.Client.Command(context.TODO()).Result()
if err == nil { if err == nil {
return cmdsInfo, nil return cmdsInfo, nil
} }
@ -541,6 +518,9 @@ func (c *Ring) cmdsInfo() (map[string]*CommandInfo, error) {
firstErr = err firstErr = err
} }
} }
if firstErr == nil {
return nil, errRingShardsDown
}
return nil, firstErr return nil, firstErr
} }
@ -589,7 +569,7 @@ func (c *Ring) _process(ctx context.Context, cmd Cmder) error {
return err return err
} }
lastErr = shard.Client.ProcessContext(ctx, cmd) lastErr = shard.Client.Process(ctx, cmd)
if lastErr == nil || !isRetryableError(lastErr, cmd.readTimeout() == nil) { if lastErr == nil || !isRetryableError(lastErr, cmd.readTimeout() == nil) {
return lastErr return lastErr
} }
@ -597,8 +577,8 @@ func (c *Ring) _process(ctx context.Context, cmd Cmder) error {
return lastErr return lastErr
} }
func (c *Ring) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) { func (c *Ring) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
return c.Pipeline().Pipelined(fn) return c.Pipeline().Pipelined(ctx, fn)
} }
func (c *Ring) Pipeline() Pipeliner { func (c *Ring) Pipeline() Pipeliner {
@ -616,8 +596,8 @@ func (c *Ring) processPipeline(ctx context.Context, cmds []Cmder) error {
}) })
} }
func (c *Ring) TxPipelined(fn func(Pipeliner) error) ([]Cmder, error) { func (c *Ring) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
return c.TxPipeline().Pipelined(fn) return c.TxPipeline().Pipelined(ctx, fn)
} }
func (c *Ring) TxPipeline() Pipeliner { func (c *Ring) TxPipeline() Pipeliner {
@ -688,7 +668,7 @@ func (c *Ring) Close() error {
return c.shards.Close() return c.shards.Close()
} }
func (c *Ring) Watch(fn func(*Tx) error, keys ...string) error { func (c *Ring) Watch(ctx context.Context, fn func(*Tx) error, keys ...string) error {
if len(keys) == 0 { if len(keys) == 0 {
return fmt.Errorf("redis: Watch requires at least one key") return fmt.Errorf("redis: Watch requires at least one key")
} }
@ -718,7 +698,7 @@ func (c *Ring) Watch(fn func(*Tx) error, keys ...string) error {
} }
} }
return shards[0].Client.Watch(fn, keys...) return shards[0].Client.Watch(ctx, fn, keys...)
} }
func newConsistentHash(opt *RingOptions) *consistenthash.Map { func newConsistentHash(opt *RingOptions) *consistenthash.Map {

View File

@ -9,7 +9,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/go-redis/redis/v7" "github.com/go-redis/redis/v8"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
@ -22,7 +22,7 @@ var _ = Describe("Redis Ring", func() {
setRingKeys := func() { setRingKeys := func() {
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
err := ring.Set(fmt.Sprintf("key%d", i), "value", 0).Err() err := ring.Set(ctx, fmt.Sprintf("key%d", i), "value", 0).Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
} }
} }
@ -32,8 +32,8 @@ var _ = Describe("Redis Ring", func() {
opt.HeartbeatFrequency = heartbeat opt.HeartbeatFrequency = heartbeat
ring = redis.NewRing(opt) ring = redis.NewRing(opt)
err := ring.ForEachShard(func(cl *redis.Client) error { err := ring.ForEachShard(ctx, func(ctx context.Context, cl *redis.Client) error {
return cl.FlushDB().Err() return cl.FlushDB(ctx).Err()
}) })
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
}) })
@ -42,11 +42,11 @@ var _ = Describe("Redis Ring", func() {
Expect(ring.Close()).NotTo(HaveOccurred()) Expect(ring.Close()).NotTo(HaveOccurred())
}) })
It("supports WithContext", func() { It("supports context", func() {
c, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(ctx)
cancel() cancel()
err := ring.WithContext(c).Ping().Err() err := ring.Ping(ctx).Err()
Expect(err).To(MatchError("context canceled")) Expect(err).To(MatchError("context canceled"))
}) })
@ -54,8 +54,8 @@ var _ = Describe("Redis Ring", func() {
setRingKeys() setRingKeys()
// Both shards should have some keys now. // Both shards should have some keys now.
Expect(ringShard1.Info("keyspace").Val()).To(ContainSubstring("keys=57")) Expect(ringShard1.Info(ctx, "keyspace").Val()).To(ContainSubstring("keys=57"))
Expect(ringShard2.Info("keyspace").Val()).To(ContainSubstring("keys=43")) Expect(ringShard2.Info(ctx, "keyspace").Val()).To(ContainSubstring("keys=43"))
}) })
It("distributes keys when using EVAL", func() { It("distributes keys when using EVAL", func() {
@ -67,12 +67,12 @@ var _ = Describe("Redis Ring", func() {
var key string var key string
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
key = fmt.Sprintf("key%d", i) key = fmt.Sprintf("key%d", i)
err := script.Run(ring, []string{key}, "value").Err() err := script.Run(ctx, ring, []string{key}, "value").Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
} }
Expect(ringShard1.Info("keyspace").Val()).To(ContainSubstring("keys=57")) Expect(ringShard1.Info(ctx, "keyspace").Val()).To(ContainSubstring("keys=57"))
Expect(ringShard2.Info("keyspace").Val()).To(ContainSubstring("keys=43")) Expect(ringShard2.Info(ctx, "keyspace").Val()).To(ContainSubstring("keys=43"))
}) })
It("uses single shard when one of the shards is down", func() { It("uses single shard when one of the shards is down", func() {
@ -86,7 +86,7 @@ var _ = Describe("Redis Ring", func() {
setRingKeys() setRingKeys()
// RingShard1 should have all keys. // RingShard1 should have all keys.
Expect(ringShard1.Info("keyspace").Val()).To(ContainSubstring("keys=100")) Expect(ringShard1.Info(ctx, "keyspace").Val()).To(ContainSubstring("keys=100"))
// Start ringShard2. // Start ringShard2.
var err error var err error
@ -100,27 +100,27 @@ var _ = Describe("Redis Ring", func() {
setRingKeys() setRingKeys()
// RingShard2 should have its keys. // RingShard2 should have its keys.
Expect(ringShard2.Info("keyspace").Val()).To(ContainSubstring("keys=43")) Expect(ringShard2.Info(ctx, "keyspace").Val()).To(ContainSubstring("keys=43"))
}) })
It("supports hash tags", func() { It("supports hash tags", func() {
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
err := ring.Set(fmt.Sprintf("key%d{tag}", i), "value", 0).Err() err := ring.Set(ctx, fmt.Sprintf("key%d{tag}", i), "value", 0).Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
} }
Expect(ringShard1.Info("keyspace").Val()).ToNot(ContainSubstring("keys=")) Expect(ringShard1.Info(ctx, "keyspace").Val()).ToNot(ContainSubstring("keys="))
Expect(ringShard2.Info("keyspace").Val()).To(ContainSubstring("keys=100")) Expect(ringShard2.Info(ctx, "keyspace").Val()).To(ContainSubstring("keys=100"))
}) })
Describe("pipeline", func() { Describe("pipeline", func() {
It("distributes keys", func() { It("distributes keys", func() {
pipe := ring.Pipeline() pipe := ring.Pipeline()
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
err := pipe.Set(fmt.Sprintf("key%d", i), "value", 0).Err() err := pipe.Set(ctx, fmt.Sprintf("key%d", i), "value", 0).Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
} }
cmds, err := pipe.Exec() cmds, err := pipe.Exec(ctx)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(cmds).To(HaveLen(100)) Expect(cmds).To(HaveLen(100))
Expect(pipe.Close()).NotTo(HaveOccurred()) Expect(pipe.Close()).NotTo(HaveOccurred())
@ -131,8 +131,8 @@ var _ = Describe("Redis Ring", func() {
} }
// Both shards should have some keys now. // Both shards should have some keys now.
Expect(ringShard1.Info().Val()).To(ContainSubstring("keys=57")) Expect(ringShard1.Info(ctx).Val()).To(ContainSubstring("keys=57"))
Expect(ringShard2.Info().Val()).To(ContainSubstring("keys=43")) Expect(ringShard2.Info(ctx).Val()).To(ContainSubstring("keys=43"))
}) })
It("is consistent with ring", func() { It("is consistent with ring", func() {
@ -144,55 +144,32 @@ var _ = Describe("Redis Ring", func() {
keys = append(keys, string(key)) keys = append(keys, string(key))
} }
_, err := ring.Pipelined(func(pipe redis.Pipeliner) error { _, err := ring.Pipelined(ctx, func(pipe redis.Pipeliner) error {
for _, key := range keys { for _, key := range keys {
pipe.Set(key, "value", 0).Err() pipe.Set(ctx, key, "value", 0).Err()
} }
return nil return nil
}) })
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
for _, key := range keys { for _, key := range keys {
val, err := ring.Get(key).Result() val, err := ring.Get(ctx, key).Result()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(val).To(Equal("value")) Expect(val).To(Equal("value"))
} }
}) })
It("supports hash tags", func() { It("supports hash tags", func() {
_, err := ring.Pipelined(func(pipe redis.Pipeliner) error { _, err := ring.Pipelined(ctx, func(pipe redis.Pipeliner) error {
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
pipe.Set(fmt.Sprintf("key%d{tag}", i), "value", 0).Err() pipe.Set(ctx, fmt.Sprintf("key%d{tag}", i), "value", 0).Err()
} }
return nil return nil
}) })
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(ringShard1.Info().Val()).ToNot(ContainSubstring("keys=")) Expect(ringShard1.Info(ctx).Val()).ToNot(ContainSubstring("keys="))
Expect(ringShard2.Info().Val()).To(ContainSubstring("keys=100")) Expect(ringShard2.Info(ctx).Val()).To(ContainSubstring("keys=100"))
})
})
Describe("shard passwords", func() {
It("can be initialized with a single password, used for all shards", func() {
opts := redisRingOptions()
opts.Password = "password"
ring = redis.NewRing(opts)
err := ring.Ping().Err()
Expect(err).To(MatchError("ERR Client sent AUTH, but no password is set"))
})
It("can be initialized with a passwords map, one for each shard", func() {
opts := redisRingOptions()
opts.Passwords = map[string]string{
"ringShardOne": "password1",
"ringShardTwo": "password2",
}
ring = redis.NewRing(opts)
err := ring.Ping().Err()
Expect(err).To(MatchError("ERR Client sent AUTH, but no password is set"))
}) })
}) })
@ -205,13 +182,13 @@ var _ = Describe("Redis Ring", func() {
} }
ring = redis.NewRing(opts) ring = redis.NewRing(opts)
err := ring.Ping().Err() err := ring.Ping(ctx).Err()
Expect(err).To(MatchError("ERR Client sent AUTH, but no password is set")) Expect(err).To(MatchError("ERR Client sent AUTH, but no password is set"))
}) })
}) })
It("supports Process hook", func() { It("supports Process hook", func() {
err := ring.Ping().Err() err := ring.Ping(ctx).Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
var stack []string var stack []string
@ -229,7 +206,7 @@ var _ = Describe("Redis Ring", func() {
}, },
}) })
ring.ForEachShard(func(shard *redis.Client) error { ring.ForEachShard(ctx, func(ctx context.Context, shard *redis.Client) error {
shard.AddHook(&hook{ shard.AddHook(&hook{
beforeProcess: func(ctx context.Context, cmd redis.Cmder) (context.Context, error) { beforeProcess: func(ctx context.Context, cmd redis.Cmder) (context.Context, error) {
Expect(cmd.String()).To(Equal("ping: ")) Expect(cmd.String()).To(Equal("ping: "))
@ -245,7 +222,7 @@ var _ = Describe("Redis Ring", func() {
return nil return nil
}) })
err = ring.Ping().Err() err = ring.Ping(ctx).Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(stack).To(Equal([]string{ Expect(stack).To(Equal([]string{
"ring.BeforeProcess", "ring.BeforeProcess",
@ -256,7 +233,7 @@ var _ = Describe("Redis Ring", func() {
}) })
It("supports Pipeline hook", func() { It("supports Pipeline hook", func() {
err := ring.Ping().Err() err := ring.Ping(ctx).Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
var stack []string var stack []string
@ -276,7 +253,7 @@ var _ = Describe("Redis Ring", func() {
}, },
}) })
ring.ForEachShard(func(shard *redis.Client) error { ring.ForEachShard(ctx, func(ctx context.Context, shard *redis.Client) error {
shard.AddHook(&hook{ shard.AddHook(&hook{
beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) {
Expect(cmds).To(HaveLen(1)) Expect(cmds).To(HaveLen(1))
@ -294,8 +271,8 @@ var _ = Describe("Redis Ring", func() {
return nil return nil
}) })
_, err = ring.Pipelined(func(pipe redis.Pipeliner) error { _, err = ring.Pipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Ping() pipe.Ping(ctx)
return nil return nil
}) })
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
@ -308,7 +285,7 @@ var _ = Describe("Redis Ring", func() {
}) })
It("supports TxPipeline hook", func() { It("supports TxPipeline hook", func() {
err := ring.Ping().Err() err := ring.Ping(ctx).Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
var stack []string var stack []string
@ -328,7 +305,7 @@ var _ = Describe("Redis Ring", func() {
}, },
}) })
ring.ForEachShard(func(shard *redis.Client) error { ring.ForEachShard(ctx, func(ctx context.Context, shard *redis.Client) error {
shard.AddHook(&hook{ shard.AddHook(&hook{
beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) {
Expect(cmds).To(HaveLen(3)) Expect(cmds).To(HaveLen(3))
@ -346,8 +323,8 @@ var _ = Describe("Redis Ring", func() {
return nil return nil
}) })
_, err = ring.TxPipelined(func(pipe redis.Pipeliner) error { _, err = ring.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Ping() pipe.Ping(ctx)
return nil return nil
}) })
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
@ -372,13 +349,13 @@ var _ = Describe("empty Redis Ring", func() {
}) })
It("returns an error", func() { It("returns an error", func() {
err := ring.Ping().Err() err := ring.Ping(ctx).Err()
Expect(err).To(MatchError("redis: all ring shards are down")) Expect(err).To(MatchError("redis: all ring shards are down"))
}) })
It("pipeline returns an error", func() { It("pipeline returns an error", func() {
_, err := ring.Pipelined(func(pipe redis.Pipeliner) error { _, err := ring.Pipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Ping() pipe.Ping(ctx)
return nil return nil
}) })
Expect(err).To(MatchError("redis: all ring shards are down")) Expect(err).To(MatchError("redis: all ring shards are down"))
@ -395,8 +372,8 @@ var _ = Describe("Ring watch", func() {
opt.HeartbeatFrequency = heartbeat opt.HeartbeatFrequency = heartbeat
ring = redis.NewRing(opt) ring = redis.NewRing(opt)
err := ring.ForEachShard(func(cl *redis.Client) error { err := ring.ForEachShard(ctx, func(ctx context.Context, cl *redis.Client) error {
return cl.FlushDB().Err() return cl.FlushDB(ctx).Err()
}) })
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
}) })
@ -410,14 +387,14 @@ var _ = Describe("Ring watch", func() {
// Transactionally increments key using GET and SET commands. // Transactionally increments key using GET and SET commands.
incr = func(key string) error { incr = func(key string) error {
err := ring.Watch(func(tx *redis.Tx) error { err := ring.Watch(ctx, func(tx *redis.Tx) error {
n, err := tx.Get(key).Int64() n, err := tx.Get(ctx, key).Int64()
if err != nil && err != redis.Nil { if err != nil && err != redis.Nil {
return err return err
} }
_, err = tx.TxPipelined(func(pipe redis.Pipeliner) error { _, err = tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Set(key, strconv.FormatInt(n+1, 10), 0) pipe.Set(ctx, key, strconv.FormatInt(n+1, 10), 0)
return nil return nil
}) })
return err return err
@ -441,17 +418,17 @@ var _ = Describe("Ring watch", func() {
} }
wg.Wait() wg.Wait()
n, err := ring.Get("key").Int64() n, err := ring.Get(ctx, "key").Int64()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(n).To(Equal(int64(100))) Expect(n).To(Equal(int64(100)))
}) })
It("should discard", func() { It("should discard", func() {
err := ring.Watch(func(tx *redis.Tx) error { err := ring.Watch(ctx, func(tx *redis.Tx) error {
cmds, err := tx.TxPipelined(func(pipe redis.Pipeliner) error { cmds, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Set("key1", "hello1", 0) pipe.Set(ctx, "key1", "hello1", 0)
pipe.Discard() pipe.Discard()
pipe.Set("key2", "hello2", 0) pipe.Set(ctx, "key2", "hello2", 0)
return nil return nil
}) })
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
@ -460,23 +437,23 @@ var _ = Describe("Ring watch", func() {
}, "key1", "key2") }, "key1", "key2")
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
get := ring.Get("key1") get := ring.Get(ctx, "key1")
Expect(get.Err()).To(Equal(redis.Nil)) Expect(get.Err()).To(Equal(redis.Nil))
Expect(get.Val()).To(Equal("")) Expect(get.Val()).To(Equal(""))
get = ring.Get("key2") get = ring.Get(ctx, "key2")
Expect(get.Err()).NotTo(HaveOccurred()) Expect(get.Err()).NotTo(HaveOccurred())
Expect(get.Val()).To(Equal("hello2")) Expect(get.Val()).To(Equal("hello2"))
}) })
It("returns no error when there are no commands", func() { It("returns no error when there are no commands", func() {
err := ring.Watch(func(tx *redis.Tx) error { err := ring.Watch(ctx, func(tx *redis.Tx) error {
_, err := tx.TxPipelined(func(redis.Pipeliner) error { return nil }) _, err := tx.TxPipelined(ctx, func(redis.Pipeliner) error { return nil })
return err return err
}, "key") }, "key")
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
v, err := ring.Ping().Result() v, err := ring.Ping(ctx).Result()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(v).To(Equal("PONG")) Expect(v).To(Equal("PONG"))
}) })
@ -484,10 +461,10 @@ var _ = Describe("Ring watch", func() {
It("should exec bulks", func() { It("should exec bulks", func() {
const N = 20000 const N = 20000
err := ring.Watch(func(tx *redis.Tx) error { err := ring.Watch(ctx, func(tx *redis.Tx) error {
cmds, err := tx.TxPipelined(func(pipe redis.Pipeliner) error { cmds, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
for i := 0; i < N; i++ { for i := 0; i < N; i++ {
pipe.Incr("key") pipe.Incr(ctx, "key")
} }
return nil return nil
}) })
@ -500,7 +477,7 @@ var _ = Describe("Ring watch", func() {
}, "key") }, "key")
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
num, err := ring.Get("key").Int64() num, err := ring.Get(ctx, "key").Int64()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(num).To(Equal(int64(N))) Expect(num).To(Equal(int64(N)))
}) })
@ -508,21 +485,21 @@ var _ = Describe("Ring watch", func() {
It("should Watch/Unwatch", func() { It("should Watch/Unwatch", func() {
var C, N int var C, N int
err := ring.Set("key", "0", 0).Err() err := ring.Set(ctx, "key", "0", 0).Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
perform(C, func(id int) { perform(C, func(id int) {
for i := 0; i < N; i++ { for i := 0; i < N; i++ {
err := ring.Watch(func(tx *redis.Tx) error { err := ring.Watch(ctx, func(tx *redis.Tx) error {
val, err := tx.Get("key").Result() val, err := tx.Get(ctx, "key").Result()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(val).NotTo(Equal(redis.Nil)) Expect(val).NotTo(Equal(redis.Nil))
num, err := strconv.ParseInt(val, 10, 64) num, err := strconv.ParseInt(val, 10, 64)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
cmds, err := tx.TxPipelined(func(pipe redis.Pipeliner) error { cmds, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Set("key", strconv.FormatInt(num+1, 10), 0) pipe.Set(ctx, "key", strconv.FormatInt(num+1, 10), 0)
return nil return nil
}) })
Expect(cmds).To(HaveLen(1)) Expect(cmds).To(HaveLen(1))
@ -536,31 +513,31 @@ var _ = Describe("Ring watch", func() {
} }
}) })
val, err := ring.Get("key").Int64() val, err := ring.Get(ctx, "key").Int64()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(val).To(Equal(int64(C * N))) Expect(val).To(Equal(int64(C * N)))
}) })
It("should close Tx without closing the client", func() { It("should close Tx without closing the client", func() {
err := ring.Watch(func(tx *redis.Tx) error { err := ring.Watch(ctx, func(tx *redis.Tx) error {
_, err := tx.TxPipelined(func(pipe redis.Pipeliner) error { _, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Ping() pipe.Ping(ctx)
return nil return nil
}) })
return err return err
}, "key") }, "key")
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(ring.Ping().Err()).NotTo(HaveOccurred()) Expect(ring.Ping(ctx).Err()).NotTo(HaveOccurred())
}) })
It("respects max size on multi", func() { It("respects max size on multi", func() {
perform(1000, func(id int) { perform(1000, func(id int) {
var ping *redis.StatusCmd var ping *redis.StatusCmd
err := ring.Watch(func(tx *redis.Tx) error { err := ring.Watch(ctx, func(tx *redis.Tx) error {
cmds, err := tx.TxPipelined(func(pipe redis.Pipeliner) error { cmds, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
ping = pipe.Ping() ping = pipe.Ping(ctx)
return nil return nil
}) })
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
@ -573,7 +550,7 @@ var _ = Describe("Ring watch", func() {
Expect(ping.Val()).To(Equal("PONG")) Expect(ping.Val()).To(Equal("PONG"))
}) })
ring.ForEachShard(func(cl *redis.Client) error { ring.ForEachShard(ctx, func(ctx context.Context, cl *redis.Client) error {
defer GinkgoRecover() defer GinkgoRecover()
pool := cl.Pool() pool := cl.Pool()
@ -597,17 +574,17 @@ var _ = Describe("Ring Tx timeout", func() {
testTimeout := func() { testTimeout := func() {
It("Tx timeouts", func() { It("Tx timeouts", func() {
err := ring.Watch(func(tx *redis.Tx) error { err := ring.Watch(ctx, func(tx *redis.Tx) error {
return tx.Ping().Err() return tx.Ping(ctx).Err()
}, "foo") }, "foo")
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
Expect(err.(net.Error).Timeout()).To(BeTrue()) Expect(err.(net.Error).Timeout()).To(BeTrue())
}) })
It("Tx Pipeline timeouts", func() { It("Tx Pipeline timeouts", func() {
err := ring.Watch(func(tx *redis.Tx) error { err := ring.Watch(ctx, func(tx *redis.Tx) error {
_, err := tx.TxPipelined(func(pipe redis.Pipeliner) error { _, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Ping() pipe.Ping(ctx)
return nil return nil
}) })
return err return err
@ -627,17 +604,17 @@ var _ = Describe("Ring Tx timeout", func() {
opt.HeartbeatFrequency = heartbeat opt.HeartbeatFrequency = heartbeat
ring = redis.NewRing(opt) ring = redis.NewRing(opt)
err := ring.ForEachShard(func(client *redis.Client) error { err := ring.ForEachShard(ctx, func(ctx context.Context, client *redis.Client) error {
return client.ClientPause(pause).Err() return client.ClientPause(ctx, pause).Err()
}) })
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
}) })
AfterEach(func() { AfterEach(func() {
_ = ring.ForEachShard(func(client *redis.Client) error { _ = ring.ForEachShard(ctx, func(ctx context.Context, client *redis.Client) error {
defer GinkgoRecover() defer GinkgoRecover()
Eventually(func() error { Eventually(func() error {
return client.Ping().Err() return client.Ping(ctx).Err()
}, 2*pause).ShouldNot(HaveOccurred()) }, 2*pause).ShouldNot(HaveOccurred())
return nil return nil
}) })

View File

@ -1,6 +1,7 @@
package redis package redis
import ( import (
"context"
"crypto/sha1" "crypto/sha1"
"encoding/hex" "encoding/hex"
"io" "io"
@ -8,10 +9,10 @@ import (
) )
type scripter interface { type scripter interface {
Eval(script string, keys []string, args ...interface{}) *Cmd Eval(ctx context.Context, script string, keys []string, args ...interface{}) *Cmd
EvalSha(sha1 string, keys []string, args ...interface{}) *Cmd EvalSha(ctx context.Context, sha1 string, keys []string, args ...interface{}) *Cmd
ScriptExists(hashes ...string) *BoolSliceCmd ScriptExists(ctx context.Context, hashes ...string) *BoolSliceCmd
ScriptLoad(script string) *StringCmd ScriptLoad(ctx context.Context, script string) *StringCmd
} }
var _ scripter = (*Client)(nil) var _ scripter = (*Client)(nil)
@ -35,28 +36,28 @@ func (s *Script) Hash() string {
return s.hash return s.hash
} }
func (s *Script) Load(c scripter) *StringCmd { func (s *Script) Load(ctx context.Context, c scripter) *StringCmd {
return c.ScriptLoad(s.src) return c.ScriptLoad(ctx, s.src)
} }
func (s *Script) Exists(c scripter) *BoolSliceCmd { func (s *Script) Exists(ctx context.Context, c scripter) *BoolSliceCmd {
return c.ScriptExists(s.hash) return c.ScriptExists(ctx, s.hash)
} }
func (s *Script) Eval(c scripter, keys []string, args ...interface{}) *Cmd { func (s *Script) Eval(ctx context.Context, c scripter, keys []string, args ...interface{}) *Cmd {
return c.Eval(s.src, keys, args...) return c.Eval(ctx, s.src, keys, args...)
} }
func (s *Script) EvalSha(c scripter, keys []string, args ...interface{}) *Cmd { func (s *Script) EvalSha(ctx context.Context, c scripter, keys []string, args ...interface{}) *Cmd {
return c.EvalSha(s.hash, keys, args...) return c.EvalSha(ctx, s.hash, keys, args...)
} }
// Run optimistically uses EVALSHA to run the script. If script does not exist // Run optimistically uses EVALSHA to run the script. If script does not exist
// it is retried using EVAL. // it is retried using EVAL.
func (s *Script) Run(c scripter, keys []string, args ...interface{}) *Cmd { func (s *Script) Run(ctx context.Context, c scripter, keys []string, args ...interface{}) *Cmd {
r := s.EvalSha(c, keys, args...) r := s.EvalSha(ctx, c, keys, args...)
if err := r.Err(); err != nil && strings.HasPrefix(err.Error(), "NOSCRIPT ") { if err := r.Err(); err != nil && strings.HasPrefix(err.Error(), "NOSCRIPT ") {
return s.Eval(c, keys, args...) return s.Eval(ctx, c, keys, args...)
} }
return r return r
} }

View File

@ -9,8 +9,8 @@ import (
"sync" "sync"
"time" "time"
"github.com/go-redis/redis/v7/internal" "github.com/go-redis/redis/v8/internal"
"github.com/go-redis/redis/v7/internal/pool" "github.com/go-redis/redis/v8/internal/pool"
) )
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
@ -139,11 +139,7 @@ func (c *SentinelClient) WithContext(ctx context.Context) *SentinelClient {
return &clone return &clone
} }
func (c *SentinelClient) Process(cmd Cmder) error { func (c *SentinelClient) Process(ctx context.Context, cmd Cmder) error {
return c.ProcessContext(c.ctx, cmd)
}
func (c *SentinelClient) ProcessContext(ctx context.Context, cmd Cmder) error {
return c.baseClient.process(ctx, cmd) return c.baseClient.process(ctx, cmd)
} }
@ -151,8 +147,8 @@ func (c *SentinelClient) pubSub() *PubSub {
pubsub := &PubSub{ pubsub := &PubSub{
opt: c.opt, opt: c.opt,
newConn: func(channels []string) (*pool.Conn, error) { newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) {
return c.newConn(context.TODO()) return c.newConn(ctx)
}, },
closeConn: c.connPool.CloseConn, closeConn: c.connPool.CloseConn,
} }
@ -162,49 +158,49 @@ func (c *SentinelClient) pubSub() *PubSub {
// Ping is used to test if a connection is still alive, or to // Ping is used to test if a connection is still alive, or to
// measure latency. // measure latency.
func (c *SentinelClient) Ping() *StringCmd { func (c *SentinelClient) Ping(ctx context.Context) *StringCmd {
cmd := NewStringCmd("ping") cmd := NewStringCmd(ctx, "ping")
_ = c.Process(cmd) _ = c.Process(ctx, cmd)
return cmd return cmd
} }
// Subscribe subscribes the client to the specified channels. // Subscribe subscribes the client to the specified channels.
// Channels can be omitted to create empty subscription. // Channels can be omitted to create empty subscription.
func (c *SentinelClient) Subscribe(channels ...string) *PubSub { func (c *SentinelClient) Subscribe(ctx context.Context, channels ...string) *PubSub {
pubsub := c.pubSub() pubsub := c.pubSub()
if len(channels) > 0 { if len(channels) > 0 {
_ = pubsub.Subscribe(channels...) _ = pubsub.Subscribe(ctx, channels...)
} }
return pubsub return pubsub
} }
// PSubscribe subscribes the client to the given patterns. // PSubscribe subscribes the client to the given patterns.
// Patterns can be omitted to create empty subscription. // Patterns can be omitted to create empty subscription.
func (c *SentinelClient) PSubscribe(channels ...string) *PubSub { func (c *SentinelClient) PSubscribe(ctx context.Context, channels ...string) *PubSub {
pubsub := c.pubSub() pubsub := c.pubSub()
if len(channels) > 0 { if len(channels) > 0 {
_ = pubsub.PSubscribe(channels...) _ = pubsub.PSubscribe(ctx, channels...)
} }
return pubsub return pubsub
} }
func (c *SentinelClient) GetMasterAddrByName(name string) *StringSliceCmd { func (c *SentinelClient) GetMasterAddrByName(ctx context.Context, name string) *StringSliceCmd {
cmd := NewStringSliceCmd("sentinel", "get-master-addr-by-name", name) cmd := NewStringSliceCmd(ctx, "sentinel", "get-master-addr-by-name", name)
_ = c.Process(cmd) _ = c.Process(ctx, cmd)
return cmd return cmd
} }
func (c *SentinelClient) Sentinels(name string) *SliceCmd { func (c *SentinelClient) Sentinels(ctx context.Context, name string) *SliceCmd {
cmd := NewSliceCmd("sentinel", "sentinels", name) cmd := NewSliceCmd(ctx, "sentinel", "sentinels", name)
_ = c.Process(cmd) _ = c.Process(ctx, cmd)
return cmd return cmd
} }
// Failover forces a failover as if the master was not reachable, and without // Failover forces a failover as if the master was not reachable, and without
// asking for agreement to other Sentinels. // asking for agreement to other Sentinels.
func (c *SentinelClient) Failover(name string) *StatusCmd { func (c *SentinelClient) Failover(ctx context.Context, name string) *StatusCmd {
cmd := NewStatusCmd("sentinel", "failover", name) cmd := NewStatusCmd(ctx, "sentinel", "failover", name)
_ = c.Process(cmd) _ = c.Process(ctx, cmd)
return cmd return cmd
} }
@ -212,38 +208,38 @@ func (c *SentinelClient) Failover(name string) *StatusCmd {
// glob-style pattern. The reset process clears any previous state in a master // glob-style pattern. The reset process clears any previous state in a master
// (including a failover in progress), and removes every slave and sentinel // (including a failover in progress), and removes every slave and sentinel
// already discovered and associated with the master. // already discovered and associated with the master.
func (c *SentinelClient) Reset(pattern string) *IntCmd { func (c *SentinelClient) Reset(ctx context.Context, pattern string) *IntCmd {
cmd := NewIntCmd("sentinel", "reset", pattern) cmd := NewIntCmd(ctx, "sentinel", "reset", pattern)
_ = c.Process(cmd) _ = c.Process(ctx, cmd)
return cmd return cmd
} }
// FlushConfig forces Sentinel to rewrite its configuration on disk, including // FlushConfig forces Sentinel to rewrite its configuration on disk, including
// the current Sentinel state. // the current Sentinel state.
func (c *SentinelClient) FlushConfig() *StatusCmd { func (c *SentinelClient) FlushConfig(ctx context.Context) *StatusCmd {
cmd := NewStatusCmd("sentinel", "flushconfig") cmd := NewStatusCmd(ctx, "sentinel", "flushconfig")
_ = c.Process(cmd) _ = c.Process(ctx, cmd)
return cmd return cmd
} }
// Master shows the state and info of the specified master. // Master shows the state and info of the specified master.
func (c *SentinelClient) Master(name string) *StringStringMapCmd { func (c *SentinelClient) Master(ctx context.Context, name string) *StringStringMapCmd {
cmd := NewStringStringMapCmd("sentinel", "master", name) cmd := NewStringStringMapCmd(ctx, "sentinel", "master", name)
_ = c.Process(cmd) _ = c.Process(ctx, cmd)
return cmd return cmd
} }
// Masters shows a list of monitored masters and their state. // Masters shows a list of monitored masters and their state.
func (c *SentinelClient) Masters() *SliceCmd { func (c *SentinelClient) Masters(ctx context.Context) *SliceCmd {
cmd := NewSliceCmd("sentinel", "masters") cmd := NewSliceCmd(ctx, "sentinel", "masters")
_ = c.Process(cmd) _ = c.Process(ctx, cmd)
return cmd return cmd
} }
// Slaves shows a list of slaves for the specified master and their state. // Slaves shows a list of slaves for the specified master and their state.
func (c *SentinelClient) Slaves(name string) *SliceCmd { func (c *SentinelClient) Slaves(ctx context.Context, name string) *SliceCmd {
cmd := NewSliceCmd("sentinel", "slaves", name) cmd := NewSliceCmd(ctx, "sentinel", "slaves", name)
_ = c.Process(cmd) _ = c.Process(ctx, cmd)
return cmd return cmd
} }
@ -251,33 +247,33 @@ func (c *SentinelClient) Slaves(name string) *SliceCmd {
// quorum needed to failover a master, and the majority needed to authorize the // quorum needed to failover a master, and the majority needed to authorize the
// failover. This command should be used in monitoring systems to check if a // failover. This command should be used in monitoring systems to check if a
// Sentinel deployment is ok. // Sentinel deployment is ok.
func (c *SentinelClient) CkQuorum(name string) *StringCmd { func (c *SentinelClient) CkQuorum(ctx context.Context, name string) *StringCmd {
cmd := NewStringCmd("sentinel", "ckquorum", name) cmd := NewStringCmd(ctx, "sentinel", "ckquorum", name)
_ = c.Process(cmd) _ = c.Process(ctx, cmd)
return cmd return cmd
} }
// Monitor tells the Sentinel to start monitoring a new master with the specified // Monitor tells the Sentinel to start monitoring a new master with the specified
// name, ip, port, and quorum. // name, ip, port, and quorum.
func (c *SentinelClient) Monitor(name, ip, port, quorum string) *StringCmd { func (c *SentinelClient) Monitor(ctx context.Context, name, ip, port, quorum string) *StringCmd {
cmd := NewStringCmd("sentinel", "monitor", name, ip, port, quorum) cmd := NewStringCmd(ctx, "sentinel", "monitor", name, ip, port, quorum)
_ = c.Process(cmd) _ = c.Process(ctx, cmd)
return cmd return cmd
} }
// Set is used in order to change configuration parameters of a specific master. // Set is used in order to change configuration parameters of a specific master.
func (c *SentinelClient) Set(name, option, value string) *StringCmd { func (c *SentinelClient) Set(ctx context.Context, name, option, value string) *StringCmd {
cmd := NewStringCmd("sentinel", "set", name, option, value) cmd := NewStringCmd(ctx, "sentinel", "set", name, option, value)
_ = c.Process(cmd) _ = c.Process(ctx, cmd)
return cmd return cmd
} }
// Remove is used in order to remove the specified master: the master will no // Remove is used in order to remove the specified master: the master will no
// longer be monitored, and will totally be removed from the internal state of // longer be monitored, and will totally be removed from the internal state of
// the Sentinel. // the Sentinel.
func (c *SentinelClient) Remove(name string) *StringCmd { func (c *SentinelClient) Remove(ctx context.Context, name string) *StringCmd {
cmd := NewStringCmd("sentinel", "remove", name) cmd := NewStringCmd(ctx, "sentinel", "remove", name)
_ = c.Process(cmd) _ = c.Process(ctx, cmd)
return cmd return cmd
} }
@ -330,7 +326,7 @@ func (c *sentinelFailover) Pool() *pool.ConnPool {
} }
func (c *sentinelFailover) dial(ctx context.Context, network, _ string) (net.Conn, error) { func (c *sentinelFailover) dial(ctx context.Context, network, _ string) (net.Conn, error) {
addr, err := c.MasterAddr() addr, err := c.MasterAddr(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -340,8 +336,8 @@ func (c *sentinelFailover) dial(ctx context.Context, network, _ string) (net.Con
return net.DialTimeout("tcp", addr, c.opt.DialTimeout) return net.DialTimeout("tcp", addr, c.opt.DialTimeout)
} }
func (c *sentinelFailover) MasterAddr() (string, error) { func (c *sentinelFailover) MasterAddr(ctx context.Context) (string, error) {
addr, err := c.masterAddr() addr, err := c.masterAddr(ctx)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -349,13 +345,13 @@ func (c *sentinelFailover) MasterAddr() (string, error) {
return addr, nil return addr, nil
} }
func (c *sentinelFailover) masterAddr() (string, error) { func (c *sentinelFailover) masterAddr(ctx context.Context) (string, error) {
c.mu.RLock() c.mu.RLock()
sentinel := c.sentinel sentinel := c.sentinel
c.mu.RUnlock() c.mu.RUnlock()
if sentinel != nil { if sentinel != nil {
addr := c.getMasterAddr(sentinel) addr := c.getMasterAddr(ctx, sentinel)
if addr != "" { if addr != "" {
return addr, nil return addr, nil
} }
@ -365,7 +361,7 @@ func (c *sentinelFailover) masterAddr() (string, error) {
defer c.mu.Unlock() defer c.mu.Unlock()
if c.sentinel != nil { if c.sentinel != nil {
addr := c.getMasterAddr(c.sentinel) addr := c.getMasterAddr(ctx, c.sentinel)
if addr != "" { if addr != "" {
return addr, nil return addr, nil
} }
@ -394,7 +390,7 @@ func (c *sentinelFailover) masterAddr() (string, error) {
TLSConfig: c.opt.TLSConfig, TLSConfig: c.opt.TLSConfig,
}) })
masterAddr, err := sentinel.GetMasterAddrByName(c.masterName).Result() masterAddr, err := sentinel.GetMasterAddrByName(ctx, c.masterName).Result()
if err != nil { if err != nil {
internal.Logger.Printf("sentinel: GetMasterAddrByName master=%q failed: %s", internal.Logger.Printf("sentinel: GetMasterAddrByName master=%q failed: %s",
c.masterName, err) c.masterName, err)
@ -404,7 +400,7 @@ func (c *sentinelFailover) masterAddr() (string, error) {
// Push working sentinel to the top. // Push working sentinel to the top.
c.sentinelAddrs[0], c.sentinelAddrs[i] = c.sentinelAddrs[i], c.sentinelAddrs[0] c.sentinelAddrs[0], c.sentinelAddrs[i] = c.sentinelAddrs[i], c.sentinelAddrs[0]
c.setSentinel(sentinel) c.setSentinel(ctx, sentinel)
addr := net.JoinHostPort(masterAddr[0], masterAddr[1]) addr := net.JoinHostPort(masterAddr[0], masterAddr[1])
return addr, nil return addr, nil
@ -413,8 +409,8 @@ func (c *sentinelFailover) masterAddr() (string, error) {
return "", errors.New("redis: all sentinels are unreachable") return "", errors.New("redis: all sentinels are unreachable")
} }
func (c *sentinelFailover) getMasterAddr(sentinel *SentinelClient) string { func (c *sentinelFailover) getMasterAddr(ctx context.Context, sentinel *SentinelClient) string {
addr, err := sentinel.GetMasterAddrByName(c.masterName).Result() addr, err := sentinel.GetMasterAddrByName(ctx, c.masterName).Result()
if err != nil { if err != nil {
internal.Logger.Printf("sentinel: GetMasterAddrByName name=%q failed: %s", internal.Logger.Printf("sentinel: GetMasterAddrByName name=%q failed: %s",
c.masterName, err) c.masterName, err)
@ -446,19 +442,19 @@ func (c *sentinelFailover) switchMaster(addr string) {
c._masterAddr = addr c._masterAddr = addr
} }
func (c *sentinelFailover) setSentinel(sentinel *SentinelClient) { func (c *sentinelFailover) setSentinel(ctx context.Context, sentinel *SentinelClient) {
if c.sentinel != nil { if c.sentinel != nil {
panic("not reached") panic("not reached")
} }
c.sentinel = sentinel c.sentinel = sentinel
c.discoverSentinels() c.discoverSentinels(ctx)
c.pubsub = sentinel.Subscribe("+switch-master") c.pubsub = sentinel.Subscribe(ctx, "+switch-master")
go c.listen(c.pubsub) go c.listen(c.pubsub)
} }
func (c *sentinelFailover) discoverSentinels() { func (c *sentinelFailover) discoverSentinels(ctx context.Context) {
sentinels, err := c.sentinel.Sentinels(c.masterName).Result() sentinels, err := c.sentinel.Sentinels(ctx, c.masterName).Result()
if err != nil { if err != nil {
internal.Logger.Printf("sentinel: Sentinels master=%q failed: %s", c.masterName, err) internal.Logger.Printf("sentinel: Sentinels master=%q failed: %s", c.masterName, err)
return return

View File

@ -1,7 +1,7 @@
package redis_test package redis_test
import ( import (
"github.com/go-redis/redis/v7" "github.com/go-redis/redis/v8"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
@ -15,7 +15,7 @@ var _ = Describe("Sentinel", func() {
MasterName: sentinelName, MasterName: sentinelName,
SentinelAddrs: []string{":" + sentinelPort}, SentinelAddrs: []string{":" + sentinelPort},
}) })
Expect(client.FlushDB().Err()).NotTo(HaveOccurred()) Expect(client.FlushDB(ctx).Err()).NotTo(HaveOccurred())
}) })
AfterEach(func() { AfterEach(func() {
@ -24,48 +24,48 @@ var _ = Describe("Sentinel", func() {
It("should facilitate failover", func() { It("should facilitate failover", func() {
// Set value on master. // Set value on master.
err := client.Set("foo", "master", 0).Err() err := client.Set(ctx, "foo", "master", 0).Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
// Verify. // Verify.
val, err := sentinelMaster.Get("foo").Result() val, err := sentinelMaster.Get(ctx, "foo").Result()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(val).To(Equal("master")) Expect(val).To(Equal("master"))
// Create subscription. // Create subscription.
ch := client.Subscribe("foo").Channel() ch := client.Subscribe(ctx, "foo").Channel()
// Wait until replicated. // Wait until replicated.
Eventually(func() string { Eventually(func() string {
return sentinelSlave1.Get("foo").Val() return sentinelSlave1.Get(ctx, "foo").Val()
}, "1s", "100ms").Should(Equal("master")) }, "1s", "100ms").Should(Equal("master"))
Eventually(func() string { Eventually(func() string {
return sentinelSlave2.Get("foo").Val() return sentinelSlave2.Get(ctx, "foo").Val()
}, "1s", "100ms").Should(Equal("master")) }, "1s", "100ms").Should(Equal("master"))
// Wait until slaves are picked up by sentinel. // Wait until slaves are picked up by sentinel.
Eventually(func() string { Eventually(func() string {
return sentinel.Info().Val() return sentinel.Info(ctx).Val()
}, "10s", "100ms").Should(ContainSubstring("slaves=2")) }, "10s", "100ms").Should(ContainSubstring("slaves=2"))
// Kill master. // Kill master.
sentinelMaster.Shutdown() sentinelMaster.Shutdown(ctx)
Eventually(func() error { Eventually(func() error {
return sentinelMaster.Ping().Err() return sentinelMaster.Ping(ctx).Err()
}, "5s", "100ms").Should(HaveOccurred()) }, "5s", "100ms").Should(HaveOccurred())
// Wait for Redis sentinel to elect new master. // Wait for Redis sentinel to elect new master.
Eventually(func() string { Eventually(func() string {
return sentinelSlave1.Info().Val() + sentinelSlave2.Info().Val() return sentinelSlave1.Info(ctx).Val() + sentinelSlave2.Info(ctx).Val()
}, "30s", "1s").Should(ContainSubstring("role:master")) }, "30s", "1s").Should(ContainSubstring("role:master"))
// Check that client picked up new master. // Check that client picked up new master.
Eventually(func() error { Eventually(func() error {
return client.Get("foo").Err() return client.Get(ctx, "foo").Err()
}, "5s", "100ms").ShouldNot(HaveOccurred()) }, "5s", "100ms").ShouldNot(HaveOccurred())
// Publish message to check if subscription is renewed. // Publish message to check if subscription is renewed.
err = client.Publish("foo", "hello").Err() err = client.Publish(ctx, "foo", "hello").Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
var msg *redis.Message var msg *redis.Message
@ -82,7 +82,7 @@ var _ = Describe("Sentinel", func() {
SentinelAddrs: []string{":" + sentinelPort}, SentinelAddrs: []string{":" + sentinelPort},
DB: 1, DB: 1,
}) })
err := client.Ping().Err() err := client.Ping(ctx).Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
}) })
}) })

46
tx.go
View File

@ -3,8 +3,8 @@ package redis
import ( import (
"context" "context"
"github.com/go-redis/redis/v7/internal/pool" "github.com/go-redis/redis/v8/internal/pool"
"github.com/go-redis/redis/v7/internal/proto" "github.com/go-redis/redis/v8/internal/proto"
) )
// TxFailedErr transaction redis failed. // TxFailedErr transaction redis failed.
@ -55,11 +55,7 @@ func (c *Tx) WithContext(ctx context.Context) *Tx {
return &clone return &clone
} }
func (c *Tx) Process(cmd Cmder) error { func (c *Tx) Process(ctx context.Context, cmd Cmder) error {
return c.ProcessContext(c.ctx, cmd)
}
func (c *Tx) ProcessContext(ctx context.Context, cmd Cmder) error {
return c.hooks.process(ctx, cmd, c.baseClient.process) return c.hooks.process(ctx, cmd, c.baseClient.process)
} }
@ -67,52 +63,48 @@ func (c *Tx) ProcessContext(ctx context.Context, cmd Cmder) error {
// for conditional execution if there are any keys. // for conditional execution if there are any keys.
// //
// The transaction is automatically closed when fn exits. // The transaction is automatically closed when fn exits.
func (c *Client) Watch(fn func(*Tx) error, keys ...string) error { func (c *Client) Watch(ctx context.Context, fn func(*Tx) error, keys ...string) error {
return c.WatchContext(c.ctx, fn, keys...)
}
func (c *Client) WatchContext(ctx context.Context, fn func(*Tx) error, keys ...string) error {
tx := c.newTx(ctx) tx := c.newTx(ctx)
if len(keys) > 0 { if len(keys) > 0 {
if err := tx.Watch(keys...).Err(); err != nil { if err := tx.Watch(ctx, keys...).Err(); err != nil {
_ = tx.Close() _ = tx.Close(ctx)
return err return err
} }
} }
err := fn(tx) err := fn(tx)
_ = tx.Close() _ = tx.Close(ctx)
return err return err
} }
// Close closes the transaction, releasing any open resources. // Close closes the transaction, releasing any open resources.
func (c *Tx) Close() error { func (c *Tx) Close(ctx context.Context) error {
_ = c.Unwatch().Err() _ = c.Unwatch(ctx).Err()
return c.baseClient.Close() return c.baseClient.Close()
} }
// Watch marks the keys to be watched for conditional execution // Watch marks the keys to be watched for conditional execution
// of a transaction. // of a transaction.
func (c *Tx) Watch(keys ...string) *StatusCmd { func (c *Tx) Watch(ctx context.Context, keys ...string) *StatusCmd {
args := make([]interface{}, 1+len(keys)) args := make([]interface{}, 1+len(keys))
args[0] = "watch" args[0] = "watch"
for i, key := range keys { for i, key := range keys {
args[1+i] = key args[1+i] = key
} }
cmd := NewStatusCmd(args...) cmd := NewStatusCmd(ctx, args...)
_ = c.Process(cmd) _ = c.Process(ctx, cmd)
return cmd return cmd
} }
// Unwatch flushes all the previously watched keys for a transaction. // Unwatch flushes all the previously watched keys for a transaction.
func (c *Tx) Unwatch(keys ...string) *StatusCmd { func (c *Tx) Unwatch(ctx context.Context, keys ...string) *StatusCmd {
args := make([]interface{}, 1+len(keys)) args := make([]interface{}, 1+len(keys))
args[0] = "unwatch" args[0] = "unwatch"
for i, key := range keys { for i, key := range keys {
args[1+i] = key args[1+i] = key
} }
cmd := NewStatusCmd(args...) cmd := NewStatusCmd(ctx, args...)
_ = c.Process(cmd) _ = c.Process(ctx, cmd)
return cmd return cmd
} }
@ -130,8 +122,8 @@ func (c *Tx) Pipeline() Pipeliner {
// Pipelined executes commands queued in the fn outside of the transaction. // Pipelined executes commands queued in the fn outside of the transaction.
// Use TxPipelined if you need transactional behavior. // Use TxPipelined if you need transactional behavior.
func (c *Tx) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) { func (c *Tx) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
return c.Pipeline().Pipelined(fn) return c.Pipeline().Pipelined(ctx, fn)
} }
// TxPipelined executes commands queued in the fn in the transaction. // TxPipelined executes commands queued in the fn in the transaction.
@ -142,8 +134,8 @@ func (c *Tx) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) {
// Exec always returns list of commands. If transaction fails // Exec always returns list of commands. If transaction fails
// TxFailedErr is returned. Otherwise Exec returns an error of the first // TxFailedErr is returned. Otherwise Exec returns an error of the first
// failed command or nil. // failed command or nil.
func (c *Tx) TxPipelined(fn func(Pipeliner) error) ([]Cmder, error) { func (c *Tx) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
return c.TxPipeline().Pipelined(fn) return c.TxPipeline().Pipelined(ctx, fn)
} }
// TxPipeline creates a pipeline. Usually it is more convenient to use TxPipelined. // TxPipeline creates a pipeline. Usually it is more convenient to use TxPipelined.

View File

@ -5,7 +5,7 @@ import (
"strconv" "strconv"
"sync" "sync"
"github.com/go-redis/redis/v7" "github.com/go-redis/redis/v8"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
@ -16,7 +16,7 @@ var _ = Describe("Tx", func() {
BeforeEach(func() { BeforeEach(func() {
client = redis.NewClient(redisOptions()) client = redis.NewClient(redisOptions())
Expect(client.FlushDB().Err()).NotTo(HaveOccurred()) Expect(client.FlushDB(ctx).Err()).NotTo(HaveOccurred())
}) })
AfterEach(func() { AfterEach(func() {
@ -28,14 +28,14 @@ var _ = Describe("Tx", func() {
// Transactionally increments key using GET and SET commands. // Transactionally increments key using GET and SET commands.
incr = func(key string) error { incr = func(key string) error {
err := client.Watch(func(tx *redis.Tx) error { err := client.Watch(ctx, func(tx *redis.Tx) error {
n, err := tx.Get(key).Int64() n, err := tx.Get(ctx, key).Int64()
if err != nil && err != redis.Nil { if err != nil && err != redis.Nil {
return err return err
} }
_, err = tx.TxPipelined(func(pipe redis.Pipeliner) error { _, err = tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Set(key, strconv.FormatInt(n+1, 10), 0) pipe.Set(ctx, key, strconv.FormatInt(n+1, 10), 0)
return nil return nil
}) })
return err return err
@ -59,17 +59,17 @@ var _ = Describe("Tx", func() {
} }
wg.Wait() wg.Wait()
n, err := client.Get("key").Int64() n, err := client.Get(ctx, "key").Int64()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(n).To(Equal(int64(100))) Expect(n).To(Equal(int64(100)))
}) })
It("should discard", func() { It("should discard", func() {
err := client.Watch(func(tx *redis.Tx) error { err := client.Watch(ctx, func(tx *redis.Tx) error {
cmds, err := tx.TxPipelined(func(pipe redis.Pipeliner) error { cmds, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Set("key1", "hello1", 0) pipe.Set(ctx, "key1", "hello1", 0)
pipe.Discard() pipe.Discard()
pipe.Set("key2", "hello2", 0) pipe.Set(ctx, "key2", "hello2", 0)
return nil return nil
}) })
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
@ -78,23 +78,23 @@ var _ = Describe("Tx", func() {
}, "key1", "key2") }, "key1", "key2")
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
get := client.Get("key1") get := client.Get(ctx, "key1")
Expect(get.Err()).To(Equal(redis.Nil)) Expect(get.Err()).To(Equal(redis.Nil))
Expect(get.Val()).To(Equal("")) Expect(get.Val()).To(Equal(""))
get = client.Get("key2") get = client.Get(ctx, "key2")
Expect(get.Err()).NotTo(HaveOccurred()) Expect(get.Err()).NotTo(HaveOccurred())
Expect(get.Val()).To(Equal("hello2")) Expect(get.Val()).To(Equal("hello2"))
}) })
It("returns no error when there are no commands", func() { It("returns no error when there are no commands", func() {
err := client.Watch(func(tx *redis.Tx) error { err := client.Watch(ctx, func(tx *redis.Tx) error {
_, err := tx.TxPipelined(func(redis.Pipeliner) error { return nil }) _, err := tx.TxPipelined(ctx, func(redis.Pipeliner) error { return nil })
return err return err
}) })
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
v, err := client.Ping().Result() v, err := client.Ping(ctx).Result()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(v).To(Equal("PONG")) Expect(v).To(Equal("PONG"))
}) })
@ -102,10 +102,10 @@ var _ = Describe("Tx", func() {
It("should exec bulks", func() { It("should exec bulks", func() {
const N = 20000 const N = 20000
err := client.Watch(func(tx *redis.Tx) error { err := client.Watch(ctx, func(tx *redis.Tx) error {
cmds, err := tx.TxPipelined(func(pipe redis.Pipeliner) error { cmds, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
for i := 0; i < N; i++ { for i := 0; i < N; i++ {
pipe.Incr("key") pipe.Incr(ctx, "key")
} }
return nil return nil
}) })
@ -118,7 +118,7 @@ var _ = Describe("Tx", func() {
}) })
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
num, err := client.Get("key").Int64() num, err := client.Get(ctx, "key").Int64()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(num).To(Equal(int64(N))) Expect(num).To(Equal(int64(N)))
}) })
@ -132,9 +132,9 @@ var _ = Describe("Tx", func() {
client.Pool().Put(cn) client.Pool().Put(cn)
do := func() error { do := func() error {
err := client.Watch(func(tx *redis.Tx) error { err := client.Watch(ctx, func(tx *redis.Tx) error {
_, err := tx.TxPipelined(func(pipe redis.Pipeliner) error { _, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Ping() pipe.Ping(ctx)
return nil return nil
}) })
return err return err

View File

@ -168,13 +168,11 @@ type UniversalClient interface {
Cmdable Cmdable
Context() context.Context Context() context.Context
AddHook(Hook) AddHook(Hook)
Watch(fn func(*Tx) error, keys ...string) error Watch(ctx context.Context, fn func(*Tx) error, keys ...string) error
Do(args ...interface{}) *Cmd Do(ctx context.Context, args ...interface{}) *Cmd
DoContext(ctx context.Context, args ...interface{}) *Cmd Process(ctx context.Context, cmd Cmder) error
Process(cmd Cmder) error Subscribe(ctx context.Context, channels ...string) *PubSub
ProcessContext(ctx context.Context, cmd Cmder) error PSubscribe(ctx context.Context, channels ...string) *PubSub
Subscribe(channels ...string) *PubSub
PSubscribe(channels ...string) *PubSub
Close() error Close() error
} }

View File

@ -4,7 +4,7 @@ import (
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
"github.com/go-redis/redis/v7" "github.com/go-redis/redis/v8"
) )
var _ = Describe("UniversalClient", func() { var _ = Describe("UniversalClient", func() {
@ -21,21 +21,21 @@ var _ = Describe("UniversalClient", func() {
MasterName: sentinelName, MasterName: sentinelName,
Addrs: []string{":" + sentinelPort}, Addrs: []string{":" + sentinelPort},
}) })
Expect(client.Ping().Err()).NotTo(HaveOccurred()) Expect(client.Ping(ctx).Err()).NotTo(HaveOccurred())
}) })
It("should connect to simple servers", func() { It("should connect to simple servers", func() {
client = redis.NewUniversalClient(&redis.UniversalOptions{ client = redis.NewUniversalClient(&redis.UniversalOptions{
Addrs: []string{redisAddr}, Addrs: []string{redisAddr},
}) })
Expect(client.Ping().Err()).NotTo(HaveOccurred()) Expect(client.Ping(ctx).Err()).NotTo(HaveOccurred())
}) })
It("should connect to clusters", func() { It("should connect to clusters", func() {
client = redis.NewUniversalClient(&redis.UniversalOptions{ client = redis.NewUniversalClient(&redis.UniversalOptions{
Addrs: cluster.addrs(), Addrs: cluster.addrs(),
}) })
Expect(client.Ping().Err()).NotTo(HaveOccurred()) Expect(client.Ping(ctx).Err()).NotTo(HaveOccurred())
}) })
}) })