diff --git a/cluster.go b/cluster.go index b6c4daac..be8217b1 100644 --- a/cluster.go +++ b/cluster.go @@ -1554,7 +1554,7 @@ func (c *ClusterClient) retryBackoff(attempt int) time.Duration { return internal.RetryBackoff(attempt, c.opt.MinRetryBackoff, c.opt.MaxRetryBackoff) } -func (c *ClusterClient) cmdsInfo() (map[string]*CommandInfo, error) { +func (c *ClusterClient) cmdsInfo(ctx context.Context) (map[string]*CommandInfo, error) { // Try 3 random nodes. const nodeLimit = 3 @@ -1581,7 +1581,7 @@ func (c *ClusterClient) cmdsInfo() (map[string]*CommandInfo, error) { continue } - info, err := node.Client.Command(c.ctx).Result() + info, err := node.Client.Command(ctx).Result() if err == nil { return info, nil } @@ -1597,7 +1597,7 @@ func (c *ClusterClient) cmdsInfo() (map[string]*CommandInfo, error) { } func (c *ClusterClient) cmdInfo(name string) *CommandInfo { - cmdsInfo, err := c.cmdsInfoCache.Get() + cmdsInfo, err := c.cmdsInfoCache.Get(c.ctx) if err != nil { return nil } diff --git a/command.go b/command.go index 1cf76168..55a5bd5c 100644 --- a/command.go +++ b/command.go @@ -2135,21 +2135,21 @@ func commandInfoParser(rd *proto.Reader, n int64) (interface{}, error) { //------------------------------------------------------------------------------ type cmdsInfoCache struct { - fn func() (map[string]*CommandInfo, error) + fn func(ctx context.Context) (map[string]*CommandInfo, error) once internal.Once cmds map[string]*CommandInfo } -func newCmdsInfoCache(fn func() (map[string]*CommandInfo, error)) *cmdsInfoCache { +func newCmdsInfoCache(fn func(ctx context.Context) (map[string]*CommandInfo, error)) *cmdsInfoCache { return &cmdsInfoCache{ fn: fn, } } -func (c *cmdsInfoCache) Get() (map[string]*CommandInfo, error) { +func (c *cmdsInfoCache) Get(ctx context.Context) (map[string]*CommandInfo, error) { err := c.once.Do(func() error { - cmds, err := c.fn() + cmds, err := c.fn(ctx) if err != nil { return err } diff --git a/ring.go b/ring.go index 370216f0..c095e4f7 100644 --- a/ring.go +++ b/ring.go @@ -547,11 +547,11 @@ func (c *Ring) ForEachShard( } } -func (c *Ring) cmdsInfo() (map[string]*CommandInfo, error) { +func (c *Ring) cmdsInfo(ctx context.Context) (map[string]*CommandInfo, error) { shards := c.shards.List() var firstErr error for _, shard := range shards { - cmdsInfo, err := shard.Client.Command(context.TODO()).Result() + cmdsInfo, err := shard.Client.Command(ctx).Result() if err == nil { return cmdsInfo, nil } @@ -566,7 +566,7 @@ func (c *Ring) cmdsInfo() (map[string]*CommandInfo, error) { } func (c *Ring) cmdInfo(name string) *CommandInfo { - cmdsInfo, err := c.cmdsInfoCache.Get() + cmdsInfo, err := c.cmdsInfoCache.Get(c.ctx) if err != nil { return nil }