redcon/pubsub.go

348 lines
7.7 KiB
Go

package redcon
import (
"fmt"
"strings"
"sync"
"github.com/tidwall/btree"
"github.com/tidwall/match"
)
// PubSub is a Redis compatible pub/sub server
type PubSub struct {
mu sync.RWMutex
nextid uint64
initd bool
chans *btree.BTree
conns map[Conn]*pubSubConn
}
// Subscribe a connection to PubSub
func (ps *PubSub) Subscribe(conn Conn, channel string) {
ps.subscribe(conn, false, channel)
}
// Psubscribe a connection to PubSub
func (ps *PubSub) Psubscribe(conn Conn, channel string) {
ps.subscribe(conn, true, channel)
}
// Publish a message to subscribers
func (ps *PubSub) Publish(channel, message string) int {
ps.mu.RLock()
defer ps.mu.RUnlock()
if !ps.initd {
return 0
}
var sent int
// write messages to all clients that are subscribed on the channel
pivot := &pubSubEntry{pattern: false, channel: channel}
ps.chans.Ascend(pivot, func(item interface{}) bool {
entry := item.(*pubSubEntry)
if entry.channel != pivot.channel || entry.pattern != pivot.pattern {
return false
}
entry.sconn.writeMessage(entry.pattern, "", channel, message)
sent++
return true
})
// match on and write all psubscribe clients
pivot = &pubSubEntry{pattern: true}
ps.chans.Ascend(pivot, func(item interface{}) bool {
entry := item.(*pubSubEntry)
if match.Match(channel, entry.channel) {
entry.sconn.writeMessage(entry.pattern, entry.channel, channel,
message)
}
sent++
return true
})
return sent
}
type pubSubConn struct {
id uint64
mu sync.Mutex
conn Conn
dconn DetachedConn
entries map[*pubSubEntry]bool
}
type pubSubEntry struct {
pattern bool
sconn *pubSubConn
channel string
}
func (sconn *pubSubConn) writeMessage(pat bool, pchan, channel, msg string) {
sconn.mu.Lock()
defer sconn.mu.Unlock()
if pat {
sconn.dconn.WriteArray(4)
sconn.dconn.WriteBulkString("pmessage")
sconn.dconn.WriteBulkString(pchan)
sconn.dconn.WriteBulkString(channel)
sconn.dconn.WriteBulkString(msg)
} else {
sconn.dconn.WriteArray(3)
sconn.dconn.WriteBulkString("message")
sconn.dconn.WriteBulkString(channel)
sconn.dconn.WriteBulkString(msg)
}
sconn.dconn.Flush()
}
// bgrunner runs in the background and reads incoming commands from the
// detached client.
func (sconn *pubSubConn) bgrunner(ps *PubSub) {
defer func() {
// client connection has ended, disconnect from the PubSub instances
// and close the network connection.
ps.mu.Lock()
defer ps.mu.Unlock()
for entry := range sconn.entries {
ps.chans.Delete(entry)
}
delete(ps.conns, sconn.conn)
sconn.mu.Lock()
defer sconn.mu.Unlock()
sconn.dconn.Close()
}()
for {
cmd, err := sconn.dconn.ReadCommand()
if err != nil {
return
}
if len(cmd.Args) == 0 {
continue
}
switch strings.ToLower(string(cmd.Args[0])) {
case "psubscribe", "subscribe":
if len(cmd.Args) < 2 {
func() {
sconn.mu.Lock()
defer sconn.mu.Unlock()
sconn.dconn.WriteError(fmt.Sprintf("ERR wrong number of "+
"arguments for '%s'", cmd.Args[0]))
sconn.dconn.Flush()
}()
continue
}
command := strings.ToLower(string(cmd.Args[0]))
for i := 1; i < len(cmd.Args); i++ {
if command == "psubscribe" {
ps.Psubscribe(sconn.conn, string(cmd.Args[i]))
} else {
ps.Subscribe(sconn.conn, string(cmd.Args[i]))
}
}
case "unsubscribe", "punsubscribe":
pattern := strings.ToLower(string(cmd.Args[0])) == "punsubscribe"
if len(cmd.Args) == 1 {
ps.unsubscribe(sconn.conn, pattern, true, "")
} else {
for i := 1; i < len(cmd.Args); i++ {
channel := string(cmd.Args[i])
ps.unsubscribe(sconn.conn, pattern, false, channel)
}
}
case "quit":
func() {
sconn.mu.Lock()
defer sconn.mu.Unlock()
sconn.dconn.WriteString("OK")
sconn.dconn.Flush()
sconn.dconn.Close()
}()
return
case "ping":
var msg string
switch len(cmd.Args) {
case 1:
case 2:
msg = string(cmd.Args[1])
default:
func() {
sconn.mu.Lock()
defer sconn.mu.Unlock()
sconn.dconn.WriteError(fmt.Sprintf("ERR wrong number of "+
"arguments for '%s'", cmd.Args[0]))
sconn.dconn.Flush()
}()
continue
}
func() {
sconn.mu.Lock()
defer sconn.mu.Unlock()
sconn.dconn.WriteArray(2)
sconn.dconn.WriteBulkString("pong")
sconn.dconn.WriteBulkString(msg)
sconn.dconn.Flush()
}()
default:
func() {
sconn.mu.Lock()
defer sconn.mu.Unlock()
sconn.dconn.WriteError(fmt.Sprintf("ERR Can't execute '%s': "+
"only (P)SUBSCRIBE / (P)UNSUBSCRIBE / PING / QUIT are "+
"allowed in this context", cmd.Args[0]))
sconn.dconn.Flush()
}()
}
}
}
// byEntry is a "less" function that sorts the entries in a btree. The tree
// is sorted be (pattern, channel, conn.id). All pattern=true entries are at
// the end (right) of the tree.
func byEntry(a, b interface{}) bool {
aa := a.(*pubSubEntry)
bb := b.(*pubSubEntry)
if !aa.pattern && bb.pattern {
return true
}
if aa.pattern && !bb.pattern {
return false
}
if aa.channel < bb.channel {
return true
}
if aa.channel > bb.channel {
return false
}
var aid uint64
var bid uint64
if aa.sconn != nil {
aid = aa.sconn.id
}
if bb.sconn != nil {
bid = bb.sconn.id
}
return aid < bid
}
func (ps *PubSub) subscribe(conn Conn, pattern bool, channel string) {
ps.mu.Lock()
defer ps.mu.Unlock()
// initialize the PubSub instance
if !ps.initd {
ps.conns = make(map[Conn]*pubSubConn)
ps.chans = btree.New(byEntry)
ps.initd = true
}
// fetch the pubSubConn
sconn, ok := ps.conns[conn]
if !ok {
// initialize a new pubSubConn, which runs on a detached connection,
// and attach it to the PubSub channels/conn btree
ps.nextid++
dconn := conn.Detach()
sconn = &pubSubConn{
id: ps.nextid,
conn: conn,
dconn: dconn,
entries: make(map[*pubSubEntry]bool),
}
ps.conns[conn] = sconn
}
sconn.mu.Lock()
defer sconn.mu.Unlock()
// add an entry to the pubsub btree
entry := &pubSubEntry{
pattern: pattern,
channel: channel,
sconn: sconn,
}
ps.chans.Set(entry)
sconn.entries[entry] = true
// send a message to the client
sconn.dconn.WriteArray(3)
if pattern {
sconn.dconn.WriteBulkString("psubscribe")
} else {
sconn.dconn.WriteBulkString("subscribe")
}
sconn.dconn.WriteBulkString(channel)
var count int
for ient := range sconn.entries {
if ient.pattern == pattern {
count++
}
}
sconn.dconn.WriteInt(count)
sconn.dconn.Flush()
// start the background client operation
if !ok {
go sconn.bgrunner(ps)
}
}
func (ps *PubSub) unsubscribe(conn Conn, pattern, all bool, channel string) {
ps.mu.Lock()
defer ps.mu.Unlock()
// fetch the pubSubConn. This must exist
sconn := ps.conns[conn]
sconn.mu.Lock()
defer sconn.mu.Unlock()
removeEntry := func(entry *pubSubEntry) {
if entry != nil {
ps.chans.Delete(entry)
delete(sconn.entries, entry)
}
sconn.dconn.WriteArray(3)
if pattern {
sconn.dconn.WriteBulkString("punsubscribe")
} else {
sconn.dconn.WriteBulkString("unsubscribe")
}
if entry != nil {
sconn.dconn.WriteBulkString(entry.channel)
} else {
sconn.dconn.WriteNull()
}
var count int
for ient := range sconn.entries {
if ient.pattern == pattern {
count++
}
}
sconn.dconn.WriteInt(count)
}
if all {
// unsubscribe from all (p)subscribe entries
var entries []*pubSubEntry
for ient := range sconn.entries {
if ient.pattern == pattern {
entries = append(entries, ient)
}
}
if len(entries) == 0 {
removeEntry(nil)
} else {
for _, entry := range entries {
removeEntry(entry)
}
}
} else {
// unsubscribe single channel from (p)subscribe.
var entry *pubSubEntry
for ient := range sconn.entries {
if ient.pattern == pattern && ient.channel == channel {
removeEntry(entry)
break
}
}
removeEntry(entry)
}
sconn.dconn.Flush()
}