package redcon import ( "errors" "io" "net" "strconv" "sync" ) type Conn interface { RemoteAddr() string Close() error WriteError(msg string) WriteString(str string) WriteBulk(bulk string) WriteInt(num int) WriteArray(count int) WriteNull() } var ( errUnbalancedQuotes = &errProtocol{"unbalanced quotes in request"} errInvalidBulkLength = &errProtocol{"invalid bulk length"} errInvalidMultiBulkLength = &errProtocol{"invalid multibulk length"} ) type errProtocol struct { msg string } func (err *errProtocol) Error() string { return "Protocol error: " + err.msg } func ListenAndServe( addr string, handler func(conn Conn, cmds [][]string), accept func(conn Conn) bool, closed func(conn Conn, err error), ) error { ln, err := net.Listen("tcp", addr) if err != nil { return err } defer ln.Close() tcpln := ln.(*net.TCPListener) if handler == nil { handler = func(conn Conn, cmds [][]string) {} } var mu sync.Mutex for { tcpc, err := tcpln.AcceptTCP() if err != nil { return err } c := &conn{ tcpc, newWriter(tcpc), newReader(tcpc), tcpc.RemoteAddr().String(), } if accept != nil && !accept(c) { c.Close() continue } go handle(c, &mu, handler, closed) } } func handle(c *conn, mu *sync.Mutex, handler func(conn Conn, cmds [][]string), closed func(conn Conn, err error)) { var err error defer func() { c.conn.Close() if closed != nil { mu.Lock() defer mu.Unlock() if err == io.EOF { err = nil } closed(c, err) } }() err = func() error { for { cmds, err := c.rd.ReadCommands() if err != nil { if err, ok := err.(*errProtocol); ok { // All protocol errors should attempt a response to // the client. Ignore errors. c.wr.WriteError("ERR " + err.Error()) c.wr.Flush() } return err } if len(cmds) > 0 { handler(c, cmds) } if c.wr.err != nil { if c.wr.err == errClosed { return nil } return c.wr.err } if err := c.wr.Flush(); err != nil { return err } } }() } type conn struct { conn *net.TCPConn wr *writer rd *reader addr string } func (c *conn) Close() error { return c.wr.Close() } func (c *conn) WriteString(str string) { c.wr.WriteString(str) } func (c *conn) WriteBulk(bulk string) { c.wr.WriteBulk(bulk) } func (c *conn) WriteInt(num int) { c.wr.WriteInt(num) } func (c *conn) WriteError(msg string) { c.wr.WriteError(msg) } func (c *conn) WriteArray(count int) { c.wr.WriteMultiBulkStart(count) } func (c *conn) WriteNull() { c.wr.WriteNull() } func (c *conn) RemoteAddr() string { return c.addr } // Reader represents a RESP command reader. type reader struct { r *net.TCPConn // base reader b []byte // unprocessed bytes a []byte // static read buffer } const buflen = 1024 * 8 // NewReader returns a RESP command reader. func newReader(r *net.TCPConn) *reader { return &reader{ r: r, a: make([]byte, buflen), } } // ReadCommands reads one or more commands from the reader. func (r *reader) ReadCommands() ([][]string, error) { if len(r.b) > 0 { // we have some potential commands. var cmds [][]string next: switch r.b[0] { default: // just a plain text command for i := 0; i < len(r.b); i++ { if r.b[i] == '\n' { var line []byte if i > 0 && r.b[i-1] == '\r' { line = r.b[:i-1] } else { line = r.b[:i] } var args []string var quote bool var escape bool outer: for { nline := make([]byte, 0, len(line)) for i := 0; i < len(line); i++ { c := line[i] if !quote { if c == ' ' { if len(nline) > 0 { args = append(args, string(nline)) } line = line[i+1:] continue outer } if c == '"' { if i != 0 { return nil, errUnbalancedQuotes } quote = true line = line[i+1:] continue outer } } else { if escape { escape = false switch c { case 'n': c = '\n' case 'r': c = '\r' case 't': c = '\t' } } else if c == '"' { quote = false args = append(args, string(nline)) line = line[i+1:] if len(line) > 0 && line[0] != ' ' { return nil, errUnbalancedQuotes } continue outer } else if c == '\\' { escape = true continue } } nline = append(nline, c) } if quote { return nil, errUnbalancedQuotes } if len(line) > 0 { args = append(args, string(line)) } break } if len(args) > 0 { cmds = append(cmds, args) } r.b = r.b[i+1:] if len(r.b) > 0 { goto next } else { goto done } } } case '*': // resp formatted command var si int outer2: for i := 0; i < len(r.b); i++ { var args []string if r.b[i] == '\n' { if r.b[i-1] != '\r' { return nil, errInvalidMultiBulkLength } ni, err := strconv.ParseInt(string(r.b[si+1:i-1]), 10, 64) if err != nil || ni <= 0 { return nil, errInvalidMultiBulkLength } args = make([]string, 0, int(ni)) for j := 0; j < int(ni); j++ { // read bulk length i++ if i < len(r.b) { if r.b[i] != '$' { return nil, &errProtocol{"expected '$', got '" + string(r.b[i]) + "'"} } si = i for ; i < len(r.b); i++ { if r.b[i] == '\n' { if r.b[i-1] != '\r' { return nil, errInvalidBulkLength } s := string(r.b[si+1 : i-1]) ni2, err := strconv.ParseInt(s, 10, 64) if err != nil || ni2 < 0 { return nil, errInvalidBulkLength } if i+int(ni2)+2 >= len(r.b) { // not ready break outer2 } if r.b[i+int(ni2)+2] != '\n' || r.b[i+int(ni2)+1] != '\r' { return nil, errInvalidBulkLength } arg := string(r.b[i+1 : i+1+int(ni2)]) i += int(ni2) + 2 args = append(args, arg) break } } } } if len(args) == cap(args) { cmds = append(cmds, args) r.b = r.b[i+1:] if len(r.b) > 0 { goto next } else { goto done } } } } } done: if len(r.b) == 0 { r.b = nil } if len(cmds) > 0 { return cmds, nil } } if len(r.a) == 0 { r.a = make([]byte, buflen) } n, err := r.r.Read(r.a) if err != nil { if err == io.EOF { if len(r.b) > 0 { return nil, io.ErrUnexpectedEOF } } return nil, err } if len(r.b) == 0 { r.b = r.a[:n] } else { r.b = append(r.b, r.a[:n]...) } r.a = r.a[n:] return r.ReadCommands() } var errClosed = errors.New("closed") type writer struct { w *net.TCPConn b []byte err error } func newWriter(w *net.TCPConn) *writer { return &writer{w: w, b: make([]byte, 0, 256)} } func (w *writer) WriteNull() error { if w.err != nil { return w.err } w.b = append(w.b, '$', '-', '1', '\r', '\n') return nil } func (w *writer) WriteMultiBulkStart(count int) error { if w.err != nil { return w.err } w.b = append(w.b, '*') w.b = append(w.b, []byte(strconv.FormatInt(int64(count), 10))...) w.b = append(w.b, '\r', '\n') return nil } func (w *writer) WriteBulk(bulk string) error { if w.err != nil { return w.err } w.b = append(w.b, '$') w.b = append(w.b, []byte(strconv.FormatInt(int64(len(bulk)), 10))...) w.b = append(w.b, '\r', '\n') w.b = append(w.b, []byte(bulk)...) w.b = append(w.b, '\r', '\n') return nil } func (w *writer) Flush() error { if w.err != nil { return w.err } if len(w.b) == 0 { return nil } if _, err := w.w.Write(w.b); err != nil { w.err = err return err } w.b = w.b[:0] return nil } func (w *writer) WriteMultiBulk(bulks []string) error { if err := w.WriteMultiBulkStart(len(bulks)); err != nil { return err } for _, bulk := range bulks { if err := w.WriteBulk(bulk); err != nil { return err } } return nil } func (w *writer) WriteError(msg string) error { if w.err != nil { return w.err } w.b = append(w.b, '-') w.b = append(w.b, []byte(msg)...) w.b = append(w.b, '\r', '\n') return nil } func (w *writer) WriteString(msg string) error { if w.err != nil { return w.err } w.b = append(w.b, '+') w.b = append(w.b, []byte(msg)...) w.b = append(w.b, '\r', '\n') return nil } func (w *writer) WriteInt(num int) error { if w.err != nil { return w.err } w.b = append(w.b, ':') w.b = append(w.b, []byte(strconv.FormatInt(int64(num), 10))...) w.b = append(w.b, '\r', '\n') return nil } func (w *writer) Close() error { if w.err != nil { return w.err } if err := w.Flush(); err != nil { w.err = err return err } w.err = errClosed return nil }