mirror of https://github.com/tidwall/redcon.git
Added PubSub
This commit is contained in:
parent
2797057b75
commit
d5576bf8a3
|
@ -13,12 +13,33 @@ var addr = ":6380"
|
|||
func main() {
|
||||
var mu sync.RWMutex
|
||||
var items = make(map[string][]byte)
|
||||
var ps redcon.PubSub
|
||||
go log.Printf("started server at %s", addr)
|
||||
err := redcon.ListenAndServe(addr,
|
||||
func(conn redcon.Conn, cmd redcon.Command) {
|
||||
switch strings.ToLower(string(cmd.Args[0])) {
|
||||
default:
|
||||
conn.WriteError("ERR unknown command '" + string(cmd.Args[0]) + "'")
|
||||
case "publish":
|
||||
if len(cmd.Args) != 3 {
|
||||
conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command")
|
||||
return
|
||||
}
|
||||
count := ps.Publish(string(cmd.Args[1]), string(cmd.Args[2]))
|
||||
conn.WriteInt(count)
|
||||
case "subscribe", "psubscribe":
|
||||
if len(cmd.Args) < 2 {
|
||||
conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command")
|
||||
return
|
||||
}
|
||||
command := strings.ToLower(string(cmd.Args[0]))
|
||||
for i := 1; i < len(cmd.Args); i++ {
|
||||
if command == "psubscribe" {
|
||||
ps.Psubscribe(conn, string(cmd.Args[i]))
|
||||
} else {
|
||||
ps.Subscribe(conn, string(cmd.Args[i]))
|
||||
}
|
||||
}
|
||||
case "detach":
|
||||
hconn := conn.Detach()
|
||||
log.Printf("connection has been detached")
|
||||
|
@ -27,7 +48,6 @@ func main() {
|
|||
hconn.WriteString("OK")
|
||||
hconn.Flush()
|
||||
}()
|
||||
return
|
||||
case "ping":
|
||||
conn.WriteString("PONG")
|
||||
case "quit":
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
module github.com/tidwall/redcon
|
||||
|
||||
go 1.15
|
||||
|
||||
require (
|
||||
github.com/tidwall/btree v0.2.2
|
||||
github.com/tidwall/match v1.0.1
|
||||
)
|
|
@ -0,0 +1,4 @@
|
|||
github.com/tidwall/btree v0.2.2 h1:VVo0JW/tdidNdQzNsDR4wMbL3heaxA1DGleyzQ3/niY=
|
||||
github.com/tidwall/btree v0.2.2/go.mod h1:huei1BkDWJ3/sLXmO+bsCNELL+Bp2Kks9OLyQFkzvA8=
|
||||
github.com/tidwall/match v1.0.1 h1:PnKP62LPNxHKTwvHHZZzdOAOCtsJTjo6dZLCwpKm5xc=
|
||||
github.com/tidwall/match v1.0.1/go.mod h1:LujAq0jyVjBy028G1WhWfIzbpQfMO8bBZ6Tyb0+pL9E=
|
|
@ -0,0 +1,347 @@
|
|||
package redcon
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/tidwall/btree"
|
||||
"github.com/tidwall/match"
|
||||
)
|
||||
|
||||
// PubSub is a Redis compatible pub/sub server
|
||||
type PubSub struct {
|
||||
mu sync.RWMutex
|
||||
nextid uint64
|
||||
initd bool
|
||||
chans *btree.BTree
|
||||
conns map[Conn]*pubSubConn
|
||||
}
|
||||
|
||||
// Subscribe a connection to PubSub
|
||||
func (ps *PubSub) Subscribe(conn Conn, channel string) {
|
||||
ps.subscribe(conn, false, channel)
|
||||
}
|
||||
|
||||
// Psubscribe a connection to PubSub
|
||||
func (ps *PubSub) Psubscribe(conn Conn, channel string) {
|
||||
ps.subscribe(conn, true, channel)
|
||||
}
|
||||
|
||||
// Publish a message to subscribers
|
||||
func (ps *PubSub) Publish(channel, message string) int {
|
||||
ps.mu.RLock()
|
||||
defer ps.mu.RUnlock()
|
||||
if !ps.initd {
|
||||
return 0
|
||||
}
|
||||
var sent int
|
||||
// write messages to all clients that are subscribed on the channel
|
||||
pivot := &pubSubEntry{pattern: false, channel: channel}
|
||||
ps.chans.Ascend(pivot, func(item interface{}) bool {
|
||||
entry := item.(*pubSubEntry)
|
||||
if entry.channel != pivot.channel || entry.pattern != pivot.pattern {
|
||||
return false
|
||||
}
|
||||
entry.sconn.writeMessage(entry.pattern, "", channel, message)
|
||||
sent++
|
||||
return true
|
||||
})
|
||||
|
||||
// match on and write all psubscribe clients
|
||||
pivot = &pubSubEntry{pattern: true}
|
||||
ps.chans.Ascend(pivot, func(item interface{}) bool {
|
||||
entry := item.(*pubSubEntry)
|
||||
if match.Match(channel, entry.channel) {
|
||||
entry.sconn.writeMessage(entry.pattern, entry.channel, channel,
|
||||
message)
|
||||
}
|
||||
sent++
|
||||
return true
|
||||
})
|
||||
|
||||
return sent
|
||||
}
|
||||
|
||||
type pubSubConn struct {
|
||||
id uint64
|
||||
mu sync.Mutex
|
||||
conn Conn
|
||||
dconn DetachedConn
|
||||
entries map[*pubSubEntry]bool
|
||||
}
|
||||
|
||||
type pubSubEntry struct {
|
||||
pattern bool
|
||||
sconn *pubSubConn
|
||||
channel string
|
||||
}
|
||||
|
||||
func (sconn *pubSubConn) writeMessage(pat bool, pchan, channel, msg string) {
|
||||
sconn.mu.Lock()
|
||||
defer sconn.mu.Unlock()
|
||||
if pat {
|
||||
sconn.dconn.WriteArray(4)
|
||||
sconn.dconn.WriteBulkString("pmessage")
|
||||
sconn.dconn.WriteBulkString(pchan)
|
||||
sconn.dconn.WriteBulkString(channel)
|
||||
sconn.dconn.WriteBulkString(msg)
|
||||
} else {
|
||||
sconn.dconn.WriteArray(3)
|
||||
sconn.dconn.WriteBulkString("message")
|
||||
sconn.dconn.WriteBulkString(channel)
|
||||
sconn.dconn.WriteBulkString(msg)
|
||||
}
|
||||
sconn.dconn.Flush()
|
||||
}
|
||||
|
||||
// bgrunner runs in the background and reads incoming commands from the
|
||||
// detached client.
|
||||
func (sconn *pubSubConn) bgrunner(ps *PubSub) {
|
||||
defer func() {
|
||||
// client connection has ended, disconnect from the PubSub instances
|
||||
// and close the network connection.
|
||||
ps.mu.Lock()
|
||||
defer ps.mu.Unlock()
|
||||
for entry := range sconn.entries {
|
||||
ps.chans.Delete(entry)
|
||||
}
|
||||
delete(ps.conns, sconn.conn)
|
||||
sconn.mu.Lock()
|
||||
defer sconn.mu.Unlock()
|
||||
sconn.dconn.Close()
|
||||
}()
|
||||
for {
|
||||
cmd, err := sconn.dconn.ReadCommand()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if len(cmd.Args) == 0 {
|
||||
continue
|
||||
}
|
||||
switch strings.ToLower(string(cmd.Args[0])) {
|
||||
case "psubscribe", "subscribe":
|
||||
if len(cmd.Args) < 2 {
|
||||
func() {
|
||||
sconn.mu.Lock()
|
||||
defer sconn.mu.Unlock()
|
||||
sconn.dconn.WriteError(fmt.Sprintf("ERR wrong number of "+
|
||||
"arguments for '%s'", cmd.Args[0]))
|
||||
sconn.dconn.Flush()
|
||||
}()
|
||||
continue
|
||||
}
|
||||
command := strings.ToLower(string(cmd.Args[0]))
|
||||
for i := 1; i < len(cmd.Args); i++ {
|
||||
if command == "psubscribe" {
|
||||
ps.Psubscribe(sconn.conn, string(cmd.Args[i]))
|
||||
} else {
|
||||
ps.Subscribe(sconn.conn, string(cmd.Args[i]))
|
||||
}
|
||||
}
|
||||
case "unsubscribe", "punsubscribe":
|
||||
pattern := strings.ToLower(string(cmd.Args[0])) == "punsubscribe"
|
||||
if len(cmd.Args) == 1 {
|
||||
ps.unsubscribe(sconn.conn, pattern, true, "")
|
||||
} else {
|
||||
for i := 1; i < len(cmd.Args); i++ {
|
||||
channel := string(cmd.Args[i])
|
||||
ps.unsubscribe(sconn.conn, pattern, false, channel)
|
||||
}
|
||||
}
|
||||
case "quit":
|
||||
func() {
|
||||
sconn.mu.Lock()
|
||||
defer sconn.mu.Unlock()
|
||||
sconn.dconn.WriteString("OK")
|
||||
sconn.dconn.Flush()
|
||||
sconn.dconn.Close()
|
||||
}()
|
||||
return
|
||||
case "ping":
|
||||
var msg string
|
||||
switch len(cmd.Args) {
|
||||
case 1:
|
||||
case 2:
|
||||
msg = string(cmd.Args[1])
|
||||
default:
|
||||
func() {
|
||||
sconn.mu.Lock()
|
||||
defer sconn.mu.Unlock()
|
||||
sconn.dconn.WriteError(fmt.Sprintf("ERR wrong number of "+
|
||||
"arguments for '%s'", cmd.Args[0]))
|
||||
sconn.dconn.Flush()
|
||||
}()
|
||||
continue
|
||||
}
|
||||
func() {
|
||||
sconn.mu.Lock()
|
||||
defer sconn.mu.Unlock()
|
||||
sconn.dconn.WriteArray(2)
|
||||
sconn.dconn.WriteBulkString("pong")
|
||||
sconn.dconn.WriteBulkString(msg)
|
||||
sconn.dconn.Flush()
|
||||
}()
|
||||
default:
|
||||
func() {
|
||||
sconn.mu.Lock()
|
||||
defer sconn.mu.Unlock()
|
||||
sconn.dconn.WriteError(fmt.Sprintf("ERR Can't execute '%s': "+
|
||||
"only (P)SUBSCRIBE / (P)UNSUBSCRIBE / PING / QUIT are "+
|
||||
"allowed in this context", cmd.Args[0]))
|
||||
sconn.dconn.Flush()
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// byEntry is a "less" function that sorts the entries in a btree. The tree
|
||||
// is sorted be (pattern, channel, conn.id). All pattern=true entries are at
|
||||
// the end (right) of the tree.
|
||||
func byEntry(a, b interface{}) bool {
|
||||
aa := a.(*pubSubEntry)
|
||||
bb := b.(*pubSubEntry)
|
||||
if !aa.pattern && bb.pattern {
|
||||
return true
|
||||
}
|
||||
if aa.pattern && !bb.pattern {
|
||||
return false
|
||||
}
|
||||
if aa.channel < bb.channel {
|
||||
return true
|
||||
}
|
||||
if aa.channel > bb.channel {
|
||||
return false
|
||||
}
|
||||
var aid uint64
|
||||
var bid uint64
|
||||
if aa.sconn != nil {
|
||||
aid = aa.sconn.id
|
||||
}
|
||||
if bb.sconn != nil {
|
||||
bid = bb.sconn.id
|
||||
}
|
||||
return aid < bid
|
||||
}
|
||||
|
||||
func (ps *PubSub) subscribe(conn Conn, pattern bool, channel string) {
|
||||
ps.mu.Lock()
|
||||
defer ps.mu.Unlock()
|
||||
|
||||
// initialize the PubSub instance
|
||||
if !ps.initd {
|
||||
ps.conns = make(map[Conn]*pubSubConn)
|
||||
ps.chans = btree.New(byEntry)
|
||||
ps.initd = true
|
||||
}
|
||||
|
||||
// fetch the pubSubConn
|
||||
sconn, ok := ps.conns[conn]
|
||||
if !ok {
|
||||
// initialize a new pubSubConn, which runs on a detached connection,
|
||||
// and attach it to the PubSub channels/conn btree
|
||||
ps.nextid++
|
||||
dconn := conn.Detach()
|
||||
sconn = &pubSubConn{
|
||||
id: ps.nextid,
|
||||
conn: conn,
|
||||
dconn: dconn,
|
||||
entries: make(map[*pubSubEntry]bool),
|
||||
}
|
||||
ps.conns[conn] = sconn
|
||||
}
|
||||
sconn.mu.Lock()
|
||||
defer sconn.mu.Unlock()
|
||||
|
||||
// add an entry to the pubsub btree
|
||||
entry := &pubSubEntry{
|
||||
pattern: pattern,
|
||||
channel: channel,
|
||||
sconn: sconn,
|
||||
}
|
||||
ps.chans.Set(entry)
|
||||
sconn.entries[entry] = true
|
||||
|
||||
// send a message to the client
|
||||
sconn.dconn.WriteArray(3)
|
||||
if pattern {
|
||||
sconn.dconn.WriteBulkString("psubscribe")
|
||||
} else {
|
||||
sconn.dconn.WriteBulkString("subscribe")
|
||||
}
|
||||
sconn.dconn.WriteBulkString(channel)
|
||||
var count int
|
||||
for ient := range sconn.entries {
|
||||
if ient.pattern == pattern {
|
||||
count++
|
||||
}
|
||||
}
|
||||
sconn.dconn.WriteInt(count)
|
||||
sconn.dconn.Flush()
|
||||
|
||||
// start the background client operation
|
||||
if !ok {
|
||||
go sconn.bgrunner(ps)
|
||||
}
|
||||
}
|
||||
|
||||
func (ps *PubSub) unsubscribe(conn Conn, pattern, all bool, channel string) {
|
||||
ps.mu.Lock()
|
||||
defer ps.mu.Unlock()
|
||||
// fetch the pubSubConn. This must exist
|
||||
sconn := ps.conns[conn]
|
||||
sconn.mu.Lock()
|
||||
defer sconn.mu.Unlock()
|
||||
|
||||
removeEntry := func(entry *pubSubEntry) {
|
||||
if entry != nil {
|
||||
ps.chans.Delete(entry)
|
||||
delete(sconn.entries, entry)
|
||||
}
|
||||
sconn.dconn.WriteArray(3)
|
||||
if pattern {
|
||||
sconn.dconn.WriteBulkString("punsubscribe")
|
||||
} else {
|
||||
sconn.dconn.WriteBulkString("unsubscribe")
|
||||
}
|
||||
if entry != nil {
|
||||
sconn.dconn.WriteBulkString(entry.channel)
|
||||
} else {
|
||||
sconn.dconn.WriteNull()
|
||||
}
|
||||
var count int
|
||||
for ient := range sconn.entries {
|
||||
if ient.pattern == pattern {
|
||||
count++
|
||||
}
|
||||
}
|
||||
sconn.dconn.WriteInt(count)
|
||||
}
|
||||
if all {
|
||||
// unsubscribe from all (p)subscribe entries
|
||||
var entries []*pubSubEntry
|
||||
for ient := range sconn.entries {
|
||||
if ient.pattern == pattern {
|
||||
entries = append(entries, ient)
|
||||
}
|
||||
}
|
||||
if len(entries) == 0 {
|
||||
removeEntry(nil)
|
||||
} else {
|
||||
for _, entry := range entries {
|
||||
removeEntry(entry)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// unsubscribe single channel from (p)subscribe.
|
||||
var entry *pubSubEntry
|
||||
for ient := range sconn.entries {
|
||||
if ient.pattern == pattern && ient.channel == channel {
|
||||
removeEntry(entry)
|
||||
break
|
||||
}
|
||||
}
|
||||
removeEntry(entry)
|
||||
}
|
||||
sconn.dconn.Flush()
|
||||
}
|
|
@ -0,0 +1,194 @@
|
|||
package redcon
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestPubSub(t *testing.T) {
|
||||
addr := ":12346"
|
||||
done := make(chan bool)
|
||||
go func() {
|
||||
var ps PubSub
|
||||
go func() {
|
||||
tch := time.NewTicker(time.Millisecond * 5)
|
||||
defer tch.Stop()
|
||||
channels := []string{"achan1", "bchan2", "cchan3", "dchan4"}
|
||||
for i := 0; ; i++ {
|
||||
select {
|
||||
case <-tch.C:
|
||||
case <-done:
|
||||
for {
|
||||
var empty bool
|
||||
ps.mu.Lock()
|
||||
if len(ps.conns) == 0 {
|
||||
if ps.chans.Len() != 0 {
|
||||
panic("chans not empty")
|
||||
}
|
||||
empty = true
|
||||
}
|
||||
ps.mu.Unlock()
|
||||
if empty {
|
||||
break
|
||||
}
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
}
|
||||
done <- true
|
||||
return
|
||||
}
|
||||
channel := channels[i%len(channels)]
|
||||
message := fmt.Sprintf("message %d", i)
|
||||
ps.Publish(channel, message)
|
||||
}
|
||||
}()
|
||||
t.Fatal(ListenAndServe(addr, func(conn Conn, cmd Command) {
|
||||
switch strings.ToLower(string(cmd.Args[0])) {
|
||||
default:
|
||||
conn.WriteError("ERR unknown command '" +
|
||||
string(cmd.Args[0]) + "'")
|
||||
case "publish":
|
||||
if len(cmd.Args) != 3 {
|
||||
conn.WriteError("ERR wrong number of arguments for '" +
|
||||
string(cmd.Args[0]) + "' command")
|
||||
return
|
||||
}
|
||||
count := ps.Publish(string(cmd.Args[1]), string(cmd.Args[2]))
|
||||
conn.WriteInt(count)
|
||||
case "subscribe", "psubscribe":
|
||||
if len(cmd.Args) < 2 {
|
||||
conn.WriteError("ERR wrong number of arguments for '" +
|
||||
string(cmd.Args[0]) + "' command")
|
||||
return
|
||||
}
|
||||
command := strings.ToLower(string(cmd.Args[0]))
|
||||
for i := 1; i < len(cmd.Args); i++ {
|
||||
if command == "psubscribe" {
|
||||
ps.Psubscribe(conn, string(cmd.Args[i]))
|
||||
} else {
|
||||
ps.Subscribe(conn, string(cmd.Args[i]))
|
||||
}
|
||||
}
|
||||
}
|
||||
}, nil, nil))
|
||||
}()
|
||||
|
||||
final := make(chan bool)
|
||||
go func() {
|
||||
select {
|
||||
case <-time.Tick(time.Second * 30):
|
||||
panic("timeout")
|
||||
case <-final:
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
// create 10 connections
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(10)
|
||||
for i := 0; i < 10; i++ {
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
var conn net.Conn
|
||||
for i := 0; i < 5; i++ {
|
||||
var err error
|
||||
conn, err = net.Dial("tcp", addr)
|
||||
if err != nil {
|
||||
time.Sleep(time.Second / 10)
|
||||
continue
|
||||
}
|
||||
}
|
||||
if conn == nil {
|
||||
panic("could not connect to server")
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
regs := make(map[string]int)
|
||||
var maxp int
|
||||
var maxs int
|
||||
fmt.Fprintf(conn, "subscribe achan1\r\n")
|
||||
fmt.Fprintf(conn, "subscribe bchan2 cchan3\r\n")
|
||||
fmt.Fprintf(conn, "psubscribe a*1\r\n")
|
||||
fmt.Fprintf(conn, "psubscribe b*2 c*3\r\n")
|
||||
|
||||
// collect 50 messages from each channel
|
||||
rd := bufio.NewReader(conn)
|
||||
var buf []byte
|
||||
for {
|
||||
line, err := rd.ReadBytes('\n')
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
buf = append(buf, line...)
|
||||
n, resp := ReadNextRESP(buf)
|
||||
if n == 0 {
|
||||
continue
|
||||
}
|
||||
buf = nil
|
||||
if resp.Type != Array {
|
||||
panic("expected array")
|
||||
}
|
||||
var vals []RESP
|
||||
resp.ForEach(func(item RESP) bool {
|
||||
vals = append(vals, item)
|
||||
return true
|
||||
})
|
||||
|
||||
name := string(vals[0].Data)
|
||||
switch name {
|
||||
case "subscribe":
|
||||
if len(vals) != 3 {
|
||||
panic("invalid count")
|
||||
}
|
||||
ch := string(vals[1].Data)
|
||||
regs[ch] = 0
|
||||
maxs, _ = strconv.Atoi(string(vals[2].Data))
|
||||
case "psubscribe":
|
||||
if len(vals) != 3 {
|
||||
panic("invalid count")
|
||||
}
|
||||
ch := string(vals[1].Data)
|
||||
regs[ch] = 0
|
||||
maxp, _ = strconv.Atoi(string(vals[2].Data))
|
||||
case "message":
|
||||
if len(vals) != 3 {
|
||||
panic("invalid count")
|
||||
}
|
||||
ch := string(vals[1].Data)
|
||||
regs[ch] = regs[ch] + 1
|
||||
case "pmessage":
|
||||
if len(vals) != 4 {
|
||||
panic("invalid count")
|
||||
}
|
||||
ch := string(vals[1].Data)
|
||||
regs[ch] = regs[ch] + 1
|
||||
}
|
||||
if len(regs) == 6 && maxp == 3 && maxs == 3 {
|
||||
ready := true
|
||||
for _, count := range regs {
|
||||
if count < 50 {
|
||||
ready = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if ready {
|
||||
// all messages have been received
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
// notify sender
|
||||
done <- true
|
||||
// wait for sender
|
||||
<-done
|
||||
// stop the timeout
|
||||
final <- true
|
||||
}
|
Loading…
Reference in New Issue