Major API update. Please see details.

The Redcon API has been changed to better reflect the wants of the
community. THIS IS A BREAKING COMMIT... sorry, it's a one time thing.

The changes include:

1. All commands and responses use []byte rather than string for data.
2. The handler signature has been changed from:
   func(conn redcon.Conn, args [][]string)
   to:
   func(conn redcon.Conn, cmd redcon.Command)
3. There's a new Reader and Writer types for reading commands and
   writing responses.

Performance remains the same.
This commit is contained in:
Josh Baker 2016-09-17 20:01:16 -07:00
parent 08e1ceff58
commit 67c21aa488
4 changed files with 737 additions and 637 deletions

View File

@ -11,15 +11,13 @@
Redcon is a custom Redis server framework for Go that is fast and simple to use. The reason for this library it to give an efficient server front-end for the [BuntDB](https://github.com/tidwall/buntdb) and [Tile38](https://github.com/tidwall/tile38) projects. Redcon is a custom Redis server framework for Go that is fast and simple to use. The reason for this library it to give an efficient server front-end for the [BuntDB](https://github.com/tidwall/buntdb) and [Tile38](https://github.com/tidwall/tile38) projects.
Features Features
-------- --------
- Create a [Fast](#benchmarks) custom Redis compatible server in Go - Create a [Fast](#benchmarks) custom Redis compatible server in Go
- Simple interface. One function `ListenAndServe` and one type `Conn` - Simple interface. One function `ListenAndServe` and two types `Conn` & `Command`
- Support for pipelining and telnet commands - Support for pipelining and telnet commands
- Works with Redis clients such as [redigo](https://github.com/garyburd/redigo), [redis-py](https://github.com/andymccurdy/redis-py), [node_redis](https://github.com/NodeRedis/node_redis), and [jedis](https://github.com/xetorthio/jedis) - Works with Redis clients such as [redigo](https://github.com/garyburd/redigo), [redis-py](https://github.com/andymccurdy/redis-py), [node_redis](https://github.com/NodeRedis/node_redis), and [jedis](https://github.com/xetorthio/jedis)
Installing Installing
---------- ----------
@ -29,6 +27,7 @@ go get -u github.com/tidwall/redcon
Example Example
------- -------
Here's a full example of a Redis clone that accepts: Here's a full example of a Redis clone that accepts:
- SET key value - SET key value
@ -58,35 +57,34 @@ var addr = ":6380"
func main() { func main() {
var mu sync.RWMutex var mu sync.RWMutex
var items = make(map[string]string) var items = make(map[string][]byte)
go log.Printf("started server at %s", addr) go log.Printf("started server at %s", addr)
err := redcon.ListenAndServe(addr, err := redcon.ListenAndServe(addr,
func(conn redcon.Conn, commands [][]string) { func(conn redcon.Conn, cmd redcon.Command) {
for _, args := range commands { switch strings.ToLower(string(cmd.Args[0])) {
switch strings.ToLower(args[0]) {
default: default:
conn.WriteError("ERR unknown command '" + args[0] + "'") conn.WriteError("ERR unknown command '" + string(cmd.Args[0]) + "'")
case "ping": case "ping":
conn.WriteString("PONG") conn.WriteString("PONG")
case "quit": case "quit":
conn.WriteString("OK") conn.WriteString("OK")
conn.Close() conn.Close()
case "set": case "set":
if len(args) != 3 { if len(cmd.Args) != 3 {
conn.WriteError("ERR wrong number of arguments for '" + args[0] + "' command") conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command")
continue return
} }
mu.Lock() mu.Lock()
items[args[1]] = args[2] items[string(cmd.Args[1])] = cmd.Args[2]
mu.Unlock() mu.Unlock()
conn.WriteString("OK") conn.WriteString("OK")
case "get": case "get":
if len(args) != 2 { if len(cmd.Args) != 2 {
conn.WriteError("ERR wrong number of arguments for '" + args[0] + "' command") conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command")
continue return
} }
mu.RLock() mu.RLock()
val, ok := items[args[1]] val, ok := items[string(cmd.Args[1])]
mu.RUnlock() mu.RUnlock()
if !ok { if !ok {
conn.WriteNull() conn.WriteNull()
@ -94,13 +92,13 @@ func main() {
conn.WriteBulk(val) conn.WriteBulk(val)
} }
case "del": case "del":
if len(args) != 2 { if len(cmd.Args) != 2 {
conn.WriteError("ERR wrong number of arguments for '" + args[0] + "' command") conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command")
continue return
} }
mu.Lock() mu.Lock()
_, ok := items[args[1]] _, ok := items[string(cmd.Args[1])]
delete(items, args[1]) delete(items, string(cmd.Args[1]))
mu.Unlock() mu.Unlock()
if !ok { if !ok {
conn.WriteInt(0) conn.WriteInt(0)
@ -108,7 +106,6 @@ func main() {
conn.WriteInt(1) conn.WriteInt(1)
} }
} }
}
}, },
func(conn redcon.Conn) bool { func(conn redcon.Conn) bool {
// use this function to accept or deny the connection. // use this function to accept or deny the connection.
@ -147,8 +144,8 @@ $ GOMAXPROCS=1 go run example/clone.go
``` ```
``` ```
redis-benchmark -p 6380 -t set,get -n 10000000 -q -P 512 -c 512 redis-benchmark -p 6380 -t set,get -n 10000000 -q -P 512 -c 512
SET: 3119151.50 requests per second SET: 2018570.88 requests per second
GET: 4142502.25 requests per second GET: 2403846.25 requests per second
``` ```
**Redcon**: Multi-threaded, no disk persistence. **Redcon**: Multi-threaded, no disk persistence.
@ -158,8 +155,8 @@ $ GOMAXPROCS=0 go run example/clone.go
``` ```
``` ```
$ redis-benchmark -p 6380 -t set,get -n 10000000 -q -P 512 -c 512 $ redis-benchmark -p 6380 -t set,get -n 10000000 -q -P 512 -c 512
SET: 3637686.25 requests per second SET: 1944390.38 requests per second
GET: 4249894.00 requests per second GET: 3993610.25 requests per second
``` ```
*Running on a MacBook Pro 15" 2.8 GHz Intel Core i7 using Go 1.7* *Running on a MacBook Pro 15" 2.8 GHz Intel Core i7 using Go 1.7*

View File

@ -12,17 +12,16 @@ var addr = ":6380"
func main() { func main() {
var mu sync.RWMutex var mu sync.RWMutex
var items = make(map[string]string) var items = make(map[string][]byte)
go log.Printf("started server at %s", addr) go log.Printf("started server at %s", addr)
err := redcon.ListenAndServe(addr, err := redcon.ListenAndServe(addr,
func(conn redcon.Conn, commands [][]string) { func(conn redcon.Conn, cmd redcon.Command) {
for _, args := range commands { switch strings.ToLower(string(cmd.Args[0])) {
switch strings.ToLower(args[0]) {
default: default:
conn.WriteError("ERR unknown command '" + args[0] + "'") conn.WriteError("ERR unknown command '" + string(cmd.Args[0]) + "'")
case "hijack": case "detach":
hconn := conn.Hijack() hconn := conn.Detach()
log.Printf("connection is hijacked") log.Printf("connection has been detached")
go func() { go func() {
defer hconn.Close() defer hconn.Close()
hconn.WriteString("OK") hconn.WriteString("OK")
@ -35,21 +34,21 @@ func main() {
conn.WriteString("OK") conn.WriteString("OK")
conn.Close() conn.Close()
case "set": case "set":
if len(args) != 3 { if len(cmd.Args) != 3 {
conn.WriteError("ERR wrong number of arguments for '" + args[0] + "' command") conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command")
continue return
} }
mu.Lock() mu.Lock()
items[args[1]] = args[2] items[string(cmd.Args[1])] = cmd.Args[2]
mu.Unlock() mu.Unlock()
conn.WriteString("OK") conn.WriteString("OK")
case "get": case "get":
if len(args) != 2 { if len(cmd.Args) != 2 {
conn.WriteError("ERR wrong number of arguments for '" + args[0] + "' command") conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command")
continue return
} }
mu.RLock() mu.RLock()
val, ok := items[args[1]] val, ok := items[string(cmd.Args[1])]
mu.RUnlock() mu.RUnlock()
if !ok { if !ok {
conn.WriteNull() conn.WriteNull()
@ -57,13 +56,13 @@ func main() {
conn.WriteBulk(val) conn.WriteBulk(val)
} }
case "del": case "del":
if len(args) != 2 { if len(cmd.Args) != 2 {
conn.WriteError("ERR wrong number of arguments for '" + args[0] + "' command") conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command")
continue return
} }
mu.Lock() mu.Lock()
_, ok := items[args[1]] _, ok := items[string(cmd.Args[1])]
delete(items, args[1]) delete(items, string(cmd.Args[1]))
mu.Unlock() mu.Unlock()
if !ok { if !ok {
conn.WriteInt(0) conn.WriteInt(0)
@ -71,7 +70,6 @@ func main() {
conn.WriteInt(1) conn.WriteInt(1)
} }
} }
}
}, },
func(conn redcon.Conn) bool { func(conn redcon.Conn) bool {
// use this function to accept or deny the connection. // use this function to accept or deny the connection.

774
redcon.go
View File

@ -1,7 +1,8 @@
// Package redcon provides a custom redis server implementation. // Package redcon implements a Redis compatible server framework
package redcon package redcon
import ( import (
"bufio"
"errors" "errors"
"io" "io"
"net" "net"
@ -9,6 +10,23 @@ import (
"sync" "sync"
) )
var (
errUnbalancedQuotes = &errProtocol{"unbalanced quotes in request"}
errInvalidBulkLength = &errProtocol{"invalid bulk length"}
errInvalidMultiBulkLength = &errProtocol{"invalid multibulk length"}
errDetached = errors.New("detached")
errIncompleteCommand = errors.New("incomplete command")
errTooMuchData = errors.New("too much data")
)
type errProtocol struct {
msg string
}
func (err *errProtocol) Error() string {
return "Protocol error: " + err.msg
}
// Conn represents a client connection // Conn represents a client connection
type Conn interface { type Conn interface {
// RemoteAddr returns the remote address of the client connection. // RemoteAddr returns the remote address of the client connection.
@ -19,10 +37,10 @@ type Conn interface {
WriteError(msg string) WriteError(msg string)
// WriteString writes a string to the client. // WriteString writes a string to the client.
WriteString(str string) WriteString(str string)
// WriteBulk writes a bulk string to the client. // WriteBulk writes bulk bytes to the client.
WriteBulk(bulk string) WriteBulk(bulk []byte)
// WriteBulkBytes writes bulk bytes to the client. // WriteBulkString writes a bulk string to the client.
WriteBulkBytes(bulk []byte) WriteBulkString(bulk string)
// WriteInt writes an integer to the client. // WriteInt writes an integer to the client.
WriteInt(num int) WriteInt(num int)
// WriteArray writes an array header. You must then write addtional // WriteArray writes an array header. You must then write addtional
@ -35,107 +53,49 @@ type Conn interface {
WriteArray(count int) WriteArray(count int)
// WriteNull writes a null to the client // WriteNull writes a null to the client
WriteNull() WriteNull()
// SetReadBuffer updates the buffer read size for the connection
SetReadBuffer(bytes int)
// Context returns a user-defined context // Context returns a user-defined context
Context() interface{} Context() interface{}
// SetContext sets a user-defined context // SetContext sets a user-defined context
SetContext(v interface{}) SetContext(v interface{})
// Hijack return an unmanaged connection. Useful for operations like PubSub. // SetReadBuffer updates the buffer read size for the connection
SetReadBuffer(bytes int)
// Detach return a connection that is detached from the server.
// Useful for operations like PubSub.
// //
// hconn := conn.Hijack() // dconn := conn.Detach()
// go func(){ // go func(){
// defer hconn.Close() // defer dconn.Close()
// cmd, err := hconn.ReadCommand() // cmd, err := dconn.ReadCommand()
// if err != nil{ // if err != nil{
// fmt.Printf("read failed: %v\n", err) // fmt.Printf("read failed: %v\n", err)
// return // return
// } // }
// fmt.Printf("received command: %v", cmd) // fmt.Printf("received command: %v", cmd)
// hconn.WriteString("OK") // hconn.WriteString("OK")
// if err := hconn.Flush(); err != nil{ // if err := dconn.Flush(); err != nil{
// fmt.Printf("write failed: %v\n", err) // fmt.Printf("write failed: %v\n", err)
// return // return
// } // }
// }() // }()
Hijack() HijackedConn Detach() DetachedConn
} }
// HijackConn represents an unmanaged connection. // NewServer returns a new Redcon server.
type HijackedConn interface { func NewServer(addr string,
// Conn is the original connection handler func(conn Conn, cmd Command),
Conn accept func(conn Conn) bool,
// ReadCommand reads the next client command. closed func(conn Conn, err error),
ReadCommand() ([]string, error)
// ReadCommandBytes reads the next client command as bytes.
ReadCommandBytes() ([][]byte, error)
// Flush flushes any writes to the network.
Flush() error
}
var (
errUnbalancedQuotes = &errProtocol{"unbalanced quotes in request"}
errInvalidBulkLength = &errProtocol{"invalid bulk length"}
errInvalidMultiBulkLength = &errProtocol{"invalid multibulk length"}
errHijacked = errors.New("hijacked")
)
const (
defaultBufLen = 4 * 1024
defaultPoolSize = 64
)
type errProtocol struct {
msg string
}
func (err *errProtocol) Error() string {
return "Protocol error: " + err.msg
}
// Server represents a Redcon server.
type Server struct {
mu sync.Mutex
addr string
shandler func(conn Conn, cmds [][]string)
bhandler func(conn Conn, cmds [][][]byte)
accept func(conn Conn) bool
closed func(conn Conn, err error)
ln *net.TCPListener
done bool
conns map[*conn]bool
rdpool [][]byte
wrpool [][]byte
}
// NewServer returns a new server
func NewServer(
addr string, handler func(conn Conn, cmds [][]string),
accept func(conn Conn) bool, closed func(conn Conn, err error),
) *Server { ) *Server {
if handler == nil {
panic("handler is nil")
}
s := &Server{ s := &Server{
addr: addr, addr: addr,
shandler: handler, handler: handler,
accept: accept, accept: accept,
closed: closed, closed: closed,
conns: make(map[*conn]bool), conns: make(map[*conn]bool),
} }
initbuf := make([]byte, defaultPoolSize*defaultBufLen)
s.rdpool = make([][]byte, defaultPoolSize)
for i := 0; i < defaultPoolSize; i++ {
s.rdpool[i] = initbuf[i*defaultBufLen : i*defaultBufLen+defaultBufLen]
}
return s
}
// NewServerBytes returns a new server
// It uses []byte instead of string for the handler commands.
func NewServerBytes(
addr string, handler func(conn Conn, cmds [][][]byte),
accept func(conn Conn) bool, closed func(conn Conn, err error),
) *Server {
s := NewServer(addr, nil, accept, closed)
s.bhandler = handler
return s return s
} }
@ -156,14 +116,19 @@ func (s *Server) ListenAndServe() error {
return s.ListenServeAndSignal(nil) return s.ListenServeAndSignal(nil)
} }
// ListenAndServe creates a new server and binds to addr.
func ListenAndServe(addr string,
handler func(conn Conn, cmd Command),
accept func(conn Conn) bool,
closed func(conn Conn, err error),
) error {
return NewServer(addr, handler, accept, closed).ListenAndServe()
}
// ListenServeAndSignal serves incoming connections and passes nil or error // ListenServeAndSignal serves incoming connections and passes nil or error
// when listening. signal can be nil. // when listening. signal can be nil.
func (s *Server) ListenServeAndSignal(signal chan error) error { func (s *Server) ListenServeAndSignal(signal chan error) error {
var addr = s.addr var addr = s.addr
var shandler = s.shandler
var bhandler = s.bhandler
var accept = s.accept
var closed = s.closed
ln, err := net.Listen("tcp", addr) ln, err := net.Listen("tcp", addr)
if err != nil { if err != nil {
if signal != nil { if signal != nil {
@ -200,137 +165,68 @@ func (s *Server) ListenServeAndSignal(signal chan error) error {
} }
return err return err
} }
c := &conn{ c := &conn{conn: tcpc, addr: tcpc.RemoteAddr().String(),
tcpc, wr: NewWriter(tcpc), rd: NewReader(tcpc)}
newWriter(tcpc),
newReader(tcpc, nil),
tcpc.RemoteAddr().String(),
nil, false,
}
s.mu.Lock() s.mu.Lock()
if len(s.rdpool) > 0 {
c.rd.buf = s.rdpool[len(s.rdpool)-1]
s.rdpool = s.rdpool[:len(s.rdpool)-1]
} else {
c.rd.buf = make([]byte, defaultBufLen)
}
if len(s.wrpool) > 0 {
c.wr.b = s.wrpool[len(s.wrpool)-1]
s.wrpool = s.wrpool[:len(s.wrpool)-1]
} else {
c.wr.b = make([]byte, 0, 64)
}
s.conns[c] = true s.conns[c] = true
s.mu.Unlock() s.mu.Unlock()
if accept != nil && !accept(c) { if s.accept != nil && !s.accept(c) {
s.mu.Lock() s.mu.Lock()
delete(s.conns, c) delete(s.conns, c)
s.mu.Unlock() s.mu.Unlock()
c.Close() c.Close()
continue continue
} }
if shandler == nil && bhandler == nil { go handle(s, c)
shandler = func(conn Conn, cmds [][]string) {}
} else if shandler != nil {
bhandler = nil
} else if bhandler != nil {
shandler = nil
}
go handle(s, c, shandler, bhandler, closed)
} }
} }
// ListenAndServe creates a new server and binds to addr. // handle manages the server connection.
func ListenAndServe( func handle(s *Server, c *conn) {
addr string, handler func(conn Conn, cmds [][]string),
accept func(conn Conn) bool, closed func(conn Conn, err error),
) error {
return NewServer(addr, handler, accept, closed).ListenAndServe()
}
// ListenAndServeBytes creates a new server and binds to addr.
// It uses []byte instead of string for the handler commands.
func ListenAndServeBytes(
addr string, handler func(conn Conn, cmds [][][]byte),
accept func(conn Conn) bool, closed func(conn Conn, err error),
) error {
return NewServerBytes(addr, handler, accept, closed).ListenAndServe()
}
func handle(
s *Server, c *conn,
shandler func(conn Conn, cmds [][]string),
bhandler func(conn Conn, cmds [][][]byte),
closed func(conn Conn, err error)) {
var err error var err error
defer func() { defer func() {
if err != errHijacked { if err != errDetached {
// do not close the connection when a detach is detected.
c.conn.Close() c.conn.Close()
} }
func() { func() {
// remove the conn from the server
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
delete(s.conns, c) delete(s.conns, c)
if closed != nil { if s.closed != nil {
if err == io.EOF { if err == io.EOF {
err = nil err = nil
} }
closed(c, err) s.closed(c, err)
}
if err != errHijacked {
if len(s.rdpool) < defaultPoolSize && len(c.rd.buf) < defaultBufLen {
s.rdpool = append(s.rdpool, c.rd.buf)
}
if len(s.wrpool) < defaultPoolSize && cap(c.wr.b) < defaultBufLen {
s.wrpool = append(s.wrpool, c.wr.b[:0])
}
} }
}() }()
}() }()
err = func() error { err = func() error {
// read commands and feed back to the client
for { for {
if c.hj { // read pipeline commands
return errHijacked cmds, err := c.rd.readCommands(nil)
}
cmds, err := c.rd.ReadCommands()
if err != nil { if err != nil {
if err, ok := err.(*errProtocol); ok { if err, ok := err.(*errProtocol); ok {
// All protocol errors should attempt a response to // All protocol errors should attempt a response to
// the client. Ignore errors. // the client. Ignore write errors.
c.wr.WriteError("ERR " + err.Error()) c.wr.WriteError("ERR " + err.Error())
c.wr.Flush() c.wr.Flush()
} }
return err return err
} }
if len(cmds) > 0 { for _, cmd := range cmds {
if shandler != nil { s.handler(c, cmd)
// convert bytes to strings
scmds := make([][]string, len(cmds))
for i := 0; i < len(cmds); i++ {
scmds[i] = make([]string, len(cmds[i]))
for j := 0; j < len(scmds[i]); j++ {
scmds[i][j] = string(cmds[i][j])
} }
if c.detached {
// client has been detached
return errDetached
} }
shandler(c, scmds) if c.closed {
} else if bhandler != nil {
// copy the byte commands once, before exposing to client.
for i := 0; i < len(cmds); i++ {
for j := 0; j < len(cmds[i]); j++ {
nb := make([]byte, len(cmds[i][j]))
copy(nb, cmds[i][j])
cmds[i][j] = nb
}
}
bhandler(c, cmds)
}
}
if c.wr.err != nil {
if c.wr.err == errClosed {
return nil return nil
} }
return c.wr.err
}
if err := c.wr.Flush(); err != nil { if err := c.wr.Flush(); err != nil {
return err return err
} }
@ -338,120 +234,218 @@ func handle(
}() }()
} }
// conn represents a client connection
type conn struct { type conn struct {
conn *net.TCPConn conn *net.TCPConn
wr *writer wr *Writer
rd *reader rd *Reader
addr string addr string
ctx interface{} ctx interface{}
hj bool detached bool
closed bool
} }
func (c *conn) Close() error { func (c *conn) Close() error {
err := c.wr.Close() // flush and close the writer c.closed = true
c.conn.Close() // close the connection. ignore this error return c.conn.Close()
return err // return the writer error only
} }
func (c *conn) Context() interface{} { func (c *conn) Context() interface{} { return c.ctx }
return c.ctx func (c *conn) SetContext(v interface{}) { c.ctx = v }
} func (c *conn) SetReadBuffer(n int) {}
func (c *conn) SetContext(v interface{}) { func (c *conn) WriteString(str string) { c.wr.WriteString(str) }
c.ctx = v func (c *conn) WriteBulk(bulk []byte) { c.wr.WriteBulk(bulk) }
} func (c *conn) WriteBulkString(bulk string) { c.wr.WriteBulkString(bulk) }
func (c *conn) WriteString(str string) { func (c *conn) WriteInt(num int) { c.wr.WriteInt(num) }
c.wr.WriteString(str) func (c *conn) WriteError(msg string) { c.wr.WriteError(msg) }
} func (c *conn) WriteArray(count int) { c.wr.WriteArray(count) }
func (c *conn) WriteBulk(bulk string) { func (c *conn) WriteNull() { c.wr.WriteNull() }
c.wr.WriteBulk(bulk) func (c *conn) RemoteAddr() string { return c.addr }
}
func (c *conn) WriteBulkBytes(bulk []byte) { // DetachedConn represents a connection that is detached from the server
c.wr.WriteBulkBytes(bulk) type DetachedConn interface {
} // Conn is the original connection
func (c *conn) WriteInt(num int) { Conn
c.wr.WriteInt(num) // ReadCommand reads the next client command.
} ReadCommand() (Command, error)
func (c *conn) WriteError(msg string) { // Flush flushes any writes to the network.
c.wr.WriteError(msg) Flush() error
}
func (c *conn) WriteArray(count int) {
c.wr.WriteArrayStart(count)
}
func (c *conn) WriteNull() {
c.wr.WriteNull()
}
func (c *conn) RemoteAddr() string {
return c.addr
}
func (c *conn) SetReadBuffer(bytes int) {
}
func (c *conn) Hijack() HijackedConn {
c.hj = true
return &hijackedConn{conn: c}
} }
type hijackedConn struct { // Detach removes the current connection from the server loop and returns
// a detached connection. This is useful for operations such as PubSub.
// The detached connection must be closed by calling Close() when done.
// All writes such as WriteString() will not be written to the client
// until Flush() is called.
func (c *conn) Detach() DetachedConn {
c.detached = true
return &detachedConn{conn: c}
}
type detachedConn struct {
*conn *conn
cmds [][][]byte
} }
func (hjc *hijackedConn) Flush() error { // Flush writes and Write* calls to the client.
return hjc.conn.wr.Flush() func (dc *detachedConn) Flush() error {
return dc.conn.wr.Flush()
} }
func (hjc *hijackedConn) ReadCommandBytes() ([][]byte, error) { // ReadCommand read the next command from the client.
if len(hjc.cmds) > 0 { func (dc *detachedConn) ReadCommand() (Command, error) {
args := hjc.cmds[0] if dc.closed {
hjc.cmds = hjc.cmds[1:] return Command{}, errors.New("closed")
for i, arg := range args {
nb := make([]byte, len(arg))
copy(nb, arg)
args[i] = nb
} }
return args, nil cmd, err := dc.rd.ReadCommand()
}
cmds, err := hjc.rd.ReadCommands()
if err != nil { if err != nil {
return nil, err return Command{}, err
} }
hjc.cmds = cmds return cmd, nil
return hjc.ReadCommandBytes()
} }
func (hjc *hijackedConn) ReadCommand() ([]string, error) { // Command represent a command
if len(hjc.cmds) > 0 { type Command struct {
args := hjc.cmds[0] // Raw is a encoded RESP message.
hjc.cmds = hjc.cmds[1:] Raw []byte
nargs := make([]string, len(args)) // Args is a series of arguments that make up the command.
for i, arg := range args { Args [][]byte
nargs[i] = string(arg)
}
return nargs, nil
}
return hjc.ReadCommand()
} }
// Reader represents a RESP command reader. // Server defines a server for clients for managing client connections.
type reader struct { type Server struct {
r io.Reader // base reader mu sync.Mutex
addr string
handler func(conn Conn, cmd Command)
accept func(conn Conn) bool
closed func(conn Conn, err error)
conns map[*conn]bool
ln *net.TCPListener
done bool
}
// Writer allows for writing RESP messages.
type Writer struct {
w io.Writer
b []byte
}
// NewWriter creates a new RESP writer.
func NewWriter(wr io.Writer) *Writer {
return &Writer{
w: wr,
}
}
// WriteNull writes a null to the client
func (w *Writer) WriteNull() {
w.b = append(w.b, '$', '-', '1', '\r', '\n')
}
// WriteArray writes an array header. You must then write addtional
// sub-responses to the client to complete the response.
// For example to write two strings:
//
// c.WriteArray(2)
// c.WriteBulk("item 1")
// c.WriteBulk("item 2")
func (w *Writer) WriteArray(count int) {
w.b = append(w.b, '*')
w.b = append(w.b, strconv.FormatInt(int64(count), 10)...)
w.b = append(w.b, '\r', '\n')
}
// WriteBulk writes bulk bytes to the client.
func (w *Writer) WriteBulk(bulk []byte) {
w.b = append(w.b, '$')
w.b = append(w.b, strconv.FormatInt(int64(len(bulk)), 10)...)
w.b = append(w.b, '\r', '\n')
w.b = append(w.b, bulk...)
w.b = append(w.b, '\r', '\n')
}
// WriteBulkString writes a bulk string to the client.
func (w *Writer) WriteBulkString(bulk string) {
w.b = append(w.b, '$')
w.b = append(w.b, strconv.FormatInt(int64(len(bulk)), 10)...)
w.b = append(w.b, '\r', '\n')
w.b = append(w.b, bulk...)
w.b = append(w.b, '\r', '\n')
}
// Flush writes all unflushed Write* calls to the underlying writer.
func (w *Writer) Flush() error {
if _, err := w.w.Write(w.b); err != nil {
return err
}
w.b = w.b[:0]
return nil
}
// WriteError writes an error to the client.
func (w *Writer) WriteError(msg string) {
w.b = append(w.b, '-')
w.b = append(w.b, msg...)
w.b = append(w.b, '\r', '\n')
}
// WriteString writes a string to the client.
func (w *Writer) WriteString(msg string) {
if msg == "OK" {
w.b = append(w.b, '+', 'O', 'K', '\r', '\n')
} else {
w.b = append(w.b, '+')
w.b = append(w.b, []byte(msg)...)
w.b = append(w.b, '\r', '\n')
}
}
// WriteInt writes an integer to the client.
func (w *Writer) WriteInt(num int) {
w.b = append(w.b, ':')
w.b = append(w.b, []byte(strconv.FormatInt(int64(num), 10))...)
w.b = append(w.b, '\r', '\n')
}
// Reader represent a reader for RESP or telnet commands.
type Reader struct {
rd *bufio.Reader
buf []byte buf []byte
start int start int
end int end int
cmds []Command
} }
// NewReader returns a RESP command reader. // NewReader returns a command reader which will read RESP or telnet commands.
func newReader(r io.Reader, buf []byte) *reader { func NewReader(rd io.Reader) *Reader {
return &reader{ return &Reader{
r: r, rd: bufio.NewReader(rd),
buf: buf, buf: make([]byte, 4096),
} }
} }
// ReadCommands reads one or more commands from the reader. func parseInt(b []byte) (int, error) {
func (r *reader) ReadCommands() ([][][]byte, error) { // shortcut atoi for 0-99. fails for negative numbers.
if r.end-r.start > 0 { switch len(b) {
b := r.buf[r.start:r.end] case 1:
// we have some potential commands. if b[0] >= '0' && b[0] <= '9' {
var cmds [][][]byte return int(b[0] - '0'), nil
}
case 2:
if b[0] >= '0' && b[0] <= '9' && b[1] >= '0' && b[1] <= '9' {
return int(b[0]-'0')*10 + int(b[1]-'0'), nil
}
}
// fallback to standard library
n, err := strconv.ParseUint(string(b), 10, 64)
return int(n), err
}
func (rd *Reader) readCommands(leftover *int) ([]Command, error) {
var cmds []Command
b := rd.buf[rd.start:rd.end]
if len(b) > 0 {
// we have data, yay!
// but is this enough data for a complete command? or multiple?
next: next:
switch b[0] { switch b[0] {
default: default:
@ -464,7 +458,7 @@ func (r *reader) ReadCommands() ([][][]byte, error) {
} else { } else {
line = b[:i] line = b[:i]
} }
var args [][]byte var cmd Command
var quote bool var quote bool
var escape bool var escape bool
outer: outer:
@ -475,7 +469,7 @@ func (r *reader) ReadCommands() ([][][]byte, error) {
if !quote { if !quote {
if c == ' ' { if c == ' ' {
if len(nline) > 0 { if len(nline) > 0 {
args = append(args, nline) cmd.Args = append(cmd.Args, nline)
} }
line = line[i+1:] line = line[i+1:]
continue outer continue outer
@ -501,7 +495,7 @@ func (r *reader) ReadCommands() ([][][]byte, error) {
} }
} else if c == '"' { } else if c == '"' {
quote = false quote = false
args = append(args, nline) cmd.Args = append(cmd.Args, nline)
line = line[i+1:] line = line[i+1:]
if len(line) > 0 && line[0] != ' ' { if len(line) > 0 && line[0] != ' ' {
return nil, errUnbalancedQuotes return nil, errUnbalancedQuotes
@ -518,12 +512,20 @@ func (r *reader) ReadCommands() ([][][]byte, error) {
return nil, errUnbalancedQuotes return nil, errUnbalancedQuotes
} }
if len(line) > 0 { if len(line) > 0 {
args = append(args, line) cmd.Args = append(cmd.Args, line)
} }
break break
} }
if len(args) > 0 { if len(cmd.Args) > 0 {
cmds = append(cmds, args) // convert this to resp command syntax
var wr Writer
wr.WriteArray(len(cmd.Args))
for i := range cmd.Args {
wr.WriteBulk(cmd.Args[i])
cmd.Args[i] = append([]byte(nil), cmd.Args[i]...)
}
cmd.Raw = wr.b
cmds = append(cmds, cmd)
} }
b = b[i+1:] b = b[i+1:]
if len(b) > 0 { if len(b) > 0 {
@ -535,20 +537,19 @@ func (r *reader) ReadCommands() ([][][]byte, error) {
} }
case '*': case '*':
// resp formatted command // resp formatted command
var si int marks := make([]int, 0, 16)
outer2: outer2:
for i := 0; i < len(b); i++ { for i := 1; i < len(b); i++ {
var args [][]byte
if b[i] == '\n' { if b[i] == '\n' {
if b[i-1] != '\r' { if b[i-1] != '\r' {
return nil, errInvalidMultiBulkLength return nil, errInvalidMultiBulkLength
} }
ni, err := parseInt(b[si+1 : i-1]) count, err := parseInt(b[1 : i-1])
if err != nil || ni <= 0 { if err != nil || count <= 0 {
return nil, errInvalidMultiBulkLength return nil, errInvalidMultiBulkLength
} }
args = make([][]byte, 0, ni) marks = marks[:0]
for j := 0; j < ni; j++ { for j := 0; j < count; j++ {
// read bulk length // read bulk length
i++ i++
if i < len(b) { if i < len(b) {
@ -556,35 +557,49 @@ func (r *reader) ReadCommands() ([][][]byte, error) {
return nil, &errProtocol{"expected '$', got '" + return nil, &errProtocol{"expected '$', got '" +
string(b[i]) + "'"} string(b[i]) + "'"}
} }
si = i si := i
for ; i < len(b); i++ { for ; i < len(b); i++ {
if b[i] == '\n' { if b[i] == '\n' {
if b[i-1] != '\r' { if b[i-1] != '\r' {
return nil, errInvalidBulkLength return nil, errInvalidBulkLength
} }
ni2, err := parseInt(b[si+1 : i-1]) size, err := parseInt(b[si+1 : i-1])
if err != nil || ni2 < 0 { if err != nil || size < 0 {
return nil, errInvalidBulkLength return nil, errInvalidBulkLength
} }
if i+ni2+2 >= len(b) { if i+size+2 >= len(b) {
// not ready // not ready
break outer2 break outer2
} }
if b[i+ni2+2] != '\n' || if b[i+size+2] != '\n' ||
b[i+ni2+1] != '\r' { b[i+size+1] != '\r' {
return nil, errInvalidBulkLength return nil, errInvalidBulkLength
} }
i++ i++
arg := b[i : i+ni2] marks = append(marks, i, i+size)
i += ni2 + 1 i += size + 1
args = append(args, arg)
break break
} }
} }
} }
} }
if len(args) == cap(args) { if len(marks) == count*2 {
cmds = append(cmds, args) var cmd Command
if rd.rd != nil {
// make a raw copy of the entire command when
// there's a underlying reader.
cmd.Raw = append([]byte(nil), b[:i+1]...)
} else {
// just assign the slice
cmd.Raw = b[:i+1]
}
cmd.Args = make([][]byte, len(marks)/2)
// slice up the raw command into the args based on
// the recorded marks.
for h := 0; h < len(marks); h += 2 {
cmd.Args[h/2] = cmd.Raw[marks[h]:marks[h+1]]
}
cmds = append(cmds, cmd)
b = b[i+1:] b = b[i+1:]
if len(b) > 0 { if len(b) > 0 {
goto next goto next
@ -596,164 +611,63 @@ func (r *reader) ReadCommands() ([][][]byte, error) {
} }
} }
done: done:
if len(b) == 0 { rd.start = rd.end - len(b)
r.start = 0 }
r.end = 0 if leftover != nil {
} else { *leftover = rd.end - rd.start
r.start = r.end - len(b)
} }
if len(cmds) > 0 { if len(cmds) > 0 {
return cmds, nil return cmds, nil
} }
if rd.rd == nil {
return nil, errIncompleteCommand
} }
if r.end == len(r.buf) { if rd.end == len(rd.buf) {
nbuf := make([]byte, len(r.buf)*2) // at the end of the buffer.
copy(nbuf, r.buf) if rd.start == rd.end {
r.buf = nbuf // rewind the to the beginning
rd.start, rd.end = 0, 0
} else {
// must grow the buffer
newbuf := make([]byte, len(rd.buf)*2)
copy(newbuf, rd.buf)
rd.buf = newbuf
} }
n, err := r.r.Read(r.buf[r.end:]) }
n, err := rd.rd.Read(rd.buf[rd.end:])
if err != nil { if err != nil {
if err == io.EOF {
if r.end > 0 {
return nil, io.ErrUnexpectedEOF
}
}
return nil, err return nil, err
} }
r.end += n rd.end += n
return r.ReadCommands() return rd.readCommands(leftover)
}
func parseInt(b []byte) (int, error) {
switch len(b) {
case 1:
if b[0] >= '0' && b[0] <= '9' {
return int(b[0] - '0'), nil
}
case 2:
if b[0] >= '0' && b[0] <= '9' && b[1] >= '0' && b[1] <= '9' {
return int(b[0]-'0')*10 + int(b[1]-'0'), nil
}
}
var n int
for i := 0; i < len(b); i++ {
if b[i] < '0' || b[i] > '9' {
return 0, errors.New("invalid number")
}
n = n*10 + int(b[i]-'0')
}
return n, nil
} }
var errClosed = errors.New("closed") // ReadCommand reads the next command.
func (rd *Reader) ReadCommand() (Command, error) {
type writer struct { if len(rd.cmds) > 0 {
w *net.TCPConn cmd := rd.cmds[0]
b []byte rd.cmds = rd.cmds[1:]
err error return cmd, nil
}
cmds, err := rd.readCommands(nil)
if err != nil {
return Command{}, err
}
rd.cmds = cmds
return rd.ReadCommand()
} }
func newWriter(w *net.TCPConn) *writer { // Parse parses a raw RESP message and returns a command.
return &writer{w: w, b: make([]byte, 0, 512)} func Parse(raw []byte) (Command, error) {
} rd := Reader{buf: raw, end: len(raw)}
var leftover int
cmds, err := rd.readCommands(&leftover)
if err != nil {
return Command{}, err
}
if leftover > 0 {
return Command{}, errTooMuchData
}
return cmds[0], nil
func (w *writer) WriteNull() error {
if w.err != nil {
return w.err
}
w.b = append(w.b, '$', '-', '1', '\r', '\n')
return nil
}
func (w *writer) WriteArrayStart(count int) error {
if w.err != nil {
return w.err
}
w.b = append(w.b, '*')
w.b = append(w.b, []byte(strconv.FormatInt(int64(count), 10))...)
w.b = append(w.b, '\r', '\n')
return nil
}
func (w *writer) WriteBulk(bulk string) error {
if w.err != nil {
return w.err
}
w.b = append(w.b, '$')
w.b = append(w.b, []byte(strconv.FormatInt(int64(len(bulk)), 10))...)
w.b = append(w.b, '\r', '\n')
w.b = append(w.b, []byte(bulk)...)
w.b = append(w.b, '\r', '\n')
return nil
}
func (w *writer) WriteBulkBytes(bulk []byte) error {
if w.err != nil {
return w.err
}
w.b = append(w.b, '$')
w.b = append(w.b, []byte(strconv.FormatInt(int64(len(bulk)), 10))...)
w.b = append(w.b, '\r', '\n')
w.b = append(w.b, bulk...)
w.b = append(w.b, '\r', '\n')
return nil
}
func (w *writer) Flush() error {
if w.err != nil {
return w.err
}
if len(w.b) == 0 {
return nil
}
if _, err := w.w.Write(w.b); err != nil {
w.err = err
return err
}
w.b = w.b[:0]
return nil
}
func (w *writer) WriteError(msg string) error {
if w.err != nil {
return w.err
}
w.b = append(w.b, '-')
w.b = append(w.b, []byte(msg)...)
w.b = append(w.b, '\r', '\n')
return nil
}
func (w *writer) WriteString(msg string) error {
if w.err != nil {
return w.err
}
if msg == "OK" {
w.b = append(w.b, '+', 'O', 'K', '\r', '\n')
} else {
w.b = append(w.b, '+')
w.b = append(w.b, []byte(msg)...)
w.b = append(w.b, '\r', '\n')
}
return nil
}
func (w *writer) WriteInt(num int) error {
if w.err != nil {
return w.err
}
w.b = append(w.b, ':')
w.b = append(w.b, []byte(strconv.FormatInt(int64(num), 10))...)
w.b = append(w.b, '\r', '\n')
return nil
}
func (w *writer) Close() error {
if w.err != nil {
return w.err
}
if err := w.Flush(); err != nil {
w.err = err
return err
}
w.err = errClosed
return nil
} }

View File

@ -1,11 +1,13 @@
package redcon package redcon
import ( import (
"bytes"
"fmt" "fmt"
"io" "io"
"log" "log"
"math/rand" "math/rand"
"net" "net"
"strconv"
"strings" "strings"
"testing" "testing"
"time" "time"
@ -145,32 +147,32 @@ func TestRandomCommands(t *testing.T) {
cnt := 0 cnt := 0
idx := 0 idx := 0
start := time.Now() start := time.Now()
r := newReader(rd, make([]byte, 256)) r := NewReader(rd)
for { for {
cmds, err := r.ReadCommands() cmd, err := r.ReadCommand()
if err != nil { if err != nil {
if err == io.EOF { if err == io.EOF {
break break
} }
log.Fatal(err) log.Fatal(err)
} }
for _, cmd := range cmds { if len(cmd.Args) == 3 && string(cmd.Args[0]) == "RESET" &&
if len(cmd) == 3 && string(cmd[0]) == "RESET" && string(cmd[1]) == "THE" && string(cmd[2]) == "INDEX" { string(cmd.Args[1]) == "THE" && string(cmd.Args[2]) == "INDEX" {
if idx != len(gcmds) { if idx != len(gcmds) {
t.Fatalf("did not process all commands") t.Fatalf("did not process all commands")
} }
idx = 0 idx = 0
break break
} }
if len(cmd) != len(gcmds[idx]) { if len(cmd.Args) != len(gcmds[idx]) {
t.Fatalf("len not equal for index %d -- %d != %d", idx, len(cmd), len(gcmds[idx])) t.Fatalf("len not equal for index %d -- %d != %d", idx, len(cmd.Args), len(gcmds[idx]))
} }
for i := 0; i < len(cmd); i++ { for i := 0; i < len(cmd.Args); i++ {
if i == 0 { if i == 0 {
if len(cmd[i]) == len(gcmds[idx][i]) { if len(cmd.Args[i]) == len(gcmds[idx][i]) {
ok := true ok := true
for j := 0; j < len(cmd[i]); j++ { for j := 0; j < len(cmd.Args[i]); j++ {
c1, c2 := cmd[i][j], gcmds[idx][i][j] c1, c2 := cmd.Args[i][j], gcmds[idx][i][j]
if c1 >= 'A' && c1 <= 'Z' { if c1 >= 'A' && c1 <= 'Z' {
c1 += 32 c1 += 32
} }
@ -186,7 +188,7 @@ func TestRandomCommands(t *testing.T) {
continue continue
} }
} }
} else if string(cmd[i]) == string(gcmds[idx][i]) { } else if string(cmd.Args[i]) == string(gcmds[idx][i]) {
continue continue
} }
t.Fatalf("not equal for index %d/%d", idx, i) t.Fatalf("not equal for index %d/%d", idx, i)
@ -194,14 +196,13 @@ func TestRandomCommands(t *testing.T) {
idx++ idx++
cnt++ cnt++
} }
}
if false { if false {
dur := time.Now().Sub(start) dur := time.Now().Sub(start)
fmt.Printf("%d commands in %s - %.0f ops/sec\n", cnt, dur, float64(cnt)/(float64(dur)/float64(time.Second))) fmt.Printf("%d commands in %s - %.0f ops/sec\n", cnt, dur, float64(cnt)/(float64(dur)/float64(time.Second)))
} }
} }
func testHijack(t *testing.T, conn HijackedConn) { func testDetached(t *testing.T, conn DetachedConn) {
conn.WriteString("HIJACKED") conn.WriteString("DETACHED")
if err := conn.Flush(); err != nil { if err := conn.Flush(); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -209,24 +210,23 @@ func testHijack(t *testing.T, conn HijackedConn) {
func TestServer(t *testing.T) { func TestServer(t *testing.T) {
s := NewServer(":12345", s := NewServer(":12345",
func(conn Conn, cmds [][]string) { func(conn Conn, cmd Command) {
for _, cmd := range cmds { switch strings.ToLower(string(cmd.Args[0])) {
switch strings.ToLower(cmd[0]) {
default: default:
conn.WriteError("ERR unknown command '" + cmd[0] + "'") conn.WriteError("ERR unknown command '" + string(cmd.Args[0]) + "'")
case "ping": case "ping":
conn.WriteString("PONG") conn.WriteString("PONG")
case "quit": case "quit":
conn.WriteString("OK") conn.WriteString("OK")
conn.Close() conn.Close()
case "hijack": case "detach":
go testHijack(t, conn.Hijack()) go testDetached(t, conn.Detach())
case "int": case "int":
conn.WriteInt(100) conn.WriteInt(100)
case "bulk": case "bulk":
conn.WriteBulk("bulk") conn.WriteBulkString("bulk")
case "bulkbytes": case "bulkbytes":
conn.WriteBulkBytes([]byte("bulkbytes")) conn.WriteBulk([]byte("bulkbytes"))
case "null": case "null":
conn.WriteNull() conn.WriteNull()
case "err": case "err":
@ -236,7 +236,6 @@ func TestServer(t *testing.T) {
conn.WriteInt(99) conn.WriteInt(99)
conn.WriteString("Hi!") conn.WriteString("Hi!")
} }
}
}, },
func(conn Conn) bool { func(conn Conn) bool {
//log.Printf("accept: %s", conn.RemoteAddr()) //log.Printf("accept: %s", conn.RemoteAddr())
@ -251,7 +250,7 @@ func TestServer(t *testing.T) {
} }
go func() { go func() {
time.Sleep(time.Second / 4) time.Sleep(time.Second / 4)
if err := ListenAndServe(":12345", nil, nil, nil); err == nil { if err := ListenAndServe(":12345", func(conn Conn, cmd Command) {}, nil, nil); err == nil {
t.Fatalf("expected an error, should not be able to listen on the same port") t.Fatalf("expected an error, should not be able to listen on the same port")
} }
time.Sleep(time.Second / 4) time.Sleep(time.Second / 4)
@ -294,56 +293,56 @@ func TestServer(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
if res != "+PONG\r\n" { if res != "+PONG\r\n" {
t.Fatal("expecting '+PONG\r\n', got '%v'", res) t.Fatalf("expecting '+PONG\r\n', got '%v'", res)
} }
res, err = do("BULK\r\n") res, err = do("BULK\r\n")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if res != "$4\r\nbulk\r\n" { if res != "$4\r\nbulk\r\n" {
t.Fatal("expecting bulk, got '%v'", res) t.Fatalf("expecting bulk, got '%v'", res)
} }
res, err = do("BULKBYTES\r\n") res, err = do("BULKBYTES\r\n")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if res != "$9\r\nbulkbytes\r\n" { if res != "$9\r\nbulkbytes\r\n" {
t.Fatal("expecting bulkbytes, got '%v'", res) t.Fatalf("expecting bulkbytes, got '%v'", res)
} }
res, err = do("INT\r\n") res, err = do("INT\r\n")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if res != ":100\r\n" { if res != ":100\r\n" {
t.Fatal("expecting int, got '%v'", res) t.Fatalf("expecting int, got '%v'", res)
} }
res, err = do("NULL\r\n") res, err = do("NULL\r\n")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if res != "$-1\r\n" { if res != "$-1\r\n" {
t.Fatal("expecting nul, got '%v'", res) t.Fatalf("expecting nul, got '%v'", res)
} }
res, err = do("ARRAY\r\n") res, err = do("ARRAY\r\n")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if res != "*2\r\n:99\r\n+Hi!\r\n" { if res != "*2\r\n:99\r\n+Hi!\r\n" {
t.Fatal("expecting array, got '%v'", res) t.Fatalf("expecting array, got '%v'", res)
} }
res, err = do("ERR\r\n") res, err = do("ERR\r\n")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if res != "-ERR error\r\n" { if res != "-ERR error\r\n" {
t.Fatal("expecting array, got '%v'", res) t.Fatalf("expecting array, got '%v'", res)
} }
res, err = do("HIJACK\r\n") res, err = do("DETACH\r\n")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if res != "+HIJACKED\r\n" { if res != "+DETACHED\r\n" {
t.Fatal("expecting string, got '%v'", res) t.Fatalf("expecting string, got '%v'", res)
} }
}() }()
go func() { go func() {
@ -354,3 +353,195 @@ func TestServer(t *testing.T) {
}() }()
<-done <-done
} }
func TestWriter(t *testing.T) {
buf := &bytes.Buffer{}
wr := NewWriter(buf)
wr.WriteError("ERR bad stuff")
wr.Flush()
if buf.String() != "-ERR bad stuff\r\n" {
t.Fatal("failed")
}
buf.Reset()
wr.WriteString("HELLO")
wr.Flush()
if buf.String() != "+HELLO\r\n" {
t.Fatal("failed")
}
buf.Reset()
wr.WriteInt(-1234)
wr.Flush()
if buf.String() != ":-1234\r\n" {
t.Fatal("failed")
}
buf.Reset()
wr.WriteNull()
wr.Flush()
if buf.String() != "$-1\r\n" {
t.Fatal("failed")
}
buf.Reset()
wr.WriteBulk([]byte("HELLO\r\nPLANET"))
wr.Flush()
if buf.String() != "$13\r\nHELLO\r\nPLANET\r\n" {
t.Fatal("failed")
}
buf.Reset()
wr.WriteBulkString("HELLO\r\nPLANET")
wr.Flush()
if buf.String() != "$13\r\nHELLO\r\nPLANET\r\n" {
t.Fatal("failed")
}
buf.Reset()
wr.WriteArray(3)
wr.WriteBulkString("THIS")
wr.WriteBulkString("THAT")
wr.WriteString("THE OTHER THING")
wr.Flush()
if buf.String() != "*3\r\n$4\r\nTHIS\r\n$4\r\nTHAT\r\n+THE OTHER THING\r\n" {
t.Fatal("failed")
}
buf.Reset()
}
func testMakeRawCommands(rawargs [][]string) []string {
var rawcmds []string
for i := 0; i < len(rawargs); i++ {
rawcmd := "*" + strconv.FormatUint(uint64(len(rawargs[i])), 10) + "\r\n"
for j := 0; j < len(rawargs[i]); j++ {
rawcmd += "$" + strconv.FormatUint(uint64(len(rawargs[i][j])), 10) + "\r\n"
rawcmd += rawargs[i][j] + "\r\n"
}
rawcmds = append(rawcmds, rawcmd)
}
return rawcmds
}
func TestReaderRespRandom(t *testing.T) {
rand.Seed(time.Now().UnixNano())
for h := 0; h < 10000; h++ {
var rawargs [][]string
for i := 0; i < 100; i++ {
var args []string
n := int(rand.Int() % 16)
for j := 0; j < n; j++ {
arg := make([]byte, rand.Int()%512)
rand.Read(arg)
args = append(args, string(arg))
}
}
rawcmds := testMakeRawCommands(rawargs)
data := strings.Join(rawcmds, "")
rd := NewReader(bytes.NewBufferString(data))
for i := 0; i < len(rawcmds); i++ {
if len(rawargs[i]) == 0 {
continue
}
cmd, err := rd.ReadCommand()
if err != nil {
t.Fatal(err)
}
if string(cmd.Raw) != rawcmds[i] {
t.Fatalf("expected '%v', got '%v'", rawcmds[i], string(cmd.Raw))
}
if len(cmd.Args) != len(rawargs[i]) {
t.Fatalf("expected '%v', got '%v'", len(rawargs[i]), len(cmd.Args))
}
for j := 0; j < len(rawargs[i]); j++ {
if string(cmd.Args[j]) != rawargs[i][j] {
t.Fatalf("expected '%v', got '%v'", rawargs[i][j], string(cmd.Args[j]))
}
}
}
}
}
func TestPlainReader(t *testing.T) {
rawargs := [][]string{
{"HELLO", "WORLD"},
{"HELLO", "WORLD"},
{"HELLO", "PLANET"},
{"HELLO", "JELLO"},
{"HELLO ", "JELLO"},
}
rawcmds := []string{
"HELLO WORLD\n",
"HELLO WORLD\r\n",
" HELLO PLANET \r\n",
" \"HELLO\" \"JELLO\" \r\n",
" \"HELLO \" JELLO \n",
}
rawres := []string{
"*2\r\n$5\r\nHELLO\r\n$5\r\nWORLD\r\n",
"*2\r\n$5\r\nHELLO\r\n$5\r\nWORLD\r\n",
"*2\r\n$5\r\nHELLO\r\n$6\r\nPLANET\r\n",
"*2\r\n$5\r\nHELLO\r\n$5\r\nJELLO\r\n",
"*2\r\n$6\r\nHELLO \r\n$5\r\nJELLO\r\n",
}
data := strings.Join(rawcmds, "")
rd := NewReader(bytes.NewBufferString(data))
for i := 0; i < len(rawcmds); i++ {
if len(rawargs[i]) == 0 {
continue
}
cmd, err := rd.ReadCommand()
if err != nil {
t.Fatal(err)
}
if string(cmd.Raw) != rawres[i] {
t.Fatalf("expected '%v', got '%v'", rawres[i], string(cmd.Raw))
}
if len(cmd.Args) != len(rawargs[i]) {
t.Fatalf("expected '%v', got '%v'", len(rawargs[i]), len(cmd.Args))
}
for j := 0; j < len(rawargs[i]); j++ {
if string(cmd.Args[j]) != rawargs[i][j] {
t.Fatalf("expected '%v', got '%v'", rawargs[i][j], string(cmd.Args[j]))
}
}
}
}
func TestParse(t *testing.T) {
_, err := Parse(nil)
if err != errIncompleteCommand {
t.Fatalf("expected '%v', got '%v'", errIncompleteCommand, err)
}
_, err = Parse([]byte("*1\r\n"))
if err != errIncompleteCommand {
t.Fatalf("expected '%v', got '%v'", errIncompleteCommand, err)
}
_, err = Parse([]byte("*-1\r\n"))
if err != errInvalidMultiBulkLength {
t.Fatalf("expected '%v', got '%v'", errInvalidMultiBulkLength, err)
}
_, err = Parse([]byte("*0\r\n"))
if err != errInvalidMultiBulkLength {
t.Fatalf("expected '%v', got '%v'", errInvalidMultiBulkLength, err)
}
cmd, err := Parse([]byte("*1\r\n$1\r\nA\r\n"))
if err != nil {
t.Fatal(err)
}
if string(cmd.Raw) != "*1\r\n$1\r\nA\r\n" {
t.Fatalf("expected '%v', got '%v'", "*1\r\n$1\r\nA\r\n", string(cmd.Raw))
}
if len(cmd.Args) != 1 {
t.Fatalf("expected '%v', got '%v'", 1, len(cmd.Args))
}
if string(cmd.Args[0]) != "A" {
t.Fatalf("expected '%v', got '%v'", "A", string(cmd.Args[0]))
}
cmd, err = Parse([]byte("A\r\n"))
if err != nil {
t.Fatal(err)
}
if string(cmd.Raw) != "*1\r\n$1\r\nA\r\n" {
t.Fatalf("expected '%v', got '%v'", "*1\r\n$1\r\nA\r\n", string(cmd.Raw))
}
if len(cmd.Args) != 1 {
t.Fatalf("expected '%v', got '%v'", 1, len(cmd.Args))
}
if string(cmd.Args[0]) != "A" {
t.Fatalf("expected '%v', got '%v'", "A", string(cmd.Args[0]))
}
}