diff --git a/options.go b/options.go index 8ba74ccd..4169b928 100644 --- a/options.go +++ b/options.go @@ -21,10 +21,12 @@ type Limiter interface { // Allow returns nil if operation is allowed or an error otherwise. // If operation is allowed client must ReportResult of the operation // whether it is a success or a failure. - Allow() error + // The returned context will be passed to ReportResult. + Allow(ctx context.Context) (context.Context, error) // ReportResult reports the result of the previously allowed operation. // nil indicates a success, non-nil error usually indicates a failure. - ReportResult(result error) + // Context can be used to access state tracked by previous Allow call. + ReportResult(ctx context.Context, result error) } // Options keeps the settings to set up redis connection. diff --git a/osscluster.go b/osscluster.go index ce258ff3..cc40f77a 100644 --- a/osscluster.go +++ b/osscluster.go @@ -1319,7 +1319,7 @@ func (c *ClusterClient) processPipelineNode( ctx context.Context, node *clusterNode, cmds []Cmder, failedCmds *cmdsMap, ) { _ = node.Client.withProcessPipelineHook(ctx, cmds, func(ctx context.Context, cmds []Cmder) error { - cn, err := node.Client.getConn(ctx) + ctx, cn, err := node.Client.getConn(ctx) if err != nil { node.MarkAsFailing() _ = c.mapCmdsByNode(ctx, failedCmds, cmds) @@ -1504,7 +1504,7 @@ func (c *ClusterClient) processTxPipelineNode( ) { cmds = wrapMultiExec(ctx, cmds) _ = node.Client.withProcessPipelineHook(ctx, cmds, func(ctx context.Context, cmds []Cmder) error { - cn, err := node.Client.getConn(ctx) + ctx, cn, err := node.Client.getConn(ctx) if err != nil { _ = c.mapCmdsByNode(ctx, failedCmds, cmds) setCmdsErr(cmds, err) diff --git a/redis.go b/redis.go index c8b50080..ca79b0fe 100644 --- a/redis.go +++ b/redis.go @@ -237,23 +237,24 @@ func (c *baseClient) newConn(ctx context.Context) (*pool.Conn, error) { return cn, nil } -func (c *baseClient) getConn(ctx context.Context) (*pool.Conn, error) { +func (c *baseClient) getConn(ctx context.Context) (context.Context, *pool.Conn, error) { + var err error if c.opt.Limiter != nil { - err := c.opt.Limiter.Allow() + ctx, err = c.opt.Limiter.Allow(ctx) if err != nil { - return nil, err + return ctx, nil, err } } cn, err := c._getConn(ctx) if err != nil { if c.opt.Limiter != nil { - c.opt.Limiter.ReportResult(err) + c.opt.Limiter.ReportResult(ctx, err) } - return nil, err + return ctx, nil, err } - return cn, nil + return ctx, cn, nil } func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) { @@ -365,7 +366,7 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { func (c *baseClient) releaseConn(ctx context.Context, cn *pool.Conn, err error) { if c.opt.Limiter != nil { - c.opt.Limiter.ReportResult(err) + c.opt.Limiter.ReportResult(ctx, err) } if isBadConn(err, false, c.opt.Addr) { @@ -378,7 +379,7 @@ func (c *baseClient) releaseConn(ctx context.Context, cn *pool.Conn, err error) func (c *baseClient) withConn( ctx context.Context, fn func(context.Context, *pool.Conn) error, ) error { - cn, err := c.getConn(ctx) + ctx, cn, err := c.getConn(ctx) if err != nil { return err }