From d2e52839eefb9d43dcd76d1d28e538ab3b405d43 Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Sun, 2 Feb 2020 14:59:27 +0200 Subject: [PATCH] Add WithTimeout --- options.go | 5 +++++ redis.go | 51 +++++++++++++++++++++++++++++++++++++++++---------- redis_test.go | 15 +++++++++++++-- sentinel.go | 9 +++------ 4 files changed, 62 insertions(+), 18 deletions(-) diff --git a/options.go b/options.go index ae269199..621d3a37 100644 --- a/options.go +++ b/options.go @@ -169,6 +169,11 @@ func (opt *Options) init() { } } +func (opt *Options) clone() *Options { + clone := *opt + return &clone +} + // ParseURL parses an URL into Options that can be used to connect to Redis. func ParseURL(redisURL string) (*Options, error) { o := &Options{Network: "tcp"} diff --git a/redis.go b/redis.go index 66dc72f8..c9a9e172 100644 --- a/redis.go +++ b/redis.go @@ -137,6 +137,29 @@ type baseClient struct { onClose func() error // hook called when client is closed } +func newBaseClient(opt *Options, connPool pool.Pooler) *baseClient { + return &baseClient{ + opt: opt, + connPool: connPool, + } +} + +func (c *baseClient) clone() *baseClient { + clone := *c + return &clone +} + +func (c *baseClient) withTimeout(timeout time.Duration) *baseClient { + opt := c.opt.clone() + opt.ReadTimeout = timeout + opt.WriteTimeout = timeout + + clone := c.clone() + clone.opt = opt + + return clone +} + func (c *baseClient) String() string { return fmt.Sprintf("Redis<%s db:%d>", c.getAddr(), c.opt.DB) } @@ -481,7 +504,7 @@ func txPipelineReadQueued(rd *proto.Reader, cmds []Cmder) error { // underlying connections. It's safe for concurrent use by multiple // goroutines. type Client struct { - baseClient + *baseClient cmdable hooks ctx context.Context @@ -492,17 +515,27 @@ func NewClient(opt *Options) *Client { opt.init() c := Client{ - baseClient: baseClient{ - opt: opt, - connPool: newConnPool(opt), - }, - ctx: context.Background(), + baseClient: newBaseClient(opt, newConnPool(opt)), + ctx: context.Background(), } c.cmdable = c.Process return &c } +func (c *Client) clone() *Client { + clone := *c + clone.cmdable = clone.Process + clone.hooks.Lock() + return &clone +} + +func (c *Client) WithTimeout(timeout time.Duration) *Client { + clone := c.clone() + clone.baseClient = c.baseClient.withTimeout(timeout) + return clone +} + func (c *Client) Context() context.Context { return c.ctx } @@ -511,11 +544,9 @@ func (c *Client) WithContext(ctx context.Context) *Client { if ctx == nil { panic("nil context") } - clone := *c - clone.cmdable = clone.Process - clone.hooks.Lock() + clone := c.clone() clone.ctx = ctx - return &clone + return clone } func (c *Client) Conn() *Conn { diff --git a/redis_test.go b/redis_test.go index 8a1a1493..d8d1d63b 100644 --- a/redis_test.go +++ b/redis_test.go @@ -59,6 +59,10 @@ var _ = Describe("Client", func() { client.Close() }) + It("should Stringer", func() { + Expect(client.String()).To(Equal("Redis<:6380 db:15>")) + }) + It("supports WithContext", func() { c, cancel := context.WithCancel(context.Background()) cancel() @@ -67,8 +71,15 @@ var _ = Describe("Client", func() { Expect(err).To(MatchError("context canceled")) }) - It("should Stringer", func() { - Expect(client.String()).To(Equal("Redis<:6380 db:15>")) + It("supports WithTimeout", func() { + err := client.ClientPause(time.Second).Err() + Expect(err).NotTo(HaveOccurred()) + + err = client.WithTimeout(10 * time.Millisecond).Ping().Err() + Expect(err).To(HaveOccurred()) + + err = client.Ping().Err() + Expect(err).NotTo(HaveOccurred()) }) It("should ping", func() { diff --git a/sentinel.go b/sentinel.go index c10a26c7..6487ef63 100644 --- a/sentinel.go +++ b/sentinel.go @@ -94,14 +94,11 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client { } c := Client{ - baseClient: baseClient{ - opt: opt, - connPool: failover.Pool(), - onClose: failover.Close, - }, - ctx: context.Background(), + baseClient: newBaseClient(opt, failover.Pool()), + ctx: context.Background(), } c.cmdable = c.Process + c.onClose = failover.Close return &c }