feat(pubsub): support sharded pub/sub

This commit is contained in:
jianghang 2022-08-03 23:10:03 +08:00
parent 084c0c8914
commit baa48a4415
8 changed files with 233 additions and 2 deletions

View File

@ -1528,6 +1528,16 @@ func (c *ClusterClient) PSubscribe(ctx context.Context, channels ...string) *Pub
return pubsub
}
// SSubscribe Subscribes the client to the specified shard channels.
func (c *ClusterClient) SSubscribe(ctx context.Context, channels ...string) *PubSub {
pubsub := c.pubSub()
if len(channels) > 0 {
_ = pubsub.SSubscribe(ctx, channels...)
}
return pubsub
}
func (c *ClusterClient) retryBackoff(attempt int) time.Duration {
return internal.RetryBackoff(attempt, c.opt.MinRetryBackoff, c.opt.MaxRetryBackoff)
}

View File

@ -549,6 +549,30 @@ var _ = Describe("ClusterClient", func() {
}, 30*time.Second).ShouldNot(HaveOccurred())
})
It("supports sharded PubSub", func() {
pubsub := client.SSubscribe(ctx, "mychannel")
defer pubsub.Close()
Eventually(func() error {
_, err := client.SPublish(ctx, "mychannel", "hello").Result()
if err != nil {
return err
}
msg, err := pubsub.ReceiveTimeout(ctx, time.Second)
if err != nil {
return err
}
_, ok := msg.(*redis.Message)
if !ok {
return fmt.Errorf("got %T, wanted *redis.Message", msg)
}
return nil
}, 30*time.Second).ShouldNot(HaveOccurred())
})
It("supports PubSub.Ping without channels", func() {
pubsub := client.Subscribe(ctx)
defer pubsub.Close()

View File

@ -345,9 +345,12 @@ type Cmdable interface {
ScriptLoad(ctx context.Context, script string) *StringCmd
Publish(ctx context.Context, channel string, message interface{}) *IntCmd
SPublish(ctx context.Context, channel string, message interface{}) *IntCmd
PubSubChannels(ctx context.Context, pattern string) *StringSliceCmd
PubSubNumSub(ctx context.Context, channels ...string) *StringIntMapCmd
PubSubNumPat(ctx context.Context) *IntCmd
PubSubShardChannels(ctx context.Context, pattern string) *StringSliceCmd
PubSubShardNumSub(ctx context.Context, channels ...string) *StringIntMapCmd
ClusterSlots(ctx context.Context) *ClusterSlotsCmd
ClusterNodes(ctx context.Context) *StringCmd
@ -3078,6 +3081,12 @@ func (c cmdable) Publish(ctx context.Context, channel string, message interface{
return cmd
}
func (c cmdable) SPublish(ctx context.Context, channel string, message interface{}) *IntCmd {
cmd := NewIntCmd(ctx, "spublish", channel, message)
_ = c(ctx, cmd)
return cmd
}
func (c cmdable) PubSubChannels(ctx context.Context, pattern string) *StringSliceCmd {
args := []interface{}{"pubsub", "channels"}
if pattern != "*" {
@ -3100,6 +3109,28 @@ func (c cmdable) PubSubNumSub(ctx context.Context, channels ...string) *StringIn
return cmd
}
func (c cmdable) PubSubShardChannels(ctx context.Context, pattern string) *StringSliceCmd {
args := []interface{}{"pubsub", "shardchannels"}
if pattern != "*" {
args = append(args, pattern)
}
cmd := NewStringSliceCmd(ctx, args...)
_ = c(ctx, cmd)
return cmd
}
func (c cmdable) PubSubShardNumSub(ctx context.Context, channels ...string) *StringIntMapCmd {
args := make([]interface{}, 2+len(channels))
args[0] = "pubsub"
args[1] = "shardnumsub"
for i, channel := range channels {
args[2+i] = channel
}
cmd := NewStringIntMapCmd(ctx, args...)
_ = c(ctx, cmd)
return cmd
}
func (c cmdable) PubSubNumPat(ctx context.Context) *IntCmd {
cmd := NewIntCmd(ctx, "pubsub", "numpat")
_ = c(ctx, cmd)

View File

@ -28,6 +28,7 @@ type PubSub struct {
cn *pool.Conn
channels map[string]struct{}
patterns map[string]struct{}
schannels map[string]struct{}
closed bool
exit chan struct{}
@ -46,6 +47,7 @@ func (c *PubSub) init() {
func (c *PubSub) String() string {
channels := mapKeys(c.channels)
channels = append(channels, mapKeys(c.patterns)...)
channels = append(channels, mapKeys(c.schannels)...)
return fmt.Sprintf("PubSub(%s)", strings.Join(channels, ", "))
}
@ -101,6 +103,13 @@ func (c *PubSub) resubscribe(ctx context.Context, cn *pool.Conn) error {
}
}
if len(c.schannels) > 0 {
err := c._subscribe(ctx, cn, "ssubscribe", mapKeys(c.schannels))
if err != nil && firstErr == nil {
firstErr = err
}
}
return firstErr
}
@ -208,6 +217,21 @@ func (c *PubSub) PSubscribe(ctx context.Context, patterns ...string) error {
return err
}
// SSubscribe Subscribes the client to the specified shard channels.
func (c *PubSub) SSubscribe(ctx context.Context, channels ...string) error {
c.mu.Lock()
defer c.mu.Unlock()
err := c.subscribe(ctx, "ssubscribe", channels...)
if c.schannels == nil {
c.schannels = make(map[string]struct{})
}
for _, s := range channels {
c.schannels[s] = struct{}{}
}
return err
}
// Unsubscribe the client from the given channels, or from all of
// them if none is given.
func (c *PubSub) Unsubscribe(ctx context.Context, channels ...string) error {
@ -234,6 +258,19 @@ func (c *PubSub) PUnsubscribe(ctx context.Context, patterns ...string) error {
return err
}
// SUnsubscribe unsubscribes the client from the given shard channels,
// or from all of them if none is given.
func (c *PubSub) SUnsubscribe(ctx context.Context, channels ...string) error {
c.mu.Lock()
defer c.mu.Unlock()
for _, channel := range channels {
delete(c.schannels, channel)
}
err := c.subscribe(ctx, "sunsubscribe", channels...)
return err
}
func (c *PubSub) subscribe(ctx context.Context, redisCmd string, channels ...string) error {
cn, err := c.conn(ctx, channels)
if err != nil {
@ -311,7 +348,7 @@ func (c *PubSub) newMessage(reply interface{}) (interface{}, error) {
}, nil
case []interface{}:
switch kind := reply[0].(string); kind {
case "subscribe", "unsubscribe", "psubscribe", "punsubscribe":
case "subscribe", "unsubscribe", "psubscribe", "punsubscribe", "ssubscribe", "sunsubscribe":
// Can be nil in case of "unsubscribe".
channel, _ := reply[1].(string)
return &Subscription{
@ -319,7 +356,7 @@ func (c *PubSub) newMessage(reply interface{}) (interface{}, error) {
Channel: channel,
Count: int(reply[2].(int64)),
}, nil
case "message":
case "message", "smessage":
switch payload := reply[2].(type) {
case string:
return &Message{

View File

@ -102,6 +102,35 @@ var _ = Describe("PubSub", func() {
Expect(len(channels)).To(BeNumerically(">=", 2))
})
It("should sharded pub/sub channels", func() {
channels, err := client.PubSubShardChannels(ctx, "mychannel*").Result()
Expect(err).NotTo(HaveOccurred())
Expect(channels).To(BeEmpty())
pubsub := client.SSubscribe(ctx, "mychannel", "mychannel2")
defer pubsub.Close()
channels, err = client.PubSubShardChannels(ctx, "mychannel*").Result()
Expect(err).NotTo(HaveOccurred())
Expect(channels).To(ConsistOf([]string{"mychannel", "mychannel2"}))
channels, err = client.PubSubShardChannels(ctx, "").Result()
Expect(err).NotTo(HaveOccurred())
Expect(channels).To(BeEmpty())
channels, err = client.PubSubShardChannels(ctx, "*").Result()
Expect(err).NotTo(HaveOccurred())
Expect(len(channels)).To(BeNumerically(">=", 2))
nums, err := client.PubSubShardNumSub(ctx, "mychannel", "mychannel2", "mychannel3").Result()
Expect(err).NotTo(HaveOccurred())
Expect(nums).To(Equal(map[string]int64{
"mychannel": 1,
"mychannel2": 1,
"mychannel3": 0,
}))
})
It("should return the numbers of subscribers", func() {
pubsub := client.Subscribe(ctx, "mychannel", "mychannel2")
defer pubsub.Close()
@ -204,6 +233,82 @@ var _ = Describe("PubSub", func() {
Expect(stats.Misses).To(Equal(uint32(1)))
})
It("should sharded pub/sub", func() {
pubsub := client.SSubscribe(ctx, "mychannel", "mychannel2")
defer pubsub.Close()
{
msgi, err := pubsub.ReceiveTimeout(ctx, time.Second)
Expect(err).NotTo(HaveOccurred())
subscr := msgi.(*redis.Subscription)
Expect(subscr.Kind).To(Equal("ssubscribe"))
Expect(subscr.Channel).To(Equal("mychannel"))
Expect(subscr.Count).To(Equal(1))
}
{
msgi, err := pubsub.ReceiveTimeout(ctx, time.Second)
Expect(err).NotTo(HaveOccurred())
subscr := msgi.(*redis.Subscription)
Expect(subscr.Kind).To(Equal("ssubscribe"))
Expect(subscr.Channel).To(Equal("mychannel2"))
Expect(subscr.Count).To(Equal(2))
}
{
msgi, err := pubsub.ReceiveTimeout(ctx, time.Second)
Expect(err.(net.Error).Timeout()).To(Equal(true))
Expect(msgi).NotTo(HaveOccurred())
}
n, err := client.SPublish(ctx, "mychannel", "hello").Result()
Expect(err).NotTo(HaveOccurred())
Expect(n).To(Equal(int64(1)))
n, err = client.SPublish(ctx, "mychannel2", "hello2").Result()
Expect(err).NotTo(HaveOccurred())
Expect(n).To(Equal(int64(1)))
Expect(pubsub.SUnsubscribe(ctx, "mychannel", "mychannel2")).NotTo(HaveOccurred())
{
msgi, err := pubsub.ReceiveTimeout(ctx, time.Second)
Expect(err).NotTo(HaveOccurred())
msg := msgi.(*redis.Message)
Expect(msg.Channel).To(Equal("mychannel"))
Expect(msg.Payload).To(Equal("hello"))
}
{
msgi, err := pubsub.ReceiveTimeout(ctx, time.Second)
Expect(err).NotTo(HaveOccurred())
msg := msgi.(*redis.Message)
Expect(msg.Channel).To(Equal("mychannel2"))
Expect(msg.Payload).To(Equal("hello2"))
}
{
msgi, err := pubsub.ReceiveTimeout(ctx, time.Second)
Expect(err).NotTo(HaveOccurred())
subscr := msgi.(*redis.Subscription)
Expect(subscr.Kind).To(Equal("sunsubscribe"))
Expect(subscr.Channel).To(Equal("mychannel"))
Expect(subscr.Count).To(Equal(1))
}
{
msgi, err := pubsub.ReceiveTimeout(ctx, time.Second)
Expect(err).NotTo(HaveOccurred())
subscr := msgi.(*redis.Subscription)
Expect(subscr.Kind).To(Equal("sunsubscribe"))
Expect(subscr.Channel).To(Equal("mychannel2"))
Expect(subscr.Count).To(Equal(0))
}
stats := client.PoolStats()
Expect(stats.Misses).To(Equal(uint32(1)))
})
It("should ping/pong", func() {
pubsub := client.Subscribe(ctx, "mychannel")
defer pubsub.Close()

View File

@ -691,6 +691,16 @@ func (c *Client) PSubscribe(ctx context.Context, channels ...string) *PubSub {
return pubsub
}
// SSubscribe Subscribes the client to the specified shard channels.
// Channels can be omitted to create empty subscription.
func (c *Client) SSubscribe(ctx context.Context, channels ...string) *PubSub {
pubsub := c.pubSub()
if len(channels) > 0 {
_ = pubsub.SSubscribe(ctx, channels...)
}
return pubsub
}
//------------------------------------------------------------------------------
type conn struct {

13
ring.go
View File

@ -504,6 +504,19 @@ func (c *Ring) PSubscribe(ctx context.Context, channels ...string) *PubSub {
return shard.Client.PSubscribe(ctx, channels...)
}
// SSubscribe Subscribes the client to the specified shard channels.
func (c *Ring) SSubscribe(ctx context.Context, channels ...string) *PubSub {
if len(channels) == 0 {
panic("at least one channel is required")
}
shard, err := c.shards.GetByKey(channels[0])
if err != nil {
// TODO: return PubSub with sticky error
panic(err)
}
return shard.Client.SSubscribe(ctx, channels...)
}
// ForEachShard concurrently calls the fn on each live shard in the ring.
// It returns the first error if any.
func (c *Ring) ForEachShard(

View File

@ -190,6 +190,7 @@ type UniversalClient interface {
Process(ctx context.Context, cmd Cmder) error
Subscribe(ctx context.Context, channels ...string) *PubSub
PSubscribe(ctx context.Context, channels ...string) *PubSub
SSubscribe(ctx context.Context, channels ...string) *PubSub
Close() error
PoolStats() *PoolStats
}