From baa48a4415f7760ed80275acfe0747dd13342937 Mon Sep 17 00:00:00 2001 From: jianghang Date: Wed, 3 Aug 2022 23:10:03 +0800 Subject: [PATCH] feat(pubsub): support sharded pub/sub --- cluster.go | 10 +++++ cluster_test.go | 24 +++++++++++ commands.go | 31 ++++++++++++++ pubsub.go | 41 ++++++++++++++++++- pubsub_test.go | 105 ++++++++++++++++++++++++++++++++++++++++++++++++ redis.go | 10 +++++ ring.go | 13 ++++++ universal.go | 1 + 8 files changed, 233 insertions(+), 2 deletions(-) diff --git a/cluster.go b/cluster.go index 05b234a..ea31228 100644 --- a/cluster.go +++ b/cluster.go @@ -1528,6 +1528,16 @@ func (c *ClusterClient) PSubscribe(ctx context.Context, channels ...string) *Pub return pubsub } +// SSubscribe Subscribes the client to the specified shard channels. +func (c *ClusterClient) SSubscribe(ctx context.Context, channels ...string) *PubSub { + pubsub := c.pubSub() + if len(channels) > 0 { + _ = pubsub.SSubscribe(ctx, channels...) + } + return pubsub +} + + func (c *ClusterClient) retryBackoff(attempt int) time.Duration { return internal.RetryBackoff(attempt, c.opt.MinRetryBackoff, c.opt.MaxRetryBackoff) } diff --git a/cluster_test.go b/cluster_test.go index ee37dc2..e503ea8 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -549,6 +549,30 @@ var _ = Describe("ClusterClient", func() { }, 30*time.Second).ShouldNot(HaveOccurred()) }) + It("supports sharded PubSub", func() { + pubsub := client.SSubscribe(ctx, "mychannel") + defer pubsub.Close() + + Eventually(func() error { + _, err := client.SPublish(ctx, "mychannel", "hello").Result() + if err != nil { + return err + } + + msg, err := pubsub.ReceiveTimeout(ctx, time.Second) + if err != nil { + return err + } + + _, ok := msg.(*redis.Message) + if !ok { + return fmt.Errorf("got %T, wanted *redis.Message", msg) + } + + return nil + }, 30*time.Second).ShouldNot(HaveOccurred()) + }) + It("supports PubSub.Ping without channels", func() { pubsub := client.Subscribe(ctx) defer pubsub.Close() diff --git a/commands.go b/commands.go index beb3af2..8db115d 100644 --- a/commands.go +++ b/commands.go @@ -345,9 +345,12 @@ type Cmdable interface { ScriptLoad(ctx context.Context, script string) *StringCmd Publish(ctx context.Context, channel string, message interface{}) *IntCmd + SPublish(ctx context.Context, channel string, message interface{}) *IntCmd PubSubChannels(ctx context.Context, pattern string) *StringSliceCmd PubSubNumSub(ctx context.Context, channels ...string) *StringIntMapCmd PubSubNumPat(ctx context.Context) *IntCmd + PubSubShardChannels(ctx context.Context, pattern string) *StringSliceCmd + PubSubShardNumSub(ctx context.Context, channels ...string) *StringIntMapCmd ClusterSlots(ctx context.Context) *ClusterSlotsCmd ClusterNodes(ctx context.Context) *StringCmd @@ -3078,6 +3081,12 @@ func (c cmdable) Publish(ctx context.Context, channel string, message interface{ return cmd } +func (c cmdable) SPublish(ctx context.Context, channel string, message interface{}) *IntCmd { + cmd := NewIntCmd(ctx, "spublish", channel, message) + _ = c(ctx, cmd) + return cmd +} + func (c cmdable) PubSubChannels(ctx context.Context, pattern string) *StringSliceCmd { args := []interface{}{"pubsub", "channels"} if pattern != "*" { @@ -3100,6 +3109,28 @@ func (c cmdable) PubSubNumSub(ctx context.Context, channels ...string) *StringIn return cmd } +func (c cmdable) PubSubShardChannels(ctx context.Context, pattern string) *StringSliceCmd { + args := []interface{}{"pubsub", "shardchannels"} + if pattern != "*" { + args = append(args, pattern) + } + cmd := NewStringSliceCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} + +func (c cmdable) PubSubShardNumSub(ctx context.Context, channels ...string) *StringIntMapCmd { + args := make([]interface{}, 2+len(channels)) + args[0] = "pubsub" + args[1] = "shardnumsub" + for i, channel := range channels { + args[2+i] = channel + } + cmd := NewStringIntMapCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} + func (c cmdable) PubSubNumPat(ctx context.Context) *IntCmd { cmd := NewIntCmd(ctx, "pubsub", "numpat") _ = c(ctx, cmd) diff --git a/pubsub.go b/pubsub.go index 75e5097..6eede91 100644 --- a/pubsub.go +++ b/pubsub.go @@ -28,6 +28,7 @@ type PubSub struct { cn *pool.Conn channels map[string]struct{} patterns map[string]struct{} + schannels map[string]struct{} closed bool exit chan struct{} @@ -46,6 +47,7 @@ func (c *PubSub) init() { func (c *PubSub) String() string { channels := mapKeys(c.channels) channels = append(channels, mapKeys(c.patterns)...) + channels = append(channels, mapKeys(c.schannels)...) return fmt.Sprintf("PubSub(%s)", strings.Join(channels, ", ")) } @@ -101,6 +103,13 @@ func (c *PubSub) resubscribe(ctx context.Context, cn *pool.Conn) error { } } + if len(c.schannels) > 0 { + err := c._subscribe(ctx, cn, "ssubscribe", mapKeys(c.schannels)) + if err != nil && firstErr == nil { + firstErr = err + } + } + return firstErr } @@ -208,6 +217,21 @@ func (c *PubSub) PSubscribe(ctx context.Context, patterns ...string) error { return err } +// SSubscribe Subscribes the client to the specified shard channels. +func (c *PubSub) SSubscribe(ctx context.Context, channels ...string) error { + c.mu.Lock() + defer c.mu.Unlock() + + err := c.subscribe(ctx, "ssubscribe", channels...) + if c.schannels == nil { + c.schannels = make(map[string]struct{}) + } + for _, s := range channels { + c.schannels[s] = struct{}{} + } + return err +} + // Unsubscribe the client from the given channels, or from all of // them if none is given. func (c *PubSub) Unsubscribe(ctx context.Context, channels ...string) error { @@ -234,6 +258,19 @@ func (c *PubSub) PUnsubscribe(ctx context.Context, patterns ...string) error { return err } +// SUnsubscribe unsubscribes the client from the given shard channels, +// or from all of them if none is given. +func (c *PubSub) SUnsubscribe(ctx context.Context, channels ...string) error { + c.mu.Lock() + defer c.mu.Unlock() + + for _, channel := range channels { + delete(c.schannels, channel) + } + err := c.subscribe(ctx, "sunsubscribe", channels...) + return err +} + func (c *PubSub) subscribe(ctx context.Context, redisCmd string, channels ...string) error { cn, err := c.conn(ctx, channels) if err != nil { @@ -311,7 +348,7 @@ func (c *PubSub) newMessage(reply interface{}) (interface{}, error) { }, nil case []interface{}: switch kind := reply[0].(string); kind { - case "subscribe", "unsubscribe", "psubscribe", "punsubscribe": + case "subscribe", "unsubscribe", "psubscribe", "punsubscribe", "ssubscribe", "sunsubscribe": // Can be nil in case of "unsubscribe". channel, _ := reply[1].(string) return &Subscription{ @@ -319,7 +356,7 @@ func (c *PubSub) newMessage(reply interface{}) (interface{}, error) { Channel: channel, Count: int(reply[2].(int64)), }, nil - case "message": + case "message", "smessage": switch payload := reply[2].(type) { case string: return &Message{ diff --git a/pubsub_test.go b/pubsub_test.go index 892118e..37f0e99 100644 --- a/pubsub_test.go +++ b/pubsub_test.go @@ -102,6 +102,35 @@ var _ = Describe("PubSub", func() { Expect(len(channels)).To(BeNumerically(">=", 2)) }) + It("should sharded pub/sub channels", func() { + channels, err := client.PubSubShardChannels(ctx, "mychannel*").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(channels).To(BeEmpty()) + + pubsub := client.SSubscribe(ctx, "mychannel", "mychannel2") + defer pubsub.Close() + + channels, err = client.PubSubShardChannels(ctx, "mychannel*").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(channels).To(ConsistOf([]string{"mychannel", "mychannel2"})) + + channels, err = client.PubSubShardChannels(ctx, "").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(channels).To(BeEmpty()) + + channels, err = client.PubSubShardChannels(ctx, "*").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(len(channels)).To(BeNumerically(">=", 2)) + + nums, err := client.PubSubShardNumSub(ctx, "mychannel", "mychannel2", "mychannel3").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(nums).To(Equal(map[string]int64{ + "mychannel": 1, + "mychannel2": 1, + "mychannel3": 0, + })) + }) + It("should return the numbers of subscribers", func() { pubsub := client.Subscribe(ctx, "mychannel", "mychannel2") defer pubsub.Close() @@ -204,6 +233,82 @@ var _ = Describe("PubSub", func() { Expect(stats.Misses).To(Equal(uint32(1))) }) + It("should sharded pub/sub", func() { + pubsub := client.SSubscribe(ctx, "mychannel", "mychannel2") + defer pubsub.Close() + + { + msgi, err := pubsub.ReceiveTimeout(ctx, time.Second) + Expect(err).NotTo(HaveOccurred()) + subscr := msgi.(*redis.Subscription) + Expect(subscr.Kind).To(Equal("ssubscribe")) + Expect(subscr.Channel).To(Equal("mychannel")) + Expect(subscr.Count).To(Equal(1)) + } + + { + msgi, err := pubsub.ReceiveTimeout(ctx, time.Second) + Expect(err).NotTo(HaveOccurred()) + subscr := msgi.(*redis.Subscription) + Expect(subscr.Kind).To(Equal("ssubscribe")) + Expect(subscr.Channel).To(Equal("mychannel2")) + Expect(subscr.Count).To(Equal(2)) + } + + { + msgi, err := pubsub.ReceiveTimeout(ctx, time.Second) + Expect(err.(net.Error).Timeout()).To(Equal(true)) + Expect(msgi).NotTo(HaveOccurred()) + } + + n, err := client.SPublish(ctx, "mychannel", "hello").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(n).To(Equal(int64(1))) + + n, err = client.SPublish(ctx, "mychannel2", "hello2").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(n).To(Equal(int64(1))) + + Expect(pubsub.SUnsubscribe(ctx, "mychannel", "mychannel2")).NotTo(HaveOccurred()) + + { + msgi, err := pubsub.ReceiveTimeout(ctx, time.Second) + Expect(err).NotTo(HaveOccurred()) + msg := msgi.(*redis.Message) + Expect(msg.Channel).To(Equal("mychannel")) + Expect(msg.Payload).To(Equal("hello")) + } + + { + msgi, err := pubsub.ReceiveTimeout(ctx, time.Second) + Expect(err).NotTo(HaveOccurred()) + msg := msgi.(*redis.Message) + Expect(msg.Channel).To(Equal("mychannel2")) + Expect(msg.Payload).To(Equal("hello2")) + } + + { + msgi, err := pubsub.ReceiveTimeout(ctx, time.Second) + Expect(err).NotTo(HaveOccurred()) + subscr := msgi.(*redis.Subscription) + Expect(subscr.Kind).To(Equal("sunsubscribe")) + Expect(subscr.Channel).To(Equal("mychannel")) + Expect(subscr.Count).To(Equal(1)) + } + + { + msgi, err := pubsub.ReceiveTimeout(ctx, time.Second) + Expect(err).NotTo(HaveOccurred()) + subscr := msgi.(*redis.Subscription) + Expect(subscr.Kind).To(Equal("sunsubscribe")) + Expect(subscr.Channel).To(Equal("mychannel2")) + Expect(subscr.Count).To(Equal(0)) + } + + stats := client.PoolStats() + Expect(stats.Misses).To(Equal(uint32(1))) + }) + It("should ping/pong", func() { pubsub := client.Subscribe(ctx, "mychannel") defer pubsub.Close() diff --git a/redis.go b/redis.go index f558181..324a1ca 100644 --- a/redis.go +++ b/redis.go @@ -691,6 +691,16 @@ func (c *Client) PSubscribe(ctx context.Context, channels ...string) *PubSub { return pubsub } +// SSubscribe Subscribes the client to the specified shard channels. +// Channels can be omitted to create empty subscription. +func (c *Client) SSubscribe(ctx context.Context, channels ...string) *PubSub { + pubsub := c.pubSub() + if len(channels) > 0 { + _ = pubsub.SSubscribe(ctx, channels...) + } + return pubsub +} + //------------------------------------------------------------------------------ type conn struct { diff --git a/ring.go b/ring.go index dede1e4..9386dfe 100644 --- a/ring.go +++ b/ring.go @@ -504,6 +504,19 @@ func (c *Ring) PSubscribe(ctx context.Context, channels ...string) *PubSub { return shard.Client.PSubscribe(ctx, channels...) } +// SSubscribe Subscribes the client to the specified shard channels. +func (c *Ring) SSubscribe(ctx context.Context, channels ...string) *PubSub { + if len(channels) == 0 { + panic("at least one channel is required") + } + shard, err := c.shards.GetByKey(channels[0]) + if err != nil { + // TODO: return PubSub with sticky error + panic(err) + } + return shard.Client.SSubscribe(ctx, channels...) +} + // ForEachShard concurrently calls the fn on each live shard in the ring. // It returns the first error if any. func (c *Ring) ForEachShard( diff --git a/universal.go b/universal.go index a3c5b9d..f800935 100644 --- a/universal.go +++ b/universal.go @@ -190,6 +190,7 @@ type UniversalClient interface { Process(ctx context.Context, cmd Cmder) error Subscribe(ctx context.Context, channels ...string) *PubSub PSubscribe(ctx context.Context, channels ...string) *PubSub + SSubscribe(ctx context.Context, channels ...string) *PubSub Close() error PoolStats() *PoolStats }