Compare commits

..

No commits in common. "master" and "pubsub" have entirely different histories.

19 changed files with 1255 additions and 1492 deletions

View File

@ -1,36 +0,0 @@
name: Go
on:
push:
branches: [ master ]
pull_request:
branches: [ master ]
jobs:
build:
name: Build
runs-on: ubuntu-latest
steps:
- name: Set up Go 1.x
uses: actions/setup-go@v2
with:
go-version: ^1.13
- name: Check out code into the Go module directory
uses: actions/checkout@v2
- name: Get dependencies
run: |
go get -v -t -d ./...
if [ -f Gopkg.toml ]; then
curl https://raw.githubusercontent.com/golang/dep/master/install.sh | sh
dep ensure
fi
- name: Build
run: go build -v .
- name: Test
run: go test -v .

4
.travis.yml Normal file
View File

@ -0,0 +1,4 @@
language: go
go:
- 1.13.x

View File

@ -1,12 +1,15 @@
<p align="center"> <p align="center">
<img <img
src="logo.png" src="logo.png"
width="336" border="0" alt="REDCON"> width="336" height="75" border="0" alt="REDCON">
<br> <br>
<a href="https://godoc.org/git.internal/re/redcon"><img src="https://img.shields.io/badge/api-reference-blue.svg?style=flat-square" alt="GoDoc"></a> <a href="https://travis-ci.org/tidwall/redcon"><img src="https://img.shields.io/travis/tidwall/redcon.svg?style=flat-square" alt="Build Status"></a>
<a href="https://godoc.org/github.com/tidwall/redcon"><img src="https://img.shields.io/badge/api-reference-blue.svg?style=flat-square" alt="GoDoc"></a>
</p> </p>
<p align="center">Redis compatible server framework for Go</p> <p align="center">Fast Redis compatible server framework for Go</p>
Redcon is a custom Redis server framework for Go that is fast and simple to use. The reason for this library it to give an efficient server front-end for the [BuntDB](https://github.com/tidwall/buntdb) and [Tile38](https://github.com/tidwall/tile38) projects.
Features Features
-------- --------
@ -15,14 +18,13 @@ Features
- Support for pipelining and telnet commands - Support for pipelining and telnet commands
- Works with Redis clients such as [redigo](https://github.com/garyburd/redigo), [redis-py](https://github.com/andymccurdy/redis-py), [node_redis](https://github.com/NodeRedis/node_redis), and [jedis](https://github.com/xetorthio/jedis) - Works with Redis clients such as [redigo](https://github.com/garyburd/redigo), [redis-py](https://github.com/andymccurdy/redis-py), [node_redis](https://github.com/NodeRedis/node_redis), and [jedis](https://github.com/xetorthio/jedis)
- [TLS Support](#tls-example) - [TLS Support](#tls-example)
- Compatible pub/sub support
- Multithreaded - Multithreaded
Installing Installing
---------- ----------
``` ```
go get -u git.internal/re/redcon go get -u github.com/tidwall/redcon
``` ```
Example Example
@ -35,8 +37,6 @@ Here's a full example of a Redis clone that accepts:
- DEL key - DEL key
- PING - PING
- QUIT - QUIT
- PUBLISH channel message
- SUBSCRIBE channel
You can run this example from a terminal: You can run this example from a terminal:
@ -52,7 +52,7 @@ import (
"strings" "strings"
"sync" "sync"
"git.internal/re/redcon" "github.com/tidwall/redcon"
) )
var addr = ":6380" var addr = ":6380"
@ -60,7 +60,6 @@ var addr = ":6380"
func main() { func main() {
var mu sync.RWMutex var mu sync.RWMutex
var items = make(map[string][]byte) var items = make(map[string][]byte)
var ps redcon.PubSub
go log.Printf("started server at %s", addr) go log.Printf("started server at %s", addr)
err := redcon.ListenAndServe(addr, err := redcon.ListenAndServe(addr,
func(conn redcon.Conn, cmd redcon.Command) { func(conn redcon.Conn, cmd redcon.Command) {
@ -108,34 +107,15 @@ func main() {
} else { } else {
conn.WriteInt(1) conn.WriteInt(1)
} }
case "publish":
if len(cmd.Args) != 3 {
conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command")
return
}
conn.WriteInt(ps.Publish(string(cmd.Args[1]), string(cmd.Args[2])))
case "subscribe", "psubscribe":
if len(cmd.Args) < 2 {
conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command")
return
}
command := strings.ToLower(string(cmd.Args[0]))
for i := 1; i < len(cmd.Args); i++ {
if command == "psubscribe" {
ps.Psubscribe(conn, string(cmd.Args[i]))
} else {
ps.Subscribe(conn, string(cmd.Args[i]))
}
}
} }
}, },
func(conn redcon.Conn) bool { func(conn redcon.Conn) bool {
// Use this function to accept or deny the connection. // use this function to accept or deny the connection.
// log.Printf("accept: %s", conn.RemoteAddr()) // log.Printf("accept: %s", conn.RemoteAddr())
return true return true
}, },
func(conn redcon.Conn, err error) { func(conn redcon.Conn, err error) {
// This is called when the connection has been closed // this is called when the connection has been closed
// log.Printf("closed: %s, err: %v", conn.RemoteAddr(), err) // log.Printf("closed: %s, err: %v", conn.RemoteAddr(), err)
}, },
) )

469
append.go Normal file
View File

@ -0,0 +1,469 @@
package redcon
import (
"fmt"
"reflect"
"sort"
"strconv"
"strings"
)
// Kind is the kind of command
type Kind int
const (
// Redis is returned for Redis protocol commands
Redis Kind = iota
// Tile38 is returnd for Tile38 native protocol commands
Tile38
// Telnet is returnd for plain telnet commands
Telnet
)
var errInvalidMessage = &errProtocol{"invalid message"}
// ReadNextCommand reads the next command from the provided packet. It's
// possible that the packet contains multiple commands, or zero commands
// when the packet is incomplete.
// 'argsbuf' is an optional reusable buffer and it can be nil.
// 'complete' indicates that a command was read. false means no more commands.
// 'args' are the output arguments for the command.
// 'kind' is the type of command that was read.
// 'leftover' is any remaining unused bytes which belong to the next command.
// 'err' is returned when a protocol error was encountered.
func ReadNextCommand(packet []byte, argsbuf [][]byte) (
complete bool, args [][]byte, kind Kind, leftover []byte, err error,
) {
args = argsbuf[:0]
if len(packet) > 0 {
if packet[0] != '*' {
if packet[0] == '$' {
return readTile38Command(packet, args)
}
return readTelnetCommand(packet, args)
}
// standard redis command
for s, i := 1, 1; i < len(packet); i++ {
if packet[i] == '\n' {
if packet[i-1] != '\r' {
return false, args[:0], Redis, packet, errInvalidMultiBulkLength
}
count, ok := parseInt(packet[s : i-1])
if !ok || count < 0 {
return false, args[:0], Redis, packet, errInvalidMultiBulkLength
}
i++
if count == 0 {
return true, args[:0], Redis, packet[i:], nil
}
nextArg:
for j := 0; j < count; j++ {
if i == len(packet) {
break
}
if packet[i] != '$' {
return false, args[:0], Redis, packet,
&errProtocol{"expected '$', got '" +
string(packet[i]) + "'"}
}
for s := i + 1; i < len(packet); i++ {
if packet[i] == '\n' {
if packet[i-1] != '\r' {
return false, args[:0], Redis, packet, errInvalidBulkLength
}
n, ok := parseInt(packet[s : i-1])
if !ok || count <= 0 {
return false, args[:0], Redis, packet, errInvalidBulkLength
}
i++
if len(packet)-i >= n+2 {
if packet[i+n] != '\r' || packet[i+n+1] != '\n' {
return false, args[:0], Redis, packet, errInvalidBulkLength
}
args = append(args, packet[i:i+n])
i += n + 2
if j == count-1 {
// done reading
return true, args, Redis, packet[i:], nil
}
continue nextArg
}
break
}
}
break
}
break
}
}
}
return false, args[:0], Redis, packet, nil
}
func readTile38Command(packet []byte, argsbuf [][]byte) (
complete bool, args [][]byte, kind Kind, leftover []byte, err error,
) {
for i := 1; i < len(packet); i++ {
if packet[i] == ' ' {
n, ok := parseInt(packet[1:i])
if !ok || n < 0 {
return false, args[:0], Tile38, packet, errInvalidMessage
}
i++
if len(packet) >= i+n+2 {
if packet[i+n] != '\r' || packet[i+n+1] != '\n' {
return false, args[:0], Tile38, packet, errInvalidMessage
}
line := packet[i : i+n]
reading:
for len(line) != 0 {
if line[0] == '{' {
// The native protocol cannot understand json boundaries so it assumes that
// a json element must be at the end of the line.
args = append(args, line)
break
}
if line[0] == '"' && line[len(line)-1] == '"' {
if len(args) > 0 &&
strings.ToLower(string(args[0])) == "set" &&
strings.ToLower(string(args[len(args)-1])) == "string" {
// Setting a string value that is contained inside double quotes.
// This is only because of the boundary issues of the native protocol.
args = append(args, line[1:len(line)-1])
break
}
}
i := 0
for ; i < len(line); i++ {
if line[i] == ' ' {
value := line[:i]
if len(value) > 0 {
args = append(args, value)
}
line = line[i+1:]
continue reading
}
}
args = append(args, line)
break
}
return true, args, Tile38, packet[i+n+2:], nil
}
break
}
}
return false, args[:0], Tile38, packet, nil
}
func readTelnetCommand(packet []byte, argsbuf [][]byte) (
complete bool, args [][]byte, kind Kind, leftover []byte, err error,
) {
// just a plain text command
for i := 0; i < len(packet); i++ {
if packet[i] == '\n' {
var line []byte
if i > 0 && packet[i-1] == '\r' {
line = packet[:i-1]
} else {
line = packet[:i]
}
var quote bool
var quotech byte
var escape bool
outer:
for {
nline := make([]byte, 0, len(line))
for i := 0; i < len(line); i++ {
c := line[i]
if !quote {
if c == ' ' {
if len(nline) > 0 {
args = append(args, nline)
}
line = line[i+1:]
continue outer
}
if c == '"' || c == '\'' {
if i != 0 {
return false, args[:0], Telnet, packet, errUnbalancedQuotes
}
quotech = c
quote = true
line = line[i+1:]
continue outer
}
} else {
if escape {
escape = false
switch c {
case 'n':
c = '\n'
case 'r':
c = '\r'
case 't':
c = '\t'
}
} else if c == quotech {
quote = false
quotech = 0
args = append(args, nline)
line = line[i+1:]
if len(line) > 0 && line[0] != ' ' {
return false, args[:0], Telnet, packet, errUnbalancedQuotes
}
continue outer
} else if c == '\\' {
escape = true
continue
}
}
nline = append(nline, c)
}
if quote {
return false, args[:0], Telnet, packet, errUnbalancedQuotes
}
if len(line) > 0 {
args = append(args, line)
}
break
}
return true, args, Telnet, packet[i+1:], nil
}
}
return false, args[:0], Telnet, packet, nil
}
// appendPrefix will append a "$3\r\n" style redis prefix for a message.
func appendPrefix(b []byte, c byte, n int64) []byte {
if n >= 0 && n <= 9 {
return append(b, c, byte('0'+n), '\r', '\n')
}
b = append(b, c)
b = strconv.AppendInt(b, n, 10)
return append(b, '\r', '\n')
}
// AppendUint appends a Redis protocol uint64 to the input bytes.
func AppendUint(b []byte, n uint64) []byte {
b = append(b, ':')
b = strconv.AppendUint(b, n, 10)
return append(b, '\r', '\n')
}
// AppendInt appends a Redis protocol int64 to the input bytes.
func AppendInt(b []byte, n int64) []byte {
return appendPrefix(b, ':', n)
}
// AppendArray appends a Redis protocol array to the input bytes.
func AppendArray(b []byte, n int) []byte {
return appendPrefix(b, '*', int64(n))
}
// AppendBulk appends a Redis protocol bulk byte slice to the input bytes.
func AppendBulk(b []byte, bulk []byte) []byte {
b = appendPrefix(b, '$', int64(len(bulk)))
b = append(b, bulk...)
return append(b, '\r', '\n')
}
// AppendBulkString appends a Redis protocol bulk string to the input bytes.
func AppendBulkString(b []byte, bulk string) []byte {
b = appendPrefix(b, '$', int64(len(bulk)))
b = append(b, bulk...)
return append(b, '\r', '\n')
}
// AppendString appends a Redis protocol string to the input bytes.
func AppendString(b []byte, s string) []byte {
b = append(b, '+')
b = append(b, stripNewlines(s)...)
return append(b, '\r', '\n')
}
// AppendError appends a Redis protocol error to the input bytes.
func AppendError(b []byte, s string) []byte {
b = append(b, '-')
b = append(b, stripNewlines(s)...)
return append(b, '\r', '\n')
}
// AppendOK appends a Redis protocol OK to the input bytes.
func AppendOK(b []byte) []byte {
return append(b, '+', 'O', 'K', '\r', '\n')
}
func stripNewlines(s string) string {
for i := 0; i < len(s); i++ {
if s[i] == '\r' || s[i] == '\n' {
s = strings.Replace(s, "\r", " ", -1)
s = strings.Replace(s, "\n", " ", -1)
break
}
}
return s
}
// AppendTile38 appends a Tile38 message to the input bytes.
func AppendTile38(b []byte, data []byte) []byte {
b = append(b, '$')
b = strconv.AppendInt(b, int64(len(data)), 10)
b = append(b, ' ')
b = append(b, data...)
return append(b, '\r', '\n')
}
// AppendNull appends a Redis protocol null to the input bytes.
func AppendNull(b []byte) []byte {
return append(b, '$', '-', '1', '\r', '\n')
}
// AppendBulkFloat appends a float64, as bulk bytes.
func AppendBulkFloat(dst []byte, f float64) []byte {
return AppendBulk(dst, strconv.AppendFloat(nil, f, 'f', -1, 64))
}
// AppendBulkInt appends an int64, as bulk bytes.
func AppendBulkInt(dst []byte, x int64) []byte {
return AppendBulk(dst, strconv.AppendInt(nil, x, 10))
}
// AppendBulkUint appends an uint64, as bulk bytes.
func AppendBulkUint(dst []byte, x uint64) []byte {
return AppendBulk(dst, strconv.AppendUint(nil, x, 10))
}
func prefixERRIfNeeded(msg string) string {
msg = strings.TrimSpace(msg)
firstWord := strings.Split(msg, " ")[0]
addERR := len(firstWord) == 0
for i := 0; i < len(firstWord); i++ {
if firstWord[i] < 'A' || firstWord[i] > 'Z' {
addERR = true
break
}
}
if addERR {
msg = strings.TrimSpace("ERR " + msg)
}
return msg
}
// SimpleString is for representing a non-bulk representation of a string
// from an *Any call.
type SimpleString string
// SimpleInt is for representing a non-bulk representation of a int
// from an *Any call.
type SimpleInt int
// AppendAny appends any type to valid Redis type.
// nil -> null
// error -> error (adds "ERR " when first word is not uppercase)
// string -> bulk-string
// numbers -> bulk-string
// []byte -> bulk-string
// bool -> bulk-string ("0" or "1")
// slice -> array
// map -> array with key/value pairs
// SimpleString -> string
// SimpleInt -> integer
// everything-else -> bulk-string representation using fmt.Sprint()
func AppendAny(b []byte, v interface{}) []byte {
switch v := v.(type) {
case SimpleString:
b = AppendString(b, string(v))
case SimpleInt:
b = AppendInt(b, int64(v))
case nil:
b = AppendNull(b)
case error:
b = AppendError(b, prefixERRIfNeeded(v.Error()))
case string:
b = AppendBulkString(b, v)
case []byte:
b = AppendBulk(b, v)
case bool:
if v {
b = AppendBulkString(b, "1")
} else {
b = AppendBulkString(b, "0")
}
case int:
b = AppendBulkInt(b, int64(v))
case int8:
b = AppendBulkInt(b, int64(v))
case int16:
b = AppendBulkInt(b, int64(v))
case int32:
b = AppendBulkInt(b, int64(v))
case int64:
b = AppendBulkInt(b, int64(v))
case uint:
b = AppendBulkUint(b, uint64(v))
case uint8:
b = AppendBulkUint(b, uint64(v))
case uint16:
b = AppendBulkUint(b, uint64(v))
case uint32:
b = AppendBulkUint(b, uint64(v))
case uint64:
b = AppendBulkUint(b, uint64(v))
case float32:
b = AppendBulkFloat(b, float64(v))
case float64:
b = AppendBulkFloat(b, float64(v))
default:
vv := reflect.ValueOf(v)
switch vv.Kind() {
case reflect.Slice:
n := vv.Len()
b = AppendArray(b, n)
for i := 0; i < n; i++ {
b = AppendAny(b, vv.Index(i).Interface())
}
case reflect.Map:
n := vv.Len()
b = AppendArray(b, n*2)
var i int
var strKey bool
var strsKeyItems []strKeyItem
iter := vv.MapRange()
for iter.Next() {
key := iter.Key().Interface()
if i == 0 {
if _, ok := key.(string); ok {
strKey = true
strsKeyItems = make([]strKeyItem, n)
}
}
if strKey {
strsKeyItems[i] = strKeyItem{
key.(string), iter.Value().Interface(),
}
} else {
b = AppendAny(b, key)
b = AppendAny(b, iter.Value().Interface())
}
i++
}
if strKey {
sort.Slice(strsKeyItems, func(i, j int) bool {
return strsKeyItems[i].key < strsKeyItems[j].key
})
for _, item := range strsKeyItems {
b = AppendBulkString(b, item.key)
b = AppendAny(b, item.value)
}
}
default:
b = AppendBulkString(b, fmt.Sprint(v))
}
}
return b
}
type strKeyItem struct {
key string
value interface{}
}

127
append_test.go Normal file
View File

@ -0,0 +1,127 @@
package redcon
import (
"bytes"
"math/rand"
"testing"
"time"
)
func TestNextCommand(t *testing.T) {
rand.Seed(time.Now().UnixNano())
start := time.Now()
for time.Since(start) < time.Second {
// keep copy of pipeline args for final compare
var plargs [][][]byte
// create a pipeline of random number of commands with random data.
N := rand.Int() % 10000
var data []byte
for i := 0; i < N; i++ {
nargs := rand.Int() % 10
data = AppendArray(data, nargs)
var args [][]byte
for j := 0; j < nargs; j++ {
arg := make([]byte, rand.Int()%100)
if _, err := rand.Read(arg); err != nil {
t.Fatal(err)
}
data = AppendBulk(data, arg)
args = append(args, arg)
}
plargs = append(plargs, args)
}
// break data into random number of chunks
chunkn := rand.Int() % 100
if chunkn == 0 {
chunkn = 1
}
if len(data) < chunkn {
continue
}
var chunks [][]byte
var chunksz int
for i := 0; i < len(data); i += chunksz {
chunksz = rand.Int() % (len(data) / chunkn)
var chunk []byte
if i+chunksz < len(data) {
chunk = data[i : i+chunksz]
} else {
chunk = data[i:]
}
chunks = append(chunks, chunk)
}
// process chunks
var rbuf []byte
var fargs [][][]byte
for _, chunk := range chunks {
var data []byte
if len(rbuf) > 0 {
data = append(rbuf, chunk...)
} else {
data = chunk
}
for {
complete, args, _, leftover, err := ReadNextCommand(data, nil)
data = leftover
if err != nil {
t.Fatal(err)
}
if !complete {
break
}
fargs = append(fargs, args)
}
rbuf = append(rbuf[:0], data...)
}
// compare final args to original
if len(plargs) != len(fargs) {
t.Fatalf("not equal size: %v != %v", len(plargs), len(fargs))
}
for i := 0; i < len(plargs); i++ {
if len(plargs[i]) != len(fargs[i]) {
t.Fatalf("not equal size for item %v: %v != %v", i, len(plargs[i]), len(fargs[i]))
}
for j := 0; j < len(plargs[i]); j++ {
if !bytes.Equal(plargs[i][j], plargs[i][j]) {
t.Fatalf("not equal for item %v:%v: %v != %v", i, j, len(plargs[i][j]), len(fargs[i][j]))
}
}
}
}
}
func TestAppendBulkFloat(t *testing.T) {
var b []byte
b = AppendString(b, "HELLO")
b = AppendBulkFloat(b, 9.123192839)
b = AppendString(b, "HELLO")
exp := "+HELLO\r\n$11\r\n9.123192839\r\n+HELLO\r\n"
if string(b) != exp {
t.Fatalf("expected '%s', got '%s'", exp, b)
}
}
func TestAppendBulkInt(t *testing.T) {
var b []byte
b = AppendString(b, "HELLO")
b = AppendBulkInt(b, -9182739137)
b = AppendString(b, "HELLO")
exp := "+HELLO\r\n$11\r\n-9182739137\r\n+HELLO\r\n"
if string(b) != exp {
t.Fatalf("expected '%s', got '%s'", exp, b)
}
}
func TestAppendBulkUint(t *testing.T) {
var b []byte
b = AppendString(b, "HELLO")
b = AppendBulkInt(b, 91827391370)
b = AppendString(b, "HELLO")
exp := "+HELLO\r\n$11\r\n91827391370\r\n+HELLO\r\n"
if string(b) != exp {
t.Fatalf("expected '%s', got '%s'", exp, b)
}
}

View File

@ -5,25 +5,22 @@ import (
"strings" "strings"
"sync" "sync"
"git.internal/re/redcon" "github.com/tidwall/redcon"
) )
var addr = ":6380" var addr = ":6380"
func main() { func main() {
var mu sync.RWMutex var mu sync.RWMutex
items := make(map[string][]byte) var items = make(map[string][]byte)
var ps redcon.PubSub var ps redcon.PubSub
go log.Printf("started server at %s", addr) go log.Printf("started server at %s", addr)
err := redcon.ListenAndServe(addr, err := redcon.ListenAndServe(addr,
func(conn redcon.Conn, cmd redcon.Command) { func(conn redcon.Conn, cmd redcon.Command) {
switch strings.ToLower(string(cmd.Args[0])) { switch strings.ToLower(string(cmd.Args[0])) {
default: default:
conn.WriteError("ERR unknown command '" + string(cmd.Args[0]) + "'") conn.WriteError("ERR unknown command '" + string(cmd.Args[0]) + "'")
case "publish": case "publish":
// Publish to all pub/sub subscribers and return the number of
// messages that were sent.
if len(cmd.Args) != 3 { if len(cmd.Args) != 3 {
conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command") conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command")
return return
@ -31,10 +28,6 @@ func main() {
count := ps.Publish(string(cmd.Args[1]), string(cmd.Args[2])) count := ps.Publish(string(cmd.Args[1]), string(cmd.Args[2]))
conn.WriteInt(count) conn.WriteInt(count)
case "subscribe", "psubscribe": case "subscribe", "psubscribe":
// Subscribe to a pub/sub channel. The `Psubscribe` and
// `Subscribe` operations will detach the connection from the
// event handler and manage all network I/O for this connection
// in the background.
if len(cmd.Args) < 2 { if len(cmd.Args) < 2 {
conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command") conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command")
return return
@ -96,21 +89,15 @@ func main() {
} else { } else {
conn.WriteInt(1) conn.WriteInt(1)
} }
case "config":
// This simple (blank) response is only here to allow for the
// redis-benchmark command to work with this example.
conn.WriteArray(2)
conn.WriteBulk(cmd.Args[2])
conn.WriteBulkString("")
} }
}, },
func(conn redcon.Conn) bool { func(conn redcon.Conn) bool {
// Use this function to accept or deny the connection. // use this function to accept or deny the connection.
// log.Printf("accept: %s", conn.RemoteAddr()) // log.Printf("accept: %s", conn.RemoteAddr())
return true return true
}, },
func(conn redcon.Conn, err error) { func(conn redcon.Conn, err error) {
// This is called when the connection has been closed // this is called when the connection has been closed
// log.Printf("closed: %s, err: %v", conn.RemoteAddr(), err) // log.Printf("closed: %s, err: %v", conn.RemoteAddr(), err)
}, },
) )

View File

View File

@ -3,7 +3,7 @@ package main
import ( import (
"log" "log"
"git.internal/re/redcon" "github.com/tidwall/redcon"
) )
var addr = ":6380" var addr = ":6380"

View File

@ -4,7 +4,7 @@ import (
"log" "log"
"sync" "sync"
"git.internal/re/redcon" "github.com/tidwall/redcon"
) )
type Handler struct { type Handler struct {

View File

@ -6,7 +6,7 @@ import (
"strings" "strings"
"sync" "sync"
"git.internal/re/redcon" "github.com/tidwall/redcon"
) )
const serverKey = `-----BEGIN EC PARAMETERS----- const serverKey = `-----BEGIN EC PARAMETERS-----
@ -44,7 +44,7 @@ func main() {
config := &tls.Config{Certificates: []tls.Certificate{cer}} config := &tls.Config{Certificates: []tls.Certificate{cer}}
var mu sync.RWMutex var mu sync.RWMutex
items := make(map[string][]byte) var items = make(map[string][]byte)
go log.Printf("started server at %s", addr) go log.Printf("started server at %s", addr)
err = redcon.ListenAndServeTLS(addr, err = redcon.ListenAndServeTLS(addr,

8
go.mod
View File

@ -1,8 +1,8 @@
module git.internal/re/redcon module github.com/tidwall/redcon
go 1.19 go 1.15
require ( require (
github.com/tidwall/btree v1.1.0 github.com/tidwall/btree v0.2.2
github.com/tidwall/match v1.1.1 github.com/tidwall/match v1.0.1
) )

8
go.sum
View File

@ -1,4 +1,4 @@
github.com/tidwall/btree v1.1.0 h1:5P+9WU8ui5uhmcg3SoPyTwoI0mVyZ1nps7YQzTZFkYM= github.com/tidwall/btree v0.2.2 h1:VVo0JW/tdidNdQzNsDR4wMbL3heaxA1DGleyzQ3/niY=
github.com/tidwall/btree v1.1.0/go.mod h1:TzIRzen6yHbibdSfK6t8QimqbUnoxUSrZfeW7Uob0q4= github.com/tidwall/btree v0.2.2/go.mod h1:huei1BkDWJ3/sLXmO+bsCNELL+Bp2Kks9OLyQFkzvA8=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.0.1 h1:PnKP62LPNxHKTwvHHZZzdOAOCtsJTjo6dZLCwpKm5xc=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/match v1.0.1/go.mod h1:LujAq0jyVjBy028G1WhWfIzbpQfMO8bBZ6Tyb0+pL9E=

View File

@ -1,3 +0,0 @@
git.internal/re/redcon,Unknown,MIT
github.com/tidwall/btree,https://github.com/tidwall/btree/blob/v1.1.0/LICENSE,MIT
github.com/tidwall/match,https://github.com/tidwall/match/blob/v1.1.1/LICENSE,MIT

347
pubsub.go Normal file
View File

@ -0,0 +1,347 @@
package redcon
import (
"fmt"
"strings"
"sync"
"github.com/tidwall/btree"
"github.com/tidwall/match"
)
// PubSub is a Redis compatible pub/sub server
type PubSub struct {
mu sync.RWMutex
nextid uint64
initd bool
chans *btree.BTree
conns map[Conn]*pubSubConn
}
// Subscribe a connection to PubSub
func (ps *PubSub) Subscribe(conn Conn, channel string) {
ps.subscribe(conn, false, channel)
}
// Psubscribe a connection to PubSub
func (ps *PubSub) Psubscribe(conn Conn, channel string) {
ps.subscribe(conn, true, channel)
}
// Publish a message to subscribers
func (ps *PubSub) Publish(channel, message string) int {
ps.mu.RLock()
defer ps.mu.RUnlock()
if !ps.initd {
return 0
}
var sent int
// write messages to all clients that are subscribed on the channel
pivot := &pubSubEntry{pattern: false, channel: channel}
ps.chans.Ascend(pivot, func(item interface{}) bool {
entry := item.(*pubSubEntry)
if entry.channel != pivot.channel || entry.pattern != pivot.pattern {
return false
}
entry.sconn.writeMessage(entry.pattern, "", channel, message)
sent++
return true
})
// match on and write all psubscribe clients
pivot = &pubSubEntry{pattern: true}
ps.chans.Ascend(pivot, func(item interface{}) bool {
entry := item.(*pubSubEntry)
if match.Match(channel, entry.channel) {
entry.sconn.writeMessage(entry.pattern, entry.channel, channel,
message)
}
sent++
return true
})
return sent
}
type pubSubConn struct {
id uint64
mu sync.Mutex
conn Conn
dconn DetachedConn
entries map[*pubSubEntry]bool
}
type pubSubEntry struct {
pattern bool
sconn *pubSubConn
channel string
}
func (sconn *pubSubConn) writeMessage(pat bool, pchan, channel, msg string) {
sconn.mu.Lock()
defer sconn.mu.Unlock()
if pat {
sconn.dconn.WriteArray(4)
sconn.dconn.WriteBulkString("pmessage")
sconn.dconn.WriteBulkString(pchan)
sconn.dconn.WriteBulkString(channel)
sconn.dconn.WriteBulkString(msg)
} else {
sconn.dconn.WriteArray(3)
sconn.dconn.WriteBulkString("message")
sconn.dconn.WriteBulkString(channel)
sconn.dconn.WriteBulkString(msg)
}
sconn.dconn.Flush()
}
// bgrunner runs in the background and reads incoming commands from the
// detached client.
func (sconn *pubSubConn) bgrunner(ps *PubSub) {
defer func() {
// client connection has ended, disconnect from the PubSub instances
// and close the network connection.
ps.mu.Lock()
defer ps.mu.Unlock()
for entry := range sconn.entries {
ps.chans.Delete(entry)
}
delete(ps.conns, sconn.conn)
sconn.mu.Lock()
defer sconn.mu.Unlock()
sconn.dconn.Close()
}()
for {
cmd, err := sconn.dconn.ReadCommand()
if err != nil {
return
}
if len(cmd.Args) == 0 {
continue
}
switch strings.ToLower(string(cmd.Args[0])) {
case "psubscribe", "subscribe":
if len(cmd.Args) < 2 {
func() {
sconn.mu.Lock()
defer sconn.mu.Unlock()
sconn.dconn.WriteError(fmt.Sprintf("ERR wrong number of "+
"arguments for '%s'", cmd.Args[0]))
sconn.dconn.Flush()
}()
continue
}
command := strings.ToLower(string(cmd.Args[0]))
for i := 1; i < len(cmd.Args); i++ {
if command == "psubscribe" {
ps.Psubscribe(sconn.conn, string(cmd.Args[i]))
} else {
ps.Subscribe(sconn.conn, string(cmd.Args[i]))
}
}
case "unsubscribe", "punsubscribe":
pattern := strings.ToLower(string(cmd.Args[0])) == "punsubscribe"
if len(cmd.Args) == 1 {
ps.unsubscribe(sconn.conn, pattern, true, "")
} else {
for i := 1; i < len(cmd.Args); i++ {
channel := string(cmd.Args[i])
ps.unsubscribe(sconn.conn, pattern, false, channel)
}
}
case "quit":
func() {
sconn.mu.Lock()
defer sconn.mu.Unlock()
sconn.dconn.WriteString("OK")
sconn.dconn.Flush()
sconn.dconn.Close()
}()
return
case "ping":
var msg string
switch len(cmd.Args) {
case 1:
case 2:
msg = string(cmd.Args[1])
default:
func() {
sconn.mu.Lock()
defer sconn.mu.Unlock()
sconn.dconn.WriteError(fmt.Sprintf("ERR wrong number of "+
"arguments for '%s'", cmd.Args[0]))
sconn.dconn.Flush()
}()
continue
}
func() {
sconn.mu.Lock()
defer sconn.mu.Unlock()
sconn.dconn.WriteArray(2)
sconn.dconn.WriteBulkString("pong")
sconn.dconn.WriteBulkString(msg)
sconn.dconn.Flush()
}()
default:
func() {
sconn.mu.Lock()
defer sconn.mu.Unlock()
sconn.dconn.WriteError(fmt.Sprintf("ERR Can't execute '%s': "+
"only (P)SUBSCRIBE / (P)UNSUBSCRIBE / PING / QUIT are "+
"allowed in this context", cmd.Args[0]))
sconn.dconn.Flush()
}()
}
}
}
// byEntry is a "less" function that sorts the entries in a btree. The tree
// is sorted be (pattern, channel, conn.id). All pattern=true entries are at
// the end (right) of the tree.
func byEntry(a, b interface{}) bool {
aa := a.(*pubSubEntry)
bb := b.(*pubSubEntry)
if !aa.pattern && bb.pattern {
return true
}
if aa.pattern && !bb.pattern {
return false
}
if aa.channel < bb.channel {
return true
}
if aa.channel > bb.channel {
return false
}
var aid uint64
var bid uint64
if aa.sconn != nil {
aid = aa.sconn.id
}
if bb.sconn != nil {
bid = bb.sconn.id
}
return aid < bid
}
func (ps *PubSub) subscribe(conn Conn, pattern bool, channel string) {
ps.mu.Lock()
defer ps.mu.Unlock()
// initialize the PubSub instance
if !ps.initd {
ps.conns = make(map[Conn]*pubSubConn)
ps.chans = btree.New(byEntry)
ps.initd = true
}
// fetch the pubSubConn
sconn, ok := ps.conns[conn]
if !ok {
// initialize a new pubSubConn, which runs on a detached connection,
// and attach it to the PubSub channels/conn btree
ps.nextid++
dconn := conn.Detach()
sconn = &pubSubConn{
id: ps.nextid,
conn: conn,
dconn: dconn,
entries: make(map[*pubSubEntry]bool),
}
ps.conns[conn] = sconn
}
sconn.mu.Lock()
defer sconn.mu.Unlock()
// add an entry to the pubsub btree
entry := &pubSubEntry{
pattern: pattern,
channel: channel,
sconn: sconn,
}
ps.chans.Set(entry)
sconn.entries[entry] = true
// send a message to the client
sconn.dconn.WriteArray(3)
if pattern {
sconn.dconn.WriteBulkString("psubscribe")
} else {
sconn.dconn.WriteBulkString("subscribe")
}
sconn.dconn.WriteBulkString(channel)
var count int
for ient := range sconn.entries {
if ient.pattern == pattern {
count++
}
}
sconn.dconn.WriteInt(count)
sconn.dconn.Flush()
// start the background client operation
if !ok {
go sconn.bgrunner(ps)
}
}
func (ps *PubSub) unsubscribe(conn Conn, pattern, all bool, channel string) {
ps.mu.Lock()
defer ps.mu.Unlock()
// fetch the pubSubConn. This must exist
sconn := ps.conns[conn]
sconn.mu.Lock()
defer sconn.mu.Unlock()
removeEntry := func(entry *pubSubEntry) {
if entry != nil {
ps.chans.Delete(entry)
delete(sconn.entries, entry)
}
sconn.dconn.WriteArray(3)
if pattern {
sconn.dconn.WriteBulkString("punsubscribe")
} else {
sconn.dconn.WriteBulkString("unsubscribe")
}
if entry != nil {
sconn.dconn.WriteBulkString(entry.channel)
} else {
sconn.dconn.WriteNull()
}
var count int
for ient := range sconn.entries {
if ient.pattern == pattern {
count++
}
}
sconn.dconn.WriteInt(count)
}
if all {
// unsubscribe from all (p)subscribe entries
var entries []*pubSubEntry
for ient := range sconn.entries {
if ient.pattern == pattern {
entries = append(entries, ient)
}
}
if len(entries) == 0 {
removeEntry(nil)
} else {
for _, entry := range entries {
removeEntry(entry)
}
}
} else {
// unsubscribe single channel from (p)subscribe.
var entry *pubSubEntry
for ient := range sconn.entries {
if ient.pattern == pattern && ient.channel == channel {
removeEntry(entry)
break
}
}
removeEntry(entry)
}
sconn.dconn.Flush()
}

194
pubsub_test.go Normal file
View File

@ -0,0 +1,194 @@
package redcon
import (
"bufio"
"fmt"
"net"
"strconv"
"strings"
"sync"
"testing"
"time"
)
func TestPubSub(t *testing.T) {
addr := ":12346"
done := make(chan bool)
go func() {
var ps PubSub
go func() {
tch := time.NewTicker(time.Millisecond * 5)
defer tch.Stop()
channels := []string{"achan1", "bchan2", "cchan3", "dchan4"}
for i := 0; ; i++ {
select {
case <-tch.C:
case <-done:
for {
var empty bool
ps.mu.Lock()
if len(ps.conns) == 0 {
if ps.chans.Len() != 0 {
panic("chans not empty")
}
empty = true
}
ps.mu.Unlock()
if empty {
break
}
time.Sleep(time.Millisecond * 10)
}
done <- true
return
}
channel := channels[i%len(channels)]
message := fmt.Sprintf("message %d", i)
ps.Publish(channel, message)
}
}()
t.Fatal(ListenAndServe(addr, func(conn Conn, cmd Command) {
switch strings.ToLower(string(cmd.Args[0])) {
default:
conn.WriteError("ERR unknown command '" +
string(cmd.Args[0]) + "'")
case "publish":
if len(cmd.Args) != 3 {
conn.WriteError("ERR wrong number of arguments for '" +
string(cmd.Args[0]) + "' command")
return
}
count := ps.Publish(string(cmd.Args[1]), string(cmd.Args[2]))
conn.WriteInt(count)
case "subscribe", "psubscribe":
if len(cmd.Args) < 2 {
conn.WriteError("ERR wrong number of arguments for '" +
string(cmd.Args[0]) + "' command")
return
}
command := strings.ToLower(string(cmd.Args[0]))
for i := 1; i < len(cmd.Args); i++ {
if command == "psubscribe" {
ps.Psubscribe(conn, string(cmd.Args[i]))
} else {
ps.Subscribe(conn, string(cmd.Args[i]))
}
}
}
}, nil, nil))
}()
final := make(chan bool)
go func() {
select {
case <-time.Tick(time.Second * 30):
panic("timeout")
case <-final:
return
}
}()
// create 10 connections
var wg sync.WaitGroup
wg.Add(10)
for i := 0; i < 10; i++ {
go func(i int) {
defer wg.Done()
var conn net.Conn
for i := 0; i < 5; i++ {
var err error
conn, err = net.Dial("tcp", addr)
if err != nil {
time.Sleep(time.Second / 10)
continue
}
}
if conn == nil {
panic("could not connect to server")
}
defer conn.Close()
regs := make(map[string]int)
var maxp int
var maxs int
fmt.Fprintf(conn, "subscribe achan1\r\n")
fmt.Fprintf(conn, "subscribe bchan2 cchan3\r\n")
fmt.Fprintf(conn, "psubscribe a*1\r\n")
fmt.Fprintf(conn, "psubscribe b*2 c*3\r\n")
// collect 50 messages from each channel
rd := bufio.NewReader(conn)
var buf []byte
for {
line, err := rd.ReadBytes('\n')
if err != nil {
panic(err)
}
buf = append(buf, line...)
n, resp := ReadNextRESP(buf)
if n == 0 {
continue
}
buf = nil
if resp.Type != Array {
panic("expected array")
}
var vals []RESP
resp.ForEach(func(item RESP) bool {
vals = append(vals, item)
return true
})
name := string(vals[0].Data)
switch name {
case "subscribe":
if len(vals) != 3 {
panic("invalid count")
}
ch := string(vals[1].Data)
regs[ch] = 0
maxs, _ = strconv.Atoi(string(vals[2].Data))
case "psubscribe":
if len(vals) != 3 {
panic("invalid count")
}
ch := string(vals[1].Data)
regs[ch] = 0
maxp, _ = strconv.Atoi(string(vals[2].Data))
case "message":
if len(vals) != 3 {
panic("invalid count")
}
ch := string(vals[1].Data)
regs[ch] = regs[ch] + 1
case "pmessage":
if len(vals) != 4 {
panic("invalid count")
}
ch := string(vals[1].Data)
regs[ch] = regs[ch] + 1
}
if len(regs) == 6 && maxp == 3 && maxs == 3 {
ready := true
for _, count := range regs {
if count < 50 {
ready = false
break
}
}
if ready {
// all messages have been received
return
}
}
}
}(i)
}
wg.Wait()
// notify sender
done <- true
// wait for sender
<-done
// stop the timeout
final <- true
}

486
redcon.go
View File

@ -5,15 +5,10 @@ import (
"bufio" "bufio"
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt"
"io" "io"
"net" "net"
"strings" "strings"
"sync" "sync"
"time"
"github.com/tidwall/btree"
"github.com/tidwall/match"
) )
var ( var (
@ -25,8 +20,6 @@ var (
errTooMuchData = errors.New("too much data") errTooMuchData = errors.New("too much data")
) )
const maxBufferCap = 262144
type errProtocol struct { type errProtocol struct {
msg string msg string
} }
@ -60,8 +53,8 @@ type Conn interface {
// For example to write two strings: // For example to write two strings:
// //
// c.WriteArray(2) // c.WriteArray(2)
// c.WriteBulkString("item 1") // c.WriteBulk("item 1")
// c.WriteBulkString("item 2") // c.WriteBulk("item 2")
WriteArray(count int) WriteArray(count int)
// WriteNull writes a null to the client // WriteNull writes a null to the client
WriteNull() WriteNull()
@ -145,12 +138,14 @@ func NewServerNetwork(
if handler == nil { if handler == nil {
panic("handler is nil") panic("handler is nil")
} }
s := newServer() s := &Server{
s.net = net net: net,
s.laddr = laddr laddr: laddr,
s.handler = handler handler: handler,
s.accept = accept accept: accept,
s.closed = closed closed: closed,
conns: make(map[*conn]bool),
}
return s return s
} }
@ -221,26 +216,22 @@ func (s *TLSServer) ListenAndServe() error {
return s.ListenServeAndSignal(nil) return s.ListenServeAndSignal(nil)
} }
func newServer() *Server {
s := &Server{
conns: make(map[*conn]bool),
}
return s
}
// Serve creates a new server and serves with the given net.Listener. // Serve creates a new server and serves with the given net.Listener.
func Serve(ln net.Listener, func Serve(ln net.Listener,
handler func(conn Conn, cmd Command), handler func(conn Conn, cmd Command),
accept func(conn Conn) bool, accept func(conn Conn) bool,
closed func(conn Conn, err error), closed func(conn Conn, err error),
) error { ) error {
s := newServer() s := &Server{
s.net = ln.Addr().Network() net: ln.Addr().Network(),
s.laddr = ln.Addr().String() laddr: ln.Addr().String(),
s.ln = ln ln: ln,
s.handler = handler handler: handler,
s.accept = accept accept: accept,
s.closed = closed closed: closed,
conns: make(map[*conn]bool),
}
return serve(s) return serve(s)
} }
@ -296,9 +287,7 @@ func (s *Server) ListenServeAndSignal(signal chan error) error {
} }
return err return err
} }
s.mu.Lock()
s.ln = ln s.ln = ln
s.mu.Unlock()
if signal != nil { if signal != nil {
signal <- nil signal <- nil
} }
@ -323,9 +312,7 @@ func (s *TLSServer) ListenServeAndSignal(signal chan error) error {
} }
return err return err
} }
s.mu.Lock()
s.ln = ln s.ln = ln
s.mu.Unlock()
if signal != nil { if signal != nil {
signal <- nil signal <- nil
} }
@ -353,10 +340,6 @@ func serve(s *Server) error {
if done { if done {
return nil return nil
} }
if errors.Is(err, net.ErrClosed) {
// see https://git.internal/re/redcon/issues/46
return nil
}
if s.AcceptError != nil { if s.AcceptError != nil {
s.AcceptError(err) s.AcceptError(err)
} }
@ -369,7 +352,6 @@ func serve(s *Server) error {
rd: NewReader(lnconn), rd: NewReader(lnconn),
} }
s.mu.Lock() s.mu.Lock()
c.idleClose = s.idleClose
s.conns[c] = true s.conns[c] = true
s.mu.Unlock() s.mu.Unlock()
if s.accept != nil && !s.accept(c) { if s.accept != nil && !s.accept(c) {
@ -409,9 +391,6 @@ func handle(s *Server, c *conn) {
// read commands and feed back to the client // read commands and feed back to the client
for { for {
// read pipeline commands // read pipeline commands
if c.idleClose != 0 {
c.conn.SetReadDeadline(time.Now().Add(c.idleClose))
}
cmds, err := c.rd.readCommands(nil) cmds, err := c.rd.readCommands(nil)
if err != nil { if err != nil {
if err, ok := err.(*errProtocol); ok { if err, ok := err.(*errProtocol); ok {
@ -456,7 +435,6 @@ type conn struct {
detached bool detached bool
closed bool closed bool
cmds []Command cmds []Command
idleClose time.Duration
} }
func (c *conn) Close() error { func (c *conn) Close() error {
@ -484,11 +462,9 @@ func (c *conn) ReadPipeline() []Command {
c.cmds = nil c.cmds = nil
return cmds return cmds
} }
func (c *conn) PeekPipeline() []Command { func (c *conn) PeekPipeline() []Command {
return c.cmds return c.cmds
} }
func (c *conn) NetConn() net.Conn { func (c *conn) NetConn() net.Conn {
return c.conn return c.conn
} }
@ -570,7 +546,6 @@ type Server struct {
conns map[*conn]bool conns map[*conn]bool
ln net.Listener ln net.Listener
done bool done bool
idleClose time.Duration
// AcceptError is an optional function used to handle Accept errors. // AcceptError is an optional function used to handle Accept errors.
AcceptError func(err error) AcceptError func(err error)
@ -586,7 +561,6 @@ type TLSServer struct {
type Writer struct { type Writer struct {
w io.Writer w io.Writer
b []byte b []byte
err error
} }
// NewWriter creates a new RESP writer. // NewWriter creates a new RESP writer.
@ -598,9 +572,6 @@ func NewWriter(wr io.Writer) *Writer {
// WriteNull writes a null to the client // WriteNull writes a null to the client
func (w *Writer) WriteNull() { func (w *Writer) WriteNull() {
if w.err != nil {
return
}
w.b = AppendNull(w.b) w.b = AppendNull(w.b)
} }
@ -609,113 +580,74 @@ func (w *Writer) WriteNull() {
// For example to write two strings: // For example to write two strings:
// //
// c.WriteArray(2) // c.WriteArray(2)
// c.WriteBulkString("item 1") // c.WriteBulk("item 1")
// c.WriteBulkString("item 2") // c.WriteBulk("item 2")
func (w *Writer) WriteArray(count int) { func (w *Writer) WriteArray(count int) {
if w.err != nil {
return
}
w.b = AppendArray(w.b, count) w.b = AppendArray(w.b, count)
} }
// WriteBulk writes bulk bytes to the client. // WriteBulk writes bulk bytes to the client.
func (w *Writer) WriteBulk(bulk []byte) { func (w *Writer) WriteBulk(bulk []byte) {
if w.err != nil {
return
}
w.b = AppendBulk(w.b, bulk) w.b = AppendBulk(w.b, bulk)
} }
// WriteBulkString writes a bulk string to the client. // WriteBulkString writes a bulk string to the client.
func (w *Writer) WriteBulkString(bulk string) { func (w *Writer) WriteBulkString(bulk string) {
if w.err != nil {
return
}
w.b = AppendBulkString(w.b, bulk) w.b = AppendBulkString(w.b, bulk)
} }
// Buffer returns the unflushed buffer. This is a copy so changes // Buffer returns the unflushed buffer. This is a copy so changes
// to the resulting []byte will not affect the writer. // to the resulting []byte will not affect the writer.
func (w *Writer) Buffer() []byte { func (w *Writer) Buffer() []byte {
if w.err != nil {
return nil
}
return append([]byte(nil), w.b...) return append([]byte(nil), w.b...)
} }
// SetBuffer replaces the unflushed buffer with new bytes. // SetBuffer replaces the unflushed buffer with new bytes.
func (w *Writer) SetBuffer(raw []byte) { func (w *Writer) SetBuffer(raw []byte) {
if w.err != nil {
return
}
w.b = w.b[:0] w.b = w.b[:0]
w.b = append(w.b, raw...) w.b = append(w.b, raw...)
} }
// Flush writes all unflushed Write* calls to the underlying writer. // Flush writes all unflushed Write* calls to the underlying writer.
func (w *Writer) Flush() error { func (w *Writer) Flush() error {
if w.err != nil { if _, err := w.w.Write(w.b); err != nil {
return w.err return err
} }
_, w.err = w.w.Write(w.b)
if cap(w.b) > maxBufferCap || w.err != nil {
w.b = nil
} else {
w.b = w.b[:0] w.b = w.b[:0]
} return nil
return w.err
} }
// WriteError writes an error to the client. // WriteError writes an error to the client.
func (w *Writer) WriteError(msg string) { func (w *Writer) WriteError(msg string) {
if w.err != nil {
return
}
w.b = AppendError(w.b, msg) w.b = AppendError(w.b, msg)
} }
// WriteString writes a string to the client. // WriteString writes a string to the client.
func (w *Writer) WriteString(msg string) { func (w *Writer) WriteString(msg string) {
if w.err != nil {
return
}
w.b = AppendString(w.b, msg) w.b = AppendString(w.b, msg)
} }
// WriteInt writes an integer to the client. // WriteInt writes an integer to the client.
func (w *Writer) WriteInt(num int) { func (w *Writer) WriteInt(num int) {
if w.err != nil {
return
}
w.WriteInt64(int64(num)) w.WriteInt64(int64(num))
} }
// WriteInt64 writes a 64-bit signed integer to the client. // WriteInt64 writes a 64-bit signed integer to the client.
func (w *Writer) WriteInt64(num int64) { func (w *Writer) WriteInt64(num int64) {
if w.err != nil {
return
}
w.b = AppendInt(w.b, num) w.b = AppendInt(w.b, num)
} }
// WriteUint64 writes a 64-bit unsigned integer to the client. // WriteUint64 writes a 64-bit unsigned integer to the client.
func (w *Writer) WriteUint64(num uint64) { func (w *Writer) WriteUint64(num uint64) {
if w.err != nil {
return
}
w.b = AppendUint(w.b, num) w.b = AppendUint(w.b, num)
} }
// WriteRaw writes raw data to the client. // WriteRaw writes raw data to the client.
func (w *Writer) WriteRaw(data []byte) { func (w *Writer) WriteRaw(data []byte) {
if w.err != nil {
return
}
w.b = append(w.b, data...) w.b = append(w.b, data...)
} }
// WriteAny writes any type to client. // WriteAny writes any type to client.
//
// nil -> null // nil -> null
// error -> error (adds "ERR " when first word is not uppercase) // error -> error (adds "ERR " when first word is not uppercase)
// string -> bulk-string // string -> bulk-string
@ -728,9 +660,6 @@ func (w *Writer) WriteRaw(data []byte) {
// SimpleInt -> integer // SimpleInt -> integer
// everything-else -> bulk-string representation using fmt.Sprint() // everything-else -> bulk-string representation using fmt.Sprint()
func (w *Writer) WriteAny(v interface{}) { func (w *Writer) WriteAny(v interface{}) {
if w.err != nil {
return
}
w.b = AppendAny(w.b, v) w.b = AppendAny(w.b, v)
} }
@ -984,22 +913,6 @@ func (rd *Reader) readCommands(leftover *int) ([]Command, error) {
return rd.readCommands(leftover) return rd.readCommands(leftover)
} }
// ReadCommands reads the next pipeline commands.
func (rd *Reader) ReadCommands() ([]Command, error) {
for {
if len(rd.cmds) > 0 {
cmds := rd.cmds
rd.cmds = nil
return cmds, nil
}
cmds, err := rd.readCommands(nil)
if err != nil {
return []Command{}, err
}
rd.cmds = cmds
}
}
// ReadCommand reads the next command. // ReadCommand reads the next command.
func (rd *Reader) ReadCommand() (Command, error) { func (rd *Reader) ReadCommand() (Command, error) {
if len(rd.cmds) > 0 { if len(rd.cmds) > 0 {
@ -1027,6 +940,7 @@ func Parse(raw []byte) (Command, error) {
return Command{}, errTooMuchData return Command{}, errTooMuchData
} }
return cmds[0], nil return cmds[0], nil
} }
// A Handler responds to an RESP request. // A Handler responds to an RESP request.
@ -1091,351 +1005,3 @@ func (m *ServeMux) ServeRESP(conn Conn, cmd Command) {
conn.WriteError("ERR unknown command '" + command + "'") conn.WriteError("ERR unknown command '" + command + "'")
} }
} }
// PubSub is a Redis compatible pub/sub server
type PubSub struct {
mu sync.RWMutex
nextid uint64
initd bool
chans *btree.BTree
conns map[Conn]*pubSubConn
}
// Subscribe a connection to PubSub
func (ps *PubSub) Subscribe(conn Conn, channel string) {
ps.subscribe(conn, false, channel)
}
// Psubscribe a connection to PubSub
func (ps *PubSub) Psubscribe(conn Conn, channel string) {
ps.subscribe(conn, true, channel)
}
// Unsubscribe a connection from PubSub
func (ps *PubSub) Unsubscribe(conn Conn, pattern, all bool, channel string) {
ps.unsubscribe(conn, pattern, all, channel)
}
// Publish a message to subscribers
func (ps *PubSub) Publish(channel, message string) int {
ps.mu.RLock()
defer ps.mu.RUnlock()
if !ps.initd {
return 0
}
var sent int
// write messages to all clients that are subscribed on the channel
pivot := &pubSubEntry{pattern: false, channel: channel}
ps.chans.Ascend(pivot, func(item interface{}) bool {
entry := item.(*pubSubEntry)
if entry.channel != pivot.channel || entry.pattern != pivot.pattern {
return false
}
entry.sconn.writeMessage(entry.pattern, "", channel, message)
sent++
return true
})
// match on and write all psubscribe clients
pivot = &pubSubEntry{pattern: true}
ps.chans.Ascend(pivot, func(item interface{}) bool {
entry := item.(*pubSubEntry)
if match.Match(channel, entry.channel) {
entry.sconn.writeMessage(entry.pattern, entry.channel, channel,
message)
}
sent++
return true
})
return sent
}
type pubSubConn struct {
id uint64
mu sync.Mutex
conn Conn
dconn DetachedConn
entries map[*pubSubEntry]bool
}
type pubSubEntry struct {
pattern bool
sconn *pubSubConn
channel string
}
func (sconn *pubSubConn) writeMessage(pat bool, pchan, channel, msg string) {
sconn.mu.Lock()
defer sconn.mu.Unlock()
if pat {
sconn.dconn.WriteArray(4)
sconn.dconn.WriteBulkString("pmessage")
sconn.dconn.WriteBulkString(pchan)
sconn.dconn.WriteBulkString(channel)
sconn.dconn.WriteBulkString(msg)
} else {
sconn.dconn.WriteArray(3)
sconn.dconn.WriteBulkString("message")
sconn.dconn.WriteBulkString(channel)
sconn.dconn.WriteBulkString(msg)
}
sconn.dconn.Flush()
}
// bgrunner runs in the background and reads incoming commands from the
// detached client.
func (sconn *pubSubConn) bgrunner(ps *PubSub) {
defer func() {
// client connection has ended, disconnect from the PubSub instances
// and close the network connection.
ps.mu.Lock()
defer ps.mu.Unlock()
for entry := range sconn.entries {
ps.chans.Delete(entry)
}
delete(ps.conns, sconn.conn)
sconn.mu.Lock()
defer sconn.mu.Unlock()
sconn.dconn.Close()
}()
for {
cmd, err := sconn.dconn.ReadCommand()
if err != nil {
return
}
if len(cmd.Args) == 0 {
continue
}
switch strings.ToLower(string(cmd.Args[0])) {
case "psubscribe", "subscribe":
if len(cmd.Args) < 2 {
func() {
sconn.mu.Lock()
defer sconn.mu.Unlock()
sconn.dconn.WriteError(fmt.Sprintf("ERR wrong number of "+
"arguments for '%s'", cmd.Args[0]))
sconn.dconn.Flush()
}()
continue
}
command := strings.ToLower(string(cmd.Args[0]))
for i := 1; i < len(cmd.Args); i++ {
if command == "psubscribe" {
ps.Psubscribe(sconn.conn, string(cmd.Args[i]))
} else {
ps.Subscribe(sconn.conn, string(cmd.Args[i]))
}
}
case "unsubscribe", "punsubscribe":
pattern := strings.ToLower(string(cmd.Args[0])) == "punsubscribe"
if len(cmd.Args) == 1 {
ps.unsubscribe(sconn.conn, pattern, true, "")
} else {
for i := 1; i < len(cmd.Args); i++ {
channel := string(cmd.Args[i])
ps.unsubscribe(sconn.conn, pattern, false, channel)
}
}
case "quit":
func() {
sconn.mu.Lock()
defer sconn.mu.Unlock()
sconn.dconn.WriteString("OK")
sconn.dconn.Flush()
sconn.dconn.Close()
}()
return
case "ping":
var msg string
switch len(cmd.Args) {
case 1:
case 2:
msg = string(cmd.Args[1])
default:
func() {
sconn.mu.Lock()
defer sconn.mu.Unlock()
sconn.dconn.WriteError(fmt.Sprintf("ERR wrong number of "+
"arguments for '%s'", cmd.Args[0]))
sconn.dconn.Flush()
}()
continue
}
func() {
sconn.mu.Lock()
defer sconn.mu.Unlock()
sconn.dconn.WriteArray(2)
sconn.dconn.WriteBulkString("pong")
sconn.dconn.WriteBulkString(msg)
sconn.dconn.Flush()
}()
default:
func() {
sconn.mu.Lock()
defer sconn.mu.Unlock()
sconn.dconn.WriteError(fmt.Sprintf("ERR Can't execute '%s': "+
"only (P)SUBSCRIBE / (P)UNSUBSCRIBE / PING / QUIT are "+
"allowed in this context", cmd.Args[0]))
sconn.dconn.Flush()
}()
}
}
}
// byEntry is a "less" function that sorts the entries in a btree. The tree
// is sorted be (pattern, channel, conn.id). All pattern=true entries are at
// the end (right) of the tree.
func byEntry(a, b interface{}) bool {
aa := a.(*pubSubEntry)
bb := b.(*pubSubEntry)
if !aa.pattern && bb.pattern {
return true
}
if aa.pattern && !bb.pattern {
return false
}
if aa.channel < bb.channel {
return true
}
if aa.channel > bb.channel {
return false
}
var aid uint64
var bid uint64
if aa.sconn != nil {
aid = aa.sconn.id
}
if bb.sconn != nil {
bid = bb.sconn.id
}
return aid < bid
}
func (ps *PubSub) subscribe(conn Conn, pattern bool, channel string) {
ps.mu.Lock()
defer ps.mu.Unlock()
// initialize the PubSub instance
if !ps.initd {
ps.conns = make(map[Conn]*pubSubConn)
ps.chans = btree.New(byEntry)
ps.initd = true
}
// fetch the pubSubConn
sconn, ok := ps.conns[conn]
if !ok {
// initialize a new pubSubConn, which runs on a detached connection,
// and attach it to the PubSub channels/conn btree
ps.nextid++
dconn := conn.Detach()
sconn = &pubSubConn{
id: ps.nextid,
conn: conn,
dconn: dconn,
entries: make(map[*pubSubEntry]bool),
}
ps.conns[conn] = sconn
}
sconn.mu.Lock()
defer sconn.mu.Unlock()
// add an entry to the pubsub btree
entry := &pubSubEntry{
pattern: pattern,
channel: channel,
sconn: sconn,
}
ps.chans.Set(entry)
sconn.entries[entry] = true
// send a message to the client
sconn.dconn.WriteArray(3)
if pattern {
sconn.dconn.WriteBulkString("psubscribe")
} else {
sconn.dconn.WriteBulkString("subscribe")
}
sconn.dconn.WriteBulkString(channel)
var count int
for entry := range sconn.entries {
if entry.pattern == pattern {
count++
}
}
sconn.dconn.WriteInt(count)
sconn.dconn.Flush()
// start the background client operation
if !ok {
go sconn.bgrunner(ps)
}
}
func (ps *PubSub) unsubscribe(conn Conn, pattern, all bool, channel string) {
ps.mu.Lock()
defer ps.mu.Unlock()
// fetch the pubSubConn. This must exist
sconn := ps.conns[conn]
sconn.mu.Lock()
defer sconn.mu.Unlock()
removeEntry := func(entry *pubSubEntry) {
if entry != nil {
ps.chans.Delete(entry)
delete(sconn.entries, entry)
}
sconn.dconn.WriteArray(3)
if pattern {
sconn.dconn.WriteBulkString("punsubscribe")
} else {
sconn.dconn.WriteBulkString("unsubscribe")
}
if entry != nil {
sconn.dconn.WriteBulkString(entry.channel)
} else {
sconn.dconn.WriteNull()
}
var count int
for entry := range sconn.entries {
if entry.pattern == pattern {
count++
}
}
sconn.dconn.WriteInt(count)
}
if all {
// unsubscribe from all (p)subscribe entries
var entries []*pubSubEntry
for entry := range sconn.entries {
if entry.pattern == pattern {
entries = append(entries, entry)
}
}
if len(entries) == 0 {
removeEntry(nil)
} else {
for _, entry := range entries {
removeEntry(entry)
}
}
} else {
// unsubscribe single channel from (p)subscribe.
for entry := range sconn.entries {
if entry.pattern == pattern && entry.channel == channel {
removeEntry(entry)
break
}
}
}
sconn.dconn.Flush()
}
// SetIdleClose will automatically close idle connections after the specified
// duration. Use zero to disable this feature.
func (s *Server) SetIdleClose(dur time.Duration) {
s.mu.Lock()
s.idleClose = dur
s.mu.Unlock()
}

View File

@ -1,7 +1,6 @@
package redcon package redcon
import ( import (
"bufio"
"bytes" "bytes"
"fmt" "fmt"
"io" "io"
@ -11,7 +10,6 @@ import (
"os" "os"
"strconv" "strconv"
"strings" "strings"
"sync"
"testing" "testing"
"time" "time"
) )
@ -200,14 +198,14 @@ func TestRandomCommands(t *testing.T) {
cnt++ cnt++
} }
if false { if false {
dur := time.Since(start) dur := time.Now().Sub(start)
fmt.Printf("%d commands in %s - %.0f ops/sec\n", cnt, dur, float64(cnt)/(float64(dur)/float64(time.Second))) fmt.Printf("%d commands in %s - %.0f ops/sec\n", cnt, dur, float64(cnt)/(float64(dur)/float64(time.Second)))
} }
} }
func testDetached(conn DetachedConn) { func testDetached(t *testing.T, conn DetachedConn) {
conn.WriteString("DETACHED") conn.WriteString("DETACHED")
if err := conn.Flush(); err != nil { if err := conn.Flush(); err != nil {
panic(err) t.Fatal(err)
} }
} }
func TestServerTCP(t *testing.T) { func TestServerTCP(t *testing.T) {
@ -231,7 +229,7 @@ func testServerNetwork(t *testing.T, network, laddr string) {
conn.WriteString("OK") conn.WriteString("OK")
conn.Close() conn.Close()
case "detach": case "detach":
go testDetached(conn.Detach()) go testDetached(t, conn.Detach())
case "int": case "int":
conn.WriteInt(100) conn.WriteInt(100)
case "bulk": case "bulk":
@ -262,17 +260,17 @@ func testServerNetwork(t *testing.T, network, laddr string) {
go func() { go func() {
time.Sleep(time.Second / 4) time.Sleep(time.Second / 4)
if err := ListenAndServeNetwork(network, laddr, func(conn Conn, cmd Command) {}, nil, nil); err == nil { if err := ListenAndServeNetwork(network, laddr, func(conn Conn, cmd Command) {}, nil, nil); err == nil {
panic("expected an error, should not be able to listen on the same port") t.Fatalf("expected an error, should not be able to listen on the same port")
} }
time.Sleep(time.Second / 4) time.Sleep(time.Second / 4)
err := s.Close() err := s.Close()
if err != nil { if err != nil {
panic(err) t.Fatal(err)
} }
err = s.Close() err = s.Close()
if err == nil { if err == nil {
panic("expected an error") t.Fatalf("expected an error")
} }
}() }()
done := make(chan bool) done := make(chan bool)
@ -283,11 +281,11 @@ func testServerNetwork(t *testing.T, network, laddr string) {
}() }()
err := <-signal err := <-signal
if err != nil { if err != nil {
panic(err) t.Fatal(err)
} }
c, err := net.Dial(network, laddr) c, err := net.Dial(network, laddr)
if err != nil { if err != nil {
panic(err) t.Fatal(err)
} }
defer c.Close() defer c.Close()
do := func(cmd string) (string, error) { do := func(cmd string) (string, error) {
@ -301,65 +299,65 @@ func testServerNetwork(t *testing.T, network, laddr string) {
} }
res, err := do("PING\r\n") res, err := do("PING\r\n")
if err != nil { if err != nil {
panic(err) t.Fatal(err)
} }
if res != "+PONG\r\n" { if res != "+PONG\r\n" {
panic(fmt.Sprintf("expecting '+PONG\r\n', got '%v'", res)) t.Fatalf("expecting '+PONG\r\n', got '%v'", res)
} }
res, err = do("BULK\r\n") res, err = do("BULK\r\n")
if err != nil { if err != nil {
panic(err) t.Fatal(err)
} }
if res != "$4\r\nbulk\r\n" { if res != "$4\r\nbulk\r\n" {
panic(fmt.Sprintf("expecting bulk, got '%v'", res)) t.Fatalf("expecting bulk, got '%v'", res)
} }
res, err = do("BULKBYTES\r\n") res, err = do("BULKBYTES\r\n")
if err != nil { if err != nil {
panic(err) t.Fatal(err)
} }
if res != "$9\r\nbulkbytes\r\n" { if res != "$9\r\nbulkbytes\r\n" {
panic(fmt.Sprintf("expecting bulkbytes, got '%v'", res)) t.Fatalf("expecting bulkbytes, got '%v'", res)
} }
res, err = do("INT\r\n") res, err = do("INT\r\n")
if err != nil { if err != nil {
panic(err) t.Fatal(err)
} }
if res != ":100\r\n" { if res != ":100\r\n" {
panic(fmt.Sprintf("expecting int, got '%v'", res)) t.Fatalf("expecting int, got '%v'", res)
} }
res, err = do("NULL\r\n") res, err = do("NULL\r\n")
if err != nil { if err != nil {
panic(err) t.Fatal(err)
} }
if res != "$-1\r\n" { if res != "$-1\r\n" {
panic(fmt.Sprintf("expecting nul, got '%v'", res)) t.Fatalf("expecting nul, got '%v'", res)
} }
res, err = do("ARRAY\r\n") res, err = do("ARRAY\r\n")
if err != nil { if err != nil {
panic(err) t.Fatal(err)
} }
if res != "*2\r\n:99\r\n+Hi!\r\n" { if res != "*2\r\n:99\r\n+Hi!\r\n" {
panic(fmt.Sprintf("expecting array, got '%v'", res)) t.Fatalf("expecting array, got '%v'", res)
} }
res, err = do("ERR\r\n") res, err = do("ERR\r\n")
if err != nil { if err != nil {
panic(err) t.Fatal(err)
} }
if res != "-ERR error\r\n" { if res != "-ERR error\r\n" {
panic(fmt.Sprintf("expecting array, got '%v'", res)) t.Fatalf("expecting array, got '%v'", res)
} }
res, err = do("DETACH\r\n") res, err = do("DETACH\r\n")
if err != nil { if err != nil {
panic(err) t.Fatal(err)
} }
if res != "+DETACHED\r\n" { if res != "+DETACHED\r\n" {
panic(fmt.Sprintf("expecting string, got '%v'", res)) t.Fatalf("expecting string, got '%v'", res)
} }
}() }()
go func() { go func() {
err := s.ListenServeAndSignal(signal) err := s.ListenServeAndSignal(signal)
if err != nil { if err != nil {
panic(err) t.Fatal(err)
} }
}() }()
<-done <-done
@ -432,12 +430,12 @@ func TestReaderRespRandom(t *testing.T) {
for h := 0; h < 10000; h++ { for h := 0; h < 10000; h++ {
var rawargs [][]string var rawargs [][]string
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
// var args []string var args []string
n := int(rand.Int() % 16) n := int(rand.Int() % 16)
for j := 0; j < n; j++ { for j := 0; j < n; j++ {
arg := make([]byte, rand.Int()%512) arg := make([]byte, rand.Int()%512)
rand.Read(arg) rand.Read(arg)
// args = append(args, string(arg)) args = append(args, string(arg))
} }
} }
rawcmds := testMakeRawCommands(rawargs) rawcmds := testMakeRawCommands(rawargs)
@ -556,185 +554,3 @@ func TestParse(t *testing.T) {
t.Fatalf("expected '%v', got '%v'", "A", string(cmd.Args[0])) t.Fatalf("expected '%v', got '%v'", "A", string(cmd.Args[0]))
} }
} }
func TestPubSub(t *testing.T) {
addr := ":12346"
done := make(chan bool)
go func() {
var ps PubSub
go func() {
tch := time.NewTicker(time.Millisecond * 5)
defer tch.Stop()
channels := []string{"achan1", "bchan2", "cchan3", "dchan4"}
for i := 0; ; i++ {
select {
case <-tch.C:
case <-done:
for {
var empty bool
ps.mu.Lock()
if len(ps.conns) == 0 {
if ps.chans.Len() != 0 {
panic("chans not empty")
}
empty = true
}
ps.mu.Unlock()
if empty {
break
}
time.Sleep(time.Millisecond * 10)
}
done <- true
return
}
channel := channels[i%len(channels)]
message := fmt.Sprintf("message %d", i)
ps.Publish(channel, message)
}
}()
panic(ListenAndServe(addr, func(conn Conn, cmd Command) {
switch strings.ToLower(string(cmd.Args[0])) {
default:
conn.WriteError("ERR unknown command '" +
string(cmd.Args[0]) + "'")
case "publish":
if len(cmd.Args) != 3 {
conn.WriteError("ERR wrong number of arguments for '" +
string(cmd.Args[0]) + "' command")
return
}
count := ps.Publish(string(cmd.Args[1]), string(cmd.Args[2]))
conn.WriteInt(count)
case "subscribe", "psubscribe":
if len(cmd.Args) < 2 {
conn.WriteError("ERR wrong number of arguments for '" +
string(cmd.Args[0]) + "' command")
return
}
command := strings.ToLower(string(cmd.Args[0]))
for i := 1; i < len(cmd.Args); i++ {
if command == "psubscribe" {
ps.Psubscribe(conn, string(cmd.Args[i]))
} else {
ps.Subscribe(conn, string(cmd.Args[i]))
}
}
}
}, nil, nil))
}()
final := make(chan bool)
go func() {
select {
case <-time.Tick(time.Second * 30):
panic("timeout")
case <-final:
return
}
}()
// create 10 connections
var wg sync.WaitGroup
wg.Add(10)
for i := 0; i < 10; i++ {
go func(i int) {
defer wg.Done()
var conn net.Conn
for i := 0; i < 5; i++ {
var err error
conn, err = net.Dial("tcp", addr)
if err != nil {
time.Sleep(time.Second / 10)
continue
}
}
if conn == nil {
panic("could not connect to server")
}
defer conn.Close()
regs := make(map[string]int)
var maxp int
var maxs int
fmt.Fprintf(conn, "subscribe achan1\r\n")
fmt.Fprintf(conn, "subscribe bchan2 cchan3\r\n")
fmt.Fprintf(conn, "psubscribe a*1\r\n")
fmt.Fprintf(conn, "psubscribe b*2 c*3\r\n")
// collect 50 messages from each channel
rd := bufio.NewReader(conn)
var buf []byte
for {
line, err := rd.ReadBytes('\n')
if err != nil {
panic(err)
}
buf = append(buf, line...)
n, resp := ReadNextRESP(buf)
if n == 0 {
continue
}
buf = nil
if resp.Type != Array {
panic("expected array")
}
var vals []RESP
resp.ForEach(func(item RESP) bool {
vals = append(vals, item)
return true
})
name := string(vals[0].Data)
switch name {
case "subscribe":
if len(vals) != 3 {
panic("invalid count")
}
ch := string(vals[1].Data)
regs[ch] = 0
maxs, _ = strconv.Atoi(string(vals[2].Data))
case "psubscribe":
if len(vals) != 3 {
panic("invalid count")
}
ch := string(vals[1].Data)
regs[ch] = 0
maxp, _ = strconv.Atoi(string(vals[2].Data))
case "message":
if len(vals) != 3 {
panic("invalid count")
}
ch := string(vals[1].Data)
regs[ch] = regs[ch] + 1
case "pmessage":
if len(vals) != 4 {
panic("invalid count")
}
ch := string(vals[1].Data)
regs[ch] = regs[ch] + 1
}
if len(regs) == 6 && maxp == 3 && maxs == 3 {
ready := true
for _, count := range regs {
if count < 50 {
ready = false
break
}
}
if ready {
// all messages have been received
return
}
}
}
}(i)
}
wg.Wait()
// notify sender
done <- true
// wait for sender
<-done
// stop the timeout
final <- true
}

547
resp.go
View File

@ -1,11 +1,7 @@
package redcon package redcon
import ( import (
"fmt"
"reflect"
"sort"
"strconv" "strconv"
"strings"
) )
// Type of RESP // Type of RESP
@ -20,6 +16,7 @@ const (
Error = '-' Error = '-'
) )
// RESP ...
type RESP struct { type RESP struct {
Type Type Type Type
Raw []byte Raw []byte
@ -28,7 +25,7 @@ type RESP struct {
} }
// ForEach iterates over each Array element // ForEach iterates over each Array element
func (r RESP) ForEach(iter func(resp RESP) bool) { func (r *RESP) ForEach(iter func(resp RESP) bool) {
data := r.Data data := r.Data
for i := 0; i < r.Count; i++ { for i := 0; i < r.Count; i++ {
n, resp := ReadNextRESP(data) n, resp := ReadNextRESP(data)
@ -39,71 +36,6 @@ func (r RESP) ForEach(iter func(resp RESP) bool) {
} }
} }
func (r RESP) Bytes() []byte {
return r.Data
}
func (r RESP) String() string {
return string(r.Data)
}
func (r RESP) Int() int64 {
x, _ := strconv.ParseInt(r.String(), 10, 64)
return x
}
func (r RESP) Float() float64 {
x, _ := strconv.ParseFloat(r.String(), 10)
return x
}
// Map returns a key/value map of an Array.
// The receiver RESP must be an Array with an equal number of values, where
// the value of the key is followed by the key.
// Example: key1,value1,key2,value2,key3,value3
func (r RESP) Map() map[string]RESP {
if r.Type != Array {
return nil
}
var n int
var key string
m := make(map[string]RESP)
r.ForEach(func(resp RESP) bool {
if n&1 == 0 {
key = resp.String()
} else {
m[key] = resp
}
n++
return true
})
return m
}
func (r RESP) MapGet(key string) RESP {
if r.Type != Array {
return RESP{}
}
var val RESP
var n int
var ok bool
r.ForEach(func(resp RESP) bool {
if n&1 == 0 {
ok = resp.String() == key
} else if ok {
val = resp
return false
}
n++
return true
})
return val
}
func (r RESP) Exists() bool {
return r.Type != 0
}
// ReadNextRESP returns the next resp in b and returns the number of bytes the // ReadNextRESP returns the next resp in b and returns the number of bytes the
// took up the result. // took up the result.
func ReadNextRESP(b []byte) (n int, resp RESP) { func ReadNextRESP(b []byte) (n int, resp RESP) {
@ -196,478 +128,3 @@ func ReadNextRESP(b []byte) (n int, resp RESP) {
resp.Raw = b[0 : i+tn] resp.Raw = b[0 : i+tn]
return len(resp.Raw), resp return len(resp.Raw), resp
} }
// Kind is the kind of command
type Kind int
const (
// Redis is returned for Redis protocol commands
Redis Kind = iota
// Tile38 is returnd for Tile38 native protocol commands
Tile38
// Telnet is returnd for plain telnet commands
Telnet
)
var errInvalidMessage = &errProtocol{"invalid message"}
// ReadNextCommand reads the next command from the provided packet. It's
// possible that the packet contains multiple commands, or zero commands
// when the packet is incomplete.
// 'argsbuf' is an optional reusable buffer and it can be nil.
// 'complete' indicates that a command was read. false means no more commands.
// 'args' are the output arguments for the command.
// 'kind' is the type of command that was read.
// 'leftover' is any remaining unused bytes which belong to the next command.
// 'err' is returned when a protocol error was encountered.
func ReadNextCommand(packet []byte, argsbuf [][]byte) (
complete bool, args [][]byte, kind Kind, leftover []byte, err error,
) {
args = argsbuf[:0]
if len(packet) > 0 {
if packet[0] != '*' {
if packet[0] == '$' {
return readTile38Command(packet, args)
}
return readTelnetCommand(packet, args)
}
// standard redis command
for s, i := 1, 1; i < len(packet); i++ {
if packet[i] == '\n' {
if packet[i-1] != '\r' {
return false, args[:0], Redis, packet, errInvalidMultiBulkLength
}
count, ok := parseInt(packet[s : i-1])
if !ok || count < 0 {
return false, args[:0], Redis, packet, errInvalidMultiBulkLength
}
i++
if count == 0 {
return true, args[:0], Redis, packet[i:], nil
}
nextArg:
for j := 0; j < count; j++ {
if i == len(packet) {
break
}
if packet[i] != '$' {
return false, args[:0], Redis, packet,
&errProtocol{"expected '$', got '" +
string(packet[i]) + "'"}
}
for s := i + 1; i < len(packet); i++ {
if packet[i] == '\n' {
if packet[i-1] != '\r' {
return false, args[:0], Redis, packet, errInvalidBulkLength
}
n, ok := parseInt(packet[s : i-1])
if !ok || count <= 0 {
return false, args[:0], Redis, packet, errInvalidBulkLength
}
i++
if len(packet)-i >= n+2 {
if packet[i+n] != '\r' || packet[i+n+1] != '\n' {
return false, args[:0], Redis, packet, errInvalidBulkLength
}
args = append(args, packet[i:i+n])
i += n + 2
if j == count-1 {
// done reading
return true, args, Redis, packet[i:], nil
}
continue nextArg
}
break
}
}
break
}
break
}
}
}
return false, args[:0], Redis, packet, nil
}
func readTile38Command(packet []byte, argsbuf [][]byte) (
complete bool, args [][]byte, kind Kind, leftover []byte, err error,
) {
for i := 1; i < len(packet); i++ {
if packet[i] == ' ' {
n, ok := parseInt(packet[1:i])
if !ok || n < 0 {
return false, args[:0], Tile38, packet, errInvalidMessage
}
i++
if len(packet) >= i+n+2 {
if packet[i+n] != '\r' || packet[i+n+1] != '\n' {
return false, args[:0], Tile38, packet, errInvalidMessage
}
line := packet[i : i+n]
reading:
for len(line) != 0 {
if line[0] == '{' {
// The native protocol cannot understand json boundaries so it assumes that
// a json element must be at the end of the line.
args = append(args, line)
break
}
if line[0] == '"' && line[len(line)-1] == '"' {
if len(args) > 0 &&
strings.ToLower(string(args[0])) == "set" &&
strings.ToLower(string(args[len(args)-1])) == "string" {
// Setting a string value that is contained inside double quotes.
// This is only because of the boundary issues of the native protocol.
args = append(args, line[1:len(line)-1])
break
}
}
i := 0
for ; i < len(line); i++ {
if line[i] == ' ' {
value := line[:i]
if len(value) > 0 {
args = append(args, value)
}
line = line[i+1:]
continue reading
}
}
args = append(args, line)
break
}
return true, args, Tile38, packet[i+n+2:], nil
}
break
}
}
return false, args[:0], Tile38, packet, nil
}
func readTelnetCommand(packet []byte, argsbuf [][]byte) (
complete bool, args [][]byte, kind Kind, leftover []byte, err error,
) {
// just a plain text command
for i := 0; i < len(packet); i++ {
if packet[i] == '\n' {
var line []byte
if i > 0 && packet[i-1] == '\r' {
line = packet[:i-1]
} else {
line = packet[:i]
}
var quote bool
var quotech byte
var escape bool
outer:
for {
nline := make([]byte, 0, len(line))
for i := 0; i < len(line); i++ {
c := line[i]
if !quote {
if c == ' ' {
if len(nline) > 0 {
args = append(args, nline)
}
line = line[i+1:]
continue outer
}
if c == '"' || c == '\'' {
if i != 0 {
return false, args[:0], Telnet, packet, errUnbalancedQuotes
}
quotech = c
quote = true
line = line[i+1:]
continue outer
}
} else {
if escape {
escape = false
switch c {
case 'n':
c = '\n'
case 'r':
c = '\r'
case 't':
c = '\t'
}
} else if c == quotech {
quote = false
quotech = 0
args = append(args, nline)
line = line[i+1:]
if len(line) > 0 && line[0] != ' ' {
return false, args[:0], Telnet, packet, errUnbalancedQuotes
}
continue outer
} else if c == '\\' {
escape = true
continue
}
}
nline = append(nline, c)
}
if quote {
return false, args[:0], Telnet, packet, errUnbalancedQuotes
}
if len(line) > 0 {
args = append(args, line)
}
break
}
return true, args, Telnet, packet[i+1:], nil
}
}
return false, args[:0], Telnet, packet, nil
}
// appendPrefix will append a "$3\r\n" style redis prefix for a message.
func appendPrefix(b []byte, c byte, n int64) []byte {
if n >= 0 && n <= 9 {
return append(b, c, byte('0'+n), '\r', '\n')
}
b = append(b, c)
b = strconv.AppendInt(b, n, 10)
return append(b, '\r', '\n')
}
// AppendUint appends a Redis protocol uint64 to the input bytes.
func AppendUint(b []byte, n uint64) []byte {
b = append(b, ':')
b = strconv.AppendUint(b, n, 10)
return append(b, '\r', '\n')
}
// AppendInt appends a Redis protocol int64 to the input bytes.
func AppendInt(b []byte, n int64) []byte {
return appendPrefix(b, ':', n)
}
// AppendArray appends a Redis protocol array to the input bytes.
func AppendArray(b []byte, n int) []byte {
return appendPrefix(b, '*', int64(n))
}
// AppendBulk appends a Redis protocol bulk byte slice to the input bytes.
func AppendBulk(b []byte, bulk []byte) []byte {
b = appendPrefix(b, '$', int64(len(bulk)))
b = append(b, bulk...)
return append(b, '\r', '\n')
}
// AppendBulkString appends a Redis protocol bulk string to the input bytes.
func AppendBulkString(b []byte, bulk string) []byte {
b = appendPrefix(b, '$', int64(len(bulk)))
b = append(b, bulk...)
return append(b, '\r', '\n')
}
// AppendString appends a Redis protocol string to the input bytes.
func AppendString(b []byte, s string) []byte {
b = append(b, '+')
b = append(b, stripNewlines(s)...)
return append(b, '\r', '\n')
}
// AppendError appends a Redis protocol error to the input bytes.
func AppendError(b []byte, s string) []byte {
b = append(b, '-')
b = append(b, stripNewlines(s)...)
return append(b, '\r', '\n')
}
// AppendOK appends a Redis protocol OK to the input bytes.
func AppendOK(b []byte) []byte {
return append(b, '+', 'O', 'K', '\r', '\n')
}
func stripNewlines(s string) string {
for i := 0; i < len(s); i++ {
if s[i] == '\r' || s[i] == '\n' {
s = strings.Replace(s, "\r", " ", -1)
s = strings.Replace(s, "\n", " ", -1)
break
}
}
return s
}
// AppendTile38 appends a Tile38 message to the input bytes.
func AppendTile38(b []byte, data []byte) []byte {
b = append(b, '$')
b = strconv.AppendInt(b, int64(len(data)), 10)
b = append(b, ' ')
b = append(b, data...)
return append(b, '\r', '\n')
}
// AppendNull appends a Redis protocol null to the input bytes.
func AppendNull(b []byte) []byte {
return append(b, '$', '-', '1', '\r', '\n')
}
// AppendBulkFloat appends a float64, as bulk bytes.
func AppendBulkFloat(dst []byte, f float64) []byte {
return AppendBulk(dst, strconv.AppendFloat(nil, f, 'f', -1, 64))
}
// AppendBulkInt appends an int64, as bulk bytes.
func AppendBulkInt(dst []byte, x int64) []byte {
return AppendBulk(dst, strconv.AppendInt(nil, x, 10))
}
// AppendBulkUint appends an uint64, as bulk bytes.
func AppendBulkUint(dst []byte, x uint64) []byte {
return AppendBulk(dst, strconv.AppendUint(nil, x, 10))
}
func prefixERRIfNeeded(msg string) string {
msg = strings.TrimSpace(msg)
firstWord := strings.Split(msg, " ")[0]
addERR := len(firstWord) == 0
for i := 0; i < len(firstWord); i++ {
if firstWord[i] < 'A' || firstWord[i] > 'Z' {
addERR = true
break
}
}
if addERR {
msg = strings.TrimSpace("ERR " + msg)
}
return msg
}
// SimpleString is for representing a non-bulk representation of a string
// from an *Any call.
type SimpleString string
// SimpleInt is for representing a non-bulk representation of a int
// from an *Any call.
type SimpleInt int
// SimpleError is for representing an error without adding the "ERR" prefix
// from an *Any call.
type SimpleError error
// Marshaler is the interface implemented by types that
// can marshal themselves into a Redis response type from an *Any call.
// The return value is not check for validity.
type Marshaler interface {
MarshalRESP() []byte
}
// AppendAny appends any type to valid Redis type.
// nil -> null
// error -> error (adds "ERR " when first word is not uppercase)
// string -> bulk-string
// numbers -> bulk-string
// []byte -> bulk-string
// bool -> bulk-string ("0" or "1")
// slice -> array
// map -> array with key/value pairs
// SimpleString -> string
// SimpleInt -> integer
// Marshaler -> raw bytes
// everything-else -> bulk-string representation using fmt.Sprint()
func AppendAny(b []byte, v interface{}) []byte {
switch v := v.(type) {
case SimpleString:
b = AppendString(b, string(v))
case SimpleInt:
b = AppendInt(b, int64(v))
case SimpleError:
b = AppendError(b, v.Error())
case nil:
b = AppendNull(b)
case error:
b = AppendError(b, prefixERRIfNeeded(v.Error()))
case string:
b = AppendBulkString(b, v)
case []byte:
b = AppendBulk(b, v)
case bool:
if v {
b = AppendBulkString(b, "1")
} else {
b = AppendBulkString(b, "0")
}
case int:
b = AppendBulkInt(b, int64(v))
case int8:
b = AppendBulkInt(b, int64(v))
case int16:
b = AppendBulkInt(b, int64(v))
case int32:
b = AppendBulkInt(b, int64(v))
case int64:
b = AppendBulkInt(b, int64(v))
case uint:
b = AppendBulkUint(b, uint64(v))
case uint8:
b = AppendBulkUint(b, uint64(v))
case uint16:
b = AppendBulkUint(b, uint64(v))
case uint32:
b = AppendBulkUint(b, uint64(v))
case uint64:
b = AppendBulkUint(b, uint64(v))
case float32:
b = AppendBulkFloat(b, float64(v))
case float64:
b = AppendBulkFloat(b, float64(v))
case Marshaler:
b = append(b, v.MarshalRESP()...)
default:
vv := reflect.ValueOf(v)
switch vv.Kind() {
case reflect.Slice:
n := vv.Len()
b = AppendArray(b, n)
for i := 0; i < n; i++ {
b = AppendAny(b, vv.Index(i).Interface())
}
case reflect.Map:
n := vv.Len()
b = AppendArray(b, n*2)
var i int
var strKey bool
var strsKeyItems []strKeyItem
iter := vv.MapRange()
for iter.Next() {
key := iter.Key().Interface()
if i == 0 {
if _, ok := key.(string); ok {
strKey = true
strsKeyItems = make([]strKeyItem, n)
}
}
if strKey {
strsKeyItems[i] = strKeyItem{
key.(string), iter.Value().Interface(),
}
} else {
b = AppendAny(b, key)
b = AppendAny(b, iter.Value().Interface())
}
i++
}
if strKey {
sort.Slice(strsKeyItems, func(i, j int) bool {
return strsKeyItems[i].key < strsKeyItems[j].key
})
for _, item := range strsKeyItems {
b = AppendBulkString(b, item.key)
b = AppendAny(b, item.value)
}
}
default:
b = AppendBulkString(b, fmt.Sprint(v))
}
}
return b
}
type strKeyItem struct {
key string
value interface{}
}

View File

@ -1,12 +1,9 @@
package redcon package redcon
import ( import (
"bytes"
"fmt" "fmt"
"math/rand"
"strconv" "strconv"
"testing" "testing"
"time"
) )
func isEmptyRESP(resp RESP) bool { func isEmptyRESP(resp RESP) bool {
@ -131,145 +128,3 @@ func TestRESP(t *testing.T) {
t.Fatalf("expected %v, got %v", 3, xx) t.Fatalf("expected %v, got %v", 3, xx)
} }
} }
func TestNextCommand(t *testing.T) {
rand.Seed(time.Now().UnixNano())
start := time.Now()
for time.Since(start) < time.Second {
// keep copy of pipeline args for final compare
var plargs [][][]byte
// create a pipeline of random number of commands with random data.
N := rand.Int() % 10000
var data []byte
for i := 0; i < N; i++ {
nargs := rand.Int() % 10
data = AppendArray(data, nargs)
var args [][]byte
for j := 0; j < nargs; j++ {
arg := make([]byte, rand.Int()%100)
if _, err := rand.Read(arg); err != nil {
t.Fatal(err)
}
data = AppendBulk(data, arg)
args = append(args, arg)
}
plargs = append(plargs, args)
}
// break data into random number of chunks
chunkn := rand.Int() % 100
if chunkn == 0 {
chunkn = 1
}
if len(data) < chunkn {
continue
}
var chunks [][]byte
var chunksz int
for i := 0; i < len(data); i += chunksz {
chunksz = rand.Int() % (len(data) / chunkn)
var chunk []byte
if i+chunksz < len(data) {
chunk = data[i : i+chunksz]
} else {
chunk = data[i:]
}
chunks = append(chunks, chunk)
}
// process chunks
var rbuf []byte
var fargs [][][]byte
for _, chunk := range chunks {
var data []byte
if len(rbuf) > 0 {
data = append(rbuf, chunk...)
} else {
data = chunk
}
for {
complete, args, _, leftover, err := ReadNextCommand(data, nil)
data = leftover
if err != nil {
t.Fatal(err)
}
if !complete {
break
}
fargs = append(fargs, args)
}
rbuf = append(rbuf[:0], data...)
}
// compare final args to original
if len(plargs) != len(fargs) {
t.Fatalf("not equal size: %v != %v", len(plargs), len(fargs))
}
for i := 0; i < len(plargs); i++ {
if len(plargs[i]) != len(fargs[i]) {
t.Fatalf("not equal size for item %v: %v != %v", i, len(plargs[i]), len(fargs[i]))
}
for j := 0; j < len(plargs[i]); j++ {
if !bytes.Equal(plargs[i][j], plargs[i][j]) {
t.Fatalf("not equal for item %v:%v: %v != %v", i, j, len(plargs[i][j]), len(fargs[i][j]))
}
}
}
}
}
func TestAppendBulkFloat(t *testing.T) {
var b []byte
b = AppendString(b, "HELLO")
b = AppendBulkFloat(b, 9.123192839)
b = AppendString(b, "HELLO")
exp := "+HELLO\r\n$11\r\n9.123192839\r\n+HELLO\r\n"
if string(b) != exp {
t.Fatalf("expected '%s', got '%s'", exp, b)
}
}
func TestAppendBulkInt(t *testing.T) {
var b []byte
b = AppendString(b, "HELLO")
b = AppendBulkInt(b, -9182739137)
b = AppendString(b, "HELLO")
exp := "+HELLO\r\n$11\r\n-9182739137\r\n+HELLO\r\n"
if string(b) != exp {
t.Fatalf("expected '%s', got '%s'", exp, b)
}
}
func TestAppendBulkUint(t *testing.T) {
var b []byte
b = AppendString(b, "HELLO")
b = AppendBulkInt(b, 91827391370)
b = AppendString(b, "HELLO")
exp := "+HELLO\r\n$11\r\n91827391370\r\n+HELLO\r\n"
if string(b) != exp {
t.Fatalf("expected '%s', got '%s'", exp, b)
}
}
func TestArrayMap(t *testing.T) {
var dst []byte
dst = AppendArray(dst, 4)
dst = AppendBulkString(dst, "key1")
dst = AppendBulkString(dst, "val1")
dst = AppendBulkString(dst, "key2")
dst = AppendBulkString(dst, "val2")
n, resp := ReadNextRESP(dst)
if n != len(dst) {
t.Fatalf("expected '%d', got '%d'", len(dst), n)
}
m := resp.Map()
if len(m) != 2 {
t.Fatalf("expected '%d', got '%d'", 2, len(m))
}
if m["key1"].String() != "val1" {
t.Fatalf("expected '%s', got '%s'", "val1", m["key1"].String())
}
if m["key2"].String() != "val2" {
t.Fatalf("expected '%s', got '%s'", "val2", m["key2"].String())
}
}