diff --git a/cluster.go b/cluster.go index 1e787add..4a295115 100644 --- a/cluster.go +++ b/cluster.go @@ -1409,3 +1409,31 @@ func appendNode(nodes []*clusterNode, node *clusterNode) []*clusterNode { } return append(nodes, node) } + +func appendIfNotExists(ss []string, es ...string) []string { +loop: + for _, e := range es { + for _, s := range ss { + if s == e { + continue loop + } + } + ss = append(ss, e) + } + return ss +} + +func remove(ss []string, es ...string) []string { + if len(es) == 0 { + return ss[:0] + } + for _, e := range es { + for i, s := range ss { + if s == e { + ss = append(ss[:i], ss[i+1:]...) + break + } + } + } + return ss +} diff --git a/pubsub.go b/pubsub.go index 3ee4ea9d..b56728f3 100644 --- a/pubsub.go +++ b/pubsub.go @@ -24,8 +24,8 @@ type PubSub struct { mu sync.Mutex cn *pool.Conn - channels []string - patterns []string + channels map[string]struct{} + patterns map[string]struct{} closed bool cmd *Cmd @@ -67,12 +67,24 @@ func (c *PubSub) _conn(channels []string) (*pool.Conn, error) { func (c *PubSub) resubscribe(cn *pool.Conn) error { var firstErr error if len(c.channels) > 0 { - if err := c._subscribe(cn, "subscribe", c.channels...); err != nil && firstErr == nil { + channels := make([]string, len(c.channels)) + i := 0 + for channel := range c.channels { + channels[i] = channel + i++ + } + if err := c._subscribe(cn, "subscribe", channels...); err != nil && firstErr == nil { firstErr = err } } if len(c.patterns) > 0 { - if err := c._subscribe(cn, "psubscribe", c.patterns...); err != nil && firstErr == nil { + patterns := make([]string, len(c.patterns)) + i := 0 + for pattern := range c.patterns { + patterns[i] = pattern + i++ + } + if err := c._subscribe(cn, "psubscribe", patterns...); err != nil && firstErr == nil { firstErr = err } } @@ -132,7 +144,12 @@ func (c *PubSub) Close() error { func (c *PubSub) Subscribe(channels ...string) error { c.mu.Lock() err := c.subscribe("subscribe", channels...) - c.channels = appendIfNotExists(c.channels, channels...) + if c.channels == nil { + c.channels = make(map[string]struct{}) + } + for _, channel := range channels { + c.channels[channel] = struct{}{} + } c.mu.Unlock() return err } @@ -142,7 +159,12 @@ func (c *PubSub) Subscribe(channels ...string) error { func (c *PubSub) PSubscribe(patterns ...string) error { c.mu.Lock() err := c.subscribe("psubscribe", patterns...) - c.patterns = appendIfNotExists(c.patterns, patterns...) + if c.patterns == nil { + c.patterns = make(map[string]struct{}) + } + for _, pattern := range patterns { + c.patterns[pattern] = struct{}{} + } c.mu.Unlock() return err } @@ -152,7 +174,9 @@ func (c *PubSub) PSubscribe(patterns ...string) error { func (c *PubSub) Unsubscribe(channels ...string) error { c.mu.Lock() err := c.subscribe("unsubscribe", channels...) - c.channels = remove(c.channels, channels...) + for _, channel := range channels { + delete(c.channels, channel) + } c.mu.Unlock() return err } @@ -162,7 +186,9 @@ func (c *PubSub) Unsubscribe(channels ...string) error { func (c *PubSub) PUnsubscribe(patterns ...string) error { c.mu.Lock() err := c.subscribe("punsubscribe", patterns...) - c.patterns = remove(c.patterns, patterns...) + for _, pattern := range patterns { + delete(c.patterns, pattern) + } c.mu.Unlock() return err } @@ -371,31 +397,3 @@ func (c *PubSub) Channel() <-chan *Message { }) return c.ch } - -func appendIfNotExists(ss []string, es ...string) []string { -loop: - for _, e := range es { - for _, s := range ss { - if s == e { - continue loop - } - } - ss = append(ss, e) - } - return ss -} - -func remove(ss []string, es ...string) []string { - if len(es) == 0 { - return ss[:0] - } - for _, e := range es { - for i, s := range ss { - if s == e { - ss = append(ss[:i], ss[i+1:]...) - break - } - } - } - return ss -}