diff --git a/pubsub.go b/pubsub.go index 2bd156db..b64c8a4f 100644 --- a/pubsub.go +++ b/pubsub.go @@ -238,9 +238,17 @@ func (c *PubSub) Unsubscribe(ctx context.Context, channels ...string) error { c.mu.Lock() defer c.mu.Unlock() - for _, channel := range channels { - delete(c.channels, channel) + if len(channels) > 0 { + for _, channel := range channels { + delete(c.channels, channel) + } + } else { + // Unsubscribe from all channels. + for channel := range c.channels { + delete(c.channels, channel) + } } + err := c.subscribe(ctx, "unsubscribe", channels...) return err } @@ -251,9 +259,17 @@ func (c *PubSub) PUnsubscribe(ctx context.Context, patterns ...string) error { c.mu.Lock() defer c.mu.Unlock() - for _, pattern := range patterns { - delete(c.patterns, pattern) + if len(patterns) > 0 { + for _, pattern := range patterns { + delete(c.patterns, pattern) + } + } else { + // Unsubscribe from all patterns. + for pattern := range c.patterns { + delete(c.patterns, pattern) + } } + err := c.subscribe(ctx, "punsubscribe", patterns...) return err } @@ -264,9 +280,17 @@ func (c *PubSub) SUnsubscribe(ctx context.Context, channels ...string) error { c.mu.Lock() defer c.mu.Unlock() - for _, channel := range channels { - delete(c.schannels, channel) + if len(channels) > 0 { + for _, channel := range channels { + delete(c.schannels, channel) + } + } else { + // Unsubscribe from all channels. + for channel := range c.schannels { + delete(c.schannels, channel) + } } + err := c.subscribe(ctx, "sunsubscribe", channels...) return err }