From 09eb10873829861034f7c20222fe37f13615afd1 Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Tue, 4 Jun 2019 13:30:47 +0300 Subject: [PATCH] Allow passing context where possible --- cluster.go | 28 ++++++++++++++++---------- iterator.go | 4 +++- options.go | 6 +++--- pipeline.go | 18 +++++++++-------- redis.go | 56 +++++++++++++++++++++++++++++++++------------------- ring.go | 22 ++++++++++++++------- sentinel.go | 6 +++++- tx.go | 6 +++++- universal.go | 3 +++ 9 files changed, 98 insertions(+), 51 deletions(-) diff --git a/cluster.go b/cluster.go index e1c1b12..6037387 100644 --- a/cluster.go +++ b/cluster.go @@ -724,16 +724,24 @@ func (c *ClusterClient) Close() error { // Do creates a Cmd from the args and processes the cmd. func (c *ClusterClient) Do(args ...interface{}) *Cmd { + return c.DoContext(c.ctx, args...) +} + +func (c *ClusterClient) DoContext(ctx context.Context, args ...interface{}) *Cmd { cmd := NewCmd(args...) - c.Process(cmd) + c.ProcessContext(ctx, cmd) return cmd } func (c *ClusterClient) Process(cmd Cmder) error { - return c.hooks.process(c.ctx, cmd, c.process) + return c.ProcessContext(c.ctx, cmd) } -func (c *ClusterClient) process(cmd Cmder) error { +func (c *ClusterClient) ProcessContext(ctx context.Context, cmd Cmder) error { + return c.hooks.process(ctx, cmd, c.process) +} + +func (c *ClusterClient) process(ctx context.Context, cmd Cmder) error { var node *clusterNode var ask bool for attempt := 0; attempt <= c.opt.MaxRedirects; attempt++ { @@ -755,11 +763,11 @@ func (c *ClusterClient) process(cmd Cmder) error { pipe := node.Client.Pipeline() _ = pipe.Process(NewCmd("ASKING")) _ = pipe.Process(cmd) - _, err = pipe.Exec() + _, err = pipe.ExecContext(ctx) _ = pipe.Close() ask = false } else { - err = node.Client.Process(cmd) + err = node.Client.ProcessContext(ctx, cmd) } // If there is no error - we are done. @@ -1022,11 +1030,11 @@ func (c *ClusterClient) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) { return c.Pipeline().Pipelined(fn) } -func (c *ClusterClient) processPipeline(cmds []Cmder) error { +func (c *ClusterClient) processPipeline(ctx context.Context, cmds []Cmder) error { return c.hooks.processPipeline(c.ctx, cmds, c._processPipeline) } -func (c *ClusterClient) _processPipeline(cmds []Cmder) error { +func (c *ClusterClient) _processPipeline(ctx context.Context, cmds []Cmder) error { cmdsMap := newCmdsMap() err := c.mapCmdsByNode(cmds, cmdsMap) if err != nil { @@ -1216,11 +1224,11 @@ func (c *ClusterClient) TxPipelined(fn func(Pipeliner) error) ([]Cmder, error) { return c.TxPipeline().Pipelined(fn) } -func (c *ClusterClient) processTxPipeline(cmds []Cmder) error { - return c.hooks.processPipeline(c.ctx, cmds, c._processTxPipeline) +func (c *ClusterClient) processTxPipeline(ctx context.Context, cmds []Cmder) error { + return c.hooks.processPipeline(ctx, cmds, c._processTxPipeline) } -func (c *ClusterClient) _processTxPipeline(cmds []Cmder) error { +func (c *ClusterClient) _processTxPipeline(ctx context.Context, cmds []Cmder) error { state, err := c.state.Get() if err != nil { return err diff --git a/iterator.go b/iterator.go index 5d4bedf..a1e69e2 100644 --- a/iterator.go +++ b/iterator.go @@ -1,6 +1,8 @@ package redis -import "sync" +import ( + "sync" +) // ScanIterator is used to incrementally iterate over a collection of elements. // It's safe for concurrent use by multiple goroutines. diff --git a/options.go b/options.go index d00ead0..2a4d39f 100644 --- a/options.go +++ b/options.go @@ -16,9 +16,9 @@ import ( // Limiter is the interface of a rate limiter or a circuit breaker. type Limiter interface { - // Allow returns a nil if operation is allowed or an error otherwise. - // If operation is allowed client must report the result of operation - // whether is a success or a failure. + // Allow returns nil if operation is allowed or an error otherwise. + // If operation is allowed client must ReportResult of the operation + // whether it is a success or a failure. Allow() error // ReportResult reports the result of previously allowed operation. // nil indicates a success, non-nil error indicates a failure. diff --git a/pipeline.go b/pipeline.go index d403965..51333a7 100644 --- a/pipeline.go +++ b/pipeline.go @@ -1,12 +1,13 @@ package redis import ( + "context" "sync" "github.com/go-redis/redis/internal/pool" ) -type pipelineExecer func([]Cmder) error +type pipelineExecer func(context.Context, []Cmder) error // Pipeliner is an mechanism to realise Redis Pipeline technique. // @@ -28,6 +29,7 @@ type Pipeliner interface { Close() error Discard() error Exec() ([]Cmder, error) + ExecContext(ctx context.Context) ([]Cmder, error) } var _ Pipeliner = (*Pipeline)(nil) @@ -96,6 +98,10 @@ func (c *Pipeline) discard() error { // Exec always returns list of commands and error of the first failed // command if any. func (c *Pipeline) Exec() ([]Cmder, error) { + return c.ExecContext(nil) +} + +func (c *Pipeline) ExecContext(ctx context.Context) ([]Cmder, error) { c.mu.Lock() defer c.mu.Unlock() @@ -110,10 +116,10 @@ func (c *Pipeline) Exec() ([]Cmder, error) { cmds := c.cmds c.cmds = nil - return cmds, c.exec(cmds) + return cmds, c.exec(ctx, cmds) } -func (c *Pipeline) pipelined(fn func(Pipeliner) error) ([]Cmder, error) { +func (c *Pipeline) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) { if err := fn(c); err != nil { return nil, err } @@ -122,16 +128,12 @@ func (c *Pipeline) pipelined(fn func(Pipeliner) error) ([]Cmder, error) { return cmds, err } -func (c *Pipeline) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) { - return c.pipelined(fn) -} - func (c *Pipeline) Pipeline() Pipeliner { return c } func (c *Pipeline) TxPipelined(fn func(Pipeliner) error) ([]Cmder, error) { - return c.pipelined(fn) + return c.Pipelined(fn) } func (c *Pipeline) TxPipeline() Pipeliner { diff --git a/redis.go b/redis.go index 8bbfbe9..793a4cc 100644 --- a/redis.go +++ b/redis.go @@ -45,13 +45,15 @@ func (hs *hooks) AddHook(hook Hook) { hs.hooks = append(hs.hooks, hook) } -func (hs hooks) process(ctx context.Context, cmd Cmder, fn func(Cmder) error) error { +func (hs hooks) process( + ctx context.Context, cmd Cmder, fn func(context.Context, Cmder) error, +) error { ctx, err := hs.beforeProcess(ctx, cmd) if err != nil { return err } - cmdErr := fn(cmd) + cmdErr := fn(ctx, cmd) _, err = hs.afterProcess(ctx, cmd) if err != nil { @@ -83,13 +85,15 @@ func (hs hooks) afterProcess(ctx context.Context, cmd Cmder) (context.Context, e return ctx, nil } -func (hs hooks) processPipeline(ctx context.Context, cmds []Cmder, fn func([]Cmder) error) error { +func (hs hooks) processPipeline( + ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error, +) error { ctx, err := hs.beforeProcessPipeline(ctx, cmds) if err != nil { return err } - cmdsErr := fn(cmds) + cmdsErr := fn(ctx, cmds) _, err = hs.afterProcessPipeline(ctx, cmds) if err != nil { @@ -246,14 +250,7 @@ func (c *baseClient) initConn(cn *pool.Conn) error { return nil } -// Do creates a Cmd from the args and processes the cmd. -func (c *baseClient) Do(args ...interface{}) *Cmd { - cmd := NewCmd(args...) - _ = c.process(cmd) - return cmd -} - -func (c *baseClient) process(cmd Cmder) error { +func (c *baseClient) process(ctx context.Context, cmd Cmder) error { for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ { if attempt > 0 { time.Sleep(c.retryBackoff(attempt)) @@ -328,11 +325,11 @@ func (c *baseClient) getAddr() string { return c.opt.Addr } -func (c *baseClient) processPipeline(cmds []Cmder) error { +func (c *baseClient) processPipeline(ctx context.Context, cmds []Cmder) error { return c.generalProcessPipeline(cmds, c.pipelineProcessCmds) } -func (c *baseClient) processTxPipeline(cmds []Cmder) error { +func (c *baseClient) processTxPipeline(ctx context.Context, cmds []Cmder) error { return c.generalProcessPipeline(cmds, c.txPipelineProcessCmds) } @@ -503,16 +500,31 @@ func (c *Client) WithContext(ctx context.Context) *Client { return &clone } +// Do creates a Cmd from the args and processes the cmd. +func (c *Client) Do(args ...interface{}) *Cmd { + return c.DoContext(c.ctx, args...) +} + +func (c *Client) DoContext(ctx context.Context, args ...interface{}) *Cmd { + cmd := NewCmd(args...) + _ = c.ProcessContext(ctx, cmd) + return cmd +} + func (c *Client) Process(cmd Cmder) error { - return c.hooks.process(c.ctx, cmd, c.baseClient.process) + return c.ProcessContext(c.ctx, cmd) } -func (c *Client) processPipeline(cmds []Cmder) error { - return c.hooks.processPipeline(c.ctx, cmds, c.baseClient.processPipeline) +func (c *Client) ProcessContext(ctx context.Context, cmd Cmder) error { + return c.hooks.process(ctx, cmd, c.baseClient.process) } -func (c *Client) processTxPipeline(cmds []Cmder) error { - return c.hooks.processPipeline(c.ctx, cmds, c.baseClient.processTxPipeline) +func (c *Client) processPipeline(ctx context.Context, cmds []Cmder) error { + return c.hooks.processPipeline(ctx, cmds, c.baseClient.processPipeline) +} + +func (c *Client) processTxPipeline(ctx context.Context, cmds []Cmder) error { + return c.hooks.processPipeline(ctx, cmds, c.baseClient.processTxPipeline) } // Options returns read-only Options that were used to create the client. @@ -637,7 +649,11 @@ func newConn(opt *Options, cn *pool.Conn) *Conn { } func (c *Conn) Process(cmd Cmder) error { - return c.baseClient.process(cmd) + return c.ProcessContext(context.TODO(), cmd) +} + +func (c *Conn) ProcessContext(ctx context.Context, cmd Cmder) error { + return c.baseClient.process(ctx, cmd) } func (c *Conn) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) { diff --git a/ring.go b/ring.go index b18ade0..c207039 100644 --- a/ring.go +++ b/ring.go @@ -396,13 +396,21 @@ func (c *Ring) WithContext(ctx context.Context) *Ring { // Do creates a Cmd from the args and processes the cmd. func (c *Ring) Do(args ...interface{}) *Cmd { + return c.DoContext(c.ctx, args...) +} + +func (c *Ring) DoContext(ctx context.Context, args ...interface{}) *Cmd { cmd := NewCmd(args...) - c.Process(cmd) + c.ProcessContext(ctx, cmd) return cmd } func (c *Ring) Process(cmd Cmder) error { - return c.hooks.process(c.ctx, cmd, c.process) + return c.ProcessContext(c.ctx, cmd) +} + +func (c *Ring) ProcessContext(ctx context.Context, cmd Cmder) error { + return c.hooks.process(ctx, cmd, c.process) } // Options returns read-only Options that were used to create the client. @@ -532,7 +540,7 @@ func (c *Ring) cmdShard(cmd Cmder) (*ringShard, error) { return c.shards.GetByKey(firstKey) } -func (c *Ring) process(cmd Cmder) error { +func (c *Ring) process(ctx context.Context, cmd Cmder) error { for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ { if attempt > 0 { time.Sleep(c.retryBackoff(attempt)) @@ -544,7 +552,7 @@ func (c *Ring) process(cmd Cmder) error { return err } - err = shard.Client.Process(cmd) + err = shard.Client.ProcessContext(ctx, cmd) if err == nil { return nil } @@ -567,11 +575,11 @@ func (c *Ring) Pipeline() Pipeliner { return &pipe } -func (c *Ring) processPipeline(cmds []Cmder) error { - return c.hooks.processPipeline(c.ctx, cmds, c._processPipeline) +func (c *Ring) processPipeline(ctx context.Context, cmds []Cmder) error { + return c.hooks.processPipeline(ctx, cmds, c._processPipeline) } -func (c *Ring) _processPipeline(cmds []Cmder) error { +func (c *Ring) _processPipeline(ctx context.Context, cmds []Cmder) error { cmdsMap := make(map[string][]Cmder) for _, cmd := range cmds { cmdInfo := c.cmdInfo(cmd.Name()) diff --git a/sentinel.go b/sentinel.go index 8898350..3636331 100644 --- a/sentinel.go +++ b/sentinel.go @@ -136,7 +136,11 @@ func (c *SentinelClient) WithContext(ctx context.Context) *SentinelClient { } func (c *SentinelClient) Process(cmd Cmder) error { - return c.baseClient.process(cmd) + return c.ProcessContext(c.ctx, cmd) +} + +func (c *SentinelClient) ProcessContext(ctx context.Context, cmd Cmder) error { + return c.baseClient.process(ctx, cmd) } func (c *SentinelClient) pubSub() *PubSub { diff --git a/tx.go b/tx.go index bc4e390..f6b8bbd 100644 --- a/tx.go +++ b/tx.go @@ -56,7 +56,11 @@ func (c *Tx) WithContext(ctx context.Context) *Tx { } func (c *Tx) Process(cmd Cmder) error { - return c.baseClient.process(cmd) + return c.ProcessContext(c.ctx, cmd) +} + +func (c *Tx) ProcessContext(ctx context.Context, cmd Cmder) error { + return c.baseClient.process(ctx, cmd) } // Watch prepares a transaction and marks the keys to be watched diff --git a/universal.go b/universal.go index 656bc5c..dd2b5f7 100644 --- a/universal.go +++ b/universal.go @@ -162,7 +162,10 @@ type UniversalClient interface { Context() context.Context AddHook(Hook) Watch(fn func(*Tx) error, keys ...string) error + Do(args ...interface{}) *Cmd + DoContext(ctx context.Context, args ...interface{}) *Cmd Process(cmd Cmder) error + ProcessContext(ctx context.Context, cmd Cmder) error Subscribe(channels ...string) *PubSub PSubscribe(channels ...string) *PubSub Close() error