redcon/redcon.go

709 lines
17 KiB
Go
Raw Normal View History

// Package redcon implements a Redis compatible server framework.
package redcon
import (
"bufio"
2016-07-28 17:54:02 +03:00
"errors"
"io"
"net"
"strconv"
"sync"
)
// Errors produced while reading and parsing client commands.
var (
	// errUnbalancedQuotes is returned when a plain text (telnet style)
	// command contains double quotes that do not pair up.
	errUnbalancedQuotes = &errProtocol{"unbalanced quotes in request"}
	// errInvalidBulkLength is returned when a '$' bulk header or its
	// terminating CRLF is malformed.
	errInvalidBulkLength = &errProtocol{"invalid bulk length"}
	// errInvalidMultiBulkLength is returned when a '*' multibulk header is
	// malformed or its count is not a positive number.
	errInvalidMultiBulkLength = &errProtocol{"invalid multibulk length"}
	// errDetached is used by the connection loop to signal that the
	// connection was detached and must not be closed by the server.
	errDetached = errors.New("detached")
	// errIncompleteCommand is returned when parsing a fixed buffer (no
	// underlying reader) that does not hold a complete command.
	errIncompleteCommand = errors.New("incomplete command")
	// errTooMuchData is returned by Parse when data remains after the
	// parsed command.
	errTooMuchData = errors.New("too much data")
)
// errProtocol is the error type used for requests that violate the wire
// protocol. msg carries the detail that follows the "Protocol error: "
// prefix.
type errProtocol struct {
	msg string
}

// Compile-time check that *errProtocol satisfies the error interface.
var _ error = (*errProtocol)(nil)

