diff --git a/redcon.go b/redcon.go index 1a3ad98..a18eae6 100644 --- a/redcon.go +++ b/redcon.go @@ -42,103 +42,108 @@ func ListenAndServe( return err } defer ln.Close() + if handler == nil { + handler = func(conn Conn, cmds [][]string) {} + } var mu sync.Mutex for { - conn, err := ln.Accept() + nc, err := ln.Accept() if err != nil { return err } - wr := newWriter(conn) - wrc := &connWriter{wr, conn.RemoteAddr().String()} - if accept != nil && !accept(wrc) { - conn.Close() + tcpc := nc.(*net.TCPConn) + c := &conn{tcpc, newWriter(tcpc), tcpc.RemoteAddr().String()} + if accept != nil && !accept(c) { + c.Close() continue } - go func() { - var err error - defer func() { - conn.Close() - if closed != nil { - mu.Lock() - defer mu.Unlock() - if err == io.EOF { - err = nil - } - closed(wrc, err) - } - }() - rd := newReader(conn) - err = func() error { - for { - cmds, err := rd.ReadCommands() - if err != nil { - if err, ok := err.(*errProtocol); ok { - // All protocol errors should attempt a response to - // the client. Ignore errors. - wr.WriteError("ERR " + err.Error()) - wr.Flush() - } - return err - } - if len(cmds) > 0 { - if handler != nil { - handler(wrc, cmds) - } - } - if wr.err != nil { - if wr.err == errClosed { - return nil - } - return wr.err - } - if err := wr.Flush(); err != nil { - return err - } - } - }() - }() + go handle(c, &mu, handler, closed) } } +func handle(c *conn, mu *sync.Mutex, + handler func(conn Conn, cmds [][]string), + closed func(conn Conn, err error)) { + var err error + defer func() { + c.conn.Close() + if closed != nil { + mu.Lock() + defer mu.Unlock() + if err == io.EOF { + err = nil + } + closed(c, err) + } + }() + rd := newReader(c.conn) + err = func() error { + for { + cmds, err := rd.ReadCommands() + if err != nil { + if err, ok := err.(*errProtocol); ok { + // All protocol errors should attempt a response to + // the client. Ignore errors. + c.wr.WriteError("ERR " + err.Error()) + c.wr.Flush() + } + return err + } + if len(cmds) > 0 { + handler(c, cmds) + } + if c.wr.err != nil { + if c.wr.err == errClosed { + return nil + } + return c.wr.err + } + if err := c.wr.Flush(); err != nil { + return err + } + } + }() +} -type connWriter struct { - wr *respWriter +type conn struct { + conn *net.TCPConn + wr *writer addr string } -func (wrc *connWriter) Close() error { - return wrc.wr.Close() +func (c *conn) Close() error { + return c.wr.Close() } -func (wrc *connWriter) WriteString(str string) { - wrc.wr.WriteString(str) +func (c *conn) WriteString(str string) { + c.wr.WriteString(str) } -func (wrc *connWriter) WriteBulk(bulk string) { - wrc.wr.WriteBulk(bulk) +func (c *conn) WriteBulk(bulk string) { + c.wr.WriteBulk(bulk) } -func (wrc *connWriter) WriteInt(num int) { - wrc.wr.WriteInt(num) +func (c *conn) WriteInt(num int) { + c.wr.WriteInt(num) } -func (wrc *connWriter) WriteError(msg string) { - wrc.wr.WriteError(msg) +func (c *conn) WriteError(msg string) { + c.wr.WriteError(msg) } -func (wrc *connWriter) WriteArray(count int) { - wrc.wr.WriteMultiBulkStart(count) +func (c *conn) WriteArray(count int) { + c.wr.WriteMultiBulkStart(count) } -func (wrc *connWriter) WriteNull() { - wrc.wr.WriteNull() +func (c *conn) WriteNull() { + c.wr.WriteNull() } -func (wrc *connWriter) RemoteAddr() string { - return wrc.addr +func (c *conn) RemoteAddr() string { + return c.addr } // Reader represents a RESP command reader. type respReader struct { - r io.Reader // base reader - b []byte // unprocessed bytes - a []byte // static read buffer + r *net.TCPConn // base reader + b []byte // unprocessed bytes + a []byte // static read buffer } // NewReader returns a RESP command reader. -func newReader(r io.Reader) *respReader { +func newReader(r *net.TCPConn) *respReader { return &respReader{ r: r, a: make([]byte, 8192), @@ -316,24 +321,24 @@ func (r *respReader) ReadCommands() ([][]string, error) { var errClosed = errors.New("closed") -type respWriter struct { - w io.Writer +type writer struct { + w *net.TCPConn b []byte err error } -func newWriter(w io.Writer) *respWriter { - return &respWriter{w: w, b: make([]byte, 0, 256)} +func newWriter(w *net.TCPConn) *writer { + return &writer{w: w, b: make([]byte, 0, 256)} } -func (w *respWriter) WriteNull() error { +func (w *writer) WriteNull() error { if w.err != nil { return w.err } w.b = append(w.b, '$', '-', '1', '\r', '\n') return nil } -func (w *respWriter) WriteMultiBulkStart(count int) error { +func (w *writer) WriteMultiBulkStart(count int) error { if w.err != nil { return w.err } @@ -343,7 +348,7 @@ func (w *respWriter) WriteMultiBulkStart(count int) error { return nil } -func (w *respWriter) WriteBulk(bulk string) error { +func (w *writer) WriteBulk(bulk string) error { if w.err != nil { return w.err } @@ -355,7 +360,7 @@ func (w *respWriter) WriteBulk(bulk string) error { return nil } -func (w *respWriter) Flush() error { +func (w *writer) Flush() error { if w.err != nil { return w.err } @@ -370,7 +375,7 @@ func (w *respWriter) Flush() error { return nil } -func (w *respWriter) WriteMultiBulk(bulks []string) error { +func (w *writer) WriteMultiBulk(bulks []string) error { if err := w.WriteMultiBulkStart(len(bulks)); err != nil { return err } @@ -382,7 +387,7 @@ func (w *respWriter) WriteMultiBulk(bulks []string) error { return nil } -func (w *respWriter) WriteError(msg string) error { +func (w *writer) WriteError(msg string) error { if w.err != nil { return w.err } @@ -392,7 +397,7 @@ func (w *respWriter) WriteError(msg string) error { return nil } -func (w *respWriter) WriteString(msg string) error { +func (w *writer) WriteString(msg string) error { if w.err != nil { return w.err } @@ -402,7 +407,7 @@ func (w *respWriter) WriteString(msg string) error { return nil } -func (w *respWriter) WriteInt(num int) error { +func (w *writer) WriteInt(num int) error { if w.err != nil { return w.err } @@ -412,7 +417,7 @@ func (w *respWriter) WriteInt(num int) error { return nil } -func (w *respWriter) Close() error { +func (w *writer) Close() error { if w.err != nil { return w.err }