mirror of https://github.com/tidwall/tile38.git
364 lines
8.0 KiB
Go
364 lines
8.0 KiB
Go
package controller
|
|
|
|
import (
|
|
"io"
|
|
"net"
|
|
"strconv"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/tidwall/gjson"
|
|
"github.com/tidwall/match"
|
|
"github.com/tidwall/redcon"
|
|
"github.com/tidwall/resp"
|
|
"github.com/tidwall/tile38/internal/log"
|
|
"github.com/tidwall/tile38/internal/server"
|
|
)
|
|
|
|
// Subscription kinds. Each value is also the index of the matching map in
// pubsub.hubs, so they must stay 0 and 1.
const (
	pubsubChannel = 0 // exact channel-name subscriptions (SUBSCRIBE)
	pubsubPattern = 1 // pattern subscriptions (PSUBSCRIBE)
)
|
|
|
|
type pubsub struct {
|
|
mu sync.RWMutex
|
|
hubs [2]map[string]*subhub
|
|
}
|
|
|
|
func newPubsub() *pubsub {
|
|
return &pubsub{
|
|
hubs: [2]map[string]*subhub{
|
|
make(map[string]*subhub),
|
|
make(map[string]*subhub),
|
|
},
|
|
}
|
|
}
|
|
|
|
// Publish a message to subscribers
|
|
func (c *Controller) Publish(channel string, message ...string) int {
|
|
var msgs []submsg
|
|
c.pubsub.mu.RLock()
|
|
if hub := c.pubsub.hubs[pubsubChannel][channel]; hub != nil {
|
|
for target := range hub.targets {
|
|
for _, message := range message {
|
|
msgs = append(msgs, submsg{
|
|
kind: pubsubChannel,
|
|
target: target,
|
|
channel: channel,
|
|
message: message,
|
|
})
|
|
}
|
|
}
|
|
}
|
|
for pattern, hub := range c.pubsub.hubs[pubsubPattern] {
|
|
if match.Match(channel, pattern) {
|
|
for target := range hub.targets {
|
|
for _, message := range message {
|
|
msgs = append(msgs, submsg{
|
|
kind: pubsubPattern,
|
|
target: target,
|
|
channel: channel,
|
|
pattern: pattern,
|
|
message: message,
|
|
})
|
|
}
|
|
}
|
|
}
|
|
}
|
|
c.pubsub.mu.RUnlock()
|
|
|
|
for _, msg := range msgs {
|
|
msg.target.cond.L.Lock()
|
|
msg.target.msgs = append(msg.target.msgs, msg)
|
|
msg.target.cond.Broadcast()
|
|
msg.target.cond.L.Unlock()
|
|
}
|
|
|
|
return len(msgs)
|
|
}
|
|
|
|
func (ps *pubsub) register(kind int, channel string, target *subtarget) {
|
|
ps.mu.Lock()
|
|
hub, ok := ps.hubs[kind][channel]
|
|
if !ok {
|
|
hub = newSubhub()
|
|
ps.hubs[kind][channel] = hub
|
|
}
|
|
hub.targets[target] = true
|
|
ps.mu.Unlock()
|
|
}
|
|
|
|
func (ps *pubsub) unregister(kind int, channel string, target *subtarget) {
|
|
ps.mu.Lock()
|
|
hub, ok := ps.hubs[kind][channel]
|
|
if ok {
|
|
delete(hub.targets, target)
|
|
if len(hub.targets) == 0 {
|
|
delete(ps.hubs[kind], channel)
|
|
}
|
|
}
|
|
ps.mu.Unlock()
|
|
}
|
|
|
|
// submsg is one message queued for delivery to a single subscriber.
type submsg struct {
	kind    byte       // pubsubChannel or pubsubPattern
	target  *subtarget // subscriber that receives the message
	pattern string     // matching pattern; left empty for channel subscriptions
	channel string     // channel the message was published on
	message string     // payload
}

// subtarget is a subscriber's mailbox. Publishers append to msgs and
// broadcast on cond. The connection's writer goroutine drains msgs until
// closed is set. cond.L guards both msgs and closed.
type subtarget struct {
	cond   *sync.Cond
	msgs   []submsg
	closed bool
}

// newSubtarget returns an open, empty mailbox whose condition variable uses
// its own fresh mutex.
func newSubtarget() *subtarget {
	return &subtarget{cond: sync.NewCond(new(sync.Mutex))}
}

// subhub is the set of subscribers attached to one channel or pattern.
type subhub struct {
	targets map[*subtarget]bool
}

