This commit is contained in:
Andreas Bergmeier 2024-11-22 00:17:39 -05:00 committed by GitHub
commit a2d86d85a0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 15 additions and 12 deletions

View File

@ -21,10 +21,12 @@ type Limiter interface {
// Allow returns nil if operation is allowed or an error otherwise. // Allow returns nil if operation is allowed or an error otherwise.
// If operation is allowed client must ReportResult of the operation // If operation is allowed client must ReportResult of the operation
// whether it is a success or a failure. // whether it is a success or a failure.
Allow() error // The returned context will be passed to ReportResult.
Allow(ctx context.Context) (context.Context, error)
// ReportResult reports the result of the previously allowed operation. // ReportResult reports the result of the previously allowed operation.
// nil indicates a success, non-nil error usually indicates a failure. // nil indicates a success, non-nil error usually indicates a failure.
ReportResult(result error) // Context can be used to access state tracked by previous Allow call.
ReportResult(ctx context.Context, result error)
} }
// Options keeps the settings to set up redis connection. // Options keeps the settings to set up redis connection.

View File

@ -1342,7 +1342,7 @@ func (c *ClusterClient) processPipelineNode(
ctx context.Context, node *clusterNode, cmds []Cmder, failedCmds *cmdsMap, ctx context.Context, node *clusterNode, cmds []Cmder, failedCmds *cmdsMap,
) { ) {
_ = node.Client.withProcessPipelineHook(ctx, cmds, func(ctx context.Context, cmds []Cmder) error { _ = node.Client.withProcessPipelineHook(ctx, cmds, func(ctx context.Context, cmds []Cmder) error {
cn, err := node.Client.getConn(ctx) ctx, cn, err := node.Client.getConn(ctx)
if err != nil { if err != nil {
node.MarkAsFailing() node.MarkAsFailing()
_ = c.mapCmdsByNode(ctx, failedCmds, cmds) _ = c.mapCmdsByNode(ctx, failedCmds, cmds)
@ -1527,7 +1527,7 @@ func (c *ClusterClient) processTxPipelineNode(
) { ) {
cmds = wrapMultiExec(ctx, cmds) cmds = wrapMultiExec(ctx, cmds)
_ = node.Client.withProcessPipelineHook(ctx, cmds, func(ctx context.Context, cmds []Cmder) error { _ = node.Client.withProcessPipelineHook(ctx, cmds, func(ctx context.Context, cmds []Cmder) error {
cn, err := node.Client.getConn(ctx) ctx, cn, err := node.Client.getConn(ctx)
if err != nil { if err != nil {
_ = c.mapCmdsByNode(ctx, failedCmds, cmds) _ = c.mapCmdsByNode(ctx, failedCmds, cmds)
setCmdsErr(cmds, err) setCmdsErr(cmds, err)

View File

@ -235,23 +235,24 @@ func (c *baseClient) newConn(ctx context.Context) (*pool.Conn, error) {
return cn, nil return cn, nil
} }
func (c *baseClient) getConn(ctx context.Context) (*pool.Conn, error) { func (c *baseClient) getConn(ctx context.Context) (context.Context, *pool.Conn, error) {
var err error
if c.opt.Limiter != nil { if c.opt.Limiter != nil {
err := c.opt.Limiter.Allow() ctx, err = c.opt.Limiter.Allow(ctx)
if err != nil { if err != nil {
return nil, err return ctx, nil, err
} }
} }
cn, err := c._getConn(ctx) cn, err := c._getConn(ctx)
if err != nil { if err != nil {
if c.opt.Limiter != nil { if c.opt.Limiter != nil {
c.opt.Limiter.ReportResult(err) c.opt.Limiter.ReportResult(ctx, err)
} }
return nil, err return ctx, nil, err
} }
return cn, nil return ctx, cn, nil
} }
func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) { func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
@ -363,7 +364,7 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
func (c *baseClient) releaseConn(ctx context.Context, cn *pool.Conn, err error) { func (c *baseClient) releaseConn(ctx context.Context, cn *pool.Conn, err error) {
if c.opt.Limiter != nil { if c.opt.Limiter != nil {
c.opt.Limiter.ReportResult(err) c.opt.Limiter.ReportResult(ctx, err)
} }
if isBadConn(err, false, c.opt.Addr) { if isBadConn(err, false, c.opt.Addr) {
@ -376,7 +377,7 @@ func (c *baseClient) releaseConn(ctx context.Context, cn *pool.Conn, err error)
func (c *baseClient) withConn( func (c *baseClient) withConn(
ctx context.Context, fn func(context.Context, *pool.Conn) error, ctx context.Context, fn func(context.Context, *pool.Conn) error,
) error { ) error {
cn, err := c.getConn(ctx) ctx, cn, err := c.getConn(ctx)
if err != nil { if err != nil {
return err return err
} }