package controller import ( "io" "net" "strconv" "sync" "time" "github.com/tidwall/gjson" "github.com/tidwall/match" "github.com/tidwall/redcon" "github.com/tidwall/resp" "github.com/tidwall/tile38/internal/log" "github.com/tidwall/tile38/internal/server" ) const ( pubsubChannel = iota pubsubPattern ) type pubsub struct { mu sync.RWMutex hubs [2]map[string]*subhub } func newPubsub() *pubsub { return &pubsub{ hubs: [2]map[string]*subhub{ make(map[string]*subhub), make(map[string]*subhub), }, } } // Publish a message to subscribers func (c *Controller) Publish(channel string, message ...string) int { var msgs []submsg c.pubsub.mu.RLock() if hub := c.pubsub.hubs[pubsubChannel][channel]; hub != nil { for target := range hub.targets { for _, message := range message { msgs = append(msgs, submsg{ kind: pubsubChannel, target: target, channel: channel, message: message, }) } } } for pattern, hub := range c.pubsub.hubs[pubsubPattern] { if match.Match(channel, pattern) { for target := range hub.targets { for _, message := range message { msgs = append(msgs, submsg{ kind: pubsubPattern, target: target, channel: channel, pattern: pattern, message: message, }) } } } } c.pubsub.mu.RUnlock() for _, msg := range msgs { msg.target.cond.L.Lock() msg.target.msgs = append(msg.target.msgs, msg) msg.target.cond.Broadcast() msg.target.cond.L.Unlock() } return len(msgs) } func (ps *pubsub) register(kind int, channel string, target *subtarget) { ps.mu.Lock() hub, ok := ps.hubs[kind][channel] if !ok { hub = newSubhub() ps.hubs[kind][channel] = hub } hub.targets[target] = true ps.mu.Unlock() } func (ps *pubsub) unregister(kind int, channel string, target *subtarget) { ps.mu.Lock() hub, ok := ps.hubs[kind][channel] if ok { delete(hub.targets, target) if len(hub.targets) == 0 { delete(ps.hubs[kind], channel) } } ps.mu.Unlock() } type submsg struct { kind byte target *subtarget pattern string channel string message string } type subtarget struct { cond *sync.Cond msgs []submsg closed bool } func newSubtarget() *subtarget { target := new(subtarget) target.cond = sync.NewCond(&sync.Mutex{}) return target } type subhub struct { targets map[*subtarget]bool } func newSubhub() *subhub { hub := new(subhub) hub.targets = make(map[*subtarget]bool) return hub } type liveSubscriptionSwitches struct { // no fields. everything is managed through the server.Message } func (sub liveSubscriptionSwitches) Error() string { return goingLive } func (c *Controller) cmdSubscribe(msg *server.Message) (resp.Value, error) { if len(msg.Values) < 2 { return resp.Value{}, errInvalidNumberOfArguments } return server.NOMessage, liveSubscriptionSwitches{} } func (c *Controller) cmdPsubscribe(msg *server.Message) (resp.Value, error) { if len(msg.Values) < 2 { return resp.Value{}, errInvalidNumberOfArguments } return server.NOMessage, liveSubscriptionSwitches{} } func (c *Controller) cmdPublish(msg *server.Message) (resp.Value, error) { start := time.Now() if len(msg.Values) != 3 { return resp.Value{}, errInvalidNumberOfArguments } channel := msg.Values[1].String() message := msg.Values[2].String() //geofence := gjson.Valid(message) && gjson.Get(message, "fence").Bool() n := c.Publish(channel, message) //, geofence) var res resp.Value switch msg.OutputType { case server.JSON: res = resp.StringValue(`{"ok":true` + `,"published":` + strconv.FormatInt(int64(n), 10) + `,"elapsed":"` + time.Now().Sub(start).String() + `"}`) case server.RESP: res = resp.IntegerValue(n) } return res, nil } func (c *Controller) liveSubscription( conn net.Conn, rd *server.PipelineReader, msg *server.Message, websocket bool, ) error { defer conn.Close() // close connection when we are done outputType := msg.OutputType connType := msg.ConnType if websocket { outputType = server.JSON } var start time.Time // write helpers var writeLock sync.Mutex write := func(data []byte) { writeLock.Lock() defer writeLock.Unlock() writeLiveMessage(conn, data, false, connType, websocket) } writeOK := func() { switch outputType { case server.JSON: write([]byte(`{"ok":true` + `,"elapsed":"` + time.Now().Sub(start).String() + `"}`)) case server.RESP: write([]byte(`+OK\r\n`)) } } writeWrongNumberOfArgsErr := func(command string) { switch outputType { case server.JSON: write([]byte(`{"ok":false,"err":"invalid number of arguments"` + `,"elapsed":"` + time.Now().Sub(start).String() + `"}`)) case server.RESP: write([]byte(`-ERR wrong number of arguments ` + `for '` + command + `' command\r\n`)) } } writeOnlyPubsubErr := func() { switch outputType { case server.JSON: write([]byte(`{"ok":false` + `,"err":"only (P)SUBSCRIBE / (P)UNSUBSCRIBE / ` + `PING / QUIT allowed in this context"` + `,"elapsed":"` + time.Now().Sub(start).String() + `"}`)) case server.RESP: write([]byte("-ERR only (P)SUBSCRIBE / (P)UNSUBSCRIBE / " + "PING / QUIT allowed in this context\r\n")) } } writeSubscribe := func(command, channel string, num int) { switch outputType { case server.JSON: write([]byte(`{"ok":true` + `,"command":` + jsonString(command) + `,"channel":` + jsonString(channel) + `,"num":` + strconv.FormatInt(int64(num), 10) + `,"elapsed":"` + time.Now().Sub(start).String() + `"}`)) case server.RESP: b := redcon.AppendArray(nil, 3) b = redcon.AppendBulkString(b, command) b = redcon.AppendBulkString(b, channel) b = redcon.AppendInt(b, int64(num)) write(b) } } writeMessage := func(msg submsg) { if msg.kind == pubsubChannel { switch outputType { case server.JSON: var data []byte if !gjson.Valid(msg.message) { data = appendJSONString(nil, msg.message) } else { data = []byte(msg.message) } write(data) case server.RESP: b := redcon.AppendArray(nil, 3) b = redcon.AppendBulkString(b, "message") b = redcon.AppendBulkString(b, msg.channel) b = redcon.AppendBulkString(b, msg.message) write(b) } } else { switch outputType { case server.JSON: var data []byte if !gjson.Valid(msg.message) { data = appendJSONString(nil, msg.message) } else { data = []byte(msg.message) } write(data) case server.RESP: b := redcon.AppendArray(nil, 4) b = redcon.AppendBulkString(b, "pmessage") b = redcon.AppendBulkString(b, msg.pattern) b = redcon.AppendBulkString(b, msg.channel) b = redcon.AppendBulkString(b, msg.message) write(b) } } } m := [2]map[string]bool{ make(map[string]bool), make(map[string]bool), } target := newSubtarget() defer func() { for i := 0; i < 2; i++ { for channel := range m[i] { c.pubsub.unregister(i, channel, target) } } target.cond.L.Lock() target.closed = true target.cond.Broadcast() target.cond.L.Unlock() }() go func() { log.Debugf("pubsub open") defer log.Debugf("pubsub closed") for { var msgs []submsg target.cond.L.Lock() if len(target.msgs) > 0 { msgs = target.msgs target.msgs = nil } target.cond.L.Unlock() for _, msg := range msgs { writeMessage(msg) } target.cond.L.Lock() if target.closed { target.cond.L.Unlock() return } target.cond.Wait() target.cond.L.Unlock() } }() msgs := []*server.Message{msg} for { for _, msg := range msgs { start = time.Now() var kind int switch msg.Command { case "quit": writeOK() return nil case "psubscribe": kind = pubsubPattern case "subscribe": kind = pubsubChannel default: writeOnlyPubsubErr() } if len(msg.Values) < 2 { writeWrongNumberOfArgsErr(msg.Command) } for i := 1; i < len(msg.Values); i++ { channel := msg.Values[i].String() m[kind][channel] = true c.pubsub.register(kind, channel, target) writeSubscribe(msg.Command, channel, len(m[0])+len(m[1])) } } var err error msgs, err = rd.ReadMessages() if err != nil { if err == io.EOF { return nil } return err } } }