// newSubhub returns a hub with an allocated, empty target set.
func newSubhub() *subhub {
	return &subhub{targets: map[*subtarget]bool{}}
}
|
|
|
|
type liveSubscriptionSwitches struct {
|
|
// no fields. everything is managed through the server.Message
|
|
}
|
|
|
|
func (sub liveSubscriptionSwitches) Error() string {
|
|
return goingLive
|
|
}
|
|
|
|
func (c *Controller) cmdSubscribe(msg *server.Message) (resp.Value, error) {
|
|
if len(msg.Values) < 2 {
|
|
return resp.Value{}, errInvalidNumberOfArguments
|
|
}
|
|
return server.NOMessage, liveSubscriptionSwitches{}
|
|
}
|
|
|
|
func (c *Controller) cmdPsubscribe(msg *server.Message) (resp.Value, error) {
|
|
if len(msg.Values) < 2 {
|
|
return resp.Value{}, errInvalidNumberOfArguments
|
|
}
|
|
return server.NOMessage, liveSubscriptionSwitches{}
|
|
}
|
|
|
|
func (c *Controller) cmdPublish(msg *server.Message) (resp.Value, error) {
|
|
start := time.Now()
|
|
if len(msg.Values) != 3 {
|
|
return resp.Value{}, errInvalidNumberOfArguments
|
|
}
|
|
|
|
channel := msg.Values[1].String()
|
|
message := msg.Values[2].String()
|
|
//geofence := gjson.Valid(message) && gjson.Get(message, "fence").Bool()
|
|
n := c.Publish(channel, message) //, geofence)
|
|
var res resp.Value
|
|
switch msg.OutputType {
|
|
case server.JSON:
|
|
res = resp.StringValue(`{"ok":true` +
|
|
`,"published":` + strconv.FormatInt(int64(n), 10) +
|
|
`,"elapsed":"` + time.Now().Sub(start).String() + `"}`)
|
|
case server.RESP:
|
|
res = resp.IntegerValue(n)
|
|
}
|
|
return res, nil
|
|
}
|
|
|
|
func (c *Controller) liveSubscription(
|
|
conn net.Conn,
|
|
rd *server.PipelineReader,
|
|
msg *server.Message,
|
|
websocket bool,
|
|
) error {
|
|
defer conn.Close() // close connection when we are done
|
|
|
|
outputType := msg.OutputType
|
|
connType := msg.ConnType
|
|
if websocket {
|
|
outputType = server.JSON
|
|
}
|
|
|
|
var start time.Time
|
|
|
|
// write helpers
|
|
var writeLock sync.Mutex
|
|
write := func(data []byte) {
|
|
writeLock.Lock()
|
|
defer writeLock.Unlock()
|
|
writeLiveMessage(conn, data, false, connType, websocket)
|
|
}
|
|
writeOK := func() {
|
|
switch outputType {
|
|
case server.JSON:
|
|
write([]byte(`{"ok":true` +
|
|
`,"elapsed":"` + time.Now().Sub(start).String() + `"}`))
|
|
case server.RESP:
|
|
write([]byte(`+OK\r\n`))
|
|
}
|
|
}
|
|
writeWrongNumberOfArgsErr := func(command string) {
|
|
switch outputType {
|
|
case server.JSON:
|
|
write([]byte(`{"ok":false,"err":"invalid number of arguments"` +
|
|
`,"elapsed":"` + time.Now().Sub(start).String() + `"}`))
|
|
case server.RESP:
|
|
write([]byte(`-ERR wrong number of arguments ` +
|
|
`for '` + command + `' command\r\n`))
|
|
}
|
|
}
|
|
writeOnlyPubsubErr := func() {
|
|
switch outputType {
|
|
case server.JSON:
|
|
write([]byte(`{"ok":false` +
|
|
`,"err":"only (P)SUBSCRIBE / (P)UNSUBSCRIBE / ` +
|
|
`PING / QUIT allowed in this context"` +
|
|
`,"elapsed":"` + time.Now().Sub(start).String() + `"}`))
|
|
case server.RESP:
|
|
write([]byte("-ERR only (P)SUBSCRIBE / (P)UNSUBSCRIBE / " +
|
|
"PING / QUIT allowed in this context\r\n"))
|
|
}
|
|
}
|
|
writeSubscribe := func(command, channel string, num int) {
|
|
switch outputType {
|
|
case server.JSON:
|
|
write([]byte(`{"ok":true` +
|
|
`,"command":` + jsonString(command) +
|
|
`,"channel":` + jsonString(channel) +
|
|
`,"num":` + strconv.FormatInt(int64(num), 10) +
|
|
`,"elapsed":"` + time.Now().Sub(start).String() + `"}`))
|
|
case server.RESP:
|
|
b := redcon.AppendArray(nil, 3)
|
|
b = redcon.AppendBulkString(b, command)
|
|
b = redcon.AppendBulkString(b, channel)
|
|
b = redcon.AppendInt(b, int64(num))
|
|
write(b)
|
|
}
|
|
}
|
|
writeMessage := func(msg submsg) {
|
|
if msg.kind == pubsubChannel {
|
|
switch outputType {
|
|
case server.JSON:
|
|
var data []byte
|
|
if !gjson.Valid(msg.message) {
|
|
data = appendJSONString(nil, msg.message)
|
|
} else {
|
|
data = []byte(msg.message)
|
|
}
|
|
write(data)
|
|
case server.RESP:
|
|
b := redcon.AppendArray(nil, 3)
|
|
b = redcon.AppendBulkString(b, "message")
|
|
b = redcon.AppendBulkString(b, msg.channel)
|
|
b = redcon.AppendBulkString(b, msg.message)
|
|
write(b)
|
|
}
|
|
} else {
|
|
switch outputType {
|
|
case server.JSON:
|
|
var data []byte
|
|
if !gjson.Valid(msg.message) {
|
|
data = appendJSONString(nil, msg.message)
|
|
} else {
|
|
data = []byte(msg.message)
|
|
}
|
|
write(data)
|
|
case server.RESP:
|
|
b := redcon.AppendArray(nil, 4)
|
|
b = redcon.AppendBulkString(b, "pmessage")
|
|
b = redcon.AppendBulkString(b, msg.pattern)
|
|
b = redcon.AppendBulkString(b, msg.channel)
|
|
b = redcon.AppendBulkString(b, msg.message)
|
|
write(b)
|
|
}
|
|
}
|
|
}
|
|
|
|
m := [2]map[string]bool{
|
|
make(map[string]bool),
|
|
make(map[string]bool),
|
|
}
|
|
|
|
target := newSubtarget()
|
|
|
|
defer func() {
|
|
for i := 0; i < 2; i++ {
|
|
for channel := range m[i] {
|
|
c.pubsub.unregister(i, channel, target)
|
|
}
|
|
}
|
|
target.cond.L.Lock()
|
|
target.closed = true
|
|
target.cond.Broadcast()
|
|
target.cond.L.Unlock()
|
|
}()
|
|
go func() {
|
|
log.Debugf("pubsub open")
|
|
defer log.Debugf("pubsub closed")
|
|
for {
|
|
var msgs []submsg
|
|
target.cond.L.Lock()
|
|
if len(target.msgs) > 0 {
|
|
msgs = target.msgs
|
|
target.msgs = nil
|
|
}
|
|
target.cond.L.Unlock()
|
|
for _, msg := range msgs {
|
|
writeMessage(msg)
|
|
}
|
|
target.cond.L.Lock()
|
|
if target.closed {
|
|
target.cond.L.Unlock()
|
|
return
|
|
}
|
|
target.cond.Wait()
|
|
target.cond.L.Unlock()
|
|
}
|
|
}()
|
|
|
|
msgs := []*server.Message{msg}
|
|
for {
|
|
for _, msg := range msgs {
|
|
start = time.Now()
|
|
var kind int
|
|
switch msg.Command {
|
|
case "quit":
|
|
writeOK()
|
|
return nil
|
|
case "psubscribe":
|
|
kind = pubsubPattern
|
|
case "subscribe":
|
|
kind = pubsubChannel
|
|
default:
|
|
writeOnlyPubsubErr()
|
|
}
|
|
if len(msg.Values) < 2 {
|
|
writeWrongNumberOfArgsErr(msg.Command)
|
|
}
|
|
for i := 1; i < len(msg.Values); i++ {
|
|
channel := msg.Values[i].String()
|
|
m[kind][channel] = true
|
|
c.pubsub.register(kind, channel, target)
|
|
writeSubscribe(msg.Command, channel, len(m[0])+len(m[1]))
|
|
}
|
|
}
|
|
var err error
|
|
msgs, err = rd.ReadMessages()
|
|
if err != nil {
|
|
if err == io.EOF {
|
|
return nil
|
|
}
|
|
return err
|
|
}
|
|
}
|
|
}
|