From 67c21aa488786a9e02490903e751a0f4c209d615 Mon Sep 17 00:00:00 2001 From: Josh Baker Date: Sat, 17 Sep 2016 20:01:16 -0700 Subject: [PATCH] Major API update. Please see details. The Redcon API has been changed to better reflect the wants of the community. THIS IS A BREAKING COMMIT... sorry, it's a one time thing. The changes include: 1. All commands and responses use []byte rather than string for data. 2. The handler signature has been changed from: func(conn redcon.Conn, args [][]string) to: func(conn redcon.Conn, cmd redcon.Command) 3. There's a new Reader and Writer types for reading commands and writing responses. Performance remains the same. --- README.md | 105 +++--- example/clone.go | 108 +++---- redcon.go | 820 +++++++++++++++++++++-------------------------- redcon_test.go | 341 +++++++++++++++----- 4 files changed, 737 insertions(+), 637 deletions(-) diff --git a/README.md b/README.md index 464f832..22603b7 100644 --- a/README.md +++ b/README.md @@ -11,15 +11,13 @@ Redcon is a custom Redis server framework for Go that is fast and simple to use. The reason for this library it to give an efficient server front-end for the [BuntDB](https://github.com/tidwall/buntdb) and [Tile38](https://github.com/tidwall/tile38) projects. - Features -------- - Create a [Fast](#benchmarks) custom Redis compatible server in Go -- Simple interface. One function `ListenAndServe` and one type `Conn` +- Simple interface. One function `ListenAndServe` and two types `Conn` & `Command` - Support for pipelining and telnet commands - Works with Redis clients such as [redigo](https://github.com/garyburd/redigo), [redis-py](https://github.com/andymccurdy/redis-py), [node_redis](https://github.com/NodeRedis/node_redis), and [jedis](https://github.com/xetorthio/jedis) - Installing ---------- @@ -29,6 +27,7 @@ go get -u github.com/tidwall/redcon Example ------- + Here's a full example of a Redis clone that accepts: - SET key value @@ -58,55 +57,53 @@ var addr = ":6380" func main() { var mu sync.RWMutex - var items = make(map[string]string) + var items = make(map[string][]byte) go log.Printf("started server at %s", addr) err := redcon.ListenAndServe(addr, - func(conn redcon.Conn, commands [][]string) { - for _, args := range commands { - switch strings.ToLower(args[0]) { - default: - conn.WriteError("ERR unknown command '" + args[0] + "'") - case "ping": - conn.WriteString("PONG") - case "quit": - conn.WriteString("OK") - conn.Close() - case "set": - if len(args) != 3 { - conn.WriteError("ERR wrong number of arguments for '" + args[0] + "' command") - continue - } - mu.Lock() - items[args[1]] = args[2] - mu.Unlock() - conn.WriteString("OK") - case "get": - if len(args) != 2 { - conn.WriteError("ERR wrong number of arguments for '" + args[0] + "' command") - continue - } - mu.RLock() - val, ok := items[args[1]] - mu.RUnlock() - if !ok { - conn.WriteNull() - } else { - conn.WriteBulk(val) - } - case "del": - if len(args) != 2 { - conn.WriteError("ERR wrong number of arguments for '" + args[0] + "' command") - continue - } - mu.Lock() - _, ok := items[args[1]] - delete(items, args[1]) - mu.Unlock() - if !ok { - conn.WriteInt(0) - } else { - conn.WriteInt(1) - } + func(conn redcon.Conn, cmd redcon.Command) { + switch strings.ToLower(string(cmd.Args[0])) { + default: + conn.WriteError("ERR unknown command '" + string(cmd.Args[0]) + "'") + case "ping": + conn.WriteString("PONG") + case "quit": + conn.WriteString("OK") + conn.Close() + case "set": + if len(cmd.Args) != 3 { + conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command") + return + } + mu.Lock() + items[string(cmd.Args[1])] = cmd.Args[2] + mu.Unlock() + conn.WriteString("OK") + case "get": + if len(cmd.Args) != 2 { + conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command") + return + } + mu.RLock() + val, ok := items[string(cmd.Args[1])] + mu.RUnlock() + if !ok { + conn.WriteNull() + } else { + conn.WriteBulk(val) + } + case "del": + if len(cmd.Args) != 2 { + conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command") + return + } + mu.Lock() + _, ok := items[string(cmd.Args[1])] + delete(items, string(cmd.Args[1])) + mu.Unlock() + if !ok { + conn.WriteInt(0) + } else { + conn.WriteInt(1) } } }, @@ -147,8 +144,8 @@ $ GOMAXPROCS=1 go run example/clone.go ``` ``` redis-benchmark -p 6380 -t set,get -n 10000000 -q -P 512 -c 512 -SET: 3119151.50 requests per second -GET: 4142502.25 requests per second +SET: 2018570.88 requests per second +GET: 2403846.25 requests per second ``` **Redcon**: Multi-threaded, no disk persistence. @@ -158,8 +155,8 @@ $ GOMAXPROCS=0 go run example/clone.go ``` ``` $ redis-benchmark -p 6380 -t set,get -n 10000000 -q -P 512 -c 512 -SET: 3637686.25 requests per second -GET: 4249894.00 requests per second +SET: 1944390.38 requests per second +GET: 3993610.25 requests per second ``` *Running on a MacBook Pro 15" 2.8 GHz Intel Core i7 using Go 1.7* diff --git a/example/clone.go b/example/clone.go index 155fd99..f32bc1c 100644 --- a/example/clone.go +++ b/example/clone.go @@ -12,64 +12,62 @@ var addr = ":6380" func main() { var mu sync.RWMutex - var items = make(map[string]string) + var items = make(map[string][]byte) go log.Printf("started server at %s", addr) err := redcon.ListenAndServe(addr, - func(conn redcon.Conn, commands [][]string) { - for _, args := range commands { - switch strings.ToLower(args[0]) { - default: - conn.WriteError("ERR unknown command '" + args[0] + "'") - case "hijack": - hconn := conn.Hijack() - log.Printf("connection is hijacked") - go func() { - defer hconn.Close() - hconn.WriteString("OK") - hconn.Flush() - }() + func(conn redcon.Conn, cmd redcon.Command) { + switch strings.ToLower(string(cmd.Args[0])) { + default: + conn.WriteError("ERR unknown command '" + string(cmd.Args[0]) + "'") + case "detach": + hconn := conn.Detach() + log.Printf("connection has been detached") + go func() { + defer hconn.Close() + hconn.WriteString("OK") + hconn.Flush() + }() + return + case "ping": + conn.WriteString("PONG") + case "quit": + conn.WriteString("OK") + conn.Close() + case "set": + if len(cmd.Args) != 3 { + conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command") return - case "ping": - conn.WriteString("PONG") - case "quit": - conn.WriteString("OK") - conn.Close() - case "set": - if len(args) != 3 { - conn.WriteError("ERR wrong number of arguments for '" + args[0] + "' command") - continue - } - mu.Lock() - items[args[1]] = args[2] - mu.Unlock() - conn.WriteString("OK") - case "get": - if len(args) != 2 { - conn.WriteError("ERR wrong number of arguments for '" + args[0] + "' command") - continue - } - mu.RLock() - val, ok := items[args[1]] - mu.RUnlock() - if !ok { - conn.WriteNull() - } else { - conn.WriteBulk(val) - } - case "del": - if len(args) != 2 { - conn.WriteError("ERR wrong number of arguments for '" + args[0] + "' command") - continue - } - mu.Lock() - _, ok := items[args[1]] - delete(items, args[1]) - mu.Unlock() - if !ok { - conn.WriteInt(0) - } else { - conn.WriteInt(1) - } + } + mu.Lock() + items[string(cmd.Args[1])] = cmd.Args[2] + mu.Unlock() + conn.WriteString("OK") + case "get": + if len(cmd.Args) != 2 { + conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command") + return + } + mu.RLock() + val, ok := items[string(cmd.Args[1])] + mu.RUnlock() + if !ok { + conn.WriteNull() + } else { + conn.WriteBulk(val) + } + case "del": + if len(cmd.Args) != 2 { + conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command") + return + } + mu.Lock() + _, ok := items[string(cmd.Args[1])] + delete(items, string(cmd.Args[1])) + mu.Unlock() + if !ok { + conn.WriteInt(0) + } else { + conn.WriteInt(1) } } }, diff --git a/redcon.go b/redcon.go index d8cfb17..bd3c81e 100644 --- a/redcon.go +++ b/redcon.go @@ -1,7 +1,8 @@ -// Package redcon provides a custom redis server implementation. +// Package redcon implements a Redis compatible server framework package redcon import ( + "bufio" "errors" "io" "net" @@ -9,6 +10,23 @@ import ( "sync" ) +var ( + errUnbalancedQuotes = &errProtocol{"unbalanced quotes in request"} + errInvalidBulkLength = &errProtocol{"invalid bulk length"} + errInvalidMultiBulkLength = &errProtocol{"invalid multibulk length"} + errDetached = errors.New("detached") + errIncompleteCommand = errors.New("incomplete command") + errTooMuchData = errors.New("too much data") +) + +type errProtocol struct { + msg string +} + +func (err *errProtocol) Error() string { + return "Protocol error: " + err.msg +} + // Conn represents a client connection type Conn interface { // RemoteAddr returns the remote address of the client connection. @@ -19,10 +37,10 @@ type Conn interface { WriteError(msg string) // WriteString writes a string to the client. WriteString(str string) - // WriteBulk writes a bulk string to the client. - WriteBulk(bulk string) - // WriteBulkBytes writes bulk bytes to the client. - WriteBulkBytes(bulk []byte) + // WriteBulk writes bulk bytes to the client. + WriteBulk(bulk []byte) + // WriteBulkString writes a bulk string to the client. + WriteBulkString(bulk string) // WriteInt writes an integer to the client. WriteInt(num int) // WriteArray writes an array header. You must then write addtional @@ -35,107 +53,49 @@ type Conn interface { WriteArray(count int) // WriteNull writes a null to the client WriteNull() - // SetReadBuffer updates the buffer read size for the connection - SetReadBuffer(bytes int) // Context returns a user-defined context Context() interface{} // SetContext sets a user-defined context SetContext(v interface{}) - // Hijack return an unmanaged connection. Useful for operations like PubSub. + // SetReadBuffer updates the buffer read size for the connection + SetReadBuffer(bytes int) + // Detach return a connection that is detached from the server. + // Useful for operations like PubSub. // - // hconn := conn.Hijack() - // go func(){ - // defer hconn.Close() - // cmd, err := hconn.ReadCommand() - // if err != nil{ - // fmt.Printf("read failed: %v\n", err) - // return - // } - // fmt.Printf("received command: %v", cmd) - // hconn.WriteString("OK") - // if err := hconn.Flush(); err != nil{ - // fmt.Printf("write failed: %v\n", err) - // return - // } - // }() - Hijack() HijackedConn + // dconn := conn.Detach() + // go func(){ + // defer dconn.Close() + // cmd, err := dconn.ReadCommand() + // if err != nil{ + // fmt.Printf("read failed: %v\n", err) + // return + // } + // fmt.Printf("received command: %v", cmd) + // hconn.WriteString("OK") + // if err := dconn.Flush(); err != nil{ + // fmt.Printf("write failed: %v\n", err) + // return + // } + // }() + Detach() DetachedConn } -// HijackConn represents an unmanaged connection. -type HijackedConn interface { - // Conn is the original connection - Conn - // ReadCommand reads the next client command. - ReadCommand() ([]string, error) - // ReadCommandBytes reads the next client command as bytes. - ReadCommandBytes() ([][]byte, error) - // Flush flushes any writes to the network. - Flush() error -} - -var ( - errUnbalancedQuotes = &errProtocol{"unbalanced quotes in request"} - errInvalidBulkLength = &errProtocol{"invalid bulk length"} - errInvalidMultiBulkLength = &errProtocol{"invalid multibulk length"} - errHijacked = errors.New("hijacked") -) - -const ( - defaultBufLen = 4 * 1024 - defaultPoolSize = 64 -) - -type errProtocol struct { - msg string -} - -func (err *errProtocol) Error() string { - return "Protocol error: " + err.msg -} - -// Server represents a Redcon server. -type Server struct { - mu sync.Mutex - addr string - shandler func(conn Conn, cmds [][]string) - bhandler func(conn Conn, cmds [][][]byte) - accept func(conn Conn) bool - closed func(conn Conn, err error) - ln *net.TCPListener - done bool - conns map[*conn]bool - rdpool [][]byte - wrpool [][]byte -} - -// NewServer returns a new server -func NewServer( - addr string, handler func(conn Conn, cmds [][]string), - accept func(conn Conn) bool, closed func(conn Conn, err error), +// NewServer returns a new Redcon server. +func NewServer(addr string, + handler func(conn Conn, cmd Command), + accept func(conn Conn) bool, + closed func(conn Conn, err error), ) *Server { + if handler == nil { + panic("handler is nil") + } s := &Server{ - addr: addr, - shandler: handler, - accept: accept, - closed: closed, - conns: make(map[*conn]bool), + addr: addr, + handler: handler, + accept: accept, + closed: closed, + conns: make(map[*conn]bool), } - initbuf := make([]byte, defaultPoolSize*defaultBufLen) - s.rdpool = make([][]byte, defaultPoolSize) - for i := 0; i < defaultPoolSize; i++ { - s.rdpool[i] = initbuf[i*defaultBufLen : i*defaultBufLen+defaultBufLen] - } - return s -} - -// NewServerBytes returns a new server -// It uses []byte instead of string for the handler commands. -func NewServerBytes( - addr string, handler func(conn Conn, cmds [][][]byte), - accept func(conn Conn) bool, closed func(conn Conn, err error), -) *Server { - s := NewServer(addr, nil, accept, closed) - s.bhandler = handler return s } @@ -156,14 +116,19 @@ func (s *Server) ListenAndServe() error { return s.ListenServeAndSignal(nil) } +// ListenAndServe creates a new server and binds to addr. +func ListenAndServe(addr string, + handler func(conn Conn, cmd Command), + accept func(conn Conn) bool, + closed func(conn Conn, err error), +) error { + return NewServer(addr, handler, accept, closed).ListenAndServe() +} + // ListenServeAndSignal serves incoming connections and passes nil or error // when listening. signal can be nil. func (s *Server) ListenServeAndSignal(signal chan error) error { var addr = s.addr - var shandler = s.shandler - var bhandler = s.bhandler - var accept = s.accept - var closed = s.closed ln, err := net.Listen("tcp", addr) if err != nil { if signal != nil { @@ -200,136 +165,67 @@ func (s *Server) ListenServeAndSignal(signal chan error) error { } return err } - c := &conn{ - tcpc, - newWriter(tcpc), - newReader(tcpc, nil), - tcpc.RemoteAddr().String(), - nil, false, - } + c := &conn{conn: tcpc, addr: tcpc.RemoteAddr().String(), + wr: NewWriter(tcpc), rd: NewReader(tcpc)} s.mu.Lock() - if len(s.rdpool) > 0 { - c.rd.buf = s.rdpool[len(s.rdpool)-1] - s.rdpool = s.rdpool[:len(s.rdpool)-1] - } else { - c.rd.buf = make([]byte, defaultBufLen) - } - if len(s.wrpool) > 0 { - c.wr.b = s.wrpool[len(s.wrpool)-1] - s.wrpool = s.wrpool[:len(s.wrpool)-1] - } else { - c.wr.b = make([]byte, 0, 64) - } s.conns[c] = true s.mu.Unlock() - if accept != nil && !accept(c) { + if s.accept != nil && !s.accept(c) { s.mu.Lock() delete(s.conns, c) s.mu.Unlock() c.Close() continue } - if shandler == nil && bhandler == nil { - shandler = func(conn Conn, cmds [][]string) {} - } else if shandler != nil { - bhandler = nil - } else if bhandler != nil { - shandler = nil - } - go handle(s, c, shandler, bhandler, closed) + go handle(s, c) } } -// ListenAndServe creates a new server and binds to addr. -func ListenAndServe( - addr string, handler func(conn Conn, cmds [][]string), - accept func(conn Conn) bool, closed func(conn Conn, err error), -) error { - return NewServer(addr, handler, accept, closed).ListenAndServe() -} - -// ListenAndServeBytes creates a new server and binds to addr. -// It uses []byte instead of string for the handler commands. -func ListenAndServeBytes( - addr string, handler func(conn Conn, cmds [][][]byte), - accept func(conn Conn) bool, closed func(conn Conn, err error), -) error { - return NewServerBytes(addr, handler, accept, closed).ListenAndServe() -} - -func handle( - s *Server, c *conn, - shandler func(conn Conn, cmds [][]string), - bhandler func(conn Conn, cmds [][][]byte), - closed func(conn Conn, err error)) { +// handle manages the server connection. +func handle(s *Server, c *conn) { var err error defer func() { - if err != errHijacked { + if err != errDetached { + // do not close the connection when a detach is detected. c.conn.Close() } func() { + // remove the conn from the server s.mu.Lock() defer s.mu.Unlock() delete(s.conns, c) - if closed != nil { + if s.closed != nil { if err == io.EOF { err = nil } - closed(c, err) - } - if err != errHijacked { - if len(s.rdpool) < defaultPoolSize && len(c.rd.buf) < defaultBufLen { - s.rdpool = append(s.rdpool, c.rd.buf) - } - if len(s.wrpool) < defaultPoolSize && cap(c.wr.b) < defaultBufLen { - s.wrpool = append(s.wrpool, c.wr.b[:0]) - } + s.closed(c, err) } }() }() + err = func() error { + // read commands and feed back to the client for { - if c.hj { - return errHijacked - } - cmds, err := c.rd.ReadCommands() + // read pipeline commands + cmds, err := c.rd.readCommands(nil) if err != nil { if err, ok := err.(*errProtocol); ok { // All protocol errors should attempt a response to - // the client. Ignore errors. + // the client. Ignore write errors. c.wr.WriteError("ERR " + err.Error()) c.wr.Flush() } return err } - if len(cmds) > 0 { - if shandler != nil { - // convert bytes to strings - scmds := make([][]string, len(cmds)) - for i := 0; i < len(cmds); i++ { - scmds[i] = make([]string, len(cmds[i])) - for j := 0; j < len(scmds[i]); j++ { - scmds[i][j] = string(cmds[i][j]) - } - } - shandler(c, scmds) - } else if bhandler != nil { - // copy the byte commands once, before exposing to client. - for i := 0; i < len(cmds); i++ { - for j := 0; j < len(cmds[i]); j++ { - nb := make([]byte, len(cmds[i][j])) - copy(nb, cmds[i][j]) - cmds[i][j] = nb - } - } - bhandler(c, cmds) - } + for _, cmd := range cmds { + s.handler(c, cmd) } - if c.wr.err != nil { - if c.wr.err == errClosed { - return nil - } - return c.wr.err + if c.detached { + // client has been detached + return errDetached + } + if c.closed { + return nil } if err := c.wr.Flush(); err != nil { return err @@ -338,120 +234,218 @@ func handle( }() } +// conn represents a client connection type conn struct { - conn *net.TCPConn - wr *writer - rd *reader - addr string - ctx interface{} - hj bool + conn *net.TCPConn + wr *Writer + rd *Reader + addr string + ctx interface{} + detached bool + closed bool } func (c *conn) Close() error { - err := c.wr.Close() // flush and close the writer - c.conn.Close() // close the connection. ignore this error - return err // return the writer error only + c.closed = true + return c.conn.Close() } -func (c *conn) Context() interface{} { - return c.ctx -} -func (c *conn) SetContext(v interface{}) { - c.ctx = v -} -func (c *conn) WriteString(str string) { - c.wr.WriteString(str) -} -func (c *conn) WriteBulk(bulk string) { - c.wr.WriteBulk(bulk) -} -func (c *conn) WriteBulkBytes(bulk []byte) { - c.wr.WriteBulkBytes(bulk) -} -func (c *conn) WriteInt(num int) { - c.wr.WriteInt(num) -} -func (c *conn) WriteError(msg string) { - c.wr.WriteError(msg) -} -func (c *conn) WriteArray(count int) { - c.wr.WriteArrayStart(count) -} -func (c *conn) WriteNull() { - c.wr.WriteNull() -} -func (c *conn) RemoteAddr() string { - return c.addr -} -func (c *conn) SetReadBuffer(bytes int) { -} -func (c *conn) Hijack() HijackedConn { - c.hj = true - return &hijackedConn{conn: c} +func (c *conn) Context() interface{} { return c.ctx } +func (c *conn) SetContext(v interface{}) { c.ctx = v } +func (c *conn) SetReadBuffer(n int) {} +func (c *conn) WriteString(str string) { c.wr.WriteString(str) } +func (c *conn) WriteBulk(bulk []byte) { c.wr.WriteBulk(bulk) } +func (c *conn) WriteBulkString(bulk string) { c.wr.WriteBulkString(bulk) } +func (c *conn) WriteInt(num int) { c.wr.WriteInt(num) } +func (c *conn) WriteError(msg string) { c.wr.WriteError(msg) } +func (c *conn) WriteArray(count int) { c.wr.WriteArray(count) } +func (c *conn) WriteNull() { c.wr.WriteNull() } +func (c *conn) RemoteAddr() string { return c.addr } + +// DetachedConn represents a connection that is detached from the server +type DetachedConn interface { + // Conn is the original connection + Conn + // ReadCommand reads the next client command. + ReadCommand() (Command, error) + // Flush flushes any writes to the network. + Flush() error } -type hijackedConn struct { +// Detach removes the current connection from the server loop and returns +// a detached connection. This is useful for operations such as PubSub. +// The detached connection must be closed by calling Close() when done. +// All writes such as WriteString() will not be written to the client +// until Flush() is called. +func (c *conn) Detach() DetachedConn { + c.detached = true + return &detachedConn{conn: c} +} + +type detachedConn struct { *conn - cmds [][][]byte } -func (hjc *hijackedConn) Flush() error { - return hjc.conn.wr.Flush() +// Flush writes and Write* calls to the client. +func (dc *detachedConn) Flush() error { + return dc.conn.wr.Flush() } -func (hjc *hijackedConn) ReadCommandBytes() ([][]byte, error) { - if len(hjc.cmds) > 0 { - args := hjc.cmds[0] - hjc.cmds = hjc.cmds[1:] - for i, arg := range args { - nb := make([]byte, len(arg)) - copy(nb, arg) - args[i] = nb - } - return args, nil +// ReadCommand read the next command from the client. +func (dc *detachedConn) ReadCommand() (Command, error) { + if dc.closed { + return Command{}, errors.New("closed") } - cmds, err := hjc.rd.ReadCommands() + cmd, err := dc.rd.ReadCommand() if err != nil { - return nil, err + return Command{}, err } - hjc.cmds = cmds - return hjc.ReadCommandBytes() + return cmd, nil } -func (hjc *hijackedConn) ReadCommand() ([]string, error) { - if len(hjc.cmds) > 0 { - args := hjc.cmds[0] - hjc.cmds = hjc.cmds[1:] - nargs := make([]string, len(args)) - for i, arg := range args { - nargs[i] = string(arg) - } - return nargs, nil - } - return hjc.ReadCommand() +// Command represent a command +type Command struct { + // Raw is a encoded RESP message. + Raw []byte + // Args is a series of arguments that make up the command. + Args [][]byte } -// Reader represents a RESP command reader. -type reader struct { - r io.Reader // base reader +// Server defines a server for clients for managing client connections. +type Server struct { + mu sync.Mutex + addr string + handler func(conn Conn, cmd Command) + accept func(conn Conn) bool + closed func(conn Conn, err error) + conns map[*conn]bool + ln *net.TCPListener + done bool +} + +// Writer allows for writing RESP messages. +type Writer struct { + w io.Writer + b []byte +} + +// NewWriter creates a new RESP writer. +func NewWriter(wr io.Writer) *Writer { + return &Writer{ + w: wr, + } +} + +// WriteNull writes a null to the client +func (w *Writer) WriteNull() { + w.b = append(w.b, '$', '-', '1', '\r', '\n') +} + +// WriteArray writes an array header. You must then write addtional +// sub-responses to the client to complete the response. +// For example to write two strings: +// +// c.WriteArray(2) +// c.WriteBulk("item 1") +// c.WriteBulk("item 2") +func (w *Writer) WriteArray(count int) { + w.b = append(w.b, '*') + w.b = append(w.b, strconv.FormatInt(int64(count), 10)...) + w.b = append(w.b, '\r', '\n') +} + +// WriteBulk writes bulk bytes to the client. +func (w *Writer) WriteBulk(bulk []byte) { + w.b = append(w.b, '$') + w.b = append(w.b, strconv.FormatInt(int64(len(bulk)), 10)...) + w.b = append(w.b, '\r', '\n') + w.b = append(w.b, bulk...) + w.b = append(w.b, '\r', '\n') +} + +// WriteBulkString writes a bulk string to the client. +func (w *Writer) WriteBulkString(bulk string) { + w.b = append(w.b, '$') + w.b = append(w.b, strconv.FormatInt(int64(len(bulk)), 10)...) + w.b = append(w.b, '\r', '\n') + w.b = append(w.b, bulk...) + w.b = append(w.b, '\r', '\n') +} + +// Flush writes all unflushed Write* calls to the underlying writer. +func (w *Writer) Flush() error { + if _, err := w.w.Write(w.b); err != nil { + return err + } + w.b = w.b[:0] + return nil +} + +// WriteError writes an error to the client. +func (w *Writer) WriteError(msg string) { + w.b = append(w.b, '-') + w.b = append(w.b, msg...) + w.b = append(w.b, '\r', '\n') +} + +// WriteString writes a string to the client. +func (w *Writer) WriteString(msg string) { + if msg == "OK" { + w.b = append(w.b, '+', 'O', 'K', '\r', '\n') + } else { + w.b = append(w.b, '+') + w.b = append(w.b, []byte(msg)...) + w.b = append(w.b, '\r', '\n') + } +} + +// WriteInt writes an integer to the client. +func (w *Writer) WriteInt(num int) { + w.b = append(w.b, ':') + w.b = append(w.b, []byte(strconv.FormatInt(int64(num), 10))...) + w.b = append(w.b, '\r', '\n') +} + +// Reader represent a reader for RESP or telnet commands. +type Reader struct { + rd *bufio.Reader buf []byte start int end int + cmds []Command } -// NewReader returns a RESP command reader. -func newReader(r io.Reader, buf []byte) *reader { - return &reader{ - r: r, - buf: buf, +// NewReader returns a command reader which will read RESP or telnet commands. +func NewReader(rd io.Reader) *Reader { + return &Reader{ + rd: bufio.NewReader(rd), + buf: make([]byte, 4096), } } -// ReadCommands reads one or more commands from the reader. -func (r *reader) ReadCommands() ([][][]byte, error) { - if r.end-r.start > 0 { - b := r.buf[r.start:r.end] - // we have some potential commands. - var cmds [][][]byte +func parseInt(b []byte) (int, error) { + // shortcut atoi for 0-99. fails for negative numbers. + switch len(b) { + case 1: + if b[0] >= '0' && b[0] <= '9' { + return int(b[0] - '0'), nil + } + case 2: + if b[0] >= '0' && b[0] <= '9' && b[1] >= '0' && b[1] <= '9' { + return int(b[0]-'0')*10 + int(b[1]-'0'), nil + } + } + // fallback to standard library + n, err := strconv.ParseUint(string(b), 10, 64) + return int(n), err +} + +func (rd *Reader) readCommands(leftover *int) ([]Command, error) { + var cmds []Command + b := rd.buf[rd.start:rd.end] + + if len(b) > 0 { + // we have data, yay! + // but is this enough data for a complete command? or multiple? next: switch b[0] { default: @@ -464,7 +458,7 @@ func (r *reader) ReadCommands() ([][][]byte, error) { } else { line = b[:i] } - var args [][]byte + var cmd Command var quote bool var escape bool outer: @@ -475,7 +469,7 @@ func (r *reader) ReadCommands() ([][][]byte, error) { if !quote { if c == ' ' { if len(nline) > 0 { - args = append(args, nline) + cmd.Args = append(cmd.Args, nline) } line = line[i+1:] continue outer @@ -501,7 +495,7 @@ func (r *reader) ReadCommands() ([][][]byte, error) { } } else if c == '"' { quote = false - args = append(args, nline) + cmd.Args = append(cmd.Args, nline) line = line[i+1:] if len(line) > 0 && line[0] != ' ' { return nil, errUnbalancedQuotes @@ -518,12 +512,20 @@ func (r *reader) ReadCommands() ([][][]byte, error) { return nil, errUnbalancedQuotes } if len(line) > 0 { - args = append(args, line) + cmd.Args = append(cmd.Args, line) } break } - if len(args) > 0 { - cmds = append(cmds, args) + if len(cmd.Args) > 0 { + // convert this to resp command syntax + var wr Writer + wr.WriteArray(len(cmd.Args)) + for i := range cmd.Args { + wr.WriteBulk(cmd.Args[i]) + cmd.Args[i] = append([]byte(nil), cmd.Args[i]...) + } + cmd.Raw = wr.b + cmds = append(cmds, cmd) } b = b[i+1:] if len(b) > 0 { @@ -535,20 +537,19 @@ func (r *reader) ReadCommands() ([][][]byte, error) { } case '*': // resp formatted command - var si int + marks := make([]int, 0, 16) outer2: - for i := 0; i < len(b); i++ { - var args [][]byte + for i := 1; i < len(b); i++ { if b[i] == '\n' { if b[i-1] != '\r' { return nil, errInvalidMultiBulkLength } - ni, err := parseInt(b[si+1 : i-1]) - if err != nil || ni <= 0 { + count, err := parseInt(b[1 : i-1]) + if err != nil || count <= 0 { return nil, errInvalidMultiBulkLength } - args = make([][]byte, 0, ni) - for j := 0; j < ni; j++ { + marks = marks[:0] + for j := 0; j < count; j++ { // read bulk length i++ if i < len(b) { @@ -556,35 +557,49 @@ func (r *reader) ReadCommands() ([][][]byte, error) { return nil, &errProtocol{"expected '$', got '" + string(b[i]) + "'"} } - si = i + si := i for ; i < len(b); i++ { if b[i] == '\n' { if b[i-1] != '\r' { return nil, errInvalidBulkLength } - ni2, err := parseInt(b[si+1 : i-1]) - if err != nil || ni2 < 0 { + size, err := parseInt(b[si+1 : i-1]) + if err != nil || size < 0 { return nil, errInvalidBulkLength } - if i+ni2+2 >= len(b) { + if i+size+2 >= len(b) { // not ready break outer2 } - if b[i+ni2+2] != '\n' || - b[i+ni2+1] != '\r' { + if b[i+size+2] != '\n' || + b[i+size+1] != '\r' { return nil, errInvalidBulkLength } i++ - arg := b[i : i+ni2] - i += ni2 + 1 - args = append(args, arg) + marks = append(marks, i, i+size) + i += size + 1 break } } } } - if len(args) == cap(args) { - cmds = append(cmds, args) + if len(marks) == count*2 { + var cmd Command + if rd.rd != nil { + // make a raw copy of the entire command when + // there's a underlying reader. + cmd.Raw = append([]byte(nil), b[:i+1]...) + } else { + // just assign the slice + cmd.Raw = b[:i+1] + } + cmd.Args = make([][]byte, len(marks)/2) + // slice up the raw command into the args based on + // the recorded marks. + for h := 0; h < len(marks); h += 2 { + cmd.Args[h/2] = cmd.Raw[marks[h]:marks[h+1]] + } + cmds = append(cmds, cmd) b = b[i+1:] if len(b) > 0 { goto next @@ -596,164 +611,63 @@ func (r *reader) ReadCommands() ([][][]byte, error) { } } done: - if len(b) == 0 { - r.start = 0 - r.end = 0 + rd.start = rd.end - len(b) + } + if leftover != nil { + *leftover = rd.end - rd.start + } + if len(cmds) > 0 { + return cmds, nil + } + if rd.rd == nil { + return nil, errIncompleteCommand + } + if rd.end == len(rd.buf) { + // at the end of the buffer. + if rd.start == rd.end { + // rewind the to the beginning + rd.start, rd.end = 0, 0 } else { - r.start = r.end - len(b) - } - if len(cmds) > 0 { - return cmds, nil + // must grow the buffer + newbuf := make([]byte, len(rd.buf)*2) + copy(newbuf, rd.buf) + rd.buf = newbuf } } - if r.end == len(r.buf) { - nbuf := make([]byte, len(r.buf)*2) - copy(nbuf, r.buf) - r.buf = nbuf - } - n, err := r.r.Read(r.buf[r.end:]) + n, err := rd.rd.Read(rd.buf[rd.end:]) if err != nil { - if err == io.EOF { - if r.end > 0 { - return nil, io.ErrUnexpectedEOF - } - } return nil, err } - r.end += n - return r.ReadCommands() -} -func parseInt(b []byte) (int, error) { - switch len(b) { - case 1: - if b[0] >= '0' && b[0] <= '9' { - return int(b[0] - '0'), nil - } - case 2: - if b[0] >= '0' && b[0] <= '9' && b[1] >= '0' && b[1] <= '9' { - return int(b[0]-'0')*10 + int(b[1]-'0'), nil - } - } - var n int - for i := 0; i < len(b); i++ { - if b[i] < '0' || b[i] > '9' { - return 0, errors.New("invalid number") - } - n = n*10 + int(b[i]-'0') - } - return n, nil + rd.end += n + return rd.readCommands(leftover) } -var errClosed = errors.New("closed") - -type writer struct { - w *net.TCPConn - b []byte - err error +// ReadCommand reads the next command. +func (rd *Reader) ReadCommand() (Command, error) { + if len(rd.cmds) > 0 { + cmd := rd.cmds[0] + rd.cmds = rd.cmds[1:] + return cmd, nil + } + cmds, err := rd.readCommands(nil) + if err != nil { + return Command{}, err + } + rd.cmds = cmds + return rd.ReadCommand() } -func newWriter(w *net.TCPConn) *writer { - return &writer{w: w, b: make([]byte, 0, 512)} -} +// Parse parses a raw RESP message and returns a command. +func Parse(raw []byte) (Command, error) { + rd := Reader{buf: raw, end: len(raw)} + var leftover int + cmds, err := rd.readCommands(&leftover) + if err != nil { + return Command{}, err + } + if leftover > 0 { + return Command{}, errTooMuchData + } + return cmds[0], nil -func (w *writer) WriteNull() error { - if w.err != nil { - return w.err - } - w.b = append(w.b, '$', '-', '1', '\r', '\n') - return nil -} -func (w *writer) WriteArrayStart(count int) error { - if w.err != nil { - return w.err - } - w.b = append(w.b, '*') - w.b = append(w.b, []byte(strconv.FormatInt(int64(count), 10))...) - w.b = append(w.b, '\r', '\n') - return nil -} - -func (w *writer) WriteBulk(bulk string) error { - if w.err != nil { - return w.err - } - w.b = append(w.b, '$') - w.b = append(w.b, []byte(strconv.FormatInt(int64(len(bulk)), 10))...) - w.b = append(w.b, '\r', '\n') - w.b = append(w.b, []byte(bulk)...) - w.b = append(w.b, '\r', '\n') - return nil -} - -func (w *writer) WriteBulkBytes(bulk []byte) error { - if w.err != nil { - return w.err - } - w.b = append(w.b, '$') - w.b = append(w.b, []byte(strconv.FormatInt(int64(len(bulk)), 10))...) - w.b = append(w.b, '\r', '\n') - w.b = append(w.b, bulk...) - w.b = append(w.b, '\r', '\n') - return nil -} - -func (w *writer) Flush() error { - if w.err != nil { - return w.err - } - if len(w.b) == 0 { - return nil - } - if _, err := w.w.Write(w.b); err != nil { - w.err = err - return err - } - w.b = w.b[:0] - return nil -} - -func (w *writer) WriteError(msg string) error { - if w.err != nil { - return w.err - } - w.b = append(w.b, '-') - w.b = append(w.b, []byte(msg)...) - w.b = append(w.b, '\r', '\n') - return nil -} - -func (w *writer) WriteString(msg string) error { - if w.err != nil { - return w.err - } - if msg == "OK" { - w.b = append(w.b, '+', 'O', 'K', '\r', '\n') - } else { - w.b = append(w.b, '+') - w.b = append(w.b, []byte(msg)...) - w.b = append(w.b, '\r', '\n') - } - return nil -} - -func (w *writer) WriteInt(num int) error { - if w.err != nil { - return w.err - } - w.b = append(w.b, ':') - w.b = append(w.b, []byte(strconv.FormatInt(int64(num), 10))...) - w.b = append(w.b, '\r', '\n') - return nil -} - -func (w *writer) Close() error { - if w.err != nil { - return w.err - } - if err := w.Flush(); err != nil { - w.err = err - return err - } - w.err = errClosed - return nil } diff --git a/redcon_test.go b/redcon_test.go index e8fcb71..97029dd 100644 --- a/redcon_test.go +++ b/redcon_test.go @@ -1,11 +1,13 @@ package redcon import ( + "bytes" "fmt" "io" "log" "math/rand" "net" + "strconv" "strings" "testing" "time" @@ -145,63 +147,62 @@ func TestRandomCommands(t *testing.T) { cnt := 0 idx := 0 start := time.Now() - r := newReader(rd, make([]byte, 256)) + r := NewReader(rd) for { - cmds, err := r.ReadCommands() + cmd, err := r.ReadCommand() if err != nil { if err == io.EOF { break } log.Fatal(err) } - for _, cmd := range cmds { - if len(cmd) == 3 && string(cmd[0]) == "RESET" && string(cmd[1]) == "THE" && string(cmd[2]) == "INDEX" { - if idx != len(gcmds) { - t.Fatalf("did not process all commands") - } - idx = 0 - break + if len(cmd.Args) == 3 && string(cmd.Args[0]) == "RESET" && + string(cmd.Args[1]) == "THE" && string(cmd.Args[2]) == "INDEX" { + if idx != len(gcmds) { + t.Fatalf("did not process all commands") } - if len(cmd) != len(gcmds[idx]) { - t.Fatalf("len not equal for index %d -- %d != %d", idx, len(cmd), len(gcmds[idx])) - } - for i := 0; i < len(cmd); i++ { - if i == 0 { - if len(cmd[i]) == len(gcmds[idx][i]) { - ok := true - for j := 0; j < len(cmd[i]); j++ { - c1, c2 := cmd[i][j], gcmds[idx][i][j] - if c1 >= 'A' && c1 <= 'Z' { - c1 += 32 - } - if c2 >= 'A' && c2 <= 'Z' { - c2 += 32 - } - if c1 != c2 { - ok = false - break - } + idx = 0 + break + } + if len(cmd.Args) != len(gcmds[idx]) { + t.Fatalf("len not equal for index %d -- %d != %d", idx, len(cmd.Args), len(gcmds[idx])) + } + for i := 0; i < len(cmd.Args); i++ { + if i == 0 { + if len(cmd.Args[i]) == len(gcmds[idx][i]) { + ok := true + for j := 0; j < len(cmd.Args[i]); j++ { + c1, c2 := cmd.Args[i][j], gcmds[idx][i][j] + if c1 >= 'A' && c1 <= 'Z' { + c1 += 32 } - if ok { - continue + if c2 >= 'A' && c2 <= 'Z' { + c2 += 32 + } + if c1 != c2 { + ok = false + break } } - } else if string(cmd[i]) == string(gcmds[idx][i]) { - continue + if ok { + continue + } } - t.Fatalf("not equal for index %d/%d", idx, i) + } else if string(cmd.Args[i]) == string(gcmds[idx][i]) { + continue } - idx++ - cnt++ + t.Fatalf("not equal for index %d/%d", idx, i) } + idx++ + cnt++ } if false { dur := time.Now().Sub(start) fmt.Printf("%d commands in %s - %.0f ops/sec\n", cnt, dur, float64(cnt)/(float64(dur)/float64(time.Second))) } } -func testHijack(t *testing.T, conn HijackedConn) { - conn.WriteString("HIJACKED") +func testDetached(t *testing.T, conn DetachedConn) { + conn.WriteString("DETACHED") if err := conn.Flush(); err != nil { t.Fatal(err) } @@ -209,33 +210,31 @@ func testHijack(t *testing.T, conn HijackedConn) { func TestServer(t *testing.T) { s := NewServer(":12345", - func(conn Conn, cmds [][]string) { - for _, cmd := range cmds { - switch strings.ToLower(cmd[0]) { - default: - conn.WriteError("ERR unknown command '" + cmd[0] + "'") - case "ping": - conn.WriteString("PONG") - case "quit": - conn.WriteString("OK") - conn.Close() - case "hijack": - go testHijack(t, conn.Hijack()) - case "int": - conn.WriteInt(100) - case "bulk": - conn.WriteBulk("bulk") - case "bulkbytes": - conn.WriteBulkBytes([]byte("bulkbytes")) - case "null": - conn.WriteNull() - case "err": - conn.WriteError("ERR error") - case "array": - conn.WriteArray(2) - conn.WriteInt(99) - conn.WriteString("Hi!") - } + func(conn Conn, cmd Command) { + switch strings.ToLower(string(cmd.Args[0])) { + default: + conn.WriteError("ERR unknown command '" + string(cmd.Args[0]) + "'") + case "ping": + conn.WriteString("PONG") + case "quit": + conn.WriteString("OK") + conn.Close() + case "detach": + go testDetached(t, conn.Detach()) + case "int": + conn.WriteInt(100) + case "bulk": + conn.WriteBulkString("bulk") + case "bulkbytes": + conn.WriteBulk([]byte("bulkbytes")) + case "null": + conn.WriteNull() + case "err": + conn.WriteError("ERR error") + case "array": + conn.WriteArray(2) + conn.WriteInt(99) + conn.WriteString("Hi!") } }, func(conn Conn) bool { @@ -251,7 +250,7 @@ func TestServer(t *testing.T) { } go func() { time.Sleep(time.Second / 4) - if err := ListenAndServe(":12345", nil, nil, nil); err == nil { + if err := ListenAndServe(":12345", func(conn Conn, cmd Command) {}, nil, nil); err == nil { t.Fatalf("expected an error, should not be able to listen on the same port") } time.Sleep(time.Second / 4) @@ -294,56 +293,56 @@ func TestServer(t *testing.T) { t.Fatal(err) } if res != "+PONG\r\n" { - t.Fatal("expecting '+PONG\r\n', got '%v'", res) + t.Fatalf("expecting '+PONG\r\n', got '%v'", res) } res, err = do("BULK\r\n") if err != nil { t.Fatal(err) } if res != "$4\r\nbulk\r\n" { - t.Fatal("expecting bulk, got '%v'", res) + t.Fatalf("expecting bulk, got '%v'", res) } res, err = do("BULKBYTES\r\n") if err != nil { t.Fatal(err) } if res != "$9\r\nbulkbytes\r\n" { - t.Fatal("expecting bulkbytes, got '%v'", res) + t.Fatalf("expecting bulkbytes, got '%v'", res) } res, err = do("INT\r\n") if err != nil { t.Fatal(err) } if res != ":100\r\n" { - t.Fatal("expecting int, got '%v'", res) + t.Fatalf("expecting int, got '%v'", res) } res, err = do("NULL\r\n") if err != nil { t.Fatal(err) } if res != "$-1\r\n" { - t.Fatal("expecting nul, got '%v'", res) + t.Fatalf("expecting nul, got '%v'", res) } res, err = do("ARRAY\r\n") if err != nil { t.Fatal(err) } if res != "*2\r\n:99\r\n+Hi!\r\n" { - t.Fatal("expecting array, got '%v'", res) + t.Fatalf("expecting array, got '%v'", res) } res, err = do("ERR\r\n") if err != nil { t.Fatal(err) } if res != "-ERR error\r\n" { - t.Fatal("expecting array, got '%v'", res) + t.Fatalf("expecting array, got '%v'", res) } - res, err = do("HIJACK\r\n") + res, err = do("DETACH\r\n") if err != nil { t.Fatal(err) } - if res != "+HIJACKED\r\n" { - t.Fatal("expecting string, got '%v'", res) + if res != "+DETACHED\r\n" { + t.Fatalf("expecting string, got '%v'", res) } }() go func() { @@ -354,3 +353,195 @@ func TestServer(t *testing.T) { }() <-done } + +func TestWriter(t *testing.T) { + buf := &bytes.Buffer{} + wr := NewWriter(buf) + wr.WriteError("ERR bad stuff") + wr.Flush() + if buf.String() != "-ERR bad stuff\r\n" { + t.Fatal("failed") + } + buf.Reset() + wr.WriteString("HELLO") + wr.Flush() + if buf.String() != "+HELLO\r\n" { + t.Fatal("failed") + } + buf.Reset() + wr.WriteInt(-1234) + wr.Flush() + if buf.String() != ":-1234\r\n" { + t.Fatal("failed") + } + buf.Reset() + wr.WriteNull() + wr.Flush() + if buf.String() != "$-1\r\n" { + t.Fatal("failed") + } + buf.Reset() + wr.WriteBulk([]byte("HELLO\r\nPLANET")) + wr.Flush() + if buf.String() != "$13\r\nHELLO\r\nPLANET\r\n" { + t.Fatal("failed") + } + buf.Reset() + wr.WriteBulkString("HELLO\r\nPLANET") + wr.Flush() + if buf.String() != "$13\r\nHELLO\r\nPLANET\r\n" { + t.Fatal("failed") + } + buf.Reset() + wr.WriteArray(3) + wr.WriteBulkString("THIS") + wr.WriteBulkString("THAT") + wr.WriteString("THE OTHER THING") + wr.Flush() + if buf.String() != "*3\r\n$4\r\nTHIS\r\n$4\r\nTHAT\r\n+THE OTHER THING\r\n" { + t.Fatal("failed") + } + buf.Reset() +} +func testMakeRawCommands(rawargs [][]string) []string { + var rawcmds []string + for i := 0; i < len(rawargs); i++ { + rawcmd := "*" + strconv.FormatUint(uint64(len(rawargs[i])), 10) + "\r\n" + for j := 0; j < len(rawargs[i]); j++ { + rawcmd += "$" + strconv.FormatUint(uint64(len(rawargs[i][j])), 10) + "\r\n" + rawcmd += rawargs[i][j] + "\r\n" + } + rawcmds = append(rawcmds, rawcmd) + } + return rawcmds +} + +func TestReaderRespRandom(t *testing.T) { + rand.Seed(time.Now().UnixNano()) + for h := 0; h < 10000; h++ { + var rawargs [][]string + for i := 0; i < 100; i++ { + var args []string + n := int(rand.Int() % 16) + for j := 0; j < n; j++ { + arg := make([]byte, rand.Int()%512) + rand.Read(arg) + args = append(args, string(arg)) + } + } + rawcmds := testMakeRawCommands(rawargs) + data := strings.Join(rawcmds, "") + rd := NewReader(bytes.NewBufferString(data)) + for i := 0; i < len(rawcmds); i++ { + if len(rawargs[i]) == 0 { + continue + } + cmd, err := rd.ReadCommand() + if err != nil { + t.Fatal(err) + } + if string(cmd.Raw) != rawcmds[i] { + t.Fatalf("expected '%v', got '%v'", rawcmds[i], string(cmd.Raw)) + } + if len(cmd.Args) != len(rawargs[i]) { + t.Fatalf("expected '%v', got '%v'", len(rawargs[i]), len(cmd.Args)) + } + for j := 0; j < len(rawargs[i]); j++ { + if string(cmd.Args[j]) != rawargs[i][j] { + t.Fatalf("expected '%v', got '%v'", rawargs[i][j], string(cmd.Args[j])) + } + } + } + } +} + +func TestPlainReader(t *testing.T) { + rawargs := [][]string{ + {"HELLO", "WORLD"}, + {"HELLO", "WORLD"}, + {"HELLO", "PLANET"}, + {"HELLO", "JELLO"}, + {"HELLO ", "JELLO"}, + } + rawcmds := []string{ + "HELLO WORLD\n", + "HELLO WORLD\r\n", + " HELLO PLANET \r\n", + " \"HELLO\" \"JELLO\" \r\n", + " \"HELLO \" JELLO \n", + } + rawres := []string{ + "*2\r\n$5\r\nHELLO\r\n$5\r\nWORLD\r\n", + "*2\r\n$5\r\nHELLO\r\n$5\r\nWORLD\r\n", + "*2\r\n$5\r\nHELLO\r\n$6\r\nPLANET\r\n", + "*2\r\n$5\r\nHELLO\r\n$5\r\nJELLO\r\n", + "*2\r\n$6\r\nHELLO \r\n$5\r\nJELLO\r\n", + } + data := strings.Join(rawcmds, "") + rd := NewReader(bytes.NewBufferString(data)) + for i := 0; i < len(rawcmds); i++ { + if len(rawargs[i]) == 0 { + continue + } + cmd, err := rd.ReadCommand() + if err != nil { + t.Fatal(err) + } + if string(cmd.Raw) != rawres[i] { + t.Fatalf("expected '%v', got '%v'", rawres[i], string(cmd.Raw)) + } + if len(cmd.Args) != len(rawargs[i]) { + t.Fatalf("expected '%v', got '%v'", len(rawargs[i]), len(cmd.Args)) + } + for j := 0; j < len(rawargs[i]); j++ { + if string(cmd.Args[j]) != rawargs[i][j] { + t.Fatalf("expected '%v', got '%v'", rawargs[i][j], string(cmd.Args[j])) + } + } + } +} + +func TestParse(t *testing.T) { + _, err := Parse(nil) + if err != errIncompleteCommand { + t.Fatalf("expected '%v', got '%v'", errIncompleteCommand, err) + } + _, err = Parse([]byte("*1\r\n")) + if err != errIncompleteCommand { + t.Fatalf("expected '%v', got '%v'", errIncompleteCommand, err) + } + _, err = Parse([]byte("*-1\r\n")) + if err != errInvalidMultiBulkLength { + t.Fatalf("expected '%v', got '%v'", errInvalidMultiBulkLength, err) + } + _, err = Parse([]byte("*0\r\n")) + if err != errInvalidMultiBulkLength { + t.Fatalf("expected '%v', got '%v'", errInvalidMultiBulkLength, err) + } + cmd, err := Parse([]byte("*1\r\n$1\r\nA\r\n")) + if err != nil { + t.Fatal(err) + } + if string(cmd.Raw) != "*1\r\n$1\r\nA\r\n" { + t.Fatalf("expected '%v', got '%v'", "*1\r\n$1\r\nA\r\n", string(cmd.Raw)) + } + if len(cmd.Args) != 1 { + t.Fatalf("expected '%v', got '%v'", 1, len(cmd.Args)) + } + if string(cmd.Args[0]) != "A" { + t.Fatalf("expected '%v', got '%v'", "A", string(cmd.Args[0])) + } + cmd, err = Parse([]byte("A\r\n")) + if err != nil { + t.Fatal(err) + } + if string(cmd.Raw) != "*1\r\n$1\r\nA\r\n" { + t.Fatalf("expected '%v', got '%v'", "*1\r\n$1\r\nA\r\n", string(cmd.Raw)) + } + if len(cmd.Args) != 1 { + t.Fatalf("expected '%v', got '%v'", 1, len(cmd.Args)) + } + if string(cmd.Args[0]) != "A" { + t.Fatalf("expected '%v', got '%v'", "A", string(cmd.Args[0])) + } +}