This commit is contained in:
Andreas Bergmeier 2024-11-22 00:17:39 -05:00 committed by GitHub
commit a2d86d85a0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 15 additions and 12 deletions

View File

@ -21,10 +21,12 @@ type Limiter interface {
// Allow returns nil if operation is allowed or an error otherwise.
// If operation is allowed client must ReportResult of the operation
// whether it is a success or a failure.
Allow() error
// The returned context will be passed to ReportResult.
Allow(ctx context.Context) (context.Context, error)
// ReportResult reports the result of the previously allowed operation.
// nil indicates a success, non-nil error usually indicates a failure.
ReportResult(result error)
// Context can be used to access state tracked by previous Allow call.
ReportResult(ctx context.Context, result error)
}
// Options keeps the settings to set up redis connection.

View File

@ -1342,7 +1342,7 @@ func (c *ClusterClient) processPipelineNode(
ctx context.Context, node *clusterNode, cmds []Cmder, failedCmds *cmdsMap,
) {
_ = node.Client.withProcessPipelineHook(ctx, cmds, func(ctx context.Context, cmds []Cmder) error {
cn, err := node.Client.getConn(ctx)
ctx, cn, err := node.Client.getConn(ctx)
if err != nil {
node.MarkAsFailing()
_ = c.mapCmdsByNode(ctx, failedCmds, cmds)
@ -1527,7 +1527,7 @@ func (c *ClusterClient) processTxPipelineNode(
) {
cmds = wrapMultiExec(ctx, cmds)
_ = node.Client.withProcessPipelineHook(ctx, cmds, func(ctx context.Context, cmds []Cmder) error {
cn, err := node.Client.getConn(ctx)
ctx, cn, err := node.Client.getConn(ctx)
if err != nil {
_ = c.mapCmdsByNode(ctx, failedCmds, cmds)
setCmdsErr(cmds, err)

View File

@ -235,23 +235,24 @@ func (c *baseClient) newConn(ctx context.Context) (*pool.Conn, error) {
return cn, nil
}
func (c *baseClient) getConn(ctx context.Context) (*pool.Conn, error) {
func (c *baseClient) getConn(ctx context.Context) (context.Context, *pool.Conn, error) {
var err error
if c.opt.Limiter != nil {
err := c.opt.Limiter.Allow()
ctx, err = c.opt.Limiter.Allow(ctx)
if err != nil {
return nil, err
return ctx, nil, err
}
}
cn, err := c._getConn(ctx)
if err != nil {
if c.opt.Limiter != nil {
c.opt.Limiter.ReportResult(err)
c.opt.Limiter.ReportResult(ctx, err)
}
return nil, err
return ctx, nil, err
}
return cn, nil
return ctx, cn, nil
}
func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
@ -363,7 +364,7 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
func (c *baseClient) releaseConn(ctx context.Context, cn *pool.Conn, err error) {
if c.opt.Limiter != nil {
c.opt.Limiter.ReportResult(err)
c.opt.Limiter.ReportResult(ctx, err)
}
if isBadConn(err, false, c.opt.Addr) {
@ -376,7 +377,7 @@ func (c *baseClient) releaseConn(ctx context.Context, cn *pool.Conn, err error)
func (c *baseClient) withConn(
ctx context.Context, fn func(context.Context, *pool.Conn) error,
) error {
cn, err := c.getConn(ctx)
ctx, cn, err := c.getConn(ctx)
if err != nil {
return err
}