mirror of https://github.com/tidwall/redcon.git
server type and better coverage
This commit is contained in:
parent
a88a6c7cb0
commit
e9d886853f
112
redcon.go
112
redcon.go
|
@ -53,24 +53,78 @@ func (err *errProtocol) Error() string {
|
||||||
return "Protocol error: " + err.msg
|
return "Protocol error: " + err.msg
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListenAndServe creates a new server and binds to addr.
|
// Server represents a Redcon server.
|
||||||
func ListenAndServe(
|
type Server struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
addr string
|
||||||
|
handler func(conn Conn, cmds [][]string)
|
||||||
|
accept func(conn Conn) bool
|
||||||
|
closed func(conn Conn, err error)
|
||||||
|
ln *net.TCPListener
|
||||||
|
done bool
|
||||||
|
conns map[*conn]bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewServer returns a new server
|
||||||
|
func NewServer(
|
||||||
addr string, handler func(conn Conn, cmds [][]string),
|
addr string, handler func(conn Conn, cmds [][]string),
|
||||||
accept func(conn Conn) bool, closed func(conn Conn, err error),
|
accept func(conn Conn) bool, closed func(conn Conn, err error),
|
||||||
) error {
|
) *Server {
|
||||||
|
return &Server{
|
||||||
|
addr: addr,
|
||||||
|
handler: handler,
|
||||||
|
accept: accept,
|
||||||
|
closed: closed,
|
||||||
|
conns: make(map[*conn]bool),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close stops listening on the TCP address.
|
||||||
|
// Already Accepted connections will be closed.
|
||||||
|
func (s *Server) Close() error {
|
||||||
|
if s.ln == nil {
|
||||||
|
return errors.New("not serving")
|
||||||
|
}
|
||||||
|
s.mu.Lock()
|
||||||
|
s.done = true
|
||||||
|
s.mu.Unlock()
|
||||||
|
return s.ln.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListenAndServe serves incoming connections.
|
||||||
|
func (s *Server) ListenAndServe() error {
|
||||||
|
var addr = s.addr
|
||||||
|
var handler = s.handler
|
||||||
|
var accept = s.accept
|
||||||
|
var closed = s.closed
|
||||||
ln, err := net.Listen("tcp", addr)
|
ln, err := net.Listen("tcp", addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer ln.Close()
|
s.ln = ln.(*net.TCPListener)
|
||||||
tcpln := ln.(*net.TCPListener)
|
defer func() {
|
||||||
|
ln.Close()
|
||||||
|
func() {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
for c := range s.conns {
|
||||||
|
c.Close()
|
||||||
|
}
|
||||||
|
s.conns = nil
|
||||||
|
}()
|
||||||
|
}()
|
||||||
if handler == nil {
|
if handler == nil {
|
||||||
handler = func(conn Conn, cmds [][]string) {}
|
handler = func(conn Conn, cmds [][]string) {}
|
||||||
}
|
}
|
||||||
var mu sync.Mutex
|
|
||||||
for {
|
for {
|
||||||
tcpc, err := tcpln.AcceptTCP()
|
tcpc, err := s.ln.AcceptTCP()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
s.mu.Lock()
|
||||||
|
done := s.done
|
||||||
|
s.mu.Unlock()
|
||||||
|
if done {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
c := &conn{
|
c := &conn{
|
||||||
|
@ -83,24 +137,40 @@ func ListenAndServe(
|
||||||
c.Close()
|
c.Close()
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
go handle(c, &mu, handler, closed)
|
s.mu.Lock()
|
||||||
|
s.conns[c] = true
|
||||||
|
s.mu.Unlock()
|
||||||
|
go handle(s, c, handler, closed)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
func handle(c *conn, mu *sync.Mutex,
|
|
||||||
|
// ListenAndServe creates a new server and binds to addr.
|
||||||
|
func ListenAndServe(
|
||||||
|
addr string, handler func(conn Conn, cmds [][]string),
|
||||||
|
accept func(conn Conn) bool, closed func(conn Conn, err error),
|
||||||
|
) error {
|
||||||
|
return NewServer(addr, handler, accept, closed).ListenAndServe()
|
||||||
|
}
|
||||||
|
|
||||||
|
func handle(
|
||||||
|
s *Server, c *conn,
|
||||||
handler func(conn Conn, cmds [][]string),
|
handler func(conn Conn, cmds [][]string),
|
||||||
closed func(conn Conn, err error)) {
|
closed func(conn Conn, err error)) {
|
||||||
var err error
|
var err error
|
||||||
defer func() {
|
defer func() {
|
||||||
c.conn.Close()
|
c.conn.Close()
|
||||||
|
func() {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
delete(s.conns, c)
|
||||||
if closed != nil {
|
if closed != nil {
|
||||||
mu.Lock()
|
|
||||||
defer mu.Unlock()
|
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
err = nil
|
err = nil
|
||||||
}
|
}
|
||||||
closed(c, err)
|
closed(c, err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
}()
|
||||||
err = func() error {
|
err = func() error {
|
||||||
for {
|
for {
|
||||||
cmds, err := c.rd.ReadCommands()
|
cmds, err := c.rd.ReadCommands()
|
||||||
|
@ -137,7 +207,9 @@ type conn struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *conn) Close() error {
|
func (c *conn) Close() error {
|
||||||
return c.wr.Close()
|
err := c.wr.Close() // flush and close the writer
|
||||||
|
c.conn.Close() // close the connection. ignore this error
|
||||||
|
return err // return the writer error only
|
||||||
}
|
}
|
||||||
func (c *conn) WriteString(str string) {
|
func (c *conn) WriteString(str string) {
|
||||||
c.wr.WriteString(str)
|
c.wr.WriteString(str)
|
||||||
|
@ -152,7 +224,7 @@ func (c *conn) WriteError(msg string) {
|
||||||
c.wr.WriteError(msg)
|
c.wr.WriteError(msg)
|
||||||
}
|
}
|
||||||
func (c *conn) WriteArray(count int) {
|
func (c *conn) WriteArray(count int) {
|
||||||
c.wr.WriteMultiBulkStart(count)
|
c.wr.WriteArrayStart(count)
|
||||||
}
|
}
|
||||||
func (c *conn) WriteNull() {
|
func (c *conn) WriteNull() {
|
||||||
c.wr.WriteNull()
|
c.wr.WriteNull()
|
||||||
|
@ -377,7 +449,7 @@ func (w *writer) WriteNull() error {
|
||||||
w.b = append(w.b, '$', '-', '1', '\r', '\n')
|
w.b = append(w.b, '$', '-', '1', '\r', '\n')
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
func (w *writer) WriteMultiBulkStart(count int) error {
|
func (w *writer) WriteArrayStart(count int) error {
|
||||||
if w.err != nil {
|
if w.err != nil {
|
||||||
return w.err
|
return w.err
|
||||||
}
|
}
|
||||||
|
@ -414,18 +486,6 @@ func (w *writer) Flush() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *writer) WriteMultiBulk(bulks []string) error {
|
|
||||||
if err := w.WriteMultiBulkStart(len(bulks)); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
for _, bulk := range bulks {
|
|
||||||
if err := w.WriteBulk(bulk); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *writer) WriteError(msg string) error {
|
func (w *writer) WriteError(msg string) error {
|
||||||
if w.err != nil {
|
if w.err != nil {
|
||||||
return w.err
|
return w.err
|
||||||
|
|
102
redcon_test.go
102
redcon_test.go
|
@ -5,6 +5,8 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
@ -178,9 +180,8 @@ func TestRandomCommands(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
|
||||||
func TestServer(t *testing.T) {
|
func TestServer(t *testing.T) {
|
||||||
err := ListenAndServe(":11111",
|
s := NewServer(":12345",
|
||||||
func(conn Conn, cmds [][]string) {
|
func(conn Conn, cmds [][]string) {
|
||||||
for _, cmd := range cmds {
|
for _, cmd := range cmds {
|
||||||
switch strings.ToLower(cmd[0]) {
|
switch strings.ToLower(cmd[0]) {
|
||||||
|
@ -191,19 +192,108 @@ func TestServer(t *testing.T) {
|
||||||
case "quit":
|
case "quit":
|
||||||
conn.WriteString("OK")
|
conn.WriteString("OK")
|
||||||
conn.Close()
|
conn.Close()
|
||||||
|
case "int":
|
||||||
|
conn.WriteInt(100)
|
||||||
|
case "bulk":
|
||||||
|
conn.WriteBulk("bulk")
|
||||||
|
case "null":
|
||||||
|
conn.WriteNull()
|
||||||
|
case "err":
|
||||||
|
conn.WriteError("ERR error")
|
||||||
|
case "array":
|
||||||
|
conn.WriteArray(2)
|
||||||
|
conn.WriteInt(99)
|
||||||
|
conn.WriteString("Hi!")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
func(conn Conn) bool {
|
func(conn Conn) bool {
|
||||||
log.Printf("accept: %s", conn.RemoteAddr())
|
//log.Printf("accept: %s", conn.RemoteAddr())
|
||||||
return true
|
return true
|
||||||
},
|
},
|
||||||
func(conn Conn, err error) {
|
func(conn Conn, err error) {
|
||||||
log.Printf("closed: %s [%v]", conn.RemoteAddr(), err)
|
//log.Printf("closed: %s [%v]", conn.RemoteAddr(), err)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
if err := s.Close(); err == nil {
|
||||||
|
t.Fatalf("expected an error, should not be able to close before serving")
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
time.Sleep(time.Second / 4)
|
||||||
|
if err := ListenAndServe(":12345", nil, nil, nil); err == nil {
|
||||||
|
t.Fatalf("expected an error, should not be able to listen on the same port")
|
||||||
|
}
|
||||||
|
time.Sleep(time.Second / 4)
|
||||||
|
|
||||||
|
err := s.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
err = s.Close()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected an error")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
c, err := net.Dial("tcp", ":12345")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer c.Close()
|
||||||
|
do := func(cmd string) (string, error) {
|
||||||
|
io.WriteString(c, cmd)
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
n, err := c.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return string(buf[:n]), nil
|
||||||
|
}
|
||||||
|
res, err := do("PING\r\n")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if res != "+PONG\r\n" {
|
||||||
|
t.Fatal("expecting '+PONG\r\n', got '%v'", res)
|
||||||
|
}
|
||||||
|
res, err = do("BULK\r\n")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if res != "$4\r\nbulk\r\n" {
|
||||||
|
t.Fatal("expecting bulk, got '%v'", res)
|
||||||
|
}
|
||||||
|
res, err = do("INT\r\n")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if res != ":100\r\n" {
|
||||||
|
t.Fatal("expecting int, got '%v'", res)
|
||||||
|
}
|
||||||
|
res, err = do("NULL\r\n")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if res != "$-1\r\n" {
|
||||||
|
t.Fatal("expecting nul, got '%v'", res)
|
||||||
|
}
|
||||||
|
res, err = do("ARRAY\r\n")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if res != "*2\r\n:99\r\n+Hi!\r\n" {
|
||||||
|
t.Fatal("expecting array, got '%v'", res)
|
||||||
|
}
|
||||||
|
res, err = do("ERR\r\n")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if res != "-ERR error\r\n" {
|
||||||
|
t.Fatal("expecting array, got '%v'", res)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
err := s.ListenAndServe()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
*/
|
|
||||||
|
|
Loading…
Reference in New Issue