mirror of https://github.com/go-redis/redis.git
feat(pubsub): support sharded pub/sub
This commit is contained in:
parent
084c0c8914
commit
baa48a4415
10
cluster.go
10
cluster.go
|
@ -1528,6 +1528,16 @@ func (c *ClusterClient) PSubscribe(ctx context.Context, channels ...string) *Pub
|
|||
return pubsub
|
||||
}
|
||||
|
||||
// SSubscribe Subscribes the client to the specified shard channels.
|
||||
func (c *ClusterClient) SSubscribe(ctx context.Context, channels ...string) *PubSub {
|
||||
pubsub := c.pubSub()
|
||||
if len(channels) > 0 {
|
||||
_ = pubsub.SSubscribe(ctx, channels...)
|
||||
}
|
||||
return pubsub
|
||||
}
|
||||
|
||||
|
||||
func (c *ClusterClient) retryBackoff(attempt int) time.Duration {
|
||||
return internal.RetryBackoff(attempt, c.opt.MinRetryBackoff, c.opt.MaxRetryBackoff)
|
||||
}
|
||||
|
|
|
@ -549,6 +549,30 @@ var _ = Describe("ClusterClient", func() {
|
|||
}, 30*time.Second).ShouldNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("supports sharded PubSub", func() {
|
||||
pubsub := client.SSubscribe(ctx, "mychannel")
|
||||
defer pubsub.Close()
|
||||
|
||||
Eventually(func() error {
|
||||
_, err := client.SPublish(ctx, "mychannel", "hello").Result()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
msg, err := pubsub.ReceiveTimeout(ctx, time.Second)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, ok := msg.(*redis.Message)
|
||||
if !ok {
|
||||
return fmt.Errorf("got %T, wanted *redis.Message", msg)
|
||||
}
|
||||
|
||||
return nil
|
||||
}, 30*time.Second).ShouldNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("supports PubSub.Ping without channels", func() {
|
||||
pubsub := client.Subscribe(ctx)
|
||||
defer pubsub.Close()
|
||||
|
|
31
commands.go
31
commands.go
|
@ -345,9 +345,12 @@ type Cmdable interface {
|
|||
ScriptLoad(ctx context.Context, script string) *StringCmd
|
||||
|
||||
Publish(ctx context.Context, channel string, message interface{}) *IntCmd
|
||||
SPublish(ctx context.Context, channel string, message interface{}) *IntCmd
|
||||
PubSubChannels(ctx context.Context, pattern string) *StringSliceCmd
|
||||
PubSubNumSub(ctx context.Context, channels ...string) *StringIntMapCmd
|
||||
PubSubNumPat(ctx context.Context) *IntCmd
|
||||
PubSubShardChannels(ctx context.Context, pattern string) *StringSliceCmd
|
||||
PubSubShardNumSub(ctx context.Context, channels ...string) *StringIntMapCmd
|
||||
|
||||
ClusterSlots(ctx context.Context) *ClusterSlotsCmd
|
||||
ClusterNodes(ctx context.Context) *StringCmd
|
||||
|
@ -3078,6 +3081,12 @@ func (c cmdable) Publish(ctx context.Context, channel string, message interface{
|
|||
return cmd
|
||||
}
|
||||
|
||||
func (c cmdable) SPublish(ctx context.Context, channel string, message interface{}) *IntCmd {
|
||||
cmd := NewIntCmd(ctx, "spublish", channel, message)
|
||||
_ = c(ctx, cmd)
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (c cmdable) PubSubChannels(ctx context.Context, pattern string) *StringSliceCmd {
|
||||
args := []interface{}{"pubsub", "channels"}
|
||||
if pattern != "*" {
|
||||
|
@ -3100,6 +3109,28 @@ func (c cmdable) PubSubNumSub(ctx context.Context, channels ...string) *StringIn
|
|||
return cmd
|
||||
}
|
||||
|
||||
func (c cmdable) PubSubShardChannels(ctx context.Context, pattern string) *StringSliceCmd {
|
||||
args := []interface{}{"pubsub", "shardchannels"}
|
||||
if pattern != "*" {
|
||||
args = append(args, pattern)
|
||||
}
|
||||
cmd := NewStringSliceCmd(ctx, args...)
|
||||
_ = c(ctx, cmd)
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (c cmdable) PubSubShardNumSub(ctx context.Context, channels ...string) *StringIntMapCmd {
|
||||
args := make([]interface{}, 2+len(channels))
|
||||
args[0] = "pubsub"
|
||||
args[1] = "shardnumsub"
|
||||
for i, channel := range channels {
|
||||
args[2+i] = channel
|
||||
}
|
||||
cmd := NewStringIntMapCmd(ctx, args...)
|
||||
_ = c(ctx, cmd)
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (c cmdable) PubSubNumPat(ctx context.Context) *IntCmd {
|
||||
cmd := NewIntCmd(ctx, "pubsub", "numpat")
|
||||
_ = c(ctx, cmd)
|
||||
|
|
41
pubsub.go
41
pubsub.go
|
@ -28,6 +28,7 @@ type PubSub struct {
|
|||
cn *pool.Conn
|
||||
channels map[string]struct{}
|
||||
patterns map[string]struct{}
|
||||
schannels map[string]struct{}
|
||||
|
||||
closed bool
|
||||
exit chan struct{}
|
||||
|
@ -46,6 +47,7 @@ func (c *PubSub) init() {
|
|||
func (c *PubSub) String() string {
|
||||
channels := mapKeys(c.channels)
|
||||
channels = append(channels, mapKeys(c.patterns)...)
|
||||
channels = append(channels, mapKeys(c.schannels)...)
|
||||
return fmt.Sprintf("PubSub(%s)", strings.Join(channels, ", "))
|
||||
}
|
||||
|
||||
|
@ -101,6 +103,13 @@ func (c *PubSub) resubscribe(ctx context.Context, cn *pool.Conn) error {
|
|||
}
|
||||
}
|
||||
|
||||
if len(c.schannels) > 0 {
|
||||
err := c._subscribe(ctx, cn, "ssubscribe", mapKeys(c.schannels))
|
||||
if err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
|
||||
return firstErr
|
||||
}
|
||||
|
||||
|
@ -208,6 +217,21 @@ func (c *PubSub) PSubscribe(ctx context.Context, patterns ...string) error {
|
|||
return err
|
||||
}
|
||||
|
||||
// SSubscribe Subscribes the client to the specified shard channels.
|
||||
func (c *PubSub) SSubscribe(ctx context.Context, channels ...string) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
err := c.subscribe(ctx, "ssubscribe", channels...)
|
||||
if c.schannels == nil {
|
||||
c.schannels = make(map[string]struct{})
|
||||
}
|
||||
for _, s := range channels {
|
||||
c.schannels[s] = struct{}{}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Unsubscribe the client from the given channels, or from all of
|
||||
// them if none is given.
|
||||
func (c *PubSub) Unsubscribe(ctx context.Context, channels ...string) error {
|
||||
|
@ -234,6 +258,19 @@ func (c *PubSub) PUnsubscribe(ctx context.Context, patterns ...string) error {
|
|||
return err
|
||||
}
|
||||
|
||||
// SUnsubscribe unsubscribes the client from the given shard channels,
|
||||
// or from all of them if none is given.
|
||||
func (c *PubSub) SUnsubscribe(ctx context.Context, channels ...string) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
for _, channel := range channels {
|
||||
delete(c.schannels, channel)
|
||||
}
|
||||
err := c.subscribe(ctx, "sunsubscribe", channels...)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *PubSub) subscribe(ctx context.Context, redisCmd string, channels ...string) error {
|
||||
cn, err := c.conn(ctx, channels)
|
||||
if err != nil {
|
||||
|
@ -311,7 +348,7 @@ func (c *PubSub) newMessage(reply interface{}) (interface{}, error) {
|
|||
}, nil
|
||||
case []interface{}:
|
||||
switch kind := reply[0].(string); kind {
|
||||
case "subscribe", "unsubscribe", "psubscribe", "punsubscribe":
|
||||
case "subscribe", "unsubscribe", "psubscribe", "punsubscribe", "ssubscribe", "sunsubscribe":
|
||||
// Can be nil in case of "unsubscribe".
|
||||
channel, _ := reply[1].(string)
|
||||
return &Subscription{
|
||||
|
@ -319,7 +356,7 @@ func (c *PubSub) newMessage(reply interface{}) (interface{}, error) {
|
|||
Channel: channel,
|
||||
Count: int(reply[2].(int64)),
|
||||
}, nil
|
||||
case "message":
|
||||
case "message", "smessage":
|
||||
switch payload := reply[2].(type) {
|
||||
case string:
|
||||
return &Message{
|
||||
|
|
105
pubsub_test.go
105
pubsub_test.go
|
@ -102,6 +102,35 @@ var _ = Describe("PubSub", func() {
|
|||
Expect(len(channels)).To(BeNumerically(">=", 2))
|
||||
})
|
||||
|
||||
It("should sharded pub/sub channels", func() {
|
||||
channels, err := client.PubSubShardChannels(ctx, "mychannel*").Result()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(channels).To(BeEmpty())
|
||||
|
||||
pubsub := client.SSubscribe(ctx, "mychannel", "mychannel2")
|
||||
defer pubsub.Close()
|
||||
|
||||
channels, err = client.PubSubShardChannels(ctx, "mychannel*").Result()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(channels).To(ConsistOf([]string{"mychannel", "mychannel2"}))
|
||||
|
||||
channels, err = client.PubSubShardChannels(ctx, "").Result()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(channels).To(BeEmpty())
|
||||
|
||||
channels, err = client.PubSubShardChannels(ctx, "*").Result()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(len(channels)).To(BeNumerically(">=", 2))
|
||||
|
||||
nums, err := client.PubSubShardNumSub(ctx, "mychannel", "mychannel2", "mychannel3").Result()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(nums).To(Equal(map[string]int64{
|
||||
"mychannel": 1,
|
||||
"mychannel2": 1,
|
||||
"mychannel3": 0,
|
||||
}))
|
||||
})
|
||||
|
||||
It("should return the numbers of subscribers", func() {
|
||||
pubsub := client.Subscribe(ctx, "mychannel", "mychannel2")
|
||||
defer pubsub.Close()
|
||||
|
@ -204,6 +233,82 @@ var _ = Describe("PubSub", func() {
|
|||
Expect(stats.Misses).To(Equal(uint32(1)))
|
||||
})
|
||||
|
||||
It("should sharded pub/sub", func() {
|
||||
pubsub := client.SSubscribe(ctx, "mychannel", "mychannel2")
|
||||
defer pubsub.Close()
|
||||
|
||||
{
|
||||
msgi, err := pubsub.ReceiveTimeout(ctx, time.Second)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
subscr := msgi.(*redis.Subscription)
|
||||
Expect(subscr.Kind).To(Equal("ssubscribe"))
|
||||
Expect(subscr.Channel).To(Equal("mychannel"))
|
||||
Expect(subscr.Count).To(Equal(1))
|
||||
}
|
||||
|
||||
{
|
||||
msgi, err := pubsub.ReceiveTimeout(ctx, time.Second)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
subscr := msgi.(*redis.Subscription)
|
||||
Expect(subscr.Kind).To(Equal("ssubscribe"))
|
||||
Expect(subscr.Channel).To(Equal("mychannel2"))
|
||||
Expect(subscr.Count).To(Equal(2))
|
||||
}
|
||||
|
||||
{
|
||||
msgi, err := pubsub.ReceiveTimeout(ctx, time.Second)
|
||||
Expect(err.(net.Error).Timeout()).To(Equal(true))
|
||||
Expect(msgi).NotTo(HaveOccurred())
|
||||
}
|
||||
|
||||
n, err := client.SPublish(ctx, "mychannel", "hello").Result()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(n).To(Equal(int64(1)))
|
||||
|
||||
n, err = client.SPublish(ctx, "mychannel2", "hello2").Result()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(n).To(Equal(int64(1)))
|
||||
|
||||
Expect(pubsub.SUnsubscribe(ctx, "mychannel", "mychannel2")).NotTo(HaveOccurred())
|
||||
|
||||
{
|
||||
msgi, err := pubsub.ReceiveTimeout(ctx, time.Second)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
msg := msgi.(*redis.Message)
|
||||
Expect(msg.Channel).To(Equal("mychannel"))
|
||||
Expect(msg.Payload).To(Equal("hello"))
|
||||
}
|
||||
|
||||
{
|
||||
msgi, err := pubsub.ReceiveTimeout(ctx, time.Second)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
msg := msgi.(*redis.Message)
|
||||
Expect(msg.Channel).To(Equal("mychannel2"))
|
||||
Expect(msg.Payload).To(Equal("hello2"))
|
||||
}
|
||||
|
||||
{
|
||||
msgi, err := pubsub.ReceiveTimeout(ctx, time.Second)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
subscr := msgi.(*redis.Subscription)
|
||||
Expect(subscr.Kind).To(Equal("sunsubscribe"))
|
||||
Expect(subscr.Channel).To(Equal("mychannel"))
|
||||
Expect(subscr.Count).To(Equal(1))
|
||||
}
|
||||
|
||||
{
|
||||
msgi, err := pubsub.ReceiveTimeout(ctx, time.Second)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
subscr := msgi.(*redis.Subscription)
|
||||
Expect(subscr.Kind).To(Equal("sunsubscribe"))
|
||||
Expect(subscr.Channel).To(Equal("mychannel2"))
|
||||
Expect(subscr.Count).To(Equal(0))
|
||||
}
|
||||
|
||||
stats := client.PoolStats()
|
||||
Expect(stats.Misses).To(Equal(uint32(1)))
|
||||
})
|
||||
|
||||
It("should ping/pong", func() {
|
||||
pubsub := client.Subscribe(ctx, "mychannel")
|
||||
defer pubsub.Close()
|
||||
|
|
10
redis.go
10
redis.go
|
@ -691,6 +691,16 @@ func (c *Client) PSubscribe(ctx context.Context, channels ...string) *PubSub {
|
|||
return pubsub
|
||||
}
|
||||
|
||||
// SSubscribe Subscribes the client to the specified shard channels.
|
||||
// Channels can be omitted to create empty subscription.
|
||||
func (c *Client) SSubscribe(ctx context.Context, channels ...string) *PubSub {
|
||||
pubsub := c.pubSub()
|
||||
if len(channels) > 0 {
|
||||
_ = pubsub.SSubscribe(ctx, channels...)
|
||||
}
|
||||
return pubsub
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
type conn struct {
|
||||
|
|
13
ring.go
13
ring.go
|
@ -504,6 +504,19 @@ func (c *Ring) PSubscribe(ctx context.Context, channels ...string) *PubSub {
|
|||
return shard.Client.PSubscribe(ctx, channels...)
|
||||
}
|
||||
|
||||
// SSubscribe Subscribes the client to the specified shard channels.
|
||||
func (c *Ring) SSubscribe(ctx context.Context, channels ...string) *PubSub {
|
||||
if len(channels) == 0 {
|
||||
panic("at least one channel is required")
|
||||
}
|
||||
shard, err := c.shards.GetByKey(channels[0])
|
||||
if err != nil {
|
||||
// TODO: return PubSub with sticky error
|
||||
panic(err)
|
||||
}
|
||||
return shard.Client.SSubscribe(ctx, channels...)
|
||||
}
|
||||
|
||||
// ForEachShard concurrently calls the fn on each live shard in the ring.
|
||||
// It returns the first error if any.
|
||||
func (c *Ring) ForEachShard(
|
||||
|
|
|
@ -190,6 +190,7 @@ type UniversalClient interface {
|
|||
Process(ctx context.Context, cmd Cmder) error
|
||||
Subscribe(ctx context.Context, channels ...string) *PubSub
|
||||
PSubscribe(ctx context.Context, channels ...string) *PubSub
|
||||
SSubscribe(ctx context.Context, channels ...string) *PubSub
|
||||
Close() error
|
||||
PoolStats() *PoolStats
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue