Merge pull request #1040 from go-redis/feature/hook-new

Feature/hook new
Authored by Vladimir Mihailenco, 2019-06-01 11:50:49 +03:00; committed by GitHub
commit 17480c545e
20 changed files with 1151 additions and 1051 deletions


@ -1,6 +1,11 @@
# Changelog
-## Unreleased
+## v7 WIP
+- WrapProcess is replaced with more convenient AddHook that has access to context.Context.
+- WithContext no longer creates shallow copy.
## v6.15
- Cluster and Ring pipelines process commands for each node in its own goroutine.
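A short sketch of the migration described in the changelog above, assuming a Redis server reachable at :6379; the exampleHook type is illustrative and not part of this commit:

```go
package main

import (
	"context"
	"fmt"

	"github.com/go-redis/redis"
)

// exampleHook is a hypothetical Hook implementation used only to illustrate
// the new API; it prints every command before and after it is processed.
type exampleHook struct{}

func (exampleHook) BeforeProcess(ctx context.Context, cmd redis.Cmder) (context.Context, error) {
	fmt.Println("before", cmd.Name())
	return ctx, nil
}

func (exampleHook) AfterProcess(ctx context.Context, cmd redis.Cmder) (context.Context, error) {
	fmt.Println("after", cmd.Name())
	return ctx, nil
}

func (exampleHook) BeforeProcessPipeline(ctx context.Context, cmds []redis.Cmder) (context.Context, error) {
	return ctx, nil
}

func (exampleHook) AfterProcessPipeline(ctx context.Context, cmds []redis.Cmder) (context.Context, error) {
	return ctx, nil
}

func main() {
	rdb := redis.NewClient(&redis.Options{Addr: ":6379"})

	// v6: rdb.WrapProcess(func(old func(redis.Cmder) error) func(redis.Cmder) error { ... })
	// v7: register a Hook instead; it also receives the request context.
	rdb.AddHook(exampleHook{})

	rdb.Ping()
}
```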


@ -2,6 +2,7 @@ package redis_test
import (
	"bytes"
+	"context"
	"fmt"
	"strings"
	"testing"
@ -198,3 +199,140 @@ func BenchmarkZAdd(b *testing.B) {
		}
	})
}
var clientSink *redis.Client
func BenchmarkWithContext(b *testing.B) {
rdb := benchmarkRedisClient(10)
defer rdb.Close()
ctx := context.Background()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
clientSink = rdb.WithContext(ctx)
}
}
var ringSink *redis.Ring
func BenchmarkRingWithContext(b *testing.B) {
rdb := redis.NewRing(&redis.RingOptions{})
defer rdb.Close()
ctx := context.Background()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
ringSink = rdb.WithContext(ctx)
}
}
//------------------------------------------------------------------------------
func newClusterScenario() *clusterScenario {
return &clusterScenario{
ports: []string{"8220", "8221", "8222", "8223", "8224", "8225"},
nodeIds: make([]string, 6),
processes: make(map[string]*redisProcess, 6),
clients: make(map[string]*redis.Client, 6),
}
}
func BenchmarkClusterPing(b *testing.B) {
if testing.Short() {
b.Skip("skipping in short mode")
}
cluster := newClusterScenario()
if err := startCluster(cluster); err != nil {
b.Fatal(err)
}
defer stopCluster(cluster)
client := cluster.clusterClient(redisClusterOptions())
defer client.Close()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
err := client.Ping().Err()
if err != nil {
b.Fatal(err)
}
}
})
}
func BenchmarkClusterSetString(b *testing.B) {
if testing.Short() {
b.Skip("skipping in short mode")
}
cluster := newClusterScenario()
if err := startCluster(cluster); err != nil {
b.Fatal(err)
}
defer stopCluster(cluster)
client := cluster.clusterClient(redisClusterOptions())
defer client.Close()
value := string(bytes.Repeat([]byte{'1'}, 10000))
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
err := client.Set("key", value, 0).Err()
if err != nil {
b.Fatal(err)
}
}
})
}
func BenchmarkClusterReloadState(b *testing.B) {
if testing.Short() {
b.Skip("skipping in short mode")
}
cluster := newClusterScenario()
if err := startCluster(cluster); err != nil {
b.Fatal(err)
}
defer stopCluster(cluster)
client := cluster.clusterClient(redisClusterOptions())
defer client.Close()
b.ResetTimer()
for i := 0; i < b.N; i++ {
err := client.ReloadState()
if err != nil {
b.Fatal(err)
}
}
}
var clusterSink *redis.ClusterClient
func BenchmarkClusterWithContext(b *testing.B) {
rdb := redis.NewClusterClient(&redis.ClusterOptions{})
defer rdb.Close()
ctx := context.Background()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
clusterSink = rdb.WithContext(ctx)
}
}


@ -639,22 +639,22 @@ func (c *clusterStateHolder) ReloadOrGet() (*clusterState, error) {
//------------------------------------------------------------------------------

-// ClusterClient is a Redis Cluster client representing a pool of zero
-// or more underlying connections. It's safe for concurrent use by
-// multiple goroutines.
-type ClusterClient struct {
-	cmdable
-
-	ctx context.Context
-
-	opt           *ClusterOptions
-	nodes         *clusterNodes
-	state         *clusterStateHolder
-	cmdsInfoCache *cmdsInfoCache
-
-	process           func(Cmder) error
-	processPipeline   func([]Cmder) error
-	processTxPipeline func([]Cmder) error
-}
+type clusterClient struct {
+	cmdable
+	hooks
+
+	opt           *ClusterOptions
+	nodes         *clusterNodes
+	state         *clusterStateHolder
+	cmdsInfoCache *cmdsInfoCache
+}
+
+// ClusterClient is a Redis Cluster client representing a pool of zero
+// or more underlying connections. It's safe for concurrent use by
+// multiple goroutines.
+type ClusterClient struct {
+	*clusterClient
+	ctx context.Context
+}
// NewClusterClient returns a Redis Cluster client as described in
@ -663,16 +663,14 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient {
	opt.init()

	c := &ClusterClient{
-		opt:   opt,
-		nodes: newClusterNodes(opt),
+		clusterClient: &clusterClient{
+			opt:   opt,
+			nodes: newClusterNodes(opt),
+		},
	}
	c.state = newClusterStateHolder(c.loadState)
	c.cmdsInfoCache = newCmdsInfoCache(c.cmdsInfo)
-	c.process = c.defaultProcess
-	c.processPipeline = c.defaultProcessPipeline
-	c.processTxPipeline = c.defaultProcessTxPipeline
	c.init()

	if opt.IdleCheckFrequency > 0 {
		go c.reaper(opt.IdleCheckFrequency)
@ -682,14 +680,7 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient {
}

func (c *ClusterClient) init() {
-	c.cmdable.setProcessor(c.Process)
+	c.cmdable = c.Process
}

-// ReloadState reloads cluster state. If available it calls ClusterSlots func
-// to get cluster slots information.
-func (c *ClusterClient) ReloadState() error {
-	_, err := c.state.Reload()
-	return err
-}
func (c *ClusterClient) Context() context.Context {
@ -703,15 +694,9 @@ func (c *ClusterClient) WithContext(ctx context.Context) *ClusterClient {
	if ctx == nil {
		panic("nil context")
	}
-	c2 := c.clone()
-	c2.ctx = ctx
-	return c2
+	clone := *c
+	clone.ctx = ctx
+	return &clone
}

-func (c *ClusterClient) clone() *ClusterClient {
-	cp := *c
-	cp.init()
-	return &cp
-}
// Options returns read-only Options that were used to create the client.
@ -719,164 +704,10 @@ func (c *ClusterClient) Options() *ClusterOptions {
	return c.opt
}

+// ReloadState reloads cluster state. If available it calls ClusterSlots func
+// to get cluster slots information.
+func (c *ClusterClient) ReloadState() error {
+	_, err := c.state.Reload()
+	return err
+}

-func (c *ClusterClient) retryBackoff(attempt int) time.Duration {
-	return internal.RetryBackoff(attempt, c.opt.MinRetryBackoff, c.opt.MaxRetryBackoff)
-}
func (c *ClusterClient) cmdsInfo() (map[string]*CommandInfo, error) {
addrs, err := c.nodes.Addrs()
if err != nil {
return nil, err
}
var firstErr error
for _, addr := range addrs {
node, err := c.nodes.Get(addr)
if err != nil {
return nil, err
}
if node == nil {
continue
}
info, err := node.Client.Command().Result()
if err == nil {
return info, nil
}
if firstErr == nil {
firstErr = err
}
}
return nil, firstErr
}
func (c *ClusterClient) cmdInfo(name string) *CommandInfo {
cmdsInfo, err := c.cmdsInfoCache.Get()
if err != nil {
return nil
}
info := cmdsInfo[name]
if info == nil {
internal.Logf("info for cmd=%s not found", name)
}
return info
}
func cmdSlot(cmd Cmder, pos int) int {
if pos == 0 {
return hashtag.RandomSlot()
}
firstKey := cmd.stringArg(pos)
return hashtag.Slot(firstKey)
}
func (c *ClusterClient) cmdSlot(cmd Cmder) int {
args := cmd.Args()
if args[0] == "cluster" && args[1] == "getkeysinslot" {
return args[2].(int)
}
cmdInfo := c.cmdInfo(cmd.Name())
return cmdSlot(cmd, cmdFirstKeyPos(cmd, cmdInfo))
}
func (c *ClusterClient) cmdSlotAndNode(cmd Cmder) (int, *clusterNode, error) {
state, err := c.state.Get()
if err != nil {
return 0, nil, err
}
cmdInfo := c.cmdInfo(cmd.Name())
slot := c.cmdSlot(cmd)
if c.opt.ReadOnly && cmdInfo != nil && cmdInfo.ReadOnly {
if c.opt.RouteByLatency {
node, err := state.slotClosestNode(slot)
return slot, node, err
}
if c.opt.RouteRandomly {
node := state.slotRandomNode(slot)
return slot, node, nil
}
node, err := state.slotSlaveNode(slot)
return slot, node, err
}
node, err := state.slotMasterNode(slot)
return slot, node, err
}
func (c *ClusterClient) slotMasterNode(slot int) (*clusterNode, error) {
state, err := c.state.Get()
if err != nil {
return nil, err
}
nodes := state.slotNodes(slot)
if len(nodes) > 0 {
return nodes[0], nil
}
return c.nodes.Random()
}
func (c *ClusterClient) Watch(fn func(*Tx) error, keys ...string) error {
if len(keys) == 0 {
return fmt.Errorf("redis: Watch requires at least one key")
}
slot := hashtag.Slot(keys[0])
for _, key := range keys[1:] {
if hashtag.Slot(key) != slot {
err := fmt.Errorf("redis: Watch requires all keys to be in the same slot")
return err
}
}
node, err := c.slotMasterNode(slot)
if err != nil {
return err
}
for attempt := 0; attempt <= c.opt.MaxRedirects; attempt++ {
if attempt > 0 {
time.Sleep(c.retryBackoff(attempt))
}
err = node.Client.Watch(fn, keys...)
if err == nil {
break
}
if err != Nil {
c.state.LazyReload()
}
moved, ask, addr := internal.IsMovedError(err)
if moved || ask {
node, err = c.nodes.GetOrCreate(addr)
if err != nil {
return err
}
continue
}
if err == pool.ErrClosed || internal.IsReadOnlyError(err) {
node, err = c.slotMasterNode(slot)
if err != nil {
return err
}
continue
}
if internal.IsRetryableError(err, true) {
continue
}
return err
}
	return err
}
@ -895,17 +726,11 @@ func (c *ClusterClient) Do(args ...interface{}) *Cmd {
	return cmd
}

-func (c *ClusterClient) WrapProcess(
-	fn func(oldProcess func(Cmder) error) func(Cmder) error,
-) {
-	c.process = fn(c.process)
-}

func (c *ClusterClient) Process(cmd Cmder) error {
-	return c.process(cmd)
+	return c.hooks.process(c.ctx, cmd, c.process)
}

-func (c *ClusterClient) defaultProcess(cmd Cmder) error {
+func (c *ClusterClient) process(cmd Cmder) error {
	var node *clusterNode
	var ask bool
	for attempt := 0; attempt <= c.opt.MaxRedirects; attempt++ {
@ -1186,7 +1011,7 @@ func (c *ClusterClient) Pipeline() Pipeliner {
	pipe := Pipeline{
		exec: c.processPipeline,
	}
-	pipe.statefulCmdable.setProcessor(pipe.Process)
+	pipe.init()
	return &pipe
}
@ -1194,14 +1019,11 @@ func (c *ClusterClient) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) {
	return c.Pipeline().Pipelined(fn)
}

-func (c *ClusterClient) WrapProcessPipeline(
-	fn func(oldProcess func([]Cmder) error) func([]Cmder) error,
-) {
-	c.processPipeline = fn(c.processPipeline)
-	c.processTxPipeline = fn(c.processTxPipeline)
-}
+func (c *ClusterClient) processPipeline(cmds []Cmder) error {
+	return c.hooks.processPipeline(c.ctx, cmds, c._processPipeline)
+}

-func (c *ClusterClient) defaultProcessPipeline(cmds []Cmder) error {
+func (c *ClusterClient) _processPipeline(cmds []Cmder) error {
	cmdsMap := newCmdsMap()
	err := c.mapCmdsByNode(cmds, cmdsMap)
	if err != nil {
@ -1383,7 +1205,7 @@ func (c *ClusterClient) TxPipeline() Pipeliner {
	pipe := Pipeline{
		exec: c.processTxPipeline,
	}
-	pipe.statefulCmdable.setProcessor(pipe.Process)
+	pipe.init()
	return &pipe
}
@ -1391,7 +1213,11 @@ func (c *ClusterClient) TxPipelined(fn func(Pipeliner) error) ([]Cmder, error) {
	return c.TxPipeline().Pipelined(fn)
}

-func (c *ClusterClient) defaultProcessTxPipeline(cmds []Cmder) error {
+func (c *ClusterClient) processTxPipeline(cmds []Cmder) error {
+	return c.hooks.processPipeline(c.ctx, cmds, c._processTxPipeline)
+}
+
+func (c *ClusterClient) _processTxPipeline(cmds []Cmder) error {
	state, err := c.state.Get()
	if err != nil {
		return err
@ -1529,6 +1355,64 @@ func (c *ClusterClient) txPipelineReadQueued(
	return nil
}
func (c *ClusterClient) Watch(fn func(*Tx) error, keys ...string) error {
if len(keys) == 0 {
return fmt.Errorf("redis: Watch requires at least one key")
}
slot := hashtag.Slot(keys[0])
for _, key := range keys[1:] {
if hashtag.Slot(key) != slot {
err := fmt.Errorf("redis: Watch requires all keys to be in the same slot")
return err
}
}
node, err := c.slotMasterNode(slot)
if err != nil {
return err
}
for attempt := 0; attempt <= c.opt.MaxRedirects; attempt++ {
if attempt > 0 {
time.Sleep(c.retryBackoff(attempt))
}
err = node.Client.Watch(fn, keys...)
if err == nil {
break
}
if err != Nil {
c.state.LazyReload()
}
moved, ask, addr := internal.IsMovedError(err)
if moved || ask {
node, err = c.nodes.GetOrCreate(addr)
if err != nil {
return err
}
continue
}
if err == pool.ErrClosed || internal.IsReadOnlyError(err) {
node, err = c.slotMasterNode(slot)
if err != nil {
return err
}
continue
}
if internal.IsRetryableError(err, true) {
continue
}
return err
}
return err
}
func (c *ClusterClient) pubSub() *PubSub {
	var node *clusterNode
	pubsub := &PubSub{
@ -1590,6 +1474,109 @@ func (c *ClusterClient) PSubscribe(channels ...string) *PubSub {
	return pubsub
}
func (c *ClusterClient) retryBackoff(attempt int) time.Duration {
return internal.RetryBackoff(attempt, c.opt.MinRetryBackoff, c.opt.MaxRetryBackoff)
}
func (c *ClusterClient) cmdsInfo() (map[string]*CommandInfo, error) {
addrs, err := c.nodes.Addrs()
if err != nil {
return nil, err
}
var firstErr error
for _, addr := range addrs {
node, err := c.nodes.Get(addr)
if err != nil {
return nil, err
}
if node == nil {
continue
}
info, err := node.Client.Command().Result()
if err == nil {
return info, nil
}
if firstErr == nil {
firstErr = err
}
}
return nil, firstErr
}
func (c *ClusterClient) cmdInfo(name string) *CommandInfo {
cmdsInfo, err := c.cmdsInfoCache.Get()
if err != nil {
return nil
}
info := cmdsInfo[name]
if info == nil {
internal.Logf("info for cmd=%s not found", name)
}
return info
}
func cmdSlot(cmd Cmder, pos int) int {
if pos == 0 {
return hashtag.RandomSlot()
}
firstKey := cmd.stringArg(pos)
return hashtag.Slot(firstKey)
}
func (c *ClusterClient) cmdSlot(cmd Cmder) int {
args := cmd.Args()
if args[0] == "cluster" && args[1] == "getkeysinslot" {
return args[2].(int)
}
cmdInfo := c.cmdInfo(cmd.Name())
return cmdSlot(cmd, cmdFirstKeyPos(cmd, cmdInfo))
}
func (c *ClusterClient) cmdSlotAndNode(cmd Cmder) (int, *clusterNode, error) {
state, err := c.state.Get()
if err != nil {
return 0, nil, err
}
cmdInfo := c.cmdInfo(cmd.Name())
slot := c.cmdSlot(cmd)
if c.opt.ReadOnly && cmdInfo != nil && cmdInfo.ReadOnly {
if c.opt.RouteByLatency {
node, err := state.slotClosestNode(slot)
return slot, node, err
}
if c.opt.RouteRandomly {
node := state.slotRandomNode(slot)
return slot, node, nil
}
node, err := state.slotSlaveNode(slot)
return slot, node, err
}
node, err := state.slotMasterNode(slot)
return slot, node, err
}
func (c *ClusterClient) slotMasterNode(slot int) (*clusterNode, error) {
state, err := c.state.Get()
if err != nil {
return nil, err
}
nodes := state.slotNodes(slot)
if len(nodes) > 0 {
return nodes[0], nil
}
return c.nodes.Random()
}
func appendUniqueNode(nodes []*clusterNode, node *clusterNode) []*clusterNode {
	for _, n := range nodes {
		if n == node {


@ -1,13 +1,11 @@
package redis_test

import (
-	"bytes"
	"fmt"
	"net"
	"strconv"
	"strings"
	"sync"
-	"testing"
	"time"

	"github.com/go-redis/redis"
@ -545,18 +543,6 @@ var _ = Describe("ClusterClient", func() {
	Expect(stats).To(BeAssignableToTypeOf(&redis.PoolStats{}))
})
It("removes idle connections", func() {
stats := client.PoolStats()
Expect(stats.TotalConns).NotTo(BeZero())
Expect(stats.IdleConns).NotTo(BeZero())
time.Sleep(2 * time.Second)
stats = client.PoolStats()
Expect(stats.TotalConns).To(BeZero())
Expect(stats.IdleConns).To(BeZero())
})
It("returns an error when there are no attempts left", func() { It("returns an error when there are no attempts left", func() {
opt := redisClusterOptions() opt := redisClusterOptions()
opt.MaxRedirects = -1 opt.MaxRedirects = -1
@ -1054,92 +1040,3 @@ var _ = Describe("ClusterClient timeout", func() {
		testTimeout()
	})
})
//------------------------------------------------------------------------------
func newClusterScenario() *clusterScenario {
return &clusterScenario{
ports: []string{"8220", "8221", "8222", "8223", "8224", "8225"},
nodeIds: make([]string, 6),
processes: make(map[string]*redisProcess, 6),
clients: make(map[string]*redis.Client, 6),
}
}
func BenchmarkClusterPing(b *testing.B) {
if testing.Short() {
b.Skip("skipping in short mode")
}
cluster := newClusterScenario()
if err := startCluster(cluster); err != nil {
b.Fatal(err)
}
defer stopCluster(cluster)
client := cluster.clusterClient(redisClusterOptions())
defer client.Close()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
err := client.Ping().Err()
if err != nil {
b.Fatal(err)
}
}
})
}
func BenchmarkClusterSetString(b *testing.B) {
if testing.Short() {
b.Skip("skipping in short mode")
}
cluster := newClusterScenario()
if err := startCluster(cluster); err != nil {
b.Fatal(err)
}
defer stopCluster(cluster)
client := cluster.clusterClient(redisClusterOptions())
defer client.Close()
value := string(bytes.Repeat([]byte{'1'}, 10000))
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
err := client.Set("key", value, 0).Err()
if err != nil {
b.Fatal(err)
}
}
})
}
func BenchmarkClusterReloadState(b *testing.B) {
if testing.Short() {
b.Skip("skipping in short mode")
}
cluster := newClusterScenario()
if err := startCluster(cluster); err != nil {
b.Fatal(err)
}
defer stopCluster(cluster)
client := cluster.clusterClient(redisClusterOptions())
defer client.Close()
b.ResetTimer()
for i := 0; i < b.N; i++ {
err := client.ReloadState()
if err != nil {
b.Fatal(err)
}
}
}


@ -100,8 +100,14 @@ type baseCmd struct {
var _ Cmder = (*Cmd)(nil)

-func (cmd *baseCmd) Err() error {
-	return cmd.err
+func (cmd *baseCmd) Name() string {
+	if len(cmd._args) > 0 {
+		// Cmd name must be lower cased.
+		s := internal.ToLower(cmd.stringArg(0))
+		cmd._args[0] = s
+		return s
+	}
+	return ""
}

func (cmd *baseCmd) Args() []interface{} {
@ -116,14 +122,8 @@ func (cmd *baseCmd) stringArg(pos int) string {
	return s
}

-func (cmd *baseCmd) Name() string {
-	if len(cmd._args) > 0 {
-		// Cmd name must be lower cased.
-		s := internal.ToLower(cmd.stringArg(0))
-		cmd._args[0] = s
-		return s
-	}
-	return ""
+func (cmd *baseCmd) Err() error {
+	return cmd.err
}

func (cmd *baseCmd) readTimeout() *time.Duration {

File diff suppressed because it is too large.


@ -1506,8 +1506,8 @@ var _ = Describe("Commands", func() {
	Expect(client.Ping().Err()).NotTo(HaveOccurred())

	stats := client.PoolStats()
-	Expect(stats.Hits).To(Equal(uint32(1)))
-	Expect(stats.Misses).To(Equal(uint32(2)))
+	Expect(stats.Hits).To(Equal(uint32(2)))
+	Expect(stats.Misses).To(Equal(uint32(1)))
	Expect(stats.Timeouts).To(Equal(uint32(0)))
})
@ -2219,8 +2219,8 @@ var _ = Describe("Commands", func() {
	Expect(client.Ping().Err()).NotTo(HaveOccurred())

	stats := client.PoolStats()
-	Expect(stats.Hits).To(Equal(uint32(1)))
-	Expect(stats.Misses).To(Equal(uint32(2)))
+	Expect(stats.Hits).To(Equal(uint32(2)))
+	Expect(stats.Misses).To(Equal(uint32(1)))
	Expect(stats.Timeouts).To(Equal(uint32(0)))
})
@ -2301,8 +2301,8 @@ var _ = Describe("Commands", func() {
	Expect(client.Ping().Err()).NotTo(HaveOccurred())

	stats := client.PoolStats()
-	Expect(stats.Hits).To(Equal(uint32(1)))
-	Expect(stats.Misses).To(Equal(uint32(2)))
+	Expect(stats.Hits).To(Equal(uint32(2)))
+	Expect(stats.Misses).To(Equal(uint32(1)))
	Expect(stats.Timeouts).To(Equal(uint32(0)))
})


@ -1,44 +1,54 @@
package redis_test

import (
+	"context"
	"fmt"

	"github.com/go-redis/redis"
)
type redisHook struct{}
var _ redis.Hook = redisHook{}
func (redisHook) BeforeProcess(ctx context.Context, cmd redis.Cmder) (context.Context, error) {
fmt.Printf("starting processing: <%s>\n", cmd)
return ctx, nil
}
func (redisHook) AfterProcess(ctx context.Context, cmd redis.Cmder) (context.Context, error) {
fmt.Printf("finished processing: <%s>\n", cmd)
return ctx, nil
}
func (redisHook) BeforeProcessPipeline(ctx context.Context, cmds []redis.Cmder) (context.Context, error) {
fmt.Printf("pipeline starting processing: %v\n", cmds)
return ctx, nil
}
func (redisHook) AfterProcessPipeline(ctx context.Context, cmds []redis.Cmder) (context.Context, error) {
fmt.Printf("pipeline finished processing: %v\n", cmds)
return ctx, nil
}
func Example_instrumentation() {
-	redisdb := redis.NewClient(&redis.Options{
+	rdb := redis.NewClient(&redis.Options{
		Addr: ":6379",
	})
+	rdb.AddHook(redisHook{})

-	redisdb.WrapProcess(func(old func(cmd redis.Cmder) error) func(cmd redis.Cmder) error {
-		return func(cmd redis.Cmder) error {
-			fmt.Printf("starting processing: <%s>\n", cmd)
-			err := old(cmd)
-			fmt.Printf("finished processing: <%s>\n", cmd)
-			return err
-		}
-	})

-	redisdb.Ping()
+	rdb.Ping()
	// Output: starting processing: <ping: >
	// finished processing: <ping: PONG>
}
func ExamplePipeline_instrumentation() {
-	redisdb := redis.NewClient(&redis.Options{
+	rdb := redis.NewClient(&redis.Options{
		Addr: ":6379",
	})
+	rdb.AddHook(redisHook{})

-	redisdb.WrapProcessPipeline(func(old func([]redis.Cmder) error) func([]redis.Cmder) error {
-		return func(cmds []redis.Cmder) error {
-			fmt.Printf("pipeline starting processing: %v\n", cmds)
-			err := old(cmds)
-			fmt.Printf("pipeline finished processing: %v\n", cmds)
-			return err
-		}
-	})

-	redisdb.Pipelined(func(pipe redis.Pipeliner) error {
+	rdb.Pipelined(func(pipe redis.Pipeliner) error {
		pipe.Ping()
		pipe.Ping()
		return nil


@ -13,14 +13,13 @@ var noDeadline = time.Time{}
type Conn struct {
	netConn net.Conn

	rd *proto.Reader
-	rdLocked bool
	wr *proto.Writer

	Inited    bool
	pooled    bool
	createdAt time.Time
-	usedAt    atomic.Value
+	usedAt    int64 // atomic
}
func NewConn(netConn net.Conn) *Conn {
@ -35,11 +34,12 @@ func NewConn(netConn net.Conn) *Conn {
}

func (cn *Conn) UsedAt() time.Time {
-	return cn.usedAt.Load().(time.Time)
+	unix := atomic.LoadInt64(&cn.usedAt)
+	return time.Unix(unix, 0)
}

func (cn *Conn) SetUsedAt(tm time.Time) {
-	cn.usedAt.Store(tm)
+	atomic.StoreInt64(&cn.usedAt, tm.Unix())
}

func (cn *Conn) SetNetConn(netConn net.Conn) {


@ -113,8 +113,8 @@ func redisOptions() *redis.Options {
		WriteTimeout: 30 * time.Second,
		PoolSize:     10,
		PoolTimeout:  30 * time.Second,
-		IdleTimeout:        500 * time.Millisecond,
-		IdleCheckFrequency: 500 * time.Millisecond,
+		IdleTimeout:        time.Minute,
+		IdleCheckFrequency: 100 * time.Millisecond,
	}
}
@ -125,8 +125,8 @@ func redisClusterOptions() *redis.ClusterOptions {
		WriteTimeout: 30 * time.Second,
		PoolSize:     10,
		PoolTimeout:  30 * time.Second,
-		IdleTimeout:        500 * time.Millisecond,
-		IdleCheckFrequency: 500 * time.Millisecond,
+		IdleTimeout:        time.Minute,
+		IdleCheckFrequency: 100 * time.Millisecond,
	}
}
@ -141,8 +141,8 @@ func redisRingOptions() *redis.RingOptions {
		WriteTimeout: 30 * time.Second,
		PoolSize:     10,
		PoolTimeout:  30 * time.Second,
-		IdleTimeout:        500 * time.Millisecond,
-		IdleCheckFrequency: 500 * time.Millisecond,
+		IdleTimeout:        time.Minute,
+		IdleCheckFrequency: 100 * time.Millisecond,
	}
}


@ -36,6 +36,7 @@ var _ Pipeliner = (*Pipeline)(nil)
// http://redis.io/topics/pipelining. It's safe for concurrent use
// by multiple goroutines.
type Pipeline struct {
+	cmdable
	statefulCmdable

	exec pipelineExecer
@ -45,6 +46,11 @@ type Pipeline struct {
	closed bool
}
func (c *Pipeline) init() {
c.cmdable = c.Process
c.statefulCmdable = c.Process
}
func (c *Pipeline) Do(args ...interface{}) *Cmd {
	cmd := NewCmd(args...)
	_ = c.Process(cmd)


@ -16,8 +16,8 @@ var _ = Describe("pool", func() {
	opt := redisOptions()
	opt.MinIdleConns = 0
	opt.MaxConnAge = 0
+	opt.IdleTimeout = time.Second
	client = redis.NewClient(opt)
-	Expect(client.FlushDB().Err()).NotTo(HaveOccurred())
})

AfterEach(func() {
@ -98,7 +98,7 @@ var _ = Describe("pool", func() {
	Expect(pool.IdleLen()).To(Equal(1))

	stats := pool.Stats()
-	Expect(stats.Hits).To(Equal(uint32(2)))
+	Expect(stats.Hits).To(Equal(uint32(1)))
	Expect(stats.Misses).To(Equal(uint32(2)))
	Expect(stats.Timeouts).To(Equal(uint32(0)))
})
@ -115,12 +115,15 @@ var _ = Describe("pool", func() {
	Expect(pool.IdleLen()).To(Equal(1))

	stats := pool.Stats()
-	Expect(stats.Hits).To(Equal(uint32(100)))
+	Expect(stats.Hits).To(Equal(uint32(99)))
	Expect(stats.Misses).To(Equal(uint32(1)))
	Expect(stats.Timeouts).To(Equal(uint32(0)))
})

It("removes idle connections", func() {
+	err := client.Ping().Err()
+	Expect(err).NotTo(HaveOccurred())
+
	stats := client.PoolStats()
	Expect(stats).To(Equal(&redis.PoolStats{
		Hits: 0,


@ -78,7 +78,7 @@ var _ = Describe("PubSub", func() {
	}

	stats := client.PoolStats()
-	Expect(stats.Misses).To(Equal(uint32(2)))
+	Expect(stats.Misses).To(Equal(uint32(1)))
})

It("should pub/sub channels", func() {
@ -201,7 +201,7 @@ var _ = Describe("PubSub", func() {
	}

	stats := client.PoolStats()
-	Expect(stats.Misses).To(Equal(uint32(2)))
+	Expect(stats.Misses).To(Equal(uint32(1)))
})

It("should ping/pong", func() {

redis.go

@ -23,24 +23,114 @@ func SetLogger(logger *log.Logger) {
	internal.Logger = logger
}
//------------------------------------------------------------------------------
type Hook interface {
BeforeProcess(ctx context.Context, cmd Cmder) (context.Context, error)
AfterProcess(ctx context.Context, cmd Cmder) (context.Context, error)
BeforeProcessPipeline(ctx context.Context, cmds []Cmder) (context.Context, error)
AfterProcessPipeline(ctx context.Context, cmds []Cmder) (context.Context, error)
}
type hooks struct {
hooks []Hook
}
func (hs *hooks) lazyCopy() {
	// Re-slice with an explicit capacity (full slice expression) so that a
	// later AddHook on a copied client appends into a fresh backing array
	// instead of mutating the hooks slice shared with the original.
	hs.hooks = hs.hooks[:len(hs.hooks):len(hs.hooks)]
}
func (hs *hooks) AddHook(hook Hook) {
hs.hooks = append(hs.hooks, hook)
}
func (hs hooks) process(ctx context.Context, cmd Cmder, fn func(Cmder) error) error {
ctx, err := hs.beforeProcess(ctx, cmd)
if err != nil {
return err
}
cmdErr := fn(cmd)
_, err = hs.afterProcess(ctx, cmd)
if err != nil {
return err
}
return cmdErr
}
func (hs hooks) beforeProcess(ctx context.Context, cmd Cmder) (context.Context, error) {
for _, h := range hs.hooks {
var err error
ctx, err = h.BeforeProcess(ctx, cmd)
if err != nil {
return nil, err
}
}
return ctx, nil
}
func (hs hooks) afterProcess(ctx context.Context, cmd Cmder) (context.Context, error) {
for _, h := range hs.hooks {
var err error
ctx, err = h.AfterProcess(ctx, cmd)
if err != nil {
return nil, err
}
}
return ctx, nil
}
func (hs hooks) processPipeline(ctx context.Context, cmds []Cmder, fn func([]Cmder) error) error {
ctx, err := hs.beforeProcessPipeline(ctx, cmds)
if err != nil {
return err
}
cmdsErr := fn(cmds)
_, err = hs.afterProcessPipeline(ctx, cmds)
if err != nil {
return err
}
return cmdsErr
}
func (hs hooks) beforeProcessPipeline(ctx context.Context, cmds []Cmder) (context.Context, error) {
for _, h := range hs.hooks {
var err error
ctx, err = h.BeforeProcessPipeline(ctx, cmds)
if err != nil {
return nil, err
}
}
return ctx, nil
}
func (hs hooks) afterProcessPipeline(ctx context.Context, cmds []Cmder) (context.Context, error) {
for _, h := range hs.hooks {
var err error
ctx, err = h.AfterProcessPipeline(ctx, cmds)
if err != nil {
return nil, err
}
}
return ctx, nil
}
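Because hooks.process hands the context returned by beforeProcess to afterProcess, a hook can carry per-command state such as a start time. A minimal, hypothetical sketch inside this package (timingHook is not part of the diff; it assumes the file's existing context, time, and internal imports):

```go
type timingHook struct{}

type startTimeKey struct{}

func (timingHook) BeforeProcess(ctx context.Context, cmd Cmder) (context.Context, error) {
	// Stash the start time; hooks.process passes this context to AfterProcess.
	return context.WithValue(ctx, startTimeKey{}, time.Now()), nil
}

func (timingHook) AfterProcess(ctx context.Context, cmd Cmder) (context.Context, error) {
	if start, ok := ctx.Value(startTimeKey{}).(time.Time); ok {
		internal.Logf("%s took %s", cmd.Name(), time.Since(start))
	}
	return ctx, nil
}

func (timingHook) BeforeProcessPipeline(ctx context.Context, cmds []Cmder) (context.Context, error) {
	return context.WithValue(ctx, startTimeKey{}, time.Now()), nil
}

func (timingHook) AfterProcessPipeline(ctx context.Context, cmds []Cmder) (context.Context, error) {
	if start, ok := ctx.Value(startTimeKey{}).(time.Time); ok {
		internal.Logf("pipeline of %d commands took %s", len(cmds), time.Since(start))
	}
	return ctx, nil
}
```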
//------------------------------------------------------------------------------
type baseClient struct {
	opt      *Options
	connPool pool.Pooler
	limiter  Limiter

-	process           func(Cmder) error
-	processPipeline   func([]Cmder) error
-	processTxPipeline func([]Cmder) error

	onClose func() error // hook called when client is closed
}

-func (c *baseClient) init() {
-	c.process = c.defaultProcess
-	c.processPipeline = c.defaultProcessPipeline
-	c.processTxPipeline = c.defaultProcessTxPipeline
-}

func (c *baseClient) String() string {
	return fmt.Sprintf("Redis<%s db:%d>", c.getAddr(), c.opt.DB)
}
@ -159,22 +249,11 @@ func (c *baseClient) initConn(cn *pool.Conn) error {
// Do creates a Cmd from the args and processes the cmd.
func (c *baseClient) Do(args ...interface{}) *Cmd {
	cmd := NewCmd(args...)
-	_ = c.Process(cmd)
+	_ = c.process(cmd)
	return cmd
}

-// WrapProcess wraps function that processes Redis commands.
-func (c *baseClient) WrapProcess(
-	fn func(oldProcess func(cmd Cmder) error) func(cmd Cmder) error,
-) {
-	c.process = fn(c.process)
-}
-
-func (c *baseClient) Process(cmd Cmder) error {
-	return c.process(cmd)
-}
-
-func (c *baseClient) defaultProcess(cmd Cmder) error {
+func (c *baseClient) process(cmd Cmder) error {
	for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
		if attempt > 0 {
			time.Sleep(c.retryBackoff(attempt))
@ -249,18 +328,11 @@ func (c *baseClient) getAddr() string {
	return c.opt.Addr
}

-func (c *baseClient) WrapProcessPipeline(
-	fn func(oldProcess func([]Cmder) error) func([]Cmder) error,
-) {
-	c.processPipeline = fn(c.processPipeline)
-	c.processTxPipeline = fn(c.processTxPipeline)
-}
-
-func (c *baseClient) defaultProcessPipeline(cmds []Cmder) error {
+func (c *baseClient) processPipeline(cmds []Cmder) error {
	return c.generalProcessPipeline(cmds, c.pipelineProcessCmds)
}

-func (c *baseClient) defaultProcessTxPipeline(cmds []Cmder) error {
+func (c *baseClient) processTxPipeline(cmds []Cmder) error {
	return c.generalProcessPipeline(cmds, c.txPipelineProcessCmds)
}
@ -380,13 +452,17 @@ func txPipelineReadQueued(rd *proto.Reader, cmds []Cmder) error {
//------------------------------------------------------------------------------
type client struct {
baseClient
cmdable
hooks
}
// Client is a Redis client representing a pool of zero or more
// underlying connections. It's safe for concurrent use by multiple
// goroutines.
type Client struct {
-	baseClient
-	cmdable
+	*client
	ctx context.Context
}
@ -395,19 +471,20 @@ func NewClient(opt *Options) *Client {
	opt.init()

	c := Client{
-		baseClient: baseClient{
-			opt:      opt,
-			connPool: newConnPool(opt),
+		client: &client{
+			baseClient: baseClient{
+				opt:      opt,
+				connPool: newConnPool(opt),
+			},
		},
	}
-	c.baseClient.init()
	c.init()

	return &c
}
func (c *Client) init() {
-	c.cmdable.setProcessor(c.Process)
+	c.cmdable = c.Process
}
func (c *Client) Context() context.Context {
@ -421,15 +498,21 @@ func (c *Client) WithContext(ctx context.Context) *Client {
	if ctx == nil {
		panic("nil context")
	}
-	c2 := c.clone()
-	c2.ctx = ctx
-	return c2
+	clone := *c
+	clone.ctx = ctx
+	return &clone
}

-func (c *Client) clone() *Client {
-	cp := *c
-	cp.init()
-	return &cp
-}
+func (c *Client) Process(cmd Cmder) error {
+	return c.hooks.process(c.ctx, cmd, c.baseClient.process)
+}
+
+func (c *Client) processPipeline(cmds []Cmder) error {
+	return c.hooks.processPipeline(c.ctx, cmds, c.baseClient.processPipeline)
+}
+
+func (c *Client) processTxPipeline(cmds []Cmder) error {
+	return c.hooks.processPipeline(c.ctx, cmds, c.baseClient.processTxPipeline)
+}

// Options returns read-only Options that were used to create the client.
@ -458,7 +541,7 @@ func (c *Client) Pipeline() Pipeliner {
	pipe := Pipeline{
		exec: c.processPipeline,
	}
-	pipe.statefulCmdable.setProcessor(pipe.Process)
+	pipe.init()
	return &pipe
}
@ -471,7 +554,7 @@ func (c *Client) TxPipeline() Pipeliner {
	pipe := Pipeline{
		exec: c.processTxPipeline,
	}
-	pipe.statefulCmdable.setProcessor(pipe.Process)
+	pipe.init()
	return &pipe
}
@ -537,6 +620,7 @@ func (c *Client) PSubscribe(channels ...string) *PubSub {
// Conn is like Client, but its pool contains single connection.
type Conn struct {
	baseClient
+	cmdable
	statefulCmdable
}
@ -547,11 +631,15 @@ func newConn(opt *Options, cn *pool.Conn) *Conn {
			connPool: pool.NewSingleConnPool(cn),
		},
	}
-	c.baseClient.init()
-	c.statefulCmdable.setProcessor(c.Process)
+	c.cmdable = c.Process
+	c.statefulCmdable = c.Process
	return &c
}
func (c *Conn) Process(cmd Cmder) error {
return c.baseClient.process(cmd)
}
func (c *Conn) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) {
	return c.Pipeline().Pipelined(fn)
}
@ -560,7 +648,7 @@ func (c *Conn) Pipeline() Pipeliner {
	pipe := Pipeline{
		exec: c.processPipeline,
	}
-	pipe.statefulCmdable.setProcessor(pipe.Process)
+	pipe.init()
	return &pipe
}
@ -573,6 +661,6 @@ func (c *Conn) TxPipeline() Pipeliner {
	pipe := Pipeline{
		exec: c.processTxPipeline,
	}
-	pipe.statefulCmdable.setProcessor(pipe.Process)
+	pipe.init()
	return &pipe
}


@ -191,6 +191,8 @@ var _ = Describe("Client", func() {
	client.Pool().Put(cn)
	Expect(cn.UsedAt().Equal(createdAt)).To(BeTrue())

+	time.Sleep(time.Second)
+
	err = client.Ping().Err()
	Expect(err).NotTo(HaveOccurred())
@ -224,43 +226,6 @@ var _ = Describe("Client", func() {
	Expect(err).NotTo(HaveOccurred())
	Expect(got).To(Equal(bigVal))
})
It("should call WrapProcess", func() {
var fnCalled bool
client.WrapProcess(func(old func(redis.Cmder) error) func(redis.Cmder) error {
return func(cmd redis.Cmder) error {
fnCalled = true
return old(cmd)
}
})
Expect(client.Ping().Err()).NotTo(HaveOccurred())
Expect(fnCalled).To(BeTrue())
})
It("should call WrapProcess after WithContext", func() {
var fn1Called, fn2Called bool
client.WrapProcess(func(old func(cmd redis.Cmder) error) func(cmd redis.Cmder) error {
return func(cmd redis.Cmder) error {
fn1Called = true
return old(cmd)
}
})
client2 := client.WithContext(client.Context())
client2.WrapProcess(func(old func(cmd redis.Cmder) error) func(cmd redis.Cmder) error {
return func(cmd redis.Cmder) error {
fn2Called = true
return old(cmd)
}
})
Expect(client2.Ping().Err()).NotTo(HaveOccurred())
Expect(fn2Called).To(BeTrue())
Expect(fn1Called).To(BeTrue())
})
})

var _ = Describe("Client timeout", func() {

ring.go

@ -323,6 +323,14 @@ func (c *ringShards) Close() error {
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
type ring struct {
cmdable
hooks
opt *RingOptions
shards *ringShards
cmdsInfoCache *cmdsInfoCache
}
// Ring is a Redis client that uses consistent hashing to distribute
// keys across multiple Redis servers (shards). It's safe for
// concurrent use by multiple goroutines.
@ -338,32 +346,23 @@ func (c *ringShards) Close() error {
// and can tolerate losing data when one of the servers dies.
// Otherwise you should use Redis Cluster.
type Ring struct {
-	cmdable
+	*ring
	ctx context.Context
-
-	opt           *RingOptions
-	shards        *ringShards
-	cmdsInfoCache *cmdsInfoCache
-
-	process         func(Cmder) error
-	processPipeline func([]Cmder) error
}
func NewRing(opt *RingOptions) *Ring {
	opt.init()

-	ring := &Ring{
-		opt:    opt,
-		shards: newRingShards(opt),
+	ring := Ring{
+		ring: &ring{
+			opt:    opt,
+			shards: newRingShards(opt),
+		},
	}
-	ring.cmdsInfoCache = newCmdsInfoCache(ring.cmdsInfo)
-
-	ring.process = ring.defaultProcess
-	ring.processPipeline = ring.defaultProcessPipeline
	ring.init()
+	ring.cmdsInfoCache = newCmdsInfoCache(ring.cmdsInfo)

	for name, addr := range opt.Addrs {
		clopt := opt.clientOptions()
		clopt.Addr = addr
@ -372,11 +371,11 @@ func NewRing(opt *RingOptions) *Ring {
	go ring.shards.Heartbeat(opt.HeartbeatFrequency)

-	return ring
+	return &ring
}

func (c *Ring) init() {
-	c.cmdable.setProcessor(c.Process)
+	c.cmdable = c.Process
}

func (c *Ring) Context() context.Context {
@ -390,16 +389,20 @@ func (c *Ring) WithContext(ctx context.Context) *Ring {
	if ctx == nil {
		panic("nil context")
	}
-	c2 := c.clone()
-	c2.ctx = ctx
-	return c2
+	clone := *c
+	clone.ctx = ctx
+	return &clone
}

-func (c *Ring) clone() *Ring {
-	cp := *c
-	cp.init()
-	return &cp
-}
+// Do creates a Cmd from the args and processes the cmd.
+func (c *Ring) Do(args ...interface{}) *Cmd {
+	cmd := NewCmd(args...)
+	c.Process(cmd)
+	return cmd
+}
+
+func (c *Ring) Process(cmd Cmder) error {
+	return c.hooks.process(c.ctx, cmd, c.process)
+}

// Options returns read-only Options that were used to create the client.
@ -529,24 +532,7 @@ func (c *Ring) cmdShard(cmd Cmder) (*ringShard, error) {
	return c.shards.GetByKey(firstKey)
}

-// Do creates a Cmd from the args and processes the cmd.
-func (c *Ring) Do(args ...interface{}) *Cmd {
-	cmd := NewCmd(args...)
-	c.Process(cmd)
-	return cmd
-}
-
-func (c *Ring) WrapProcess(
-	fn func(oldProcess func(cmd Cmder) error) func(cmd Cmder) error,
-) {
-	c.process = fn(c.process)
-}
-
-func (c *Ring) Process(cmd Cmder) error {
-	return c.process(cmd)
-}
-
-func (c *Ring) defaultProcess(cmd Cmder) error {
+func (c *Ring) process(cmd Cmder) error {
	for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
		if attempt > 0 {
			time.Sleep(c.retryBackoff(attempt))
@ -569,25 +555,23 @@ func (c *Ring) defaultProcess(cmd Cmder) error {
	return cmd.Err()
}

-func (c *Ring) Pipeline() Pipeliner {
-	pipe := Pipeline{
-		exec: c.processPipeline,
-	}
-	pipe.cmdable.setProcessor(pipe.Process)
-	return &pipe
-}
-
func (c *Ring) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) {
	return c.Pipeline().Pipelined(fn)
}

-func (c *Ring) WrapProcessPipeline(
-	fn func(oldProcess func([]Cmder) error) func([]Cmder) error,
-) {
-	c.processPipeline = fn(c.processPipeline)
-}
+func (c *Ring) Pipeline() Pipeliner {
+	pipe := Pipeline{
+		exec: c.processPipeline,
+	}
+	pipe.init()
+	return &pipe
+}

-func (c *Ring) defaultProcessPipeline(cmds []Cmder) error {
+func (c *Ring) processPipeline(cmds []Cmder) error {
+	return c.hooks.processPipeline(c.ctx, cmds, c._processPipeline)
+}
+
+func (c *Ring) _processPipeline(cmds []Cmder) error {
	cmdsMap := make(map[string][]Cmder)
	for _, cmd := range cmds {
		cmdInfo := c.cmdInfo(cmd.Name())


@ -1,7 +1,6 @@
package redis_test

import (
-	"context"
	"crypto/rand"
	"fmt"
	"net"
@ -105,27 +104,6 @@ var _ = Describe("Redis Ring", func() {
	Expect(ringShard2.Info("keyspace").Val()).To(ContainSubstring("keys=100"))
})
It("propagates process for WithContext", func() {
var fromWrap []string
wrapper := func(oldProcess func(cmd redis.Cmder) error) func(cmd redis.Cmder) error {
return func(cmd redis.Cmder) error {
fromWrap = append(fromWrap, cmd.Name())
return oldProcess(cmd)
}
}
ctx := context.Background()
ring = ring.WithContext(ctx)
ring.WrapProcess(wrapper)
ring.Ping()
Expect(fromWrap).To(Equal([]string{"ping"}))
ring.Ping()
Expect(fromWrap).To(Equal([]string{"ping", "ping"}))
})
Describe("pipeline", func() { Describe("pipeline", func() {
It("distributes keys", func() { It("distributes keys", func() {
pipe := ring.Pipeline() pipe := ring.Pipeline()


@ -1,6 +1,7 @@
package redis

import (
+	"context"
	"crypto/tls"
	"errors"
	"net"
@ -86,15 +87,15 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client {
	}

	c := Client{
-		baseClient: baseClient{
-			opt:      opt,
-			connPool: failover.Pool(),
-			onClose:  failover.Close,
+		client: &client{
+			baseClient: baseClient{
+				opt:      opt,
+				connPool: failover.Pool(),
+				onClose:  failover.Close,
+			},
		},
	}
-	c.baseClient.init()
-	c.cmdable.setProcessor(c.Process)
+	c.cmdable = c.Process

	return &c
}
@ -102,21 +103,41 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client {
//------------------------------------------------------------------------------
type SentinelClient struct {
-	baseClient
+	*baseClient
+	ctx context.Context
}

func NewSentinelClient(opt *Options) *SentinelClient {
	opt.init()
	c := &SentinelClient{
-		baseClient: baseClient{
+		baseClient: &baseClient{
			opt:      opt,
			connPool: newConnPool(opt),
		},
	}
-	c.baseClient.init()
	return c
}
func (c *SentinelClient) Context() context.Context {
if c.ctx != nil {
return c.ctx
}
return context.Background()
}
func (c *SentinelClient) WithContext(ctx context.Context) *SentinelClient {
if ctx == nil {
panic("nil context")
}
clone := *c
clone.ctx = ctx
return &clone
}
func (c *SentinelClient) Process(cmd Cmder) error {
return c.baseClient.process(cmd)
}
func (c *SentinelClient) pubSub() *PubSub {
	pubsub := &PubSub{
		opt: c.opt,

tx.go

@ -1,6 +1,8 @@
package redis

import (
+	"context"
+
	"github.com/go-redis/redis/internal/pool"
	"github.com/go-redis/redis/internal/proto"
)
@ -13,8 +15,11 @@ const TxFailedErr = proto.RedisError("redis: transaction failed")
// by multiple goroutines, because Exec resets list of watched keys.
// If you don't need WATCH it is better to use Pipeline.
type Tx struct {
+	cmdable
	statefulCmdable
	baseClient
+	ctx context.Context
}
func (c *Client) newTx() *Tx {
@ -23,12 +28,37 @@ func (c *Client) newTx() *Tx {
			opt:      c.opt,
			connPool: pool.NewStickyConnPool(c.connPool.(*pool.ConnPool), true),
		},
+		ctx: c.ctx,
	}
-	tx.baseClient.init()
-	tx.statefulCmdable.setProcessor(tx.Process)
+	tx.init()
	return &tx
}
func (c *Tx) init() {
c.cmdable = c.Process
c.statefulCmdable = c.Process
}
func (c *Tx) Context() context.Context {
if c.ctx != nil {
return c.ctx
}
return context.Background()
}
func (c *Tx) WithContext(ctx context.Context) *Tx {
if ctx == nil {
panic("nil context")
}
clone := *c
clone.ctx = ctx
return &clone
}
func (c *Tx) Process(cmd Cmder) error {
return c.baseClient.process(cmd)
}
// Watch prepares a transaction and marks the keys to be watched
// for conditional execution if there are any keys.
//
@ -83,7 +113,7 @@ func (c *Tx) Pipeline() Pipeliner {
	pipe := Pipeline{
		exec: c.processTxPipeline,
	}
-	pipe.statefulCmdable.setProcessor(pipe.Process)
+	pipe.init()
	return &pipe
}


@ -1,6 +1,7 @@
package redis

import (
+	"context"
	"crypto/tls"
	"time"
)
@ -147,15 +148,15 @@ func (o *UniversalOptions) simple() *Options {
// --------------------------------------------------------------------

// UniversalClient is an abstract client which - based on the provided options -
-// can connect to either clusters, or sentinel-backed failover instances or simple
-// single-instance servers. This can be useful for testing cluster-specific
-// applications locally.
+// can connect to either clusters, or sentinel-backed failover instances
+// or simple single-instance servers. This can be useful for testing
+// cluster-specific applications locally.
type UniversalClient interface {
	Cmdable
+	Context() context.Context
+	AddHook(Hook)
	Watch(fn func(*Tx) error, keys ...string) error
	Process(cmd Cmder) error
-	WrapProcess(fn func(oldProcess func(cmd Cmder) error) func(cmd Cmder) error)
-	WrapProcessPipeline(fn func(oldProcess func([]Cmder) error) func([]Cmder) error)
	Subscribe(channels ...string) *PubSub
	PSubscribe(channels ...string) *PubSub
	Close() error
@ -163,6 +164,7 @@ type UniversalClient interface {
var _ UniversalClient = (*Client)(nil)
var _ UniversalClient = (*ClusterClient)(nil)
+var _ UniversalClient = (*Ring)(nil)
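With Ring added to the assertions above, all three concrete clients satisfy the interface, including the new Context and AddHook methods. A small sketch of what that enables; newUniversalClient is a hypothetical helper, and redisHook is the hook type from the instrumentation example earlier in this diff:

```go
package redis_test

import "github.com/go-redis/redis"

// newUniversalClient returns whichever concrete client the options select
// (*redis.Client, *redis.ClusterClient or *redis.Ring); all of them can now
// be instrumented through the shared UniversalClient interface.
func newUniversalClient(addrs []string) redis.UniversalClient {
	rdb := redis.NewUniversalClient(&redis.UniversalOptions{Addrs: addrs})
	rdb.AddHook(redisHook{})
	return rdb
}
```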
// NewUniversalClient returns a new multi client. The type of client returned depends
// on the following three conditions: