package redcon import ( "fmt" "strings" "sync" "github.com/tidwall/btree" "github.com/tidwall/match" ) // PubSub is a Redis compatible pub/sub server type PubSub struct { mu sync.RWMutex nextid uint64 initd bool chans *btree.BTree conns map[Conn]*pubSubConn } // Subscribe a connection to PubSub func (ps *PubSub) Subscribe(conn Conn, channel string) { ps.subscribe(conn, false, channel) } // Psubscribe a connection to PubSub func (ps *PubSub) Psubscribe(conn Conn, channel string) { ps.subscribe(conn, true, channel) } // Publish a message to subscribers func (ps *PubSub) Publish(channel, message string) int { ps.mu.RLock() defer ps.mu.RUnlock() if !ps.initd { return 0 } var sent int // write messages to all clients that are subscribed on the channel pivot := &pubSubEntry{pattern: false, channel: channel} ps.chans.Ascend(pivot, func(item interface{}) bool { entry := item.(*pubSubEntry) if entry.channel != pivot.channel || entry.pattern != pivot.pattern { return false } entry.sconn.writeMessage(entry.pattern, "", channel, message) sent++ return true }) // match on and write all psubscribe clients pivot = &pubSubEntry{pattern: true} ps.chans.Ascend(pivot, func(item interface{}) bool { entry := item.(*pubSubEntry) if match.Match(channel, entry.channel) { entry.sconn.writeMessage(entry.pattern, entry.channel, channel, message) } sent++ return true }) return sent } type pubSubConn struct { id uint64 mu sync.Mutex conn Conn dconn DetachedConn entries map[*pubSubEntry]bool } type pubSubEntry struct { pattern bool sconn *pubSubConn channel string } func (sconn *pubSubConn) writeMessage(pat bool, pchan, channel, msg string) { sconn.mu.Lock() defer sconn.mu.Unlock() if pat { sconn.dconn.WriteArray(4) sconn.dconn.WriteBulkString("pmessage") sconn.dconn.WriteBulkString(pchan) sconn.dconn.WriteBulkString(channel) sconn.dconn.WriteBulkString(msg) } else { sconn.dconn.WriteArray(3) sconn.dconn.WriteBulkString("message") sconn.dconn.WriteBulkString(channel) sconn.dconn.WriteBulkString(msg) } sconn.dconn.Flush() } // bgrunner runs in the background and reads incoming commands from the // detached client. func (sconn *pubSubConn) bgrunner(ps *PubSub) { defer func() { // client connection has ended, disconnect from the PubSub instances // and close the network connection. ps.mu.Lock() defer ps.mu.Unlock() for entry := range sconn.entries { ps.chans.Delete(entry) } delete(ps.conns, sconn.conn) sconn.mu.Lock() defer sconn.mu.Unlock() sconn.dconn.Close() }() for { cmd, err := sconn.dconn.ReadCommand() if err != nil { return } if len(cmd.Args) == 0 { continue } switch strings.ToLower(string(cmd.Args[0])) { case "psubscribe", "subscribe": if len(cmd.Args) < 2 { func() { sconn.mu.Lock() defer sconn.mu.Unlock() sconn.dconn.WriteError(fmt.Sprintf("ERR wrong number of "+ "arguments for '%s'", cmd.Args[0])) sconn.dconn.Flush() }() continue } command := strings.ToLower(string(cmd.Args[0])) for i := 1; i < len(cmd.Args); i++ { if command == "psubscribe" { ps.Psubscribe(sconn.conn, string(cmd.Args[i])) } else { ps.Subscribe(sconn.conn, string(cmd.Args[i])) } } case "unsubscribe", "punsubscribe": pattern := strings.ToLower(string(cmd.Args[0])) == "punsubscribe" if len(cmd.Args) == 1 { ps.unsubscribe(sconn.conn, pattern, true, "") } else { for i := 1; i < len(cmd.Args); i++ { channel := string(cmd.Args[i]) ps.unsubscribe(sconn.conn, pattern, false, channel) } } case "quit": func() { sconn.mu.Lock() defer sconn.mu.Unlock() sconn.dconn.WriteString("OK") sconn.dconn.Flush() sconn.dconn.Close() }() return case "ping": var msg string switch len(cmd.Args) { case 1: case 2: msg = string(cmd.Args[1]) default: func() { sconn.mu.Lock() defer sconn.mu.Unlock() sconn.dconn.WriteError(fmt.Sprintf("ERR wrong number of "+ "arguments for '%s'", cmd.Args[0])) sconn.dconn.Flush() }() continue } func() { sconn.mu.Lock() defer sconn.mu.Unlock() sconn.dconn.WriteArray(2) sconn.dconn.WriteBulkString("pong") sconn.dconn.WriteBulkString(msg) sconn.dconn.Flush() }() default: func() { sconn.mu.Lock() defer sconn.mu.Unlock() sconn.dconn.WriteError(fmt.Sprintf("ERR Can't execute '%s': "+ "only (P)SUBSCRIBE / (P)UNSUBSCRIBE / PING / QUIT are "+ "allowed in this context", cmd.Args[0])) sconn.dconn.Flush() }() } } } // byEntry is a "less" function that sorts the entries in a btree. The tree // is sorted be (pattern, channel, conn.id). All pattern=true entries are at // the end (right) of the tree. func byEntry(a, b interface{}) bool { aa := a.(*pubSubEntry) bb := b.(*pubSubEntry) if !aa.pattern && bb.pattern { return true } if aa.pattern && !bb.pattern { return false } if aa.channel < bb.channel { return true } if aa.channel > bb.channel { return false } var aid uint64 var bid uint64 if aa.sconn != nil { aid = aa.sconn.id } if bb.sconn != nil { bid = bb.sconn.id } return aid < bid } func (ps *PubSub) subscribe(conn Conn, pattern bool, channel string) { ps.mu.Lock() defer ps.mu.Unlock() // initialize the PubSub instance if !ps.initd { ps.conns = make(map[Conn]*pubSubConn) ps.chans = btree.New(byEntry) ps.initd = true } // fetch the pubSubConn sconn, ok := ps.conns[conn] if !ok { // initialize a new pubSubConn, which runs on a detached connection, // and attach it to the PubSub channels/conn btree ps.nextid++ dconn := conn.Detach() sconn = &pubSubConn{ id: ps.nextid, conn: conn, dconn: dconn, entries: make(map[*pubSubEntry]bool), } ps.conns[conn] = sconn } sconn.mu.Lock() defer sconn.mu.Unlock() // add an entry to the pubsub btree entry := &pubSubEntry{ pattern: pattern, channel: channel, sconn: sconn, } ps.chans.Set(entry) sconn.entries[entry] = true // send a message to the client sconn.dconn.WriteArray(3) if pattern { sconn.dconn.WriteBulkString("psubscribe") } else { sconn.dconn.WriteBulkString("subscribe") } sconn.dconn.WriteBulkString(channel) var count int for ient := range sconn.entries { if ient.pattern == pattern { count++ } } sconn.dconn.WriteInt(count) sconn.dconn.Flush() // start the background client operation if !ok { go sconn.bgrunner(ps) } } func (ps *PubSub) unsubscribe(conn Conn, pattern, all bool, channel string) { ps.mu.Lock() defer ps.mu.Unlock() // fetch the pubSubConn. This must exist sconn := ps.conns[conn] sconn.mu.Lock() defer sconn.mu.Unlock() removeEntry := func(entry *pubSubEntry) { if entry != nil { ps.chans.Delete(entry) delete(sconn.entries, entry) } sconn.dconn.WriteArray(3) if pattern { sconn.dconn.WriteBulkString("punsubscribe") } else { sconn.dconn.WriteBulkString("unsubscribe") } if entry != nil { sconn.dconn.WriteBulkString(entry.channel) } else { sconn.dconn.WriteNull() } var count int for ient := range sconn.entries { if ient.pattern == pattern { count++ } } sconn.dconn.WriteInt(count) } if all { // unsubscribe from all (p)subscribe entries var entries []*pubSubEntry for ient := range sconn.entries { if ient.pattern == pattern { entries = append(entries, ient) } } if len(entries) == 0 { removeEntry(nil) } else { for _, entry := range entries { removeEntry(entry) } } } else { // unsubscribe single channel from (p)subscribe. var entry *pubSubEntry for ient := range sconn.entries { if ient.pattern == pattern && ient.channel == channel { removeEntry(entry) break } } removeEntry(entry) } sconn.dconn.Flush() }