Allow passing context where possible

This commit is contained in:
Vladimir Mihailenco 2019-06-04 13:30:47 +03:00
parent 3da4357c0c
commit 09eb108738
9 changed files with 98 additions and 51 deletions

View File

@ -724,16 +724,24 @@ func (c *ClusterClient) Close() error {
// Do creates a Cmd from the args and processes the cmd. // Do creates a Cmd from the args and processes the cmd.
func (c *ClusterClient) Do(args ...interface{}) *Cmd { func (c *ClusterClient) Do(args ...interface{}) *Cmd {
return c.DoContext(c.ctx, args...)
}
func (c *ClusterClient) DoContext(ctx context.Context, args ...interface{}) *Cmd {
cmd := NewCmd(args...) cmd := NewCmd(args...)
c.Process(cmd) c.ProcessContext(ctx, cmd)
return cmd return cmd
} }
func (c *ClusterClient) Process(cmd Cmder) error { func (c *ClusterClient) Process(cmd Cmder) error {
return c.hooks.process(c.ctx, cmd, c.process) return c.ProcessContext(c.ctx, cmd)
} }
func (c *ClusterClient) process(cmd Cmder) error { func (c *ClusterClient) ProcessContext(ctx context.Context, cmd Cmder) error {
return c.hooks.process(ctx, cmd, c.process)
}
func (c *ClusterClient) process(ctx context.Context, cmd Cmder) error {
var node *clusterNode var node *clusterNode
var ask bool var ask bool
for attempt := 0; attempt <= c.opt.MaxRedirects; attempt++ { for attempt := 0; attempt <= c.opt.MaxRedirects; attempt++ {
@ -755,11 +763,11 @@ func (c *ClusterClient) process(cmd Cmder) error {
pipe := node.Client.Pipeline() pipe := node.Client.Pipeline()
_ = pipe.Process(NewCmd("ASKING")) _ = pipe.Process(NewCmd("ASKING"))
_ = pipe.Process(cmd) _ = pipe.Process(cmd)
_, err = pipe.Exec() _, err = pipe.ExecContext(ctx)
_ = pipe.Close() _ = pipe.Close()
ask = false ask = false
} else { } else {
err = node.Client.Process(cmd) err = node.Client.ProcessContext(ctx, cmd)
} }
// If there is no error - we are done. // If there is no error - we are done.
@ -1022,11 +1030,11 @@ func (c *ClusterClient) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) {
return c.Pipeline().Pipelined(fn) return c.Pipeline().Pipelined(fn)
} }
func (c *ClusterClient) processPipeline(cmds []Cmder) error { func (c *ClusterClient) processPipeline(ctx context.Context, cmds []Cmder) error {
return c.hooks.processPipeline(c.ctx, cmds, c._processPipeline) return c.hooks.processPipeline(c.ctx, cmds, c._processPipeline)
} }
func (c *ClusterClient) _processPipeline(cmds []Cmder) error { func (c *ClusterClient) _processPipeline(ctx context.Context, cmds []Cmder) error {
cmdsMap := newCmdsMap() cmdsMap := newCmdsMap()
err := c.mapCmdsByNode(cmds, cmdsMap) err := c.mapCmdsByNode(cmds, cmdsMap)
if err != nil { if err != nil {
@ -1216,11 +1224,11 @@ func (c *ClusterClient) TxPipelined(fn func(Pipeliner) error) ([]Cmder, error) {
return c.TxPipeline().Pipelined(fn) return c.TxPipeline().Pipelined(fn)
} }
func (c *ClusterClient) processTxPipeline(cmds []Cmder) error { func (c *ClusterClient) processTxPipeline(ctx context.Context, cmds []Cmder) error {
return c.hooks.processPipeline(c.ctx, cmds, c._processTxPipeline) return c.hooks.processPipeline(ctx, cmds, c._processTxPipeline)
} }
func (c *ClusterClient) _processTxPipeline(cmds []Cmder) error { func (c *ClusterClient) _processTxPipeline(ctx context.Context, cmds []Cmder) error {
state, err := c.state.Get() state, err := c.state.Get()
if err != nil { if err != nil {
return err return err

View File

@ -1,6 +1,8 @@
package redis package redis
import "sync" import (
"sync"
)
// ScanIterator is used to incrementally iterate over a collection of elements. // ScanIterator is used to incrementally iterate over a collection of elements.
// It's safe for concurrent use by multiple goroutines. // It's safe for concurrent use by multiple goroutines.

View File

@ -16,9 +16,9 @@ import (
// Limiter is the interface of a rate limiter or a circuit breaker. // Limiter is the interface of a rate limiter or a circuit breaker.
type Limiter interface { type Limiter interface {
// Allow returns a nil if operation is allowed or an error otherwise. // Allow returns nil if operation is allowed or an error otherwise.
// If operation is allowed client must report the result of operation // If operation is allowed client must ReportResult of the operation
// whether is a success or a failure. // whether it is a success or a failure.
Allow() error Allow() error
// ReportResult reports the result of previously allowed operation. // ReportResult reports the result of previously allowed operation.
// nil indicates a success, non-nil error indicates a failure. // nil indicates a success, non-nil error indicates a failure.

View File

@ -1,12 +1,13 @@
package redis package redis
import ( import (
"context"
"sync" "sync"
"github.com/go-redis/redis/internal/pool" "github.com/go-redis/redis/internal/pool"
) )
type pipelineExecer func([]Cmder) error type pipelineExecer func(context.Context, []Cmder) error
// Pipeliner is an mechanism to realise Redis Pipeline technique. // Pipeliner is an mechanism to realise Redis Pipeline technique.
// //
@ -28,6 +29,7 @@ type Pipeliner interface {
Close() error Close() error
Discard() error Discard() error
Exec() ([]Cmder, error) Exec() ([]Cmder, error)
ExecContext(ctx context.Context) ([]Cmder, error)
} }
var _ Pipeliner = (*Pipeline)(nil) var _ Pipeliner = (*Pipeline)(nil)
@ -96,6 +98,10 @@ func (c *Pipeline) discard() error {
// Exec always returns list of commands and error of the first failed // Exec always returns list of commands and error of the first failed
// command if any. // command if any.
func (c *Pipeline) Exec() ([]Cmder, error) { func (c *Pipeline) Exec() ([]Cmder, error) {
return c.ExecContext(nil)
}
func (c *Pipeline) ExecContext(ctx context.Context) ([]Cmder, error) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
@ -110,10 +116,10 @@ func (c *Pipeline) Exec() ([]Cmder, error) {
cmds := c.cmds cmds := c.cmds
c.cmds = nil c.cmds = nil
return cmds, c.exec(cmds) return cmds, c.exec(ctx, cmds)
} }
func (c *Pipeline) pipelined(fn func(Pipeliner) error) ([]Cmder, error) { func (c *Pipeline) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) {
if err := fn(c); err != nil { if err := fn(c); err != nil {
return nil, err return nil, err
} }
@ -122,16 +128,12 @@ func (c *Pipeline) pipelined(fn func(Pipeliner) error) ([]Cmder, error) {
return cmds, err return cmds, err
} }
func (c *Pipeline) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) {
return c.pipelined(fn)
}
func (c *Pipeline) Pipeline() Pipeliner { func (c *Pipeline) Pipeline() Pipeliner {
return c return c
} }
func (c *Pipeline) TxPipelined(fn func(Pipeliner) error) ([]Cmder, error) { func (c *Pipeline) TxPipelined(fn func(Pipeliner) error) ([]Cmder, error) {
return c.pipelined(fn) return c.Pipelined(fn)
} }
func (c *Pipeline) TxPipeline() Pipeliner { func (c *Pipeline) TxPipeline() Pipeliner {

View File

@ -45,13 +45,15 @@ func (hs *hooks) AddHook(hook Hook) {
hs.hooks = append(hs.hooks, hook) hs.hooks = append(hs.hooks, hook)
} }
func (hs hooks) process(ctx context.Context, cmd Cmder, fn func(Cmder) error) error { func (hs hooks) process(
ctx context.Context, cmd Cmder, fn func(context.Context, Cmder) error,
) error {
ctx, err := hs.beforeProcess(ctx, cmd) ctx, err := hs.beforeProcess(ctx, cmd)
if err != nil { if err != nil {
return err return err
} }
cmdErr := fn(cmd) cmdErr := fn(ctx, cmd)
_, err = hs.afterProcess(ctx, cmd) _, err = hs.afterProcess(ctx, cmd)
if err != nil { if err != nil {
@ -83,13 +85,15 @@ func (hs hooks) afterProcess(ctx context.Context, cmd Cmder) (context.Context, e
return ctx, nil return ctx, nil
} }
func (hs hooks) processPipeline(ctx context.Context, cmds []Cmder, fn func([]Cmder) error) error { func (hs hooks) processPipeline(
ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error,
) error {
ctx, err := hs.beforeProcessPipeline(ctx, cmds) ctx, err := hs.beforeProcessPipeline(ctx, cmds)
if err != nil { if err != nil {
return err return err
} }
cmdsErr := fn(cmds) cmdsErr := fn(ctx, cmds)
_, err = hs.afterProcessPipeline(ctx, cmds) _, err = hs.afterProcessPipeline(ctx, cmds)
if err != nil { if err != nil {
@ -246,14 +250,7 @@ func (c *baseClient) initConn(cn *pool.Conn) error {
return nil return nil
} }
// Do creates a Cmd from the args and processes the cmd. func (c *baseClient) process(ctx context.Context, cmd Cmder) error {
func (c *baseClient) Do(args ...interface{}) *Cmd {
cmd := NewCmd(args...)
_ = c.process(cmd)
return cmd
}
func (c *baseClient) process(cmd Cmder) error {
for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ { for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
if attempt > 0 { if attempt > 0 {
time.Sleep(c.retryBackoff(attempt)) time.Sleep(c.retryBackoff(attempt))
@ -328,11 +325,11 @@ func (c *baseClient) getAddr() string {
return c.opt.Addr return c.opt.Addr
} }
func (c *baseClient) processPipeline(cmds []Cmder) error { func (c *baseClient) processPipeline(ctx context.Context, cmds []Cmder) error {
return c.generalProcessPipeline(cmds, c.pipelineProcessCmds) return c.generalProcessPipeline(cmds, c.pipelineProcessCmds)
} }
func (c *baseClient) processTxPipeline(cmds []Cmder) error { func (c *baseClient) processTxPipeline(ctx context.Context, cmds []Cmder) error {
return c.generalProcessPipeline(cmds, c.txPipelineProcessCmds) return c.generalProcessPipeline(cmds, c.txPipelineProcessCmds)
} }
@ -503,16 +500,31 @@ func (c *Client) WithContext(ctx context.Context) *Client {
return &clone return &clone
} }
// Do creates a Cmd from the args and processes the cmd.
func (c *Client) Do(args ...interface{}) *Cmd {
return c.DoContext(c.ctx, args...)
}
func (c *Client) DoContext(ctx context.Context, args ...interface{}) *Cmd {
cmd := NewCmd(args...)
_ = c.ProcessContext(ctx, cmd)
return cmd
}
func (c *Client) Process(cmd Cmder) error { func (c *Client) Process(cmd Cmder) error {
return c.hooks.process(c.ctx, cmd, c.baseClient.process) return c.ProcessContext(c.ctx, cmd)
} }
func (c *Client) processPipeline(cmds []Cmder) error { func (c *Client) ProcessContext(ctx context.Context, cmd Cmder) error {
return c.hooks.processPipeline(c.ctx, cmds, c.baseClient.processPipeline) return c.hooks.process(ctx, cmd, c.baseClient.process)
} }
func (c *Client) processTxPipeline(cmds []Cmder) error { func (c *Client) processPipeline(ctx context.Context, cmds []Cmder) error {
return c.hooks.processPipeline(c.ctx, cmds, c.baseClient.processTxPipeline) return c.hooks.processPipeline(ctx, cmds, c.baseClient.processPipeline)
}
func (c *Client) processTxPipeline(ctx context.Context, cmds []Cmder) error {
return c.hooks.processPipeline(ctx, cmds, c.baseClient.processTxPipeline)
} }
// Options returns read-only Options that were used to create the client. // Options returns read-only Options that were used to create the client.
@ -637,7 +649,11 @@ func newConn(opt *Options, cn *pool.Conn) *Conn {
} }
func (c *Conn) Process(cmd Cmder) error { func (c *Conn) Process(cmd Cmder) error {
return c.baseClient.process(cmd) return c.ProcessContext(context.TODO(), cmd)
}
func (c *Conn) ProcessContext(ctx context.Context, cmd Cmder) error {
return c.baseClient.process(ctx, cmd)
} }
func (c *Conn) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) { func (c *Conn) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) {

22
ring.go
View File

@ -396,13 +396,21 @@ func (c *Ring) WithContext(ctx context.Context) *Ring {
// Do creates a Cmd from the args and processes the cmd. // Do creates a Cmd from the args and processes the cmd.
func (c *Ring) Do(args ...interface{}) *Cmd { func (c *Ring) Do(args ...interface{}) *Cmd {
return c.DoContext(c.ctx, args...)
}
func (c *Ring) DoContext(ctx context.Context, args ...interface{}) *Cmd {
cmd := NewCmd(args...) cmd := NewCmd(args...)
c.Process(cmd) c.ProcessContext(ctx, cmd)
return cmd return cmd
} }
func (c *Ring) Process(cmd Cmder) error { func (c *Ring) Process(cmd Cmder) error {
return c.hooks.process(c.ctx, cmd, c.process) return c.ProcessContext(c.ctx, cmd)
}
func (c *Ring) ProcessContext(ctx context.Context, cmd Cmder) error {
return c.hooks.process(ctx, cmd, c.process)
} }
// Options returns read-only Options that were used to create the client. // Options returns read-only Options that were used to create the client.
@ -532,7 +540,7 @@ func (c *Ring) cmdShard(cmd Cmder) (*ringShard, error) {
return c.shards.GetByKey(firstKey) return c.shards.GetByKey(firstKey)
} }
func (c *Ring) process(cmd Cmder) error { func (c *Ring) process(ctx context.Context, cmd Cmder) error {
for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ { for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
if attempt > 0 { if attempt > 0 {
time.Sleep(c.retryBackoff(attempt)) time.Sleep(c.retryBackoff(attempt))
@ -544,7 +552,7 @@ func (c *Ring) process(cmd Cmder) error {
return err return err
} }
err = shard.Client.Process(cmd) err = shard.Client.ProcessContext(ctx, cmd)
if err == nil { if err == nil {
return nil return nil
} }
@ -567,11 +575,11 @@ func (c *Ring) Pipeline() Pipeliner {
return &pipe return &pipe
} }
func (c *Ring) processPipeline(cmds []Cmder) error { func (c *Ring) processPipeline(ctx context.Context, cmds []Cmder) error {
return c.hooks.processPipeline(c.ctx, cmds, c._processPipeline) return c.hooks.processPipeline(ctx, cmds, c._processPipeline)
} }
func (c *Ring) _processPipeline(cmds []Cmder) error { func (c *Ring) _processPipeline(ctx context.Context, cmds []Cmder) error {
cmdsMap := make(map[string][]Cmder) cmdsMap := make(map[string][]Cmder)
for _, cmd := range cmds { for _, cmd := range cmds {
cmdInfo := c.cmdInfo(cmd.Name()) cmdInfo := c.cmdInfo(cmd.Name())

View File

@ -136,7 +136,11 @@ func (c *SentinelClient) WithContext(ctx context.Context) *SentinelClient {
} }
func (c *SentinelClient) Process(cmd Cmder) error { func (c *SentinelClient) Process(cmd Cmder) error {
return c.baseClient.process(cmd) return c.ProcessContext(c.ctx, cmd)
}
func (c *SentinelClient) ProcessContext(ctx context.Context, cmd Cmder) error {
return c.baseClient.process(ctx, cmd)
} }
func (c *SentinelClient) pubSub() *PubSub { func (c *SentinelClient) pubSub() *PubSub {

6
tx.go
View File

@ -56,7 +56,11 @@ func (c *Tx) WithContext(ctx context.Context) *Tx {
} }
func (c *Tx) Process(cmd Cmder) error { func (c *Tx) Process(cmd Cmder) error {
return c.baseClient.process(cmd) return c.ProcessContext(c.ctx, cmd)
}
func (c *Tx) ProcessContext(ctx context.Context, cmd Cmder) error {
return c.baseClient.process(ctx, cmd)
} }
// Watch prepares a transaction and marks the keys to be watched // Watch prepares a transaction and marks the keys to be watched

View File

@ -162,7 +162,10 @@ type UniversalClient interface {
Context() context.Context Context() context.Context
AddHook(Hook) AddHook(Hook)
Watch(fn func(*Tx) error, keys ...string) error Watch(fn func(*Tx) error, keys ...string) error
Do(args ...interface{}) *Cmd
DoContext(ctx context.Context, args ...interface{}) *Cmd
Process(cmd Cmder) error Process(cmd Cmder) error
ProcessContext(ctx context.Context, cmd Cmder) error
Subscribe(channels ...string) *PubSub Subscribe(channels ...string) *PubSub
PSubscribe(channels ...string) *PubSub PSubscribe(channels ...string) *PubSub
Close() error Close() error