diff --git a/redcon.go b/redcon.go index 79b3632..67d2035 100644 --- a/redcon.go +++ b/redcon.go @@ -53,24 +53,78 @@ func (err *errProtocol) Error() string { return "Protocol error: " + err.msg } -// ListenAndServe creates a new server and binds to addr. -func ListenAndServe( +// Server represents a Redcon server. +type Server struct { + mu sync.Mutex + addr string + handler func(conn Conn, cmds [][]string) + accept func(conn Conn) bool + closed func(conn Conn, err error) + ln *net.TCPListener + done bool + conns map[*conn]bool +} + +// NewServer returns a new server +func NewServer( addr string, handler func(conn Conn, cmds [][]string), accept func(conn Conn) bool, closed func(conn Conn, err error), -) error { +) *Server { + return &Server{ + addr: addr, + handler: handler, + accept: accept, + closed: closed, + conns: make(map[*conn]bool), + } +} + +// Close stops listening on the TCP address. +// Already Accepted connections will be closed. +func (s *Server) Close() error { + if s.ln == nil { + return errors.New("not serving") + } + s.mu.Lock() + s.done = true + s.mu.Unlock() + return s.ln.Close() +} + +// ListenAndServe serves incoming connections. +func (s *Server) ListenAndServe() error { + var addr = s.addr + var handler = s.handler + var accept = s.accept + var closed = s.closed ln, err := net.Listen("tcp", addr) if err != nil { return err } - defer ln.Close() - tcpln := ln.(*net.TCPListener) + s.ln = ln.(*net.TCPListener) + defer func() { + ln.Close() + func() { + s.mu.Lock() + defer s.mu.Unlock() + for c := range s.conns { + c.Close() + } + s.conns = nil + }() + }() if handler == nil { handler = func(conn Conn, cmds [][]string) {} } - var mu sync.Mutex for { - tcpc, err := tcpln.AcceptTCP() + tcpc, err := s.ln.AcceptTCP() if err != nil { + s.mu.Lock() + done := s.done + s.mu.Unlock() + if done { + return nil + } return err } c := &conn{ @@ -83,23 +137,39 @@ func ListenAndServe( c.Close() continue } - go handle(c, &mu, handler, closed) + s.mu.Lock() + s.conns[c] = true + s.mu.Unlock() + go handle(s, c, handler, closed) } } -func handle(c *conn, mu *sync.Mutex, + +// ListenAndServe creates a new server and binds to addr. +func ListenAndServe( + addr string, handler func(conn Conn, cmds [][]string), + accept func(conn Conn) bool, closed func(conn Conn, err error), +) error { + return NewServer(addr, handler, accept, closed).ListenAndServe() +} + +func handle( + s *Server, c *conn, handler func(conn Conn, cmds [][]string), closed func(conn Conn, err error)) { var err error defer func() { c.conn.Close() - if closed != nil { - mu.Lock() - defer mu.Unlock() - if err == io.EOF { - err = nil + func() { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.conns, c) + if closed != nil { + if err == io.EOF { + err = nil + } + closed(c, err) } - closed(c, err) - } + }() }() err = func() error { for { @@ -137,7 +207,9 @@ type conn struct { } func (c *conn) Close() error { - return c.wr.Close() + err := c.wr.Close() // flush and close the writer + c.conn.Close() // close the connection. ignore this error + return err // return the writer error only } func (c *conn) WriteString(str string) { c.wr.WriteString(str) @@ -152,7 +224,7 @@ func (c *conn) WriteError(msg string) { c.wr.WriteError(msg) } func (c *conn) WriteArray(count int) { - c.wr.WriteMultiBulkStart(count) + c.wr.WriteArrayStart(count) } func (c *conn) WriteNull() { c.wr.WriteNull() @@ -377,7 +449,7 @@ func (w *writer) WriteNull() error { w.b = append(w.b, '$', '-', '1', '\r', '\n') return nil } -func (w *writer) WriteMultiBulkStart(count int) error { +func (w *writer) WriteArrayStart(count int) error { if w.err != nil { return w.err } @@ -414,18 +486,6 @@ func (w *writer) Flush() error { return nil } -func (w *writer) WriteMultiBulk(bulks []string) error { - if err := w.WriteMultiBulkStart(len(bulks)); err != nil { - return err - } - for _, bulk := range bulks { - if err := w.WriteBulk(bulk); err != nil { - return err - } - } - return nil -} - func (w *writer) WriteError(msg string) error { if w.err != nil { return w.err diff --git a/redcon_test.go b/redcon_test.go index 8f88239..471320b 100644 --- a/redcon_test.go +++ b/redcon_test.go @@ -5,6 +5,8 @@ import ( "io" "log" "math/rand" + "net" + "strings" "testing" "time" ) @@ -178,9 +180,8 @@ func TestRandomCommands(t *testing.T) { } } -/* func TestServer(t *testing.T) { - err := ListenAndServe(":11111", + s := NewServer(":12345", func(conn Conn, cmds [][]string) { for _, cmd := range cmds { switch strings.ToLower(cmd[0]) { @@ -191,19 +192,108 @@ func TestServer(t *testing.T) { case "quit": conn.WriteString("OK") conn.Close() + case "int": + conn.WriteInt(100) + case "bulk": + conn.WriteBulk("bulk") + case "null": + conn.WriteNull() + case "err": + conn.WriteError("ERR error") + case "array": + conn.WriteArray(2) + conn.WriteInt(99) + conn.WriteString("Hi!") } } }, func(conn Conn) bool { - log.Printf("accept: %s", conn.RemoteAddr()) + //log.Printf("accept: %s", conn.RemoteAddr()) return true }, func(conn Conn, err error) { - log.Printf("closed: %s [%v]", conn.RemoteAddr(), err) + //log.Printf("closed: %s [%v]", conn.RemoteAddr(), err) }, ) + if err := s.Close(); err == nil { + t.Fatalf("expected an error, should not be able to close before serving") + } + go func() { + time.Sleep(time.Second / 4) + if err := ListenAndServe(":12345", nil, nil, nil); err == nil { + t.Fatalf("expected an error, should not be able to listen on the same port") + } + time.Sleep(time.Second / 4) + + err := s.Close() + if err != nil { + t.Fatal(err) + } + err = s.Close() + if err == nil { + t.Fatalf("expected an error") + } + }() + go func() { + c, err := net.Dial("tcp", ":12345") + if err != nil { + t.Fatal(err) + } + defer c.Close() + do := func(cmd string) (string, error) { + io.WriteString(c, cmd) + buf := make([]byte, 1024) + n, err := c.Read(buf) + if err != nil { + return "", err + } + return string(buf[:n]), nil + } + res, err := do("PING\r\n") + if err != nil { + t.Fatal(err) + } + if res != "+PONG\r\n" { + t.Fatal("expecting '+PONG\r\n', got '%v'", res) + } + res, err = do("BULK\r\n") + if err != nil { + t.Fatal(err) + } + if res != "$4\r\nbulk\r\n" { + t.Fatal("expecting bulk, got '%v'", res) + } + res, err = do("INT\r\n") + if err != nil { + t.Fatal(err) + } + if res != ":100\r\n" { + t.Fatal("expecting int, got '%v'", res) + } + res, err = do("NULL\r\n") + if err != nil { + t.Fatal(err) + } + if res != "$-1\r\n" { + t.Fatal("expecting nul, got '%v'", res) + } + res, err = do("ARRAY\r\n") + if err != nil { + t.Fatal(err) + } + if res != "*2\r\n:99\r\n+Hi!\r\n" { + t.Fatal("expecting array, got '%v'", res) + } + res, err = do("ERR\r\n") + if err != nil { + t.Fatal(err) + } + if res != "-ERR error\r\n" { + t.Fatal("expecting array, got '%v'", res) + } + }() + err := s.ListenAndServe() if err != nil { - log.Fatal(err) + t.Fatal(err) } } -*/