From 69bab3800731d01bcdc9127f4f761151810e5cfb Mon Sep 17 00:00:00 2001 From: Andreas Bergmeier Date: Thu, 10 Oct 2024 16:30:30 +0200 Subject: [PATCH] Allow for tracking state in Limiter Extend Allow and ReportResult functions to handle a Context. Allow can override the passed-in Context. The returned Context is then further passed down to ReportResult. Using this Context, it is then possible to store values/track state between Allow and ReportResult calls. Without this tracking, HalfOpen/Generation state is hard to implement efficiently for Circuit Breakers. --- options.go | 6 ++++-- osscluster.go | 4 ++-- redis.go | 17 +++++++++-------- 3 files changed, 15 insertions(+), 12 deletions(-) diff --git a/options.go b/options.go index 8ba74ccd..4169b928 100644 --- a/options.go +++ b/options.go @@ -21,10 +21,12 @@ type Limiter interface { // Allow returns nil if operation is allowed or an error otherwise. // If operation is allowed client must ReportResult of the operation // whether it is a success or a failure. - Allow() error + // The returned context will be passed to ReportResult. + Allow(ctx context.Context) (context.Context, error) // ReportResult reports the result of the previously allowed operation. // nil indicates a success, non-nil error usually indicates a failure. - ReportResult(result error) + // Context can be used to access state tracked by previous Allow call. + ReportResult(ctx context.Context, result error) } // Options keeps the settings to set up redis connection. 
diff --git a/osscluster.go b/osscluster.go index ce258ff3..cc40f77a 100644 --- a/osscluster.go +++ b/osscluster.go @@ -1319,7 +1319,7 @@ func (c *ClusterClient) processPipelineNode( ctx context.Context, node *clusterNode, cmds []Cmder, failedCmds *cmdsMap, ) { _ = node.Client.withProcessPipelineHook(ctx, cmds, func(ctx context.Context, cmds []Cmder) error { - cn, err := node.Client.getConn(ctx) + ctx, cn, err := node.Client.getConn(ctx) if err != nil { node.MarkAsFailing() _ = c.mapCmdsByNode(ctx, failedCmds, cmds) @@ -1504,7 +1504,7 @@ func (c *ClusterClient) processTxPipelineNode( ) { cmds = wrapMultiExec(ctx, cmds) _ = node.Client.withProcessPipelineHook(ctx, cmds, func(ctx context.Context, cmds []Cmder) error { - cn, err := node.Client.getConn(ctx) + ctx, cn, err := node.Client.getConn(ctx) if err != nil { _ = c.mapCmdsByNode(ctx, failedCmds, cmds) setCmdsErr(cmds, err) diff --git a/redis.go b/redis.go index c8b50080..ca79b0fe 100644 --- a/redis.go +++ b/redis.go @@ -237,23 +237,24 @@ func (c *baseClient) newConn(ctx context.Context) (*pool.Conn, error) { return cn, nil } -func (c *baseClient) getConn(ctx context.Context) (*pool.Conn, error) { +func (c *baseClient) getConn(ctx context.Context) (context.Context, *pool.Conn, error) { + var err error if c.opt.Limiter != nil { - err := c.opt.Limiter.Allow() + ctx, err = c.opt.Limiter.Allow(ctx) if err != nil { - return nil, err + return ctx, nil, err } } cn, err := c._getConn(ctx) if err != nil { if c.opt.Limiter != nil { - c.opt.Limiter.ReportResult(err) + c.opt.Limiter.ReportResult(ctx, err) } - return nil, err + return ctx, nil, err } - return cn, nil + return ctx, cn, nil } func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) { @@ -365,7 +366,7 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { func (c *baseClient) releaseConn(ctx context.Context, cn *pool.Conn, err error) { if c.opt.Limiter != nil { - c.opt.Limiter.ReportResult(err) + c.opt.Limiter.ReportResult(ctx, 
err) } if isBadConn(err, false, c.opt.Addr) { @@ -378,7 +379,7 @@ func (c *baseClient) releaseConn(ctx context.Context, cn *pool.Conn, err error) func (c *baseClient) withConn( ctx context.Context, fn func(context.Context, *pool.Conn) error, ) error { - cn, err := c.getConn(ctx) + ctx, cn, err := c.getConn(ctx) if err != nil { return err }