// Package redcon implements a Redis compatible server framework package redcon import ( "bufio" "errors" "io" "net" "strconv" "sync" ) var ( errUnbalancedQuotes = &errProtocol{"unbalanced quotes in request"} errInvalidBulkLength = &errProtocol{"invalid bulk length"} errInvalidMultiBulkLength = &errProtocol{"invalid multibulk length"} errDetached = errors.New("detached") errIncompleteCommand = errors.New("incomplete command") errTooMuchData = errors.New("too much data") ) type errProtocol struct { msg string } func (err *errProtocol) Error() string { return "Protocol error: " + err.msg } // Conn represents a client connection type Conn interface { // RemoteAddr returns the remote address of the client connection. RemoteAddr() string // Close closes the connection. Close() error // WriteError writes an error to the client. WriteError(msg string) // WriteString writes a string to the client. WriteString(str string) // WriteBulk writes bulk bytes to the client. WriteBulk(bulk []byte) // WriteBulkString writes a bulk string to the client. WriteBulkString(bulk string) // WriteInt writes an integer to the client. WriteInt(num int) // WriteArray writes an array header. You must then write addtional // sub-responses to the client to complete the response. // For example to write two strings: // // c.WriteArray(2) // c.WriteBulk("item 1") // c.WriteBulk("item 2") WriteArray(count int) // WriteNull writes a null to the client WriteNull() // Context returns a user-defined context Context() interface{} // SetContext sets a user-defined context SetContext(v interface{}) // SetReadBuffer updates the buffer read size for the connection SetReadBuffer(bytes int) // Detach return a connection that is detached from the server. // Useful for operations like PubSub. // // dconn := conn.Detach() // go func(){ // defer dconn.Close() // cmd, err := dconn.ReadCommand() // if err != nil{ // fmt.Printf("read failed: %v\n", err) // return // } // fmt.Printf("received command: %v", cmd) // hconn.WriteString("OK") // if err := dconn.Flush(); err != nil{ // fmt.Printf("write failed: %v\n", err) // return // } // }() Detach() DetachedConn } // NewServer returns a new Redcon server. func NewServer(addr string, handler func(conn Conn, cmd Command), accept func(conn Conn) bool, closed func(conn Conn, err error), ) *Server { if handler == nil { panic("handler is nil") } s := &Server{ addr: addr, handler: handler, accept: accept, closed: closed, conns: make(map[*conn]bool), } return s } // Close stops listening on the TCP address. // Already Accepted connections will be closed. func (s *Server) Close() error { s.mu.Lock() defer s.mu.Unlock() if s.ln == nil { return errors.New("not serving") } s.done = true return s.ln.Close() } // ListenAndServe serves incoming connections. func (s *Server) ListenAndServe() error { return s.ListenServeAndSignal(nil) } // ListenAndServe creates a new server and binds to addr. func ListenAndServe(addr string, handler func(conn Conn, cmd Command), accept func(conn Conn) bool, closed func(conn Conn, err error), ) error { return NewServer(addr, handler, accept, closed).ListenAndServe() } // ListenServeAndSignal serves incoming connections and passes nil or error // when listening. signal can be nil. func (s *Server) ListenServeAndSignal(signal chan error) error { var addr = s.addr ln, err := net.Listen("tcp", addr) if err != nil { if signal != nil { signal <- err } return err } if signal != nil { signal <- nil } tln := ln.(*net.TCPListener) s.mu.Lock() s.ln = tln s.mu.Unlock() defer func() { ln.Close() func() { s.mu.Lock() defer s.mu.Unlock() for c := range s.conns { c.Close() } s.conns = nil }() }() for { tcpc, err := tln.AcceptTCP() if err != nil { s.mu.Lock() done := s.done s.mu.Unlock() if done { return nil } return err } c := &conn{conn: tcpc, addr: tcpc.RemoteAddr().String(), wr: NewWriter(tcpc), rd: NewReader(tcpc)} s.mu.Lock() s.conns[c] = true s.mu.Unlock() if s.accept != nil && !s.accept(c) { s.mu.Lock() delete(s.conns, c) s.mu.Unlock() c.Close() continue } go handle(s, c) } } // handle manages the server connection. func handle(s *Server, c *conn) { var err error defer func() { if err != errDetached { // do not close the connection when a detach is detected. c.conn.Close() } func() { // remove the conn from the server s.mu.Lock() defer s.mu.Unlock() delete(s.conns, c) if s.closed != nil { if err == io.EOF { err = nil } s.closed(c, err) } }() }() err = func() error { // read commands and feed back to the client for { // read pipeline commands cmds, err := c.rd.readCommands(nil) if err != nil { if err, ok := err.(*errProtocol); ok { // All protocol errors should attempt a response to // the client. Ignore write errors. c.wr.WriteError("ERR " + err.Error()) c.wr.Flush() } return err } for _, cmd := range cmds { s.handler(c, cmd) } if c.detached { // client has been detached return errDetached } if c.closed { return nil } if err := c.wr.Flush(); err != nil { return err } } }() } // conn represents a client connection type conn struct { conn *net.TCPConn wr *Writer rd *Reader addr string ctx interface{} detached bool closed bool } func (c *conn) Close() error { c.closed = true return c.conn.Close() } func (c *conn) Context() interface{} { return c.ctx } func (c *conn) SetContext(v interface{}) { c.ctx = v } func (c *conn) SetReadBuffer(n int) {} func (c *conn) WriteString(str string) { c.wr.WriteString(str) } func (c *conn) WriteBulk(bulk []byte) { c.wr.WriteBulk(bulk) } func (c *conn) WriteBulkString(bulk string) { c.wr.WriteBulkString(bulk) } func (c *conn) WriteInt(num int) { c.wr.WriteInt(num) } func (c *conn) WriteError(msg string) { c.wr.WriteError(msg) } func (c *conn) WriteArray(count int) { c.wr.WriteArray(count) } func (c *conn) WriteNull() { c.wr.WriteNull() } func (c *conn) RemoteAddr() string { return c.addr } // DetachedConn represents a connection that is detached from the server type DetachedConn interface { // Conn is the original connection Conn // ReadCommand reads the next client command. ReadCommand() (Command, error) // Flush flushes any writes to the network. Flush() error } // Detach removes the current connection from the server loop and returns // a detached connection. This is useful for operations such as PubSub. // The detached connection must be closed by calling Close() when done. // All writes such as WriteString() will not be written to the client // until Flush() is called. func (c *conn) Detach() DetachedConn { c.detached = true return &detachedConn{conn: c} } type detachedConn struct { *conn } // Flush writes and Write* calls to the client. func (dc *detachedConn) Flush() error { return dc.conn.wr.Flush() } // ReadCommand read the next command from the client. func (dc *detachedConn) ReadCommand() (Command, error) { if dc.closed { return Command{}, errors.New("closed") } cmd, err := dc.rd.ReadCommand() if err != nil { return Command{}, err } return cmd, nil } // Command represent a command type Command struct { // Raw is a encoded RESP message. Raw []byte // Args is a series of arguments that make up the command. Args [][]byte } // Server defines a server for clients for managing client connections. type Server struct { mu sync.Mutex addr string handler func(conn Conn, cmd Command) accept func(conn Conn) bool closed func(conn Conn, err error) conns map[*conn]bool ln *net.TCPListener done bool } // Writer allows for writing RESP messages. type Writer struct { w io.Writer b []byte } // NewWriter creates a new RESP writer. func NewWriter(wr io.Writer) *Writer { return &Writer{ w: wr, } } // WriteNull writes a null to the client func (w *Writer) WriteNull() { w.b = append(w.b, '$', '-', '1', '\r', '\n') } // WriteArray writes an array header. You must then write addtional // sub-responses to the client to complete the response. // For example to write two strings: // // c.WriteArray(2) // c.WriteBulk("item 1") // c.WriteBulk("item 2") func (w *Writer) WriteArray(count int) { w.b = append(w.b, '*') w.b = append(w.b, strconv.FormatInt(int64(count), 10)...) w.b = append(w.b, '\r', '\n') } // WriteBulk writes bulk bytes to the client. func (w *Writer) WriteBulk(bulk []byte) { w.b = append(w.b, '$') w.b = append(w.b, strconv.FormatInt(int64(len(bulk)), 10)...) w.b = append(w.b, '\r', '\n') w.b = append(w.b, bulk...) w.b = append(w.b, '\r', '\n') } // WriteBulkString writes a bulk string to the client. func (w *Writer) WriteBulkString(bulk string) { w.b = append(w.b, '$') w.b = append(w.b, strconv.FormatInt(int64(len(bulk)), 10)...) w.b = append(w.b, '\r', '\n') w.b = append(w.b, bulk...) w.b = append(w.b, '\r', '\n') } // Flush writes all unflushed Write* calls to the underlying writer. func (w *Writer) Flush() error { if _, err := w.w.Write(w.b); err != nil { return err } w.b = w.b[:0] return nil } // WriteError writes an error to the client. func (w *Writer) WriteError(msg string) { w.b = append(w.b, '-') w.b = append(w.b, msg...) w.b = append(w.b, '\r', '\n') } // WriteString writes a string to the client. func (w *Writer) WriteString(msg string) { if msg == "OK" { w.b = append(w.b, '+', 'O', 'K', '\r', '\n') } else { w.b = append(w.b, '+') w.b = append(w.b, []byte(msg)...) w.b = append(w.b, '\r', '\n') } } // WriteInt writes an integer to the client. func (w *Writer) WriteInt(num int) { w.b = append(w.b, ':') w.b = append(w.b, []byte(strconv.FormatInt(int64(num), 10))...) w.b = append(w.b, '\r', '\n') } // Reader represent a reader for RESP or telnet commands. type Reader struct { rd *bufio.Reader buf []byte start int end int cmds []Command } // NewReader returns a command reader which will read RESP or telnet commands. func NewReader(rd io.Reader) *Reader { return &Reader{ rd: bufio.NewReader(rd), buf: make([]byte, 4096), } } func parseInt(b []byte) (int, error) { // shortcut atoi for 0-99. fails for negative numbers. switch len(b) { case 1: if b[0] >= '0' && b[0] <= '9' { return int(b[0] - '0'), nil } case 2: if b[0] >= '0' && b[0] <= '9' && b[1] >= '0' && b[1] <= '9' { return int(b[0]-'0')*10 + int(b[1]-'0'), nil } } // fallback to standard library n, err := strconv.ParseUint(string(b), 10, 64) return int(n), err } func (rd *Reader) readCommands(leftover *int) ([]Command, error) { var cmds []Command b := rd.buf[rd.start:rd.end] if len(b) > 0 { // we have data, yay! // but is this enough data for a complete command? or multiple? next: switch b[0] { default: // just a plain text command for i := 0; i < len(b); i++ { if b[i] == '\n' { var line []byte if i > 0 && b[i-1] == '\r' { line = b[:i-1] } else { line = b[:i] } var cmd Command var quote bool var escape bool outer: for { nline := make([]byte, 0, len(line)) for i := 0; i < len(line); i++ { c := line[i] if !quote { if c == ' ' { if len(nline) > 0 { cmd.Args = append(cmd.Args, nline) } line = line[i+1:] continue outer } if c == '"' { if i != 0 { return nil, errUnbalancedQuotes } quote = true line = line[i+1:] continue outer } } else { if escape { escape = false switch c { case 'n': c = '\n' case 'r': c = '\r' case 't': c = '\t' } } else if c == '"' { quote = false cmd.Args = append(cmd.Args, nline) line = line[i+1:] if len(line) > 0 && line[0] != ' ' { return nil, errUnbalancedQuotes } continue outer } else if c == '\\' { escape = true continue } } nline = append(nline, c) } if quote { return nil, errUnbalancedQuotes } if len(line) > 0 { cmd.Args = append(cmd.Args, line) } break } if len(cmd.Args) > 0 { // convert this to resp command syntax var wr Writer wr.WriteArray(len(cmd.Args)) for i := range cmd.Args { wr.WriteBulk(cmd.Args[i]) cmd.Args[i] = append([]byte(nil), cmd.Args[i]...) } cmd.Raw = wr.b cmds = append(cmds, cmd) } b = b[i+1:] if len(b) > 0 { goto next } else { goto done } } } case '*': // resp formatted command marks := make([]int, 0, 16) outer2: for i := 1; i < len(b); i++ { if b[i] == '\n' { if b[i-1] != '\r' { return nil, errInvalidMultiBulkLength } count, err := parseInt(b[1 : i-1]) if err != nil || count <= 0 { return nil, errInvalidMultiBulkLength } marks = marks[:0] for j := 0; j < count; j++ { // read bulk length i++ if i < len(b) { if b[i] != '$' { return nil, &errProtocol{"expected '$', got '" + string(b[i]) + "'"} } si := i for ; i < len(b); i++ { if b[i] == '\n' { if b[i-1] != '\r' { return nil, errInvalidBulkLength } size, err := parseInt(b[si+1 : i-1]) if err != nil || size < 0 { return nil, errInvalidBulkLength } if i+size+2 >= len(b) { // not ready break outer2 } if b[i+size+2] != '\n' || b[i+size+1] != '\r' { return nil, errInvalidBulkLength } i++ marks = append(marks, i, i+size) i += size + 1 break } } } } if len(marks) == count*2 { var cmd Command if rd.rd != nil { // make a raw copy of the entire command when // there's a underlying reader. cmd.Raw = append([]byte(nil), b[:i+1]...) } else { // just assign the slice cmd.Raw = b[:i+1] } cmd.Args = make([][]byte, len(marks)/2) // slice up the raw command into the args based on // the recorded marks. for h := 0; h < len(marks); h += 2 { cmd.Args[h/2] = cmd.Raw[marks[h]:marks[h+1]] } cmds = append(cmds, cmd) b = b[i+1:] if len(b) > 0 { goto next } else { goto done } } } } } done: rd.start = rd.end - len(b) } if leftover != nil { *leftover = rd.end - rd.start } if len(cmds) > 0 { return cmds, nil } if rd.rd == nil { return nil, errIncompleteCommand } if rd.end == len(rd.buf) { // at the end of the buffer. if rd.start == rd.end { // rewind the to the beginning rd.start, rd.end = 0, 0 } else { // must grow the buffer newbuf := make([]byte, len(rd.buf)*2) copy(newbuf, rd.buf) rd.buf = newbuf } } n, err := rd.rd.Read(rd.buf[rd.end:]) if err != nil { return nil, err } rd.end += n return rd.readCommands(leftover) } // ReadCommand reads the next command. func (rd *Reader) ReadCommand() (Command, error) { if len(rd.cmds) > 0 { cmd := rd.cmds[0] rd.cmds = rd.cmds[1:] return cmd, nil } cmds, err := rd.readCommands(nil) if err != nil { return Command{}, err } rd.cmds = cmds return rd.ReadCommand() } // Parse parses a raw RESP message and returns a command. func Parse(raw []byte) (Command, error) { rd := Reader{buf: raw, end: len(raw)} var leftover int cmds, err := rd.readCommands(&leftover) if err != nil { return Command{}, err } if leftover > 0 { return Command{}, errTooMuchData } return cmds[0], nil }