// Error implements the error interface.
func (e *errProtocol) Error() string {
	const prefix = "Protocol error: "
	return prefix + e.msg
}
2016-08-20 20:42:34 +03:00
// Conn represents a client connection
2016-07-28 17:54:02 +03:00
type Conn interface {
2016-08-20 21:32:41 +03:00
// RemoteAddr returns the remote address of the client connection.
2016-07-28 17:54:02 +03:00
RemoteAddr() string
2016-08-20 21:32:41 +03:00
// Close closes the connection.
2016-07-28 17:54:02 +03:00
Close() error
2016-08-20 21:32:41 +03:00
// WriteError writes an error to the client.
2016-07-28 17:54:02 +03:00
WriteError(msg string)
2016-08-20 21:32:41 +03:00
// WriteString writes a string to the client.
2016-07-28 17:54:02 +03:00
WriteString(str string)
// WriteBulk writes bulk bytes to the client.
WriteBulk(bulk []byte)
// WriteBulkString writes a bulk string to the client.
WriteBulkString(bulk string)
2016-08-20 21:32:41 +03:00
// WriteInt writes an integer to the client.
2016-07-28 17:54:02 +03:00
WriteInt(num int)
2016-08-20 21:32:41 +03:00
// WriteArray writes an array header. You must then write addtional
// sub-responses to the client to complete the response.
// For example to write two strings:
//
// c.WriteArray(2)
// c.WriteBulk("item 1")
// c.WriteBulk("item 2")
2016-07-28 17:54:02 +03:00
WriteArray(count int)
2016-08-20 21:32:41 +03:00
// WriteNull writes a null to the client
2016-07-28 17:54:02 +03:00
WriteNull()
2016-09-22 23:11:27 +03:00
// WriteRaw writes raw data to the client.
WriteRaw(data []byte)
2016-08-29 17:01:30 +03:00
// Context returns a user-defined context
Context() interface{}
// SetContext sets a user-defined context
SetContext(v interface{})
// SetReadBuffer updates the buffer read size for the connection
SetReadBuffer(bytes int)
// Detach return a connection that is detached from the server.
// Useful for operations like PubSub.
//
// dconn := conn.Detach()
// go func(){
// defer dconn.Close()
// cmd, err := dconn.ReadCommand()
// if err != nil{
// fmt.Printf("read failed: %v\n", err)
// return
// }
// fmt.Printf("received command: %v", cmd)
// hconn.WriteString("OK")
// if err := dconn.Flush(); err != nil{
// fmt.Printf("write failed: %v\n", err)
// return
// }
// }()
Detach() DetachedConn
2016-09-23 00:32:11 +03:00
// ReadPipeline returns all commands in current pipeline, if any
ReadPipeline() []Command
}
// NewServer returns a new Redcon server.
func NewServer(addr string,
handler func(conn Conn, cmd Command),
accept func(conn Conn) bool,
closed func(conn Conn, err error),
2016-08-21 19:34:23 +03:00
) *Server {
if handler == nil {
panic("handler is nil")
2016-09-07 15:50:27 +03:00
}
s := &Server{
addr: addr,
handler: handler,
accept: accept,
closed: closed,
conns: make(map[*conn]bool),
2016-09-07 16:19:55 +03:00
}
return s
2016-09-07 15:50:27 +03:00
}
2016-08-21 19:34:23 +03:00
// Close stops listening on the TCP address.
// Already Accepted connections will be closed.
func (s *Server) Close() error {
2016-08-22 21:05:11 +03:00
s.mu.Lock()
defer s.mu.Unlock()
2016-08-21 19:34:23 +03:00
if s.ln == nil {
return errors.New("not serving")
}
s.done = true
return s.ln.Close()
}
// ListenAndServe serves incoming connections.
func (s *Server) ListenAndServe() error {
2016-08-22 20:53:04 +03:00
return s.ListenServeAndSignal(nil)
}
// ListenAndServe creates a new server bound to addr and serves incoming
// connections with it. The arguments are passed to NewServer unchanged.
func ListenAndServe(addr string,
	handler func(conn Conn, cmd Command),
	accept func(conn Conn) bool,
	closed func(conn Conn, err error),
) error {
	srv := NewServer(addr, handler, accept, closed)
	return srv.ListenAndServe()
}
2016-08-22 20:53:04 +03:00
// ListenServeAndSignal serves incoming connections and passes nil or error
// when listening. signal can be nil.
func (s *Server) ListenServeAndSignal(signal chan error) error {
2016-08-21 19:34:23 +03:00
var addr = s.addr
2016-07-28 17:54:02 +03:00
ln, err := net.Listen("tcp", addr)
if err != nil {
2016-08-22 20:53:04 +03:00
if signal != nil {
signal <- err
}
2016-07-28 17:54:02 +03:00
return err
}
2016-08-22 20:53:04 +03:00
if signal != nil {
signal <- nil
}
2016-08-22 21:11:30 +03:00
tln := ln.(*net.TCPListener)
2016-08-22 21:07:42 +03:00
s.mu.Lock()
2016-08-22 21:11:30 +03:00
s.ln = tln
2016-08-22 21:07:42 +03:00
s.mu.Unlock()
2016-08-21 19:34:23 +03:00
defer func() {
ln.Close()
func() {
s.mu.Lock()
defer s.mu.Unlock()
for c := range s.conns {
c.Close()
}
s.conns = nil
}()
}()
2016-07-28 17:54:02 +03:00
for {
2016-08-22 21:11:30 +03:00
tcpc, err := tln.AcceptTCP()
2016-07-28 17:54:02 +03:00
if err != nil {
2016-08-21 19:34:23 +03:00
s.mu.Lock()
done := s.done
s.mu.Unlock()
if done {
return nil
}
2016-07-28 17:54:02 +03:00
return err
}
c := &conn{conn: tcpc, addr: tcpc.RemoteAddr().String(),
wr: NewWriter(tcpc), rd: NewReader(tcpc)}
2016-09-07 08:56:44 +03:00
s.mu.Lock()
s.conns[c] = true
s.mu.Unlock()
if s.accept != nil && !s.accept(c) {
2016-09-07 08:56:44 +03:00
s.mu.Lock()
delete(s.conns, c)
s.mu.Unlock()
2016-08-20 18:19:08 +03:00
c.Close()
2016-07-28 17:54:02 +03:00
continue
}
go handle(s, c)
2016-08-20 18:19:08 +03:00
}
}
2016-08-21 19:34:23 +03:00
// handle manages the server connection.
func handle(s *Server, c *conn) {
2016-08-20 18:19:08 +03:00
var err error
defer func() {
if err != errDetached {
// do not close the connection when a detach is detected.
c.conn.Close()
}
2016-08-21 19:34:23 +03:00
func() {
// remove the conn from the server
2016-08-21 19:34:23 +03:00
s.mu.Lock()
defer s.mu.Unlock()
delete(s.conns, c)
if s.closed != nil {
2016-08-21 19:34:23 +03:00
if err == io.EOF {
err = nil
}
s.closed(c, err)
2016-09-07 08:56:44 +03:00
}
2016-08-21 19:34:23 +03:00
}()
2016-08-20 18:19:08 +03:00
}()
2016-08-20 18:19:08 +03:00
err = func() error {
// read commands and feed back to the client
2016-08-20 18:19:08 +03:00
for {
// read pipeline commands
cmds, err := c.rd.readCommands(nil)
2016-08-20 18:19:08 +03:00
if err != nil {
if err, ok := err.(*errProtocol); ok {
// All protocol errors should attempt a response to
// the client. Ignore write errors.
2016-08-20 18:19:08 +03:00
c.wr.WriteError("ERR " + err.Error())
c.wr.Flush()
2016-07-28 17:54:02 +03:00
}
2016-08-20 18:19:08 +03:00
return err
}
2016-09-23 00:32:11 +03:00
c.cmds = cmds
for len(c.cmds) > 0 {
cmd := c.cmds[0]
if len(c.cmds) == 1 {
c.cmds = nil
} else {
c.cmds = c.cmds[1:]
}
s.handler(c, cmd)
2016-08-20 18:19:08 +03:00
}
if c.detached {
// client has been detached
return errDetached
}
if c.closed {
return nil
2016-08-20 18:19:08 +03:00
}
if err := c.wr.Flush(); err != nil {
return err
}
}
}()
2016-07-28 17:54:02 +03:00
}
// conn represents a client connection
2016-08-20 18:19:08 +03:00
type conn struct {
conn *net.TCPConn
wr *Writer
rd *Reader
addr string
ctx interface{}
detached bool
closed bool
2016-09-23 00:32:11 +03:00
cmds []Command
2016-07-28 17:54:02 +03:00
}
2016-08-20 18:19:08 +03:00
// Close marks the connection as closed and closes the underlying TCP
// connection, returning the error from that close.
func (c *conn) Close() error {
	c.closed = true
	closeErr := c.conn.Close()
	return closeErr
}
func (c *conn) Context() interface{} { return c.ctx }
func (c *conn) SetContext(v interface{}) { c.ctx = v }
func (c *conn) SetReadBuffer(n int) {}
func (c *conn) WriteString(str string) { c.wr.WriteString(str) }
func (c *conn) WriteBulk(bulk []byte) { c.wr.WriteBulk(bulk) }
func (c *conn) WriteBulkString(bulk string) { c.wr.WriteBulkString(bulk) }
func (c *conn) WriteInt(num int) { c.wr.WriteInt(num) }
func (c *conn) WriteError(msg string) { c.wr.WriteError(msg) }
func (c *conn) WriteArray(count int) { c.wr.WriteArray(count) }
func (c *conn) WriteNull() { c.wr.WriteNull() }
2016-09-22 23:11:27 +03:00
func (c *conn) WriteRaw(data []byte) { c.wr.WriteRaw(data) }
func (c *conn) RemoteAddr() string { return c.addr }
2016-09-23 00:32:11 +03:00
// ReadPipeline hands over the commands still pending in the current
// pipeline and clears them from the connection. The result is nil when
// nothing is pending.
func (c *conn) ReadPipeline() []Command {
	var pending []Command
	pending, c.cmds = c.cmds, nil
	return pending
}
// DetachedConn represents a connection that is detached from the server
type DetachedConn interface {
// Conn is the original connection
Conn
// ReadCommand reads the next client command.
ReadCommand() (Command, error)
// Flush flushes any writes to the network.
Flush() error
2016-08-29 17:01:30 +03:00
}
// Detach removes the current connection from the server loop and returns
// a detached connection. This is useful for operations such as PubSub.
// The detached connection must be closed by calling Close() when done.
// All writes such as WriteString() will not be written to the client
// until Flush() is called.
func (c *conn) Detach() DetachedConn {
c.detached = true
2016-09-23 00:32:11 +03:00
cmds := c.cmds
c.cmds = nil
return &detachedConn{conn: c, cmds: cmds}
2016-07-28 17:54:02 +03:00
}
type detachedConn struct {
*conn
2016-09-23 00:32:11 +03:00
cmds []Command
2016-07-28 17:54:02 +03:00
}
// Flush writes and Write* calls to the client.
func (dc *detachedConn) Flush() error {
return dc.conn.wr.Flush()
2016-08-23 15:54:17 +03:00
}
// ReadCommand read the next command from the client.
func (dc *detachedConn) ReadCommand() (Command, error) {
if dc.closed {
return Command{}, errors.New("closed")
}
2016-09-23 00:32:11 +03:00
if len(dc.cmds) > 0 {
cmd := dc.cmds[0]
if len(dc.cmds) == 1 {
dc.cmds = nil
} else {
dc.cmds = dc.cmds[1:]
}
return cmd, nil
}
cmd, err := dc.rd.ReadCommand()
if err != nil {
return Command{}, err
}
return cmd, nil
2016-07-28 17:54:02 +03:00
}
// Command represents a single client command.
type Command struct {
	// Raw is the command as an encoded RESP message.
	Raw []byte
	// Args is the series of arguments that make up the command.
	Args [][]byte
}
// Server defines a server for clients for managing client connections.
type Server struct {
	// mu guards conns, ln and done.
	mu sync.Mutex
	// addr is the TCP address to listen on.
	addr string
	// handler is called for every command; NewServer requires it.
	handler func(conn Conn, cmd Command)
	// accept, when non-nil, decides whether a new connection is served.
	accept func(conn Conn) bool
	// closed, when non-nil, is called after a connection's loop has ended.
	closed func(conn Conn, err error)
	// conns is the set of connections currently registered.
	conns map[*conn]bool
	// ln is the active listener; nil until the server starts listening.
	ln *net.TCPListener
	// done is set by Close so the accept loop can tell an intentional
	// shutdown from an accept failure.
	done bool
}
// Writer allows for writing RESP messages.
//
// All Write* methods only append to an internal buffer; nothing reaches the
// underlying writer until Flush is called.
type Writer struct {
	w io.Writer // destination of Flush
	b []byte    // encoded replies not yet flushed
}

