diff --git a/commands_test.go b/commands_test.go index fdc3452c..821d335a 100644 --- a/commands_test.go +++ b/commands_test.go @@ -4843,6 +4843,24 @@ var _ = Describe("Commands", func() { Expect(err).To(Equal(redis.Nil)) }) + Describe("canceled context", func() { + It("should unblock XRead", func() { + ctx2, cancel := context.WithCancel(ctx) + errCh := make(chan error, 1) + go func() { + errCh <- client.XRead(ctx2, &redis.XReadArgs{ + Streams: []string{"stream", "$"}, + }).Err() + }() + + var gotErr error + Consistently(errCh).ShouldNot(Receive(&gotErr), "Received %v", gotErr) + cancel() + Eventually(errCh).Should(Receive(&gotErr)) + Expect(gotErr).To(HaveOccurred()) + }) + }) + Describe("group", func() { BeforeEach(func() { err := client.XGroupCreate(ctx, "stream", "group", "0").Err() @@ -5023,6 +5041,26 @@ var _ = Describe("Commands", func() { Expect(err).NotTo(HaveOccurred()) Expect(n).To(Equal(int64(2))) }) + + Describe("canceled context", func() { + It("should unblock XReadGroup", func() { + ctx2, cancel := context.WithCancel(ctx) + errCh := make(chan error, 1) + go func() { + errCh <- client.XReadGroup(ctx2, &redis.XReadGroupArgs{ + Group: "group", + Consumer: "consumer", + Streams: []string{"stream", ">"}, + }).Err() + }() + + var gotErr error + Consistently(errCh).ShouldNot(Receive(&gotErr), "Received %v", gotErr) + cancel() + Eventually(errCh).Should(Receive(&gotErr)) + Expect(gotErr).To(HaveOccurred()) + }) + }) }) Describe("xinfo", func() { diff --git a/internal_test.go b/internal_test.go index a6317196..494cb96e 100644 --- a/internal_test.go +++ b/internal_test.go @@ -351,4 +351,21 @@ var _ = Describe("withConn", func() { Expect(newConn).NotTo(Equal(conn)) Expect(client.connPool.Len()).To(Equal(1)) }) + + It("should remove the connection from the pool if the context is canceled", func() { + var conn *pool.Conn + + ctx2, cancel := context.WithCancel(ctx) + cancel() + + client.withConn(ctx2, func(ctx context.Context, c *pool.Conn) error { + conn = c + return nil + }) + + newConn, err := client.connPool.Get(ctx) + Expect(err).To(BeNil()) + Expect(newConn).NotTo(Equal(conn)) + Expect(client.connPool.Len()).To(Equal(1)) + }) }) diff --git a/pubsub_test.go b/pubsub_test.go index a7610065..43b60f0a 100644 --- a/pubsub_test.go +++ b/pubsub_test.go @@ -1,6 +1,7 @@ package redis_test import ( + "context" "io" "net" "sync" @@ -567,4 +568,24 @@ var _ = Describe("PubSub", func() { Expect(msg.Channel).To(Equal("mychannel")) Expect(msg.Payload).To(Equal(text)) }) + + Describe("canceled context", func() { + It("should unblock ReceiveMessage", func() { + pubsub := client.Subscribe(ctx, "mychannel") + defer pubsub.Close() + + ctx2, cancel := context.WithCancel(ctx) + errCh := make(chan error, 1) + go func() { + _, err := pubsub.ReceiveMessage(ctx2) + errCh <- err + }() + + var gotErr error + Consistently(errCh).ShouldNot(Receive(&gotErr), "Received %v", gotErr) + cancel() + Eventually(errCh).Should(Receive(&gotErr)) + Expect(gotErr).To(HaveOccurred()) + }) + }) })