From d5576bf8a32c3c239fc8180376681ff4ebfdadfc Mon Sep 17 00:00:00 2001 From: tidwall Date: Fri, 30 Oct 2020 06:46:06 -0700 Subject: [PATCH] Added PubSub --- example/clone.go | 22 ++- go.mod | 8 ++ go.sum | 4 + pubsub.go | 347 +++++++++++++++++++++++++++++++++++++++++++++++ pubsub_test.go | 194 ++++++++++++++++++++++++++ 5 files changed, 574 insertions(+), 1 deletion(-) create mode 100644 go.mod create mode 100644 go.sum create mode 100644 pubsub.go create mode 100644 pubsub_test.go diff --git a/example/clone.go b/example/clone.go index f32bc1c..cac6440 100644 --- a/example/clone.go +++ b/example/clone.go @@ -13,12 +13,33 @@ var addr = ":6380" func main() { var mu sync.RWMutex var items = make(map[string][]byte) + var ps redcon.PubSub go log.Printf("started server at %s", addr) err := redcon.ListenAndServe(addr, func(conn redcon.Conn, cmd redcon.Command) { switch strings.ToLower(string(cmd.Args[0])) { default: conn.WriteError("ERR unknown command '" + string(cmd.Args[0]) + "'") + case "publish": + if len(cmd.Args) != 3 { + conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command") + return + } + count := ps.Publish(string(cmd.Args[1]), string(cmd.Args[2])) + conn.WriteInt(count) + case "subscribe", "psubscribe": + if len(cmd.Args) < 2 { + conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command") + return + } + command := strings.ToLower(string(cmd.Args[0])) + for i := 1; i < len(cmd.Args); i++ { + if command == "psubscribe" { + ps.Psubscribe(conn, string(cmd.Args[i])) + } else { + ps.Subscribe(conn, string(cmd.Args[i])) + } + } case "detach": hconn := conn.Detach() log.Printf("connection has been detached") @@ -27,7 +48,6 @@ func main() { hconn.WriteString("OK") hconn.Flush() }() - return case "ping": conn.WriteString("PONG") case "quit": diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..bb6ad5a --- /dev/null +++ b/go.mod @@ -0,0 +1,8 @@ +module github.com/tidwall/redcon + +go 1.15 + +require ( + github.com/tidwall/btree v0.2.2 + github.com/tidwall/match v1.0.1 +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..34a771c --- /dev/null +++ b/go.sum @@ -0,0 +1,4 @@ +github.com/tidwall/btree v0.2.2 h1:VVo0JW/tdidNdQzNsDR4wMbL3heaxA1DGleyzQ3/niY= +github.com/tidwall/btree v0.2.2/go.mod h1:huei1BkDWJ3/sLXmO+bsCNELL+Bp2Kks9OLyQFkzvA8= +github.com/tidwall/match v1.0.1 h1:PnKP62LPNxHKTwvHHZZzdOAOCtsJTjo6dZLCwpKm5xc= +github.com/tidwall/match v1.0.1/go.mod h1:LujAq0jyVjBy028G1WhWfIzbpQfMO8bBZ6Tyb0+pL9E= diff --git a/pubsub.go b/pubsub.go new file mode 100644 index 0000000..86dc92c --- /dev/null +++ b/pubsub.go @@ -0,0 +1,347 @@ +package redcon + +import ( + "fmt" + "strings" + "sync" + + "github.com/tidwall/btree" + "github.com/tidwall/match" +) + +// PubSub is a Redis compatible pub/sub server +type PubSub struct { + mu sync.RWMutex + nextid uint64 + initd bool + chans *btree.BTree + conns map[Conn]*pubSubConn +} + +// Subscribe a connection to PubSub +func (ps *PubSub) Subscribe(conn Conn, channel string) { + ps.subscribe(conn, false, channel) +} + +// Psubscribe a connection to PubSub +func (ps *PubSub) Psubscribe(conn Conn, channel string) { + ps.subscribe(conn, true, channel) +} + +// Publish a message to subscribers +func (ps *PubSub) Publish(channel, message string) int { + ps.mu.RLock() + defer ps.mu.RUnlock() + if !ps.initd { + return 0 + } + var sent int + // write messages to all clients that are subscribed on the channel + pivot := &pubSubEntry{pattern: false, channel: channel} + ps.chans.Ascend(pivot, func(item interface{}) bool { + entry := item.(*pubSubEntry) + if entry.channel != pivot.channel || entry.pattern != pivot.pattern { + return false + } + entry.sconn.writeMessage(entry.pattern, "", channel, message) + sent++ + return true + }) + + // match on and write all psubscribe clients + pivot = &pubSubEntry{pattern: true} + ps.chans.Ascend(pivot, func(item interface{}) bool { + entry := item.(*pubSubEntry) + if match.Match(channel, entry.channel) { + entry.sconn.writeMessage(entry.pattern, entry.channel, channel, + message) + } + sent++ + return true + }) + + return sent +} + +type pubSubConn struct { + id uint64 + mu sync.Mutex + conn Conn + dconn DetachedConn + entries map[*pubSubEntry]bool +} + +type pubSubEntry struct { + pattern bool + sconn *pubSubConn + channel string +} + +func (sconn *pubSubConn) writeMessage(pat bool, pchan, channel, msg string) { + sconn.mu.Lock() + defer sconn.mu.Unlock() + if pat { + sconn.dconn.WriteArray(4) + sconn.dconn.WriteBulkString("pmessage") + sconn.dconn.WriteBulkString(pchan) + sconn.dconn.WriteBulkString(channel) + sconn.dconn.WriteBulkString(msg) + } else { + sconn.dconn.WriteArray(3) + sconn.dconn.WriteBulkString("message") + sconn.dconn.WriteBulkString(channel) + sconn.dconn.WriteBulkString(msg) + } + sconn.dconn.Flush() +} + +// bgrunner runs in the background and reads incoming commands from the +// detached client. +func (sconn *pubSubConn) bgrunner(ps *PubSub) { + defer func() { + // client connection has ended, disconnect from the PubSub instances + // and close the network connection. + ps.mu.Lock() + defer ps.mu.Unlock() + for entry := range sconn.entries { + ps.chans.Delete(entry) + } + delete(ps.conns, sconn.conn) + sconn.mu.Lock() + defer sconn.mu.Unlock() + sconn.dconn.Close() + }() + for { + cmd, err := sconn.dconn.ReadCommand() + if err != nil { + return + } + if len(cmd.Args) == 0 { + continue + } + switch strings.ToLower(string(cmd.Args[0])) { + case "psubscribe", "subscribe": + if len(cmd.Args) < 2 { + func() { + sconn.mu.Lock() + defer sconn.mu.Unlock() + sconn.dconn.WriteError(fmt.Sprintf("ERR wrong number of "+ + "arguments for '%s'", cmd.Args[0])) + sconn.dconn.Flush() + }() + continue + } + command := strings.ToLower(string(cmd.Args[0])) + for i := 1; i < len(cmd.Args); i++ { + if command == "psubscribe" { + ps.Psubscribe(sconn.conn, string(cmd.Args[i])) + } else { + ps.Subscribe(sconn.conn, string(cmd.Args[i])) + } + } + case "unsubscribe", "punsubscribe": + pattern := strings.ToLower(string(cmd.Args[0])) == "punsubscribe" + if len(cmd.Args) == 1 { + ps.unsubscribe(sconn.conn, pattern, true, "") + } else { + for i := 1; i < len(cmd.Args); i++ { + channel := string(cmd.Args[i]) + ps.unsubscribe(sconn.conn, pattern, false, channel) + } + } + case "quit": + func() { + sconn.mu.Lock() + defer sconn.mu.Unlock() + sconn.dconn.WriteString("OK") + sconn.dconn.Flush() + sconn.dconn.Close() + }() + return + case "ping": + var msg string + switch len(cmd.Args) { + case 1: + case 2: + msg = string(cmd.Args[1]) + default: + func() { + sconn.mu.Lock() + defer sconn.mu.Unlock() + sconn.dconn.WriteError(fmt.Sprintf("ERR wrong number of "+ + "arguments for '%s'", cmd.Args[0])) + sconn.dconn.Flush() + }() + continue + } + func() { + sconn.mu.Lock() + defer sconn.mu.Unlock() + sconn.dconn.WriteArray(2) + sconn.dconn.WriteBulkString("pong") + sconn.dconn.WriteBulkString(msg) + sconn.dconn.Flush() + }() + default: + func() { + sconn.mu.Lock() + defer sconn.mu.Unlock() + sconn.dconn.WriteError(fmt.Sprintf("ERR Can't execute '%s': "+ + "only (P)SUBSCRIBE / (P)UNSUBSCRIBE / PING / QUIT are "+ + "allowed in this context", cmd.Args[0])) + sconn.dconn.Flush() + }() + } + } +} + +// byEntry is a "less" function that sorts the entries in a btree. The tree +// is sorted be (pattern, channel, conn.id). All pattern=true entries are at +// the end (right) of the tree. +func byEntry(a, b interface{}) bool { + aa := a.(*pubSubEntry) + bb := b.(*pubSubEntry) + if !aa.pattern && bb.pattern { + return true + } + if aa.pattern && !bb.pattern { + return false + } + if aa.channel < bb.channel { + return true + } + if aa.channel > bb.channel { + return false + } + var aid uint64 + var bid uint64 + if aa.sconn != nil { + aid = aa.sconn.id + } + if bb.sconn != nil { + bid = bb.sconn.id + } + return aid < bid +} + +func (ps *PubSub) subscribe(conn Conn, pattern bool, channel string) { + ps.mu.Lock() + defer ps.mu.Unlock() + + // initialize the PubSub instance + if !ps.initd { + ps.conns = make(map[Conn]*pubSubConn) + ps.chans = btree.New(byEntry) + ps.initd = true + } + + // fetch the pubSubConn + sconn, ok := ps.conns[conn] + if !ok { + // initialize a new pubSubConn, which runs on a detached connection, + // and attach it to the PubSub channels/conn btree + ps.nextid++ + dconn := conn.Detach() + sconn = &pubSubConn{ + id: ps.nextid, + conn: conn, + dconn: dconn, + entries: make(map[*pubSubEntry]bool), + } + ps.conns[conn] = sconn + } + sconn.mu.Lock() + defer sconn.mu.Unlock() + + // add an entry to the pubsub btree + entry := &pubSubEntry{ + pattern: pattern, + channel: channel, + sconn: sconn, + } + ps.chans.Set(entry) + sconn.entries[entry] = true + + // send a message to the client + sconn.dconn.WriteArray(3) + if pattern { + sconn.dconn.WriteBulkString("psubscribe") + } else { + sconn.dconn.WriteBulkString("subscribe") + } + sconn.dconn.WriteBulkString(channel) + var count int + for ient := range sconn.entries { + if ient.pattern == pattern { + count++ + } + } + sconn.dconn.WriteInt(count) + sconn.dconn.Flush() + + // start the background client operation + if !ok { + go sconn.bgrunner(ps) + } +} + +func (ps *PubSub) unsubscribe(conn Conn, pattern, all bool, channel string) { + ps.mu.Lock() + defer ps.mu.Unlock() + // fetch the pubSubConn. This must exist + sconn := ps.conns[conn] + sconn.mu.Lock() + defer sconn.mu.Unlock() + + removeEntry := func(entry *pubSubEntry) { + if entry != nil { + ps.chans.Delete(entry) + delete(sconn.entries, entry) + } + sconn.dconn.WriteArray(3) + if pattern { + sconn.dconn.WriteBulkString("punsubscribe") + } else { + sconn.dconn.WriteBulkString("unsubscribe") + } + if entry != nil { + sconn.dconn.WriteBulkString(entry.channel) + } else { + sconn.dconn.WriteNull() + } + var count int + for ient := range sconn.entries { + if ient.pattern == pattern { + count++ + } + } + sconn.dconn.WriteInt(count) + } + if all { + // unsubscribe from all (p)subscribe entries + var entries []*pubSubEntry + for ient := range sconn.entries { + if ient.pattern == pattern { + entries = append(entries, ient) + } + } + if len(entries) == 0 { + removeEntry(nil) + } else { + for _, entry := range entries { + removeEntry(entry) + } + } + } else { + // unsubscribe single channel from (p)subscribe. + var entry *pubSubEntry + for ient := range sconn.entries { + if ient.pattern == pattern && ient.channel == channel { + removeEntry(entry) + break + } + } + removeEntry(entry) + } + sconn.dconn.Flush() +} diff --git a/pubsub_test.go b/pubsub_test.go new file mode 100644 index 0000000..2b64669 --- /dev/null +++ b/pubsub_test.go @@ -0,0 +1,194 @@ +package redcon + +import ( + "bufio" + "fmt" + "net" + "strconv" + "strings" + "sync" + "testing" + "time" +) + +func TestPubSub(t *testing.T) { + addr := ":12346" + done := make(chan bool) + go func() { + var ps PubSub + go func() { + tch := time.NewTicker(time.Millisecond * 5) + defer tch.Stop() + channels := []string{"achan1", "bchan2", "cchan3", "dchan4"} + for i := 0; ; i++ { + select { + case <-tch.C: + case <-done: + for { + var empty bool + ps.mu.Lock() + if len(ps.conns) == 0 { + if ps.chans.Len() != 0 { + panic("chans not empty") + } + empty = true + } + ps.mu.Unlock() + if empty { + break + } + time.Sleep(time.Millisecond * 10) + } + done <- true + return + } + channel := channels[i%len(channels)] + message := fmt.Sprintf("message %d", i) + ps.Publish(channel, message) + } + }() + t.Fatal(ListenAndServe(addr, func(conn Conn, cmd Command) { + switch strings.ToLower(string(cmd.Args[0])) { + default: + conn.WriteError("ERR unknown command '" + + string(cmd.Args[0]) + "'") + case "publish": + if len(cmd.Args) != 3 { + conn.WriteError("ERR wrong number of arguments for '" + + string(cmd.Args[0]) + "' command") + return + } + count := ps.Publish(string(cmd.Args[1]), string(cmd.Args[2])) + conn.WriteInt(count) + case "subscribe", "psubscribe": + if len(cmd.Args) < 2 { + conn.WriteError("ERR wrong number of arguments for '" + + string(cmd.Args[0]) + "' command") + return + } + command := strings.ToLower(string(cmd.Args[0])) + for i := 1; i < len(cmd.Args); i++ { + if command == "psubscribe" { + ps.Psubscribe(conn, string(cmd.Args[i])) + } else { + ps.Subscribe(conn, string(cmd.Args[i])) + } + } + } + }, nil, nil)) + }() + + final := make(chan bool) + go func() { + select { + case <-time.Tick(time.Second * 30): + panic("timeout") + case <-final: + return + } + }() + + // create 10 connections + var wg sync.WaitGroup + wg.Add(10) + for i := 0; i < 10; i++ { + go func(i int) { + defer wg.Done() + var conn net.Conn + for i := 0; i < 5; i++ { + var err error + conn, err = net.Dial("tcp", addr) + if err != nil { + time.Sleep(time.Second / 10) + continue + } + } + if conn == nil { + panic("could not connect to server") + } + defer conn.Close() + + regs := make(map[string]int) + var maxp int + var maxs int + fmt.Fprintf(conn, "subscribe achan1\r\n") + fmt.Fprintf(conn, "subscribe bchan2 cchan3\r\n") + fmt.Fprintf(conn, "psubscribe a*1\r\n") + fmt.Fprintf(conn, "psubscribe b*2 c*3\r\n") + + // collect 50 messages from each channel + rd := bufio.NewReader(conn) + var buf []byte + for { + line, err := rd.ReadBytes('\n') + if err != nil { + panic(err) + } + buf = append(buf, line...) + n, resp := ReadNextRESP(buf) + if n == 0 { + continue + } + buf = nil + if resp.Type != Array { + panic("expected array") + } + var vals []RESP + resp.ForEach(func(item RESP) bool { + vals = append(vals, item) + return true + }) + + name := string(vals[0].Data) + switch name { + case "subscribe": + if len(vals) != 3 { + panic("invalid count") + } + ch := string(vals[1].Data) + regs[ch] = 0 + maxs, _ = strconv.Atoi(string(vals[2].Data)) + case "psubscribe": + if len(vals) != 3 { + panic("invalid count") + } + ch := string(vals[1].Data) + regs[ch] = 0 + maxp, _ = strconv.Atoi(string(vals[2].Data)) + case "message": + if len(vals) != 3 { + panic("invalid count") + } + ch := string(vals[1].Data) + regs[ch] = regs[ch] + 1 + case "pmessage": + if len(vals) != 4 { + panic("invalid count") + } + ch := string(vals[1].Data) + regs[ch] = regs[ch] + 1 + } + if len(regs) == 6 && maxp == 3 && maxs == 3 { + ready := true + for _, count := range regs { + if count < 50 { + ready = false + break + } + } + if ready { + // all messages have been received + return + } + } + } + }(i) + } + wg.Wait() + // notify sender + done <- true + // wait for sender + <-done + // stop the timeout + final <- true +}