server type and better coverage

This commit is contained in:
Josh Baker 2016-08-21 09:34:23 -07:00
parent a88a6c7cb0
commit e9d886853f
2 changed files with 187 additions and 37 deletions

122
redcon.go
View File

@ -53,24 +53,78 @@ func (err *errProtocol) Error() string {
return "Protocol error: " + err.msg
}
// ListenAndServe creates a new server and binds to addr.
func ListenAndServe(
// Server represents a Redcon server.
type Server struct {
mu sync.Mutex
addr string
handler func(conn Conn, cmds [][]string)
accept func(conn Conn) bool
closed func(conn Conn, err error)
ln *net.TCPListener
done bool
conns map[*conn]bool
}
// NewServer returns a new server
func NewServer(
addr string, handler func(conn Conn, cmds [][]string),
accept func(conn Conn) bool, closed func(conn Conn, err error),
) error {
) *Server {
return &Server{
addr: addr,
handler: handler,
accept: accept,
closed: closed,
conns: make(map[*conn]bool),
}
}
// Close stops listening on the TCP address.
// Already Accepted connections will be closed.
func (s *Server) Close() error {
if s.ln == nil {
return errors.New("not serving")
}
s.mu.Lock()
s.done = true
s.mu.Unlock()
return s.ln.Close()
}
// ListenAndServe serves incoming connections.
func (s *Server) ListenAndServe() error {
var addr = s.addr
var handler = s.handler
var accept = s.accept
var closed = s.closed
ln, err := net.Listen("tcp", addr)
if err != nil {
return err
}
defer ln.Close()
tcpln := ln.(*net.TCPListener)
s.ln = ln.(*net.TCPListener)
defer func() {
ln.Close()
func() {
s.mu.Lock()
defer s.mu.Unlock()
for c := range s.conns {
c.Close()
}
s.conns = nil
}()
}()
if handler == nil {
handler = func(conn Conn, cmds [][]string) {}
}
var mu sync.Mutex
for {
tcpc, err := tcpln.AcceptTCP()
tcpc, err := s.ln.AcceptTCP()
if err != nil {
s.mu.Lock()
done := s.done
s.mu.Unlock()
if done {
return nil
}
return err
}
c := &conn{
@ -83,23 +137,39 @@ func ListenAndServe(
c.Close()
continue
}
go handle(c, &mu, handler, closed)
s.mu.Lock()
s.conns[c] = true
s.mu.Unlock()
go handle(s, c, handler, closed)
}
}
func handle(c *conn, mu *sync.Mutex,
// ListenAndServe creates a new server and binds to addr.
func ListenAndServe(
addr string, handler func(conn Conn, cmds [][]string),
accept func(conn Conn) bool, closed func(conn Conn, err error),
) error {
return NewServer(addr, handler, accept, closed).ListenAndServe()
}
func handle(
s *Server, c *conn,
handler func(conn Conn, cmds [][]string),
closed func(conn Conn, err error)) {
var err error
defer func() {
c.conn.Close()
if closed != nil {
mu.Lock()
defer mu.Unlock()
if err == io.EOF {
err = nil
func() {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.conns, c)
if closed != nil {
if err == io.EOF {
err = nil
}
closed(c, err)
}
closed(c, err)
}
}()
}()
err = func() error {
for {
@ -137,7 +207,9 @@ type conn struct {
}
func (c *conn) Close() error {
return c.wr.Close()
err := c.wr.Close() // flush and close the writer
c.conn.Close() // close the connection. ignore this error
return err // return the writer error only
}
func (c *conn) WriteString(str string) {
c.wr.WriteString(str)
@ -152,7 +224,7 @@ func (c *conn) WriteError(msg string) {
c.wr.WriteError(msg)
}
func (c *conn) WriteArray(count int) {
c.wr.WriteMultiBulkStart(count)
c.wr.WriteArrayStart(count)
}
func (c *conn) WriteNull() {
c.wr.WriteNull()
@ -377,7 +449,7 @@ func (w *writer) WriteNull() error {
w.b = append(w.b, '$', '-', '1', '\r', '\n')
return nil
}
func (w *writer) WriteMultiBulkStart(count int) error {
func (w *writer) WriteArrayStart(count int) error {
if w.err != nil {
return w.err
}
@ -414,18 +486,6 @@ func (w *writer) Flush() error {
return nil
}
func (w *writer) WriteMultiBulk(bulks []string) error {
if err := w.WriteMultiBulkStart(len(bulks)); err != nil {
return err
}
for _, bulk := range bulks {
if err := w.WriteBulk(bulk); err != nil {
return err
}
}
return nil
}
func (w *writer) WriteError(msg string) error {
if w.err != nil {
return w.err

View File

@ -5,6 +5,8 @@ import (
"io"
"log"
"math/rand"
"net"
"strings"
"testing"
"time"
)
@ -178,9 +180,8 @@ func TestRandomCommands(t *testing.T) {
}
}
/*
func TestServer(t *testing.T) {
err := ListenAndServe(":11111",
s := NewServer(":12345",
func(conn Conn, cmds [][]string) {
for _, cmd := range cmds {
switch strings.ToLower(cmd[0]) {
@ -191,19 +192,108 @@ func TestServer(t *testing.T) {
case "quit":
conn.WriteString("OK")
conn.Close()
case "int":
conn.WriteInt(100)
case "bulk":
conn.WriteBulk("bulk")
case "null":
conn.WriteNull()
case "err":
conn.WriteError("ERR error")
case "array":
conn.WriteArray(2)
conn.WriteInt(99)
conn.WriteString("Hi!")
}
}
},
func(conn Conn) bool {
log.Printf("accept: %s", conn.RemoteAddr())
//log.Printf("accept: %s", conn.RemoteAddr())
return true
},
func(conn Conn, err error) {
log.Printf("closed: %s [%v]", conn.RemoteAddr(), err)
//log.Printf("closed: %s [%v]", conn.RemoteAddr(), err)
},
)
if err := s.Close(); err == nil {
t.Fatalf("expected an error, should not be able to close before serving")
}
go func() {
time.Sleep(time.Second / 4)
if err := ListenAndServe(":12345", nil, nil, nil); err == nil {
t.Fatalf("expected an error, should not be able to listen on the same port")
}
time.Sleep(time.Second / 4)
err := s.Close()
if err != nil {
t.Fatal(err)
}
err = s.Close()
if err == nil {
t.Fatalf("expected an error")
}
}()
go func() {
c, err := net.Dial("tcp", ":12345")
if err != nil {
t.Fatal(err)
}
defer c.Close()
do := func(cmd string) (string, error) {
io.WriteString(c, cmd)
buf := make([]byte, 1024)
n, err := c.Read(buf)
if err != nil {
return "", err
}
return string(buf[:n]), nil
}
res, err := do("PING\r\n")
if err != nil {
t.Fatal(err)
}
if res != "+PONG\r\n" {
t.Fatal("expecting '+PONG\r\n', got '%v'", res)
}
res, err = do("BULK\r\n")
if err != nil {
t.Fatal(err)
}
if res != "$4\r\nbulk\r\n" {
t.Fatal("expecting bulk, got '%v'", res)
}
res, err = do("INT\r\n")
if err != nil {
t.Fatal(err)
}
if res != ":100\r\n" {
t.Fatal("expecting int, got '%v'", res)
}
res, err = do("NULL\r\n")
if err != nil {
t.Fatal(err)
}
if res != "$-1\r\n" {
t.Fatal("expecting nul, got '%v'", res)
}
res, err = do("ARRAY\r\n")
if err != nil {
t.Fatal(err)
}
if res != "*2\r\n:99\r\n+Hi!\r\n" {
t.Fatal("expecting array, got '%v'", res)
}
res, err = do("ERR\r\n")
if err != nil {
t.Fatal(err)
}
if res != "-ERR error\r\n" {
t.Fatal("expecting array, got '%v'", res)
}
}()
err := s.ListenAndServe()
if err != nil {
log.Fatal(err)
t.Fatal(err)
}
}
*/