// NewWriter creates a new RESP writer that flushes to wr.
func NewWriter(wr io.Writer) *Writer {
	return &Writer{w: wr}
}

// WriteNull writes a null to the client.
func (w *Writer) WriteNull() {
	w.b = append(w.b, "$-1\r\n"...)
}

// WriteArray writes an array header. You must then write additional
// sub-responses to the client to complete the response.
// For example to write two strings:
//
//   c.WriteArray(2)
//   c.WriteBulkString("item 1")
//   c.WriteBulkString("item 2")
func (w *Writer) WriteArray(count int) {
	w.b = append(w.b, '*')
	// AppendInt formats straight into the buffer without a temporary.
	w.b = strconv.AppendInt(w.b, int64(count), 10)
	w.b = append(w.b, '\r', '\n')
}

// WriteBulk writes bulk bytes to the client.
func (w *Writer) WriteBulk(bulk []byte) {
	w.b = append(w.b, '$')
	w.b = strconv.AppendInt(w.b, int64(len(bulk)), 10)
	w.b = append(w.b, '\r', '\n')
	w.b = append(w.b, bulk...)
	w.b = append(w.b, '\r', '\n')
}

// WriteBulkString writes a bulk string to the client.
func (w *Writer) WriteBulkString(bulk string) {
	w.b = append(w.b, '$')
	w.b = strconv.AppendInt(w.b, int64(len(bulk)), 10)
	w.b = append(w.b, '\r', '\n')
	w.b = append(w.b, bulk...)
	w.b = append(w.b, '\r', '\n')
}

