diff --git a/example/clonebytes.go b/example/clonebytes.go new file mode 100644 index 0000000..4db5053 --- /dev/null +++ b/example/clonebytes.go @@ -0,0 +1,79 @@ +package main + +import ( + "log" + "sync" + + "github.com/tidwall/redcon" +) + +var addr = ":6380" + +func main() { + var mu sync.RWMutex + var items = make(map[string][]byte) + go log.Printf("started server at %s", addr) + err := redcon.ListenAndServeBytes(addr, + func(conn redcon.Conn, commands [][][]byte) { + for _, args := range commands { + switch string(args[0]) { + default: + conn.WriteError("ERR unknown command '" + string(args[0]) + "'") + case "ping": + conn.WriteString("PONG") + case "quit": + conn.WriteString("OK") + conn.Close() + case "set": + if len(args) != 3 { + conn.WriteError("ERR wrong number of arguments for '" + string(args[0]) + "' command") + continue + } + mu.Lock() + items[string(args[1])] = args[2] + mu.Unlock() + conn.WriteString("OK") + case "get": + if len(args) != 2 { + conn.WriteError("ERR wrong number of arguments for '" + string(args[0]) + "' command") + continue + } + mu.RLock() + val, ok := items[string(args[1])] + mu.RUnlock() + if !ok { + conn.WriteNull() + } else { + conn.WriteBulkBytes(val) + } + case "del": + if len(args) != 2 { + conn.WriteError("ERR wrong number of arguments for '" + string(args[0]) + "' command") + continue + } + mu.Lock() + _, ok := items[string(args[1])] + delete(items, string(args[1])) + mu.Unlock() + if !ok { + conn.WriteInt(0) + } else { + conn.WriteInt(1) + } + } + } + }, + func(conn redcon.Conn) bool { + // use this function to accept or deny the connection. + // log.Printf("accept: %s", conn.RemoteAddr()) + return true + }, + func(conn redcon.Conn, err error) { + // this is called when the connection has been closed + // log.Printf("closed: %s, err: %v", conn.RemoteAddr(), err) + }, + ) + if err != nil { + log.Fatal(err) + } +} diff --git a/redcon.go b/redcon.go index fdcc306..6274204 100644 --- a/redcon.go +++ b/redcon.go @@ -64,16 +64,17 @@ func (err *errProtocol) Error() string { // Server represents a Redcon server. type Server struct { - mu sync.Mutex - addr string - handler func(conn Conn, cmds [][]string) - accept func(conn Conn) bool - closed func(conn Conn, err error) - ln *net.TCPListener - done bool - conns map[*conn]bool - rdpool [][]byte - wrpool [][]byte + mu sync.Mutex + addr string + shandler func(conn Conn, cmds [][]string) + bhandler func(conn Conn, cmds [][][]byte) + accept func(conn Conn) bool + closed func(conn Conn, err error) + ln *net.TCPListener + done bool + conns map[*conn]bool + rdpool [][]byte + wrpool [][]byte } // NewServer returns a new server @@ -82,11 +83,26 @@ func NewServer( accept func(conn Conn) bool, closed func(conn Conn, err error), ) *Server { return &Server{ - addr: addr, - handler: handler, - accept: accept, - closed: closed, - conns: make(map[*conn]bool), + addr: addr, + shandler: handler, + accept: accept, + closed: closed, + conns: make(map[*conn]bool), + } +} + +// NewServerBytes returns a new server +// It uses []byte instead of string for the handler commands. +func NewServerBytes( + addr string, handler func(conn Conn, cmds [][][]byte), + accept func(conn Conn) bool, closed func(conn Conn, err error), +) *Server { + return &Server{ + addr: addr, + bhandler: handler, + accept: accept, + closed: closed, + conns: make(map[*conn]bool), } } @@ -111,7 +127,8 @@ func (s *Server) ListenAndServe() error { // when listening. signal can be nil. func (s *Server) ListenServeAndSignal(signal chan error) error { var addr = s.addr - var handler = s.handler + var shandler = s.shandler + var bhandler = s.bhandler var accept = s.accept var closed = s.closed ln, err := net.Listen("tcp", addr) @@ -139,9 +156,6 @@ func (s *Server) ListenServeAndSignal(signal chan error) error { s.conns = nil }() }() - if handler == nil { - handler = func(conn Conn, cmds [][]string) {} - } for { tcpc, err := tln.AcceptTCP() if err != nil { @@ -178,7 +192,14 @@ func (s *Server) ListenServeAndSignal(signal chan error) error { c.Close() continue } - go handle(s, c, handler, closed) + if shandler == nil && bhandler == nil { + shandler = func(conn Conn, cmds [][]string) {} + } else if shandler != nil { + bhandler = nil + } else if bhandler != nil { + shandler = nil + } + go handle(s, c, shandler, bhandler, closed) } } @@ -190,9 +211,19 @@ func ListenAndServe( return NewServer(addr, handler, accept, closed).ListenAndServe() } +// ListenAndServeBytes creates a new server and binds to addr. +// It uses []byte instead of string for the handler commands. +func ListenAndServeBytes( + addr string, handler func(conn Conn, cmds [][][]byte), + accept func(conn Conn) bool, closed func(conn Conn, err error), +) error { + return NewServerBytes(addr, handler, accept, closed).ListenAndServe() +} + func handle( s *Server, c *conn, - handler func(conn Conn, cmds [][]string), + shandler func(conn Conn, cmds [][]string), + bhandler func(conn Conn, cmds [][][]byte), closed func(conn Conn, err error)) { var err error defer func() { @@ -228,7 +259,28 @@ func handle( return err } if len(cmds) > 0 { - handler(c, cmds) + if shandler != nil { + // convert bytes to strings + scmds := make([][]string, len(cmds)) + for i := 0; i < len(cmds); i++ { + scmds[i] = make([]string, len(cmds[i])) + for j := 0; j < len(scmds[i]); j++ { + scmds[i][j] = string(scmds[i][j]) + } + } + shandler(c, scmds) + } else if bhandler != nil { + // copy the byte commands once, before exposing to the + // client. + for i := 0; i < len(cmds); i++ { + for j := 0; j < len(cmds[i]); j++ { + nb := make([]byte, len(cmds[i][j])) + copy(nb, cmds[i][j]) + cmds[i][j] = nb + } + } + bhandler(c, cmds) + } } if c.wr.err != nil { if c.wr.err == errClosed { @@ -313,11 +365,11 @@ func (rd *reader) reassign(r io.Reader) { } // ReadCommands reads one or more commands from the reader. -func (r *reader) ReadCommands() ([][]string, error) { +func (r *reader) ReadCommands() ([][][]byte, error) { if r.end-r.start > 0 { b := r.buf[r.start:r.end] // we have some potential commands. - var cmds [][]string + var cmds [][][]byte next: switch b[0] { default: @@ -330,7 +382,7 @@ func (r *reader) ReadCommands() ([][]string, error) { } else { line = b[:i] } - var args []string + var args [][]byte var quote bool var escape bool outer: @@ -341,7 +393,7 @@ func (r *reader) ReadCommands() ([][]string, error) { if !quote { if c == ' ' { if len(nline) > 0 { - args = append(args, string(nline)) + args = append(args, nline) } line = line[i+1:] continue outer @@ -367,7 +419,7 @@ func (r *reader) ReadCommands() ([][]string, error) { } } else if c == '"' { quote = false - args = append(args, string(nline)) + args = append(args, nline) line = line[i+1:] if len(line) > 0 && line[0] != ' ' { return nil, errUnbalancedQuotes @@ -384,7 +436,7 @@ func (r *reader) ReadCommands() ([][]string, error) { return nil, errUnbalancedQuotes } if len(line) > 0 { - args = append(args, string(line)) + args = append(args, line) } break } @@ -404,7 +456,7 @@ func (r *reader) ReadCommands() ([][]string, error) { var si int outer2: for i := 0; i < len(b); i++ { - var args []string + var args [][]byte if b[i] == '\n' { if b[i-1] != '\r' { return nil, errInvalidMultiBulkLength @@ -413,7 +465,7 @@ func (r *reader) ReadCommands() ([][]string, error) { if err != nil || ni <= 0 { return nil, errInvalidMultiBulkLength } - args = make([]string, 0, ni) + args = make([][]byte, 0, ni) for j := 0; j < ni; j++ { // read bulk length i++ @@ -450,7 +502,7 @@ func (r *reader) ReadCommands() ([][]string, error) { } } i += ni2 + 1 - args = append(args, string(arg)) + args = append(args, arg) break } } diff --git a/redcon_test.go b/redcon_test.go index 73c2b4a..7baf369 100644 --- a/redcon_test.go +++ b/redcon_test.go @@ -155,7 +155,7 @@ func TestRandomCommands(t *testing.T) { log.Fatal(err) } for _, cmd := range cmds { - if len(cmd) == 3 && cmd[0] == "RESET" && cmd[1] == "THE" && cmd[2] == "INDEX" { + if len(cmd) == 3 && string(cmd[0]) == "RESET" && string(cmd[1]) == "THE" && string(cmd[2]) == "INDEX" { if idx != len(gcmds) { t.Fatalf("did not process all commands") } @@ -186,7 +186,7 @@ func TestRandomCommands(t *testing.T) { continue } } - } else if cmd[i] == gcmds[idx][i] { + } else if string(cmd[i]) == string(gcmds[idx][i]) { continue } t.Fatalf("not equal for index %d/%d", idx, i)