redcon/redcon.go

1409 lines
33 KiB
Go

// Package redcon implements a Redis compatible server framework
package redcon
import (
"bufio"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"strings"
"sync"
"time"
"github.com/tidwall/btree"
"github.com/tidwall/match"
)
var (
errUnbalancedQuotes = &errProtocol{"unbalanced quotes in request"}
errInvalidBulkLength = &errProtocol{"invalid bulk length"}
errInvalidMultiBulkLength = &errProtocol{"invalid multibulk length"}
errDetached = errors.New("detached")
errIncompleteCommand = errors.New("incomplete command")
errTooMuchData = errors.New("too much data")
errContextDone = errors.New("context done")
)
type errProtocol struct {
msg string
}
func (err *errProtocol) Error() string {
return "Protocol error: " + err.msg
}
// Conn represents a client connection
type Conn interface {
// RemoteAddr returns the remote address of the client connection.
RemoteAddr() string
// Close closes the connection.
Close() error
// WriteError writes an error to the client.
WriteError(msg string)
// WriteString writes a string to the client.
WriteString(str string)
// WriteBulk writes bulk bytes to the client.
WriteBulk(bulk []byte)
// WriteBulkString writes a bulk string to the client.
WriteBulkString(bulk string)
// WriteInt writes an integer to the client.
WriteInt(num int)
// WriteInt64 writes a 64-bit signed integer to the client.
WriteInt64(num int64)
// WriteUint64 writes a 64-bit unsigned integer to the client.
WriteUint64(num uint64)
// WriteArray writes an array header. You must then write additional
// sub-responses to the client to complete the response.
// For example to write two strings:
//
// c.WriteArray(2)
// c.WriteBulk("item 1")
// c.WriteBulk("item 2")
WriteArray(count int)
// WriteNull writes a null to the client
WriteNull()
// WriteRaw writes raw data to the client.
WriteRaw(data []byte)
// WriteAny writes any type to the client.
// nil -> null
// error -> error (adds "ERR " when first word is not uppercase)
// string -> bulk-string
// numbers -> bulk-string
// []byte -> bulk-string
// bool -> bulk-string ("0" or "1")
// slice -> array
// map -> array with key/value pairs
// SimpleString -> string
// SimpleInt -> integer
// everything-else -> bulk-string representation using fmt.Sprint()
WriteAny(any interface{})
// Context returns a user-defined context
Context() interface{}
// SetContext sets a user-defined context
SetContext(v interface{})
// SetReadBuffer updates the buffer read size for the connection
SetReadBuffer(bytes int)
// Detach return a connection that is detached from the server.
// Useful for operations like PubSub.
//
// dconn := conn.Detach()
// go func(){
// defer dconn.Close()
// cmd, err := dconn.ReadCommand()
// if err != nil{
// fmt.Printf("read failed: %v\n", err)
// return
// }
// fmt.Printf("received command: %v", cmd)
// hconn.WriteString("OK")
// if err := dconn.Flush(); err != nil{
// fmt.Printf("write failed: %v\n", err)
// return
// }
// }()
Detach() DetachedConn
// ReadPipeline returns all commands in current pipeline, if any
// The commands are removed from the pipeline.
ReadPipeline() []Command
// PeekPipeline returns all commands in current pipeline, if any.
// The commands remain in the pipeline.
PeekPipeline() []Command
// NetConn returns the base net.Conn connection
NetConn() net.Conn
}
// NewServer returns a new Redcon server configured on "tcp" network net.
func NewServer(
ctx context.Context,
addr string,
handler func(conn Conn, cmd Command),
accept func(conn Conn) bool,
closed func(conn Conn, err error),
) *Server {
return NewServerNetwork(ctx, "tcp", addr, handler, accept, closed)
}
// NewServerTLS returns a new Redcon TLS server configured on "tcp" network net.
func NewServerTLS(ctx context.Context,
addr string,
handler func(conn Conn, cmd Command),
accept func(conn Conn) bool,
closed func(conn Conn, err error),
config *tls.Config,
) *TLSServer {
return NewServerNetworkTLS(ctx, "tcp", addr, handler, accept, closed, config)
}
// NewServerNetwork returns a new Redcon server. The network net must be
// a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket"
func NewServerNetwork(
ctx context.Context,
net, laddr string,
handler func(conn Conn, cmd Command),
accept func(conn Conn) bool,
closed func(conn Conn, err error),
) *Server {
if handler == nil {
panic("handler is nil")
}
s := &Server{
ctx: ctx,
net: net,
laddr: laddr,
handler: handler,
accept: accept,
closed: closed,
conns: make(map[*conn]bool),
}
return s
}
// NewServerNetworkTLS returns a new TLS Redcon server. The network net must be
// a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket"
func NewServerNetworkTLS(
ctx context.Context,
net, laddr string,
handler func(conn Conn, cmd Command),
accept func(conn Conn) bool,
closed func(conn Conn, err error),
config *tls.Config,
) *TLSServer {
if handler == nil {
panic("handler is nil")
}
s := Server{
ctx: ctx,
net: net,
laddr: laddr,
handler: handler,
accept: accept,
closed: closed,
conns: make(map[*conn]bool),
}
tls := &TLSServer{
config: config,
Server: &s,
}
return tls
}
// Close stops listening on the TCP address.
// Already Accepted connections will be closed.
func (s *Server) Close() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.ln == nil {
return errors.New("not serving")
}
s.done = true
return s.ln.Close()
}
// ListenAndServe serves incoming connections.
func (s *Server) ListenAndServe() error {
return s.ListenServeAndSignal(nil)
}
// Addr returns server's listen address
func (s *Server) Addr() net.Addr {
return s.ln.Addr()
}
// Close stops listening on the TCP address.
// Already Accepted connections will be closed.
func (s *TLSServer) Close() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.ln == nil {
return errors.New("not serving")
}
s.done = true
return s.ln.Close()
}
// ListenAndServe serves incoming connections.
func (s *TLSServer) ListenAndServe() error {
return s.ListenServeAndSignal(nil)
}
// Serve creates a new server and serves with the given net.Listener.
func Serve(ln net.Listener,
handler func(conn Conn, cmd Command),
accept func(conn Conn) bool,
closed func(conn Conn, err error),
) error {
s := &Server{
net: ln.Addr().Network(),
laddr: ln.Addr().String(),
ln: ln,
handler: handler,
accept: accept,
closed: closed,
conns: make(map[*conn]bool),
}
return serve(s)
}
// ListenAndServe creates a new server and binds to addr configured on "tcp" network net.
func ListenAndServe(
ctx context.Context,
addr string,
handler func(conn Conn, cmd Command),
accept func(conn Conn) bool,
closed func(conn Conn, err error),
) error {
return ListenAndServeNetwork(ctx, "tcp", addr, handler, accept, closed)
}
// ListenAndServeTLS creates a new TLS server and binds to addr configured on "tcp" network net.
func ListenAndServeTLS(
ctx context.Context,
addr string,
handler func(conn Conn, cmd Command),
accept func(conn Conn) bool,
closed func(conn Conn, err error),
config *tls.Config,
) error {
return ListenAndServeNetworkTLS(ctx, "tcp", addr, handler, accept, closed, config)
}
// ListenAndServeNetwork creates a new server and binds to addr. The network net must be
// a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket"
func ListenAndServeNetwork(
ctx context.Context,
net, laddr string,
handler func(conn Conn, cmd Command),
accept func(conn Conn) bool,
closed func(conn Conn, err error),
) error {
return NewServerNetwork(ctx, net, laddr, handler, accept, closed).ListenAndServe()
}
// ListenAndServeNetworkTLS creates a new TLS server and binds to addr. The network net must be
// a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket"
func ListenAndServeNetworkTLS(
ctx context.Context,
net, laddr string,
handler func(conn Conn, cmd Command),
accept func(conn Conn) bool,
closed func(conn Conn, err error),
config *tls.Config,
) error {
return NewServerNetworkTLS(ctx, net, laddr, handler, accept, closed, config).ListenAndServe()
}
// ListenServeAndSignal serves incoming connections and passes nil or error
// when listening. signal can be nil.
func (s *Server) ListenServeAndSignal(signal chan error) error {
//var lc net.ListenConfig
//ln, err := lc.Listen(s.ctx, s.net, s.laddr)
ln, err := net.Listen(s.net, s.laddr)
if err != nil {
if signal != nil {
signal <- err
}
return err
}
s.ln = ln
if signal != nil {
signal <- nil
}
return serve(s)
}
// Serve serves incoming connections with the given net.Listener.
func (s *Server) Serve(ln net.Listener) error {
s.ln = ln
s.net = ln.Addr().Network()
s.laddr = ln.Addr().String()
return serve(s)
}
// ListenServeAndSignal serves incoming connections and passes nil or error
// when listening. signal can be nil.
func (s *TLSServer) ListenServeAndSignal(signal chan error) error {
ln, err := tls.Listen(s.net, s.laddr, s.config)
if err != nil {
if signal != nil {
signal <- err
}
return err
}
s.ln = ln
if signal != nil {
signal <- nil
}
return serve(s.Server)
}
func serve(s *Server) error {
defer func() {
s.ln.Close()
func() {
s.mu.Lock()
defer s.mu.Unlock()
for c := range s.conns {
c.Close()
}
s.conns = nil
}()
}()
go func() {
select {
case <-s.ctx.Done():
s.Close()
}
}()
for {
lnconn, err := s.ln.Accept()
if err != nil {
s.mu.Lock()
done := s.done
s.mu.Unlock()
if done {
select {
case <-s.ctx.Done():
return errContextDone
default:
}
return nil
}
if s.AcceptError != nil {
s.AcceptError(err)
}
continue
}
c := &conn{
conn: lnconn,
addr: lnconn.RemoteAddr().String(),
wr: NewWriter(lnconn),
rd: NewReader(lnconn),
}
s.mu.Lock()
c.idleClose = s.idleClose
s.conns[c] = true
s.mu.Unlock()
if s.accept != nil && !s.accept(c) {
s.mu.Lock()
delete(s.conns, c)
s.mu.Unlock()
c.Close()
continue
}
go handle(s, c)
}
}
// handle manages the server connection.
func handle(s *Server, c *conn) {
var err error
defer func() {
if err != errDetached {
// do not close the connection when a detach is detected.
c.conn.Close()
}
func() {
// remove the conn from the server
s.mu.Lock()
defer s.mu.Unlock()
delete(s.conns, c)
if s.closed != nil {
if err == io.EOF {
err = nil
}
s.closed(c, err)
}
}()
}()
err = func() error {
// read commands and feed back to the client
for {
// read pipeline commands
if c.idleClose != 0 {
c.conn.SetReadDeadline(time.Now().Add(c.idleClose))
}
cmds, err := c.rd.readCommands(nil)
if err != nil {
if err, ok := err.(*errProtocol); ok {
// All protocol errors should attempt a response to
// the client. Ignore write errors.
c.wr.WriteError("ERR " + err.Error())
c.wr.Flush()
}
return err
}
c.cmds = cmds
for len(c.cmds) > 0 {
cmd := c.cmds[0]
if len(c.cmds) == 1 {
c.cmds = nil
} else {
c.cmds = c.cmds[1:]
}
s.handler(c, cmd)
}
if c.detached {
// client has been detached
return errDetached
}
if c.closed {
return nil
}
if err := c.wr.Flush(); err != nil {
return err
}
}
}()
}
// conn represents a client connection
type conn struct {
conn net.Conn
wr *Writer
rd *Reader
addr string
ctx interface{}
detached bool
closed bool
cmds []Command
idleClose time.Duration
}
func (c *conn) Close() error {
c.wr.Flush()
c.closed = true
return c.conn.Close()
}
func (c *conn) Context() interface{} { return c.ctx }
func (c *conn) SetContext(v interface{}) { c.ctx = v }
func (c *conn) SetReadBuffer(n int) {}
func (c *conn) WriteString(str string) { c.wr.WriteString(str) }
func (c *conn) WriteBulk(bulk []byte) { c.wr.WriteBulk(bulk) }
func (c *conn) WriteBulkString(bulk string) { c.wr.WriteBulkString(bulk) }
func (c *conn) WriteInt(num int) { c.wr.WriteInt(num) }
func (c *conn) WriteInt64(num int64) { c.wr.WriteInt64(num) }
func (c *conn) WriteUint64(num uint64) { c.wr.WriteUint64(num) }
func (c *conn) WriteError(msg string) { c.wr.WriteError(msg) }
func (c *conn) WriteArray(count int) { c.wr.WriteArray(count) }
func (c *conn) WriteNull() { c.wr.WriteNull() }
func (c *conn) WriteRaw(data []byte) { c.wr.WriteRaw(data) }
func (c *conn) WriteAny(v interface{}) { c.wr.WriteAny(v) }
func (c *conn) RemoteAddr() string { return c.addr }
func (c *conn) ReadPipeline() []Command {
cmds := c.cmds
c.cmds = nil
return cmds
}
func (c *conn) PeekPipeline() []Command {
return c.cmds
}
func (c *conn) NetConn() net.Conn {
return c.conn
}
// BaseWriter returns the underlying connection writer, if any
func BaseWriter(c Conn) *Writer {
if c, ok := c.(*conn); ok {
return c.wr
}
return nil
}
// DetachedConn represents a connection that is detached from the server
type DetachedConn interface {
// Conn is the original connection
Conn
// ReadCommand reads the next client command.
ReadCommand() (Command, error)
// Flush flushes any writes to the network.
Flush() error
}
// Detach removes the current connection from the server loop and returns
// a detached connection. This is useful for operations such as PubSub.
// The detached connection must be closed by calling Close() when done.
// All writes such as WriteString() will not be written to the client
// until Flush() is called.
func (c *conn) Detach() DetachedConn {
c.detached = true
cmds := c.cmds
c.cmds = nil
return &detachedConn{conn: c, cmds: cmds}
}
type detachedConn struct {
*conn
cmds []Command
}
// Flush writes and Write* calls to the client.
func (dc *detachedConn) Flush() error {
return dc.conn.wr.Flush()
}
// ReadCommand read the next command from the client.
func (dc *detachedConn) ReadCommand() (Command, error) {
if len(dc.cmds) > 0 {
cmd := dc.cmds[0]
if len(dc.cmds) == 1 {
dc.cmds = nil
} else {
dc.cmds = dc.cmds[1:]
}
return cmd, nil
}
cmd, err := dc.rd.ReadCommand()
if err != nil {
return Command{}, err
}
return cmd, nil
}
// Command represent a command
type Command struct {
// Raw is a encoded RESP message.
Raw []byte
// Args is a series of arguments that make up the command.
Args [][]byte
}
// Server defines a server for clients for managing client connections.
type Server struct {
ctx context.Context
mu sync.Mutex
net string
laddr string
handler func(conn Conn, cmd Command)
accept func(conn Conn) bool
closed func(conn Conn, err error)
conns map[*conn]bool
ln net.Listener
done bool
idleClose time.Duration
// AcceptError is an optional function used to handle Accept errors.
AcceptError func(err error)
}
// TLSServer defines a server for clients for managing client connections.
type TLSServer struct {
*Server
config *tls.Config
}
// Writer allows for writing RESP messages.
type Writer struct {
w io.Writer
b []byte
}
// NewWriter creates a new RESP writer.
func NewWriter(wr io.Writer) *Writer {
return &Writer{
w: wr,
}
}
// WriteNull writes a null to the client
func (w *Writer) WriteNull() {
w.b = AppendNull(w.b)
}
// WriteArray writes an array header. You must then write additional
// sub-responses to the client to complete the response.
// For example to write two strings:
//
// c.WriteArray(2)
// c.WriteBulk("item 1")
// c.WriteBulk("item 2")
func (w *Writer) WriteArray(count int) {
w.b = AppendArray(w.b, count)
}
// WriteBulk writes bulk bytes to the client.
func (w *Writer) WriteBulk(bulk []byte) {
w.b = AppendBulk(w.b, bulk)
}
// WriteBulkString writes a bulk string to the client.
func (w *Writer) WriteBulkString(bulk string) {
w.b = AppendBulkString(w.b, bulk)
}
// Buffer returns the unflushed buffer. This is a copy so changes
// to the resulting []byte will not affect the writer.
func (w *Writer) Buffer() []byte {
return append([]byte(nil), w.b...)
}
// SetBuffer replaces the unflushed buffer with new bytes.
func (w *Writer) SetBuffer(raw []byte) {
w.b = w.b[:0]
w.b = append(w.b, raw...)
}
// Flush writes all unflushed Write* calls to the underlying writer.
func (w *Writer) Flush() error {
if _, err := w.w.Write(w.b); err != nil {
return err
}
w.b = w.b[:0]
return nil
}
// WriteError writes an error to the client.
func (w *Writer) WriteError(msg string) {
w.b = AppendError(w.b, msg)
}
// WriteString writes a string to the client.
func (w *Writer) WriteString(msg string) {
w.b = AppendString(w.b, msg)
}
// WriteInt writes an integer to the client.
func (w *Writer) WriteInt(num int) {
w.WriteInt64(int64(num))
}
// WriteInt64 writes a 64-bit signed integer to the client.
func (w *Writer) WriteInt64(num int64) {
w.b = AppendInt(w.b, num)
}
// WriteUint64 writes a 64-bit unsigned integer to the client.
func (w *Writer) WriteUint64(num uint64) {
w.b = AppendUint(w.b, num)
}
// WriteRaw writes raw data to the client.
func (w *Writer) WriteRaw(data []byte) {
w.b = append(w.b, data...)
}
// WriteAny writes any type to client.
// nil -> null
// error -> error (adds "ERR " when first word is not uppercase)
// string -> bulk-string
// numbers -> bulk-string
// []byte -> bulk-string
// bool -> bulk-string ("0" or "1")
// slice -> array
// map -> array with key/value pairs
// SimpleString -> string
// SimpleInt -> integer
// everything-else -> bulk-string representation using fmt.Sprint()
func (w *Writer) WriteAny(v interface{}) {
w.b = AppendAny(w.b, v)
}
// Reader represent a reader for RESP or telnet commands.
type Reader struct {
rd *bufio.Reader
buf []byte
start int
end int
cmds []Command
}
// NewReader returns a command reader which will read RESP or telnet commands.
func NewReader(rd io.Reader) *Reader {
return &Reader{
rd: bufio.NewReader(rd),
buf: make([]byte, 4096),
}
}
func parseInt(b []byte) (int, bool) {
if len(b) == 1 && b[0] >= '0' && b[0] <= '9' {
return int(b[0] - '0'), true
}
var n int
var sign bool
var i int
if len(b) > 0 && b[0] == '-' {
sign = true
i++
}
for ; i < len(b); i++ {
if b[i] < '0' || b[i] > '9' {
return 0, false
}
n = n*10 + int(b[i]-'0')
}
if sign {
n *= -1
}
return n, true
}
func (rd *Reader) readCommands(leftover *int) ([]Command, error) {
var cmds []Command
b := rd.buf[rd.start:rd.end]
if rd.end-rd.start == 0 && len(rd.buf) > 4096 {
rd.buf = rd.buf[:4096]
rd.start = 0
rd.end = 0
}
if len(b) > 0 {
// we have data, yay!
// but is this enough data for a complete command? or multiple?
next:
switch b[0] {
default:
// just a plain text command
for i := 0; i < len(b); i++ {
if b[i] == '\n' {
var line []byte
if i > 0 && b[i-1] == '\r' {
line = b[:i-1]
} else {
line = b[:i]
}
var cmd Command
var quote bool
var quotech byte
var escape bool
outer:
for {
nline := make([]byte, 0, len(line))
for i := 0; i < len(line); i++ {
c := line[i]
if !quote {
if c == ' ' {
if len(nline) > 0 {
cmd.Args = append(cmd.Args, nline)
}
line = line[i+1:]
continue outer
}
if c == '"' || c == '\'' {
if i != 0 {
return nil, errUnbalancedQuotes
}
quotech = c
quote = true
line = line[i+1:]
continue outer
}
} else {
if escape {
escape = false
switch c {
case 'n':
c = '\n'
case 'r':
c = '\r'
case 't':
c = '\t'
}
} else if c == quotech {
quote = false
quotech = 0
cmd.Args = append(cmd.Args, nline)
line = line[i+1:]
if len(line) > 0 && line[0] != ' ' {
return nil, errUnbalancedQuotes
}
continue outer
} else if c == '\\' {
escape = true
continue
}
}
nline = append(nline, c)
}
if quote {
return nil, errUnbalancedQuotes
}
if len(line) > 0 {
cmd.Args = append(cmd.Args, line)
}
break
}
if len(cmd.Args) > 0 {
// convert this to resp command syntax
var wr Writer
wr.WriteArray(len(cmd.Args))
for i := range cmd.Args {
wr.WriteBulk(cmd.Args[i])
cmd.Args[i] = append([]byte(nil), cmd.Args[i]...)
}
cmd.Raw = wr.b
cmds = append(cmds, cmd)
}
b = b[i+1:]
if len(b) > 0 {
goto next
} else {
goto done
}
}
}
case '*':
// resp formatted command
marks := make([]int, 0, 16)
outer2:
for i := 1; i < len(b); i++ {
if b[i] == '\n' {
if b[i-1] != '\r' {
return nil, errInvalidMultiBulkLength
}
count, ok := parseInt(b[1 : i-1])
if !ok || count <= 0 {
return nil, errInvalidMultiBulkLength
}
marks = marks[:0]
for j := 0; j < count; j++ {
// read bulk length
i++
if i < len(b) {
if b[i] != '$' {
return nil, &errProtocol{"expected '$', got '" +
string(b[i]) + "'"}
}
si := i
for ; i < len(b); i++ {
if b[i] == '\n' {
if b[i-1] != '\r' {
return nil, errInvalidBulkLength
}
size, ok := parseInt(b[si+1 : i-1])
if !ok || size < 0 {
return nil, errInvalidBulkLength
}
if i+size+2 >= len(b) {
// not ready
break outer2
}
if b[i+size+2] != '\n' ||
b[i+size+1] != '\r' {
return nil, errInvalidBulkLength
}
i++
marks = append(marks, i, i+size)
i += size + 1
break
}
}
}
}
if len(marks) == count*2 {
var cmd Command
if rd.rd != nil {
// make a raw copy of the entire command when
// there's a underlying reader.
cmd.Raw = append([]byte(nil), b[:i+1]...)
} else {
// just assign the slice
cmd.Raw = b[:i+1]
}
cmd.Args = make([][]byte, len(marks)/2)
// slice up the raw command into the args based on
// the recorded marks.
for h := 0; h < len(marks); h += 2 {
cmd.Args[h/2] = cmd.Raw[marks[h]:marks[h+1]]
}
cmds = append(cmds, cmd)
b = b[i+1:]
if len(b) > 0 {
goto next
} else {
goto done
}
}
}
}
}
done:
rd.start = rd.end - len(b)
}
if leftover != nil {
*leftover = rd.end - rd.start
}
if len(cmds) > 0 {
return cmds, nil
}
if rd.rd == nil {
return nil, errIncompleteCommand
}
if rd.end == len(rd.buf) {
// at the end of the buffer.
if rd.start == rd.end {
// rewind the to the beginning
rd.start, rd.end = 0, 0
} else {
// must grow the buffer
newbuf := make([]byte, len(rd.buf)*2)
copy(newbuf, rd.buf)
rd.buf = newbuf
}
}
n, err := rd.rd.Read(rd.buf[rd.end:])
if err != nil {
return nil, err
}
rd.end += n
return rd.readCommands(leftover)
}
// ReadCommands reads the next pipeline commands.
func (rd *Reader) ReadCommands() ([]Command, error) {
for {
if len(rd.cmds) > 0 {
cmds := rd.cmds
rd.cmds = nil
return cmds, nil
}
cmds, err := rd.readCommands(nil)
if err != nil {
return []Command{}, err
}
rd.cmds = cmds
}
}
// ReadCommand reads the next command.
func (rd *Reader) ReadCommand() (Command, error) {
if len(rd.cmds) > 0 {
cmd := rd.cmds[0]
rd.cmds = rd.cmds[1:]
return cmd, nil
}
cmds, err := rd.readCommands(nil)
if err != nil {
return Command{}, err
}
rd.cmds = cmds
return rd.ReadCommand()
}
// Parse parses a raw RESP message and returns a command.
func Parse(raw []byte) (Command, error) {
rd := Reader{buf: raw, end: len(raw)}
var leftover int
cmds, err := rd.readCommands(&leftover)
if err != nil {
return Command{}, err
}
if leftover > 0 {
return Command{}, errTooMuchData
}
return cmds[0], nil
}
// A Handler responds to an RESP request.
type Handler interface {
ServeRESP(conn Conn, cmd Command)
}
// The HandlerFunc type is an adapter to allow the use of
// ordinary functions as RESP handlers. If f is a function
// with the appropriate signature, HandlerFunc(f) is a
// Handler that calls f.
type HandlerFunc func(conn Conn, cmd Command)
// ServeRESP calls f(w, r)
func (f HandlerFunc) ServeRESP(conn Conn, cmd Command) {
f(conn, cmd)
}
// ServeMux is an RESP command multiplexer.
type ServeMux struct {
handlers map[string]Handler
}
// NewServeMux allocates and returns a new ServeMux.
func NewServeMux() *ServeMux {
return &ServeMux{
handlers: make(map[string]Handler),
}
}
// HandleFunc registers the handler function for the given command.
func (m *ServeMux) HandleFunc(command string, handler func(conn Conn, cmd Command)) {
if handler == nil {
panic("redcon: nil handler")
}
m.Handle(command, HandlerFunc(handler))
}
// Handle registers the handler for the given command.
// If a handler already exists for command, Handle panics.
func (m *ServeMux) Handle(command string, handler Handler) {
if command == "" {
panic("redcon: invalid command")
}
if handler == nil {
panic("redcon: nil handler")
}
if _, exist := m.handlers[command]; exist {
panic("redcon: multiple registrations for " + command)
}
m.handlers[command] = handler
}
// ServeRESP dispatches the command to the handler.
func (m *ServeMux) ServeRESP(conn Conn, cmd Command) {
command := strings.ToLower(string(cmd.Args[0]))
if handler, ok := m.handlers[command]; ok {
handler.ServeRESP(conn, cmd)
} else {
conn.WriteError("ERR unknown command '" + command + "'")
}
}
// PubSub is a Redis compatible pub/sub server
type PubSub struct {
mu sync.RWMutex
nextid uint64
initd bool
chans *btree.BTree
conns map[Conn]*pubSubConn
}
// Subscribe a connection to PubSub
func (ps *PubSub) Subscribe(conn Conn, channel string) {
ps.subscribe(conn, false, channel)
}
// Psubscribe a connection to PubSub
func (ps *PubSub) Psubscribe(conn Conn, channel string) {
ps.subscribe(conn, true, channel)
}
// Publish a message to subscribers
func (ps *PubSub) Publish(channel, message string) int {
ps.mu.RLock()
defer ps.mu.RUnlock()
if !ps.initd {
return 0
}
var sent int
// write messages to all clients that are subscribed on the channel
pivot := &pubSubEntry{pattern: false, channel: channel}
ps.chans.Ascend(pivot, func(item interface{}) bool {
entry := item.(*pubSubEntry)
if entry.channel != pivot.channel || entry.pattern != pivot.pattern {
return false
}
entry.sconn.writeMessage(entry.pattern, "", channel, message)
sent++
return true
})
// match on and write all psubscribe clients
pivot = &pubSubEntry{pattern: true}
ps.chans.Ascend(pivot, func(item interface{}) bool {
entry := item.(*pubSubEntry)
if match.Match(channel, entry.channel) {
entry.sconn.writeMessage(entry.pattern, entry.channel, channel,
message)
}
sent++
return true
})
return sent
}
type pubSubConn struct {
id uint64
mu sync.Mutex
conn Conn
dconn DetachedConn
entries map[*pubSubEntry]bool
}
type pubSubEntry struct {
pattern bool
sconn *pubSubConn
channel string
}
func (sconn *pubSubConn) writeMessage(pat bool, pchan, channel, msg string) {
sconn.mu.Lock()
defer sconn.mu.Unlock()
if pat {
sconn.dconn.WriteArray(4)
sconn.dconn.WriteBulkString("pmessage")
sconn.dconn.WriteBulkString(pchan)
sconn.dconn.WriteBulkString(channel)
sconn.dconn.WriteBulkString(msg)
} else {
sconn.dconn.WriteArray(3)
sconn.dconn.WriteBulkString("message")
sconn.dconn.WriteBulkString(channel)
sconn.dconn.WriteBulkString(msg)
}
sconn.dconn.Flush()
}
// bgrunner runs in the background and reads incoming commands from the
// detached client.
func (sconn *pubSubConn) bgrunner(ps *PubSub) {
defer func() {
// client connection has ended, disconnect from the PubSub instances
// and close the network connection.
ps.mu.Lock()
defer ps.mu.Unlock()
for entry := range sconn.entries {
ps.chans.Delete(entry)
}
delete(ps.conns, sconn.conn)
sconn.mu.Lock()
defer sconn.mu.Unlock()
sconn.dconn.Close()
}()
for {
cmd, err := sconn.dconn.ReadCommand()
if err != nil {
return
}
if len(cmd.Args) == 0 {
continue
}
switch strings.ToLower(string(cmd.Args[0])) {
case "psubscribe", "subscribe":
if len(cmd.Args) < 2 {
func() {
sconn.mu.Lock()
defer sconn.mu.Unlock()
sconn.dconn.WriteError(fmt.Sprintf("ERR wrong number of "+
"arguments for '%s'", cmd.Args[0]))
sconn.dconn.Flush()
}()
continue
}
command := strings.ToLower(string(cmd.Args[0]))
for i := 1; i < len(cmd.Args); i++ {
if command == "psubscribe" {
ps.Psubscribe(sconn.conn, string(cmd.Args[i]))
} else {
ps.Subscribe(sconn.conn, string(cmd.Args[i]))
}
}
case "unsubscribe", "punsubscribe":
pattern := strings.ToLower(string(cmd.Args[0])) == "punsubscribe"
if len(cmd.Args) == 1 {
ps.unsubscribe(sconn.conn, pattern, true, "")
} else {
for i := 1; i < len(cmd.Args); i++ {
channel := string(cmd.Args[i])
ps.unsubscribe(sconn.conn, pattern, false, channel)
}
}
case "quit":
func() {
sconn.mu.Lock()
defer sconn.mu.Unlock()
sconn.dconn.WriteString("OK")
sconn.dconn.Flush()
sconn.dconn.Close()
}()
return
case "ping":
var msg string
switch len(cmd.Args) {
case 1:
case 2:
msg = string(cmd.Args[1])
default:
func() {
sconn.mu.Lock()
defer sconn.mu.Unlock()
sconn.dconn.WriteError(fmt.Sprintf("ERR wrong number of "+
"arguments for '%s'", cmd.Args[0]))
sconn.dconn.Flush()
}()
continue
}
func() {
sconn.mu.Lock()
defer sconn.mu.Unlock()
sconn.dconn.WriteArray(2)
sconn.dconn.WriteBulkString("pong")
sconn.dconn.WriteBulkString(msg)
sconn.dconn.Flush()
}()
default:
func() {
sconn.mu.Lock()
defer sconn.mu.Unlock()
sconn.dconn.WriteError(fmt.Sprintf("ERR Can't execute '%s': "+
"only (P)SUBSCRIBE / (P)UNSUBSCRIBE / PING / QUIT are "+
"allowed in this context", cmd.Args[0]))
sconn.dconn.Flush()
}()
}
}
}
// byEntry is a "less" function that sorts the entries in a btree. The tree
// is sorted be (pattern, channel, conn.id). All pattern=true entries are at
// the end (right) of the tree.
func byEntry(a, b interface{}) bool {
aa := a.(*pubSubEntry)
bb := b.(*pubSubEntry)
if !aa.pattern && bb.pattern {
return true
}
if aa.pattern && !bb.pattern {
return false
}
if aa.channel < bb.channel {
return true
}
if aa.channel > bb.channel {
return false
}
var aid uint64
var bid uint64
if aa.sconn != nil {
aid = aa.sconn.id
}
if bb.sconn != nil {
bid = bb.sconn.id
}
return aid < bid
}
func (ps *PubSub) subscribe(conn Conn, pattern bool, channel string) {
ps.mu.Lock()
defer ps.mu.Unlock()
// initialize the PubSub instance
if !ps.initd {
ps.conns = make(map[Conn]*pubSubConn)
ps.chans = btree.New(byEntry)
ps.initd = true
}
// fetch the pubSubConn
sconn, ok := ps.conns[conn]
if !ok {
// initialize a new pubSubConn, which runs on a detached connection,
// and attach it to the PubSub channels/conn btree
ps.nextid++
dconn := conn.Detach()
sconn = &pubSubConn{
id: ps.nextid,
conn: conn,
dconn: dconn,
entries: make(map[*pubSubEntry]bool),
}
ps.conns[conn] = sconn
}
sconn.mu.Lock()
defer sconn.mu.Unlock()
// add an entry to the pubsub btree
entry := &pubSubEntry{
pattern: pattern,
channel: channel,
sconn: sconn,
}
ps.chans.Set(entry)
sconn.entries[entry] = true
// send a message to the client
sconn.dconn.WriteArray(3)
if pattern {
sconn.dconn.WriteBulkString("psubscribe")
} else {
sconn.dconn.WriteBulkString("subscribe")
}
sconn.dconn.WriteBulkString(channel)
var count int
for ient := range sconn.entries {
if ient.pattern == pattern {
count++
}
}
sconn.dconn.WriteInt(count)
sconn.dconn.Flush()
// start the background client operation
if !ok {
go sconn.bgrunner(ps)
}
}
func (ps *PubSub) unsubscribe(conn Conn, pattern, all bool, channel string) {
ps.mu.Lock()
defer ps.mu.Unlock()
// fetch the pubSubConn. This must exist
sconn := ps.conns[conn]
sconn.mu.Lock()
defer sconn.mu.Unlock()
removeEntry := func(entry *pubSubEntry) {
if entry != nil {
ps.chans.Delete(entry)
delete(sconn.entries, entry)
}
sconn.dconn.WriteArray(3)
if pattern {
sconn.dconn.WriteBulkString("punsubscribe")
} else {
sconn.dconn.WriteBulkString("unsubscribe")
}
if entry != nil {
sconn.dconn.WriteBulkString(entry.channel)
} else {
sconn.dconn.WriteNull()
}
var count int
for ient := range sconn.entries {
if ient.pattern == pattern {
count++
}
}
sconn.dconn.WriteInt(count)
}
if all {
// unsubscribe from all (p)subscribe entries
var entries []*pubSubEntry
for ient := range sconn.entries {
if ient.pattern == pattern {
entries = append(entries, ient)
}
}
if len(entries) == 0 {
removeEntry(nil)
} else {
for _, entry := range entries {
removeEntry(entry)
}
}
} else {
// unsubscribe single channel from (p)subscribe.
for ient := range sconn.entries {
if ient.pattern == pattern && ient.channel == channel {
removeEntry(ient)
break
}
}
}
sconn.dconn.Flush()
}
// SetIdleClose will automatically close idle connections after the specified
// duration. Use zero to disable this feature.
func (s *Server) SetIdleClose(dur time.Duration) {
s.mu.Lock()
s.idleClose = dur
s.mu.Unlock()
}