// Flush writes all unflushed Write* calls to the underlying writer.
// On a write error the error is returned and the buffer is left untouched;
// on success the buffer is emptied and its storage kept for reuse.
func (w *Writer) Flush() error {
	if _, err := w.w.Write(w.b); err != nil {
		return err
	}
	w.b = w.b[:0]
	return nil
}

// WriteError writes an error to the client.
func (w *Writer) WriteError(msg string) {
	w.b = append(w.b, '-')
	w.b = append(w.b, msg...)
	w.b = append(w.b, '\r', '\n')
}

// WriteString writes a string to the client.
func (w *Writer) WriteString(msg string) {
	w.b = append(w.b, '+')
	w.b = append(w.b, msg...)
	w.b = append(w.b, '\r', '\n')
}

// WriteInt writes an integer to the client.
func (w *Writer) WriteInt(num int) {
	w.b = append(w.b, ':')
	w.b = strconv.AppendInt(w.b, int64(num), 10)
	w.b = append(w.b, '\r', '\n')
}

// WriteRaw writes raw data to the client. The data is appended as is,
// without any framing.
func (w *Writer) WriteRaw(data []byte) {
	w.b = append(w.b, data...)
}
// Reader represent a reader for RESP or telnet commands.
type Reader struct {
rd *bufio.Reader
2016-09-07 16:19:55 +03:00
buf []byte
start int
end int
cmds []Command
2016-07-28 17:54:02 +03:00
}
// NewReader returns a command reader which will read RESP or telnet
// commands from rd. The reader starts out with a 4096 byte parse buffer.
func NewReader(rd io.Reader) *Reader {
	const initialBufSize = 4096
	r := new(Reader)
	r.rd = bufio.NewReader(rd)
	r.buf = make([]byte, initialBufSize)
	return r
}
// parseInt parses b as a non-negative decimal integer.
//
// It returns an error for empty input, for anything that is not made up of
// decimal digits only (including a leading sign) and for values that do not
// fit into an int. On error the returned value is 0.
func parseInt(b []byte) (int, error) {
	// shortcut atoi for 0-99. fails for negative numbers.
	switch len(b) {
	case 1:
		if b[0] >= '0' && b[0] <= '9' {
			return int(b[0] - '0'), nil
		}
	case 2:
		if b[0] >= '0' && b[0] <= '9' && b[1] >= '0' && b[1] <= '9' {
			return int(b[0]-'0')*10 + int(b[1]-'0'), nil
		}
	}
	// fallback to standard library. Parsing with one bit less than the
	// width of int makes ParseUint reject every value above the largest
	// int, so the conversion below can never wrap around to a negative or
	// truncated number.
	n, err := strconv.ParseUint(string(b), 10, strconv.IntSize-1)
	if err != nil {
		return 0, err
	}
	return int(n), nil
}
func (rd *Reader) readCommands(leftover *int) ([]Command, error) {
var cmds []Command
b := rd.buf[rd.start:rd.end]
if len(b) > 0 {
// we have data, yay!
// but is this enough data for a complete command? or multiple?
2016-07-28 17:54:02 +03:00
next:
2016-09-07 08:56:44 +03:00
switch b[0] {
2016-07-28 17:54:02 +03:00
default:
// just a plain text command
2016-09-07 08:56:44 +03:00
for i := 0; i < len(b); i++ {
if b[i] == '\n' {
2016-07-28 17:54:02 +03:00
var line []byte
2016-09-07 08:56:44 +03:00
if i > 0 && b[i-1] == '\r' {
line = b[:i-1]
2016-07-28 17:54:02 +03:00
} else {
2016-09-07 08:56:44 +03:00
line = b[:i]
2016-07-28 17:54:02 +03:00
}
var cmd Command
2016-07-28 17:54:02 +03:00
var quote bool
var escape bool
outer:
for {
nline := make([]byte, 0, len(line))
for i := 0; i < len(line); i++ {
c := line[i]
if !quote {
if c == ' ' {
if len(nline) > 0 {
cmd.Args = append(cmd.Args, nline)
2016-07-28 17:54:02 +03:00
}
line = line[i+1:]
continue outer
}
if c == '"' {
if i != 0 {
return nil, errUnbalancedQuotes
}
quote = true
line = line[i+1:]
continue outer
}
} else {
if escape {
escape = false
switch c {
case 'n':
c = '\n'
case 'r':
c = '\r'
case 't':
c = '\t'
}
} else if c == '"' {
quote = false
cmd.Args = append(cmd.Args, nline)
2016-07-28 17:54:02 +03:00
line = line[i+1:]
if len(line) > 0 && line[0] != ' ' {
return nil, errUnbalancedQuotes
}
continue outer
} else if c == '\\' {
escape = true
continue
}
}
nline = append(nline, c)
}
if quote {
return nil, errUnbalancedQuotes
}
if len(line) > 0 {
cmd.Args = append(cmd.Args, line)
2016-07-28 17:54:02 +03:00
}
break
}
if len(cmd.Args) > 0 {
// convert this to resp command syntax
var wr Writer
wr.WriteArray(len(cmd.Args))
for i := range cmd.Args {
wr.WriteBulk(cmd.Args[i])
cmd.Args[i] = append([]byte(nil), cmd.Args[i]...)
}
cmd.Raw = wr.b
cmds = append(cmds, cmd)
2016-07-28 17:54:02 +03:00
}
2016-09-07 08:56:44 +03:00
b = b[i+1:]
if len(b) > 0 {
2016-07-28 17:54:02 +03:00
goto next
} else {
goto done
}
}
}
case '*':
// resp formatted command
marks := make([]int, 0, 16)
2016-07-28 17:54:02 +03:00
outer2:
for i := 1; i < len(b); i++ {
2016-09-07 08:56:44 +03:00
if b[i] == '\n' {
if b[i-1] != '\r' {
2016-07-28 17:54:02 +03:00
return nil, errInvalidMultiBulkLength
}
count, err := parseInt(b[1 : i-1])
if err != nil || count <= 0 {
2016-07-28 17:54:02 +03:00
return nil, errInvalidMultiBulkLength
}
marks = marks[:0]
for j := 0; j < count; j++ {
2016-07-28 17:54:02 +03:00
// read bulk length
i++
2016-09-07 08:56:44 +03:00
if i < len(b) {
if b[i] != '$' {
2016-07-28 17:54:02 +03:00
return nil, &errProtocol{"expected '$', got '" +
2016-09-07 08:56:44 +03:00
string(b[i]) + "'"}
2016-07-28 17:54:02 +03:00
}
si := i
2016-09-07 08:56:44 +03:00
for ; i < len(b); i++ {
if b[i] == '\n' {
if b[i-1] != '\r' {
2016-07-28 17:54:02 +03:00
return nil, errInvalidBulkLength
}
size, err := parseInt(b[si+1 : i-1])
if err != nil || size < 0 {
2016-07-28 17:54:02 +03:00
return nil, errInvalidBulkLength
}
if i+size+2 >= len(b) {
2016-07-28 17:54:02 +03:00
// not ready
break outer2
}
if b[i+size+2] != '\n' ||
b[i+size+1] != '\r' {
2016-07-28 17:54:02 +03:00
return nil, errInvalidBulkLength
}
2016-09-07 08:56:44 +03:00
i++
marks = append(marks, i, i+size)
i += size + 1
2016-07-28 17:54:02 +03:00
break
}
}
}
}
if len(marks) == count*2 {
var cmd Command
if rd.rd != nil {
// make a raw copy of the entire command when
// there's a underlying reader.
cmd.Raw = append([]byte(nil), b[:i+1]...)
} else {
// just assign the slice
cmd.Raw = b[:i+1]
}
cmd.Args = make([][]byte, len(marks)/2)
// slice up the raw command into the args based on
// the recorded marks.
for h := 0; h < len(marks); h += 2 {
cmd.Args[h/2] = cmd.Raw[marks[h]:marks[h+1]]
}
cmds = append(cmds, cmd)
2016-09-07 08:56:44 +03:00
b = b[i+1:]
if len(b) > 0 {
2016-07-28 17:54:02 +03:00
goto next
} else {
goto done
}
}
}
}
}
done:
rd.start = rd.end - len(b)
2016-07-28 17:54:02 +03:00
}
if leftover != nil {
*leftover = rd.end - rd.start
2016-08-20 19:23:39 +03:00
}
if len(cmds) > 0 {
return cmds, nil
2016-07-28 17:54:02 +03:00
}
if rd.rd == nil {
return nil, errIncompleteCommand
2016-09-07 08:56:44 +03:00
}
if rd.end == len(rd.buf) {
// at the end of the buffer.
if rd.start == rd.end {
// rewind the to the beginning
rd.start, rd.end = 0, 0
} else {
// must grow the buffer
newbuf := make([]byte, len(rd.buf)*2)
copy(newbuf, rd.buf)
rd.buf = newbuf
2016-09-07 08:56:44 +03:00
}
2016-07-28 17:54:02 +03:00
}
n, err := rd.rd.Read(rd.buf[rd.end:])
if err != nil {
return nil, err
2016-08-23 15:54:17 +03:00
}
rd.end += n
return rd.readCommands(leftover)
2016-07-28 17:54:02 +03:00
}
// ReadCommand reads the next command.
func (rd *Reader) ReadCommand() (Command, error) {
if len(rd.cmds) > 0 {
cmd := rd.cmds[0]
rd.cmds = rd.cmds[1:]
return cmd, nil
2016-07-28 17:54:02 +03:00
}
cmds, err := rd.readCommands(nil)
if err != nil {
return Command{}, err
2016-07-28 17:54:02 +03:00
}
rd.cmds = cmds
return rd.ReadCommand()
2016-07-28 17:54:02 +03:00
}
// Parse parses a raw RESP message and returns a command.
func Parse(raw []byte) (Command, error) {
rd := Reader{buf: raw, end: len(raw)}
var leftover int
cmds, err := rd.readCommands(&leftover)
if err != nil {
return Command{}, err
2016-09-07 08:56:44 +03:00
}
if leftover > 0 {
return Command{}, errTooMuchData
2016-07-28 17:54:02 +03:00
}
return cmds[0], nil
2016-07-28 17:54:02 +03:00
}