forked from mirror/redcon
1378 lines
33 KiB
Go
1378 lines
33 KiB
Go
// Package redcon implements a Redis compatible server framework
|
|
package redcon
|
|
|
|
import (
|
|
"bufio"
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/tidwall/btree"
|
|
"github.com/tidwall/match"
|
|
)
|
|
|
|
// Parse failures reported to the client. handle writes them back, prefixed
// with "ERR ", before ending the connection.
var errUnbalancedQuotes = &errProtocol{msg: "unbalanced quotes in request"}
var errInvalidBulkLength = &errProtocol{msg: "invalid bulk length"}
var errInvalidMultiBulkLength = &errProtocol{msg: "invalid multibulk length"}

// errDetached is returned by handle's read loop after a handler called
// Detach, telling handle to leave the network connection open.
var errDetached = errors.New("detached")

// errIncompleteCommand is returned by a Reader without an underlying
// io.Reader (as used by Parse) when no complete command is buffered.
var errIncompleteCommand = errors.New("incomplete command")

// errTooMuchData is returned by Parse when bytes remain after the command.
var errTooMuchData = errors.New("too much data")

// errProtocol describes malformed client input.
type errProtocol struct {
	msg string
}

// Error implements the error interface.
func (e *errProtocol) Error() string {
	const prefix = "Protocol error: "
	return prefix + e.msg
}
|
|
|
|
// Conn represents a client connection
|
|
type Conn interface {
|
|
// RemoteAddr returns the remote address of the client connection.
|
|
RemoteAddr() string
|
|
// Close closes the connection.
|
|
Close() error
|
|
// WriteError writes an error to the client.
|
|
WriteError(msg string)
|
|
// WriteString writes a string to the client.
|
|
WriteString(str string)
|
|
// WriteBulk writes bulk bytes to the client.
|
|
WriteBulk(bulk []byte)
|
|
// WriteBulkString writes a bulk string to the client.
|
|
WriteBulkString(bulk string)
|
|
// WriteInt writes an integer to the client.
|
|
WriteInt(num int)
|
|
// WriteInt64 writes a 64-bit signed integer to the client.
|
|
WriteInt64(num int64)
|
|
// WriteUint64 writes a 64-bit unsigned integer to the client.
|
|
WriteUint64(num uint64)
|
|
// WriteArray writes an array header. You must then write additional
|
|
// sub-responses to the client to complete the response.
|
|
// For example to write two strings:
|
|
//
|
|
// c.WriteArray(2)
|
|
// c.WriteBulkString("item 1")
|
|
// c.WriteBulkString("item 2")
|
|
WriteArray(count int)
|
|
// WriteNull writes a null to the client
|
|
WriteNull()
|
|
// WriteRaw writes raw data to the client.
|
|
WriteRaw(data []byte)
|
|
// WriteAny writes any type to the client.
|
|
// nil -> null
|
|
// error -> error (adds "ERR " when first word is not uppercase)
|
|
// string -> bulk-string
|
|
// numbers -> bulk-string
|
|
// []byte -> bulk-string
|
|
// bool -> bulk-string ("0" or "1")
|
|
// slice -> array
|
|
// map -> array with key/value pairs
|
|
// SimpleString -> string
|
|
// SimpleInt -> integer
|
|
// everything-else -> bulk-string representation using fmt.Sprint()
|
|
WriteAny(any interface{})
|
|
// Context returns a user-defined context
|
|
Context() interface{}
|
|
// SetContext sets a user-defined context
|
|
SetContext(v interface{})
|
|
// SetReadBuffer updates the buffer read size for the connection
|
|
SetReadBuffer(bytes int)
|
|
// Detach return a connection that is detached from the server.
|
|
// Useful for operations like PubSub.
|
|
//
|
|
// dconn := conn.Detach()
|
|
// go func(){
|
|
// defer dconn.Close()
|
|
// cmd, err := dconn.ReadCommand()
|
|
// if err != nil{
|
|
// fmt.Printf("read failed: %v\n", err)
|
|
// return
|
|
// }
|
|
// fmt.Printf("received command: %v", cmd)
|
|
// hconn.WriteString("OK")
|
|
// if err := dconn.Flush(); err != nil{
|
|
// fmt.Printf("write failed: %v\n", err)
|
|
// return
|
|
// }
|
|
// }()
|
|
Detach() DetachedConn
|
|
// ReadPipeline returns all commands in current pipeline, if any
|
|
// The commands are removed from the pipeline.
|
|
ReadPipeline() []Command
|
|
// PeekPipeline returns all commands in current pipeline, if any.
|
|
// The commands remain in the pipeline.
|
|
PeekPipeline() []Command
|
|
// NetConn returns the base net.Conn connection
|
|
NetConn() net.Conn
|
|
}
|
|
|
|
// NewServer returns a new Redcon server configured on "tcp" network net.
|
|
func NewServer(addr string,
|
|
handler func(conn Conn, cmd Command),
|
|
accept func(conn Conn) bool,
|
|
closed func(conn Conn, err error),
|
|
) *Server {
|
|
return NewServerNetwork("tcp", addr, handler, accept, closed)
|
|
}
|
|
|
|
// NewServerTLS returns a new Redcon TLS server configured on "tcp" network net.
|
|
func NewServerTLS(addr string,
|
|
handler func(conn Conn, cmd Command),
|
|
accept func(conn Conn) bool,
|
|
closed func(conn Conn, err error),
|
|
config *tls.Config,
|
|
) *TLSServer {
|
|
return NewServerNetworkTLS("tcp", addr, handler, accept, closed, config)
|
|
}
|
|
|
|
// NewServerNetwork returns a new Redcon server. The network net must be
|
|
// a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket"
|
|
func NewServerNetwork(
|
|
net, laddr string,
|
|
handler func(conn Conn, cmd Command),
|
|
accept func(conn Conn) bool,
|
|
closed func(conn Conn, err error),
|
|
) *Server {
|
|
if handler == nil {
|
|
panic("handler is nil")
|
|
}
|
|
s := &Server{
|
|
net: net,
|
|
laddr: laddr,
|
|
handler: handler,
|
|
accept: accept,
|
|
closed: closed,
|
|
conns: make(map[*conn]bool),
|
|
}
|
|
return s
|
|
}
|
|
|
|
// NewServerNetworkTLS returns a new TLS Redcon server. The network net must be
|
|
// a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket"
|
|
func NewServerNetworkTLS(
|
|
net, laddr string,
|
|
handler func(conn Conn, cmd Command),
|
|
accept func(conn Conn) bool,
|
|
closed func(conn Conn, err error),
|
|
config *tls.Config,
|
|
) *TLSServer {
|
|
if handler == nil {
|
|
panic("handler is nil")
|
|
}
|
|
s := Server{
|
|
net: net,
|
|
laddr: laddr,
|
|
handler: handler,
|
|
accept: accept,
|
|
closed: closed,
|
|
conns: make(map[*conn]bool),
|
|
}
|
|
|
|
tls := &TLSServer{
|
|
config: config,
|
|
Server: &s,
|
|
}
|
|
return tls
|
|
}
|
|
|
|
// Close stops listening on the TCP address.
|
|
// Already Accepted connections will be closed.
|
|
func (s *Server) Close() error {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if s.ln == nil {
|
|
return errors.New("not serving")
|
|
}
|
|
s.done = true
|
|
return s.ln.Close()
|
|
}
|
|
|
|
// ListenAndServe serves incoming connections.
|
|
func (s *Server) ListenAndServe() error {
|
|
return s.ListenServeAndSignal(nil)
|
|
}
|
|
|
|
// Addr returns server's listen address
|
|
func (s *Server) Addr() net.Addr {
|
|
return s.ln.Addr()
|
|
}
|
|
|
|
// Close stops listening on the TCP address.
|
|
// Already Accepted connections will be closed.
|
|
func (s *TLSServer) Close() error {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if s.ln == nil {
|
|
return errors.New("not serving")
|
|
}
|
|
s.done = true
|
|
return s.ln.Close()
|
|
}
|
|
|
|
// ListenAndServe serves incoming connections.
|
|
func (s *TLSServer) ListenAndServe() error {
|
|
return s.ListenServeAndSignal(nil)
|
|
}
|
|
|
|
// Serve creates a new server and serves with the given net.Listener.
|
|
func Serve(ln net.Listener,
|
|
handler func(conn Conn, cmd Command),
|
|
accept func(conn Conn) bool,
|
|
closed func(conn Conn, err error),
|
|
) error {
|
|
s := &Server{
|
|
net: ln.Addr().Network(),
|
|
laddr: ln.Addr().String(),
|
|
ln: ln,
|
|
handler: handler,
|
|
accept: accept,
|
|
closed: closed,
|
|
conns: make(map[*conn]bool),
|
|
}
|
|
|
|
return serve(s)
|
|
}
|
|
|
|
// ListenAndServe creates a new server and binds to addr configured on "tcp" network net.
|
|
func ListenAndServe(addr string,
|
|
handler func(conn Conn, cmd Command),
|
|
accept func(conn Conn) bool,
|
|
closed func(conn Conn, err error),
|
|
) error {
|
|
return ListenAndServeNetwork("tcp", addr, handler, accept, closed)
|
|
}
|
|
|
|
// ListenAndServeTLS creates a new TLS server and binds to addr configured on "tcp" network net.
|
|
func ListenAndServeTLS(addr string,
|
|
handler func(conn Conn, cmd Command),
|
|
accept func(conn Conn) bool,
|
|
closed func(conn Conn, err error),
|
|
config *tls.Config,
|
|
) error {
|
|
return ListenAndServeNetworkTLS("tcp", addr, handler, accept, closed, config)
|
|
}
|
|
|
|
// ListenAndServeNetwork creates a new server and binds to addr. The network net must be
|
|
// a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket"
|
|
func ListenAndServeNetwork(
|
|
net, laddr string,
|
|
handler func(conn Conn, cmd Command),
|
|
accept func(conn Conn) bool,
|
|
closed func(conn Conn, err error),
|
|
) error {
|
|
return NewServerNetwork(net, laddr, handler, accept, closed).ListenAndServe()
|
|
}
|
|
|
|
// ListenAndServeNetworkTLS creates a new TLS server and binds to addr. The network net must be
|
|
// a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket"
|
|
func ListenAndServeNetworkTLS(
|
|
net, laddr string,
|
|
handler func(conn Conn, cmd Command),
|
|
accept func(conn Conn) bool,
|
|
closed func(conn Conn, err error),
|
|
config *tls.Config,
|
|
) error {
|
|
return NewServerNetworkTLS(net, laddr, handler, accept, closed, config).ListenAndServe()
|
|
}
|
|
|
|
// ListenServeAndSignal serves incoming connections and passes nil or error
|
|
// when listening. signal can be nil.
|
|
func (s *Server) ListenServeAndSignal(signal chan error) error {
|
|
ln, err := net.Listen(s.net, s.laddr)
|
|
if err != nil {
|
|
if signal != nil {
|
|
signal <- err
|
|
}
|
|
return err
|
|
}
|
|
s.ln = ln
|
|
if signal != nil {
|
|
signal <- nil
|
|
}
|
|
return serve(s)
|
|
}
|
|
|
|
// Serve serves incoming connections with the given net.Listener.
|
|
func (s *Server) Serve(ln net.Listener) error {
|
|
s.ln = ln
|
|
s.net = ln.Addr().Network()
|
|
s.laddr = ln.Addr().String()
|
|
return serve(s)
|
|
}
|
|
|
|
// ListenServeAndSignal serves incoming connections and passes nil or error
|
|
// when listening. signal can be nil.
|
|
func (s *TLSServer) ListenServeAndSignal(signal chan error) error {
|
|
ln, err := tls.Listen(s.net, s.laddr, s.config)
|
|
if err != nil {
|
|
if signal != nil {
|
|
signal <- err
|
|
}
|
|
return err
|
|
}
|
|
s.ln = ln
|
|
if signal != nil {
|
|
signal <- nil
|
|
}
|
|
return serve(s.Server)
|
|
}
|
|
|
|
func serve(s *Server) error {
|
|
defer func() {
|
|
s.ln.Close()
|
|
func() {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
for c := range s.conns {
|
|
c.Close()
|
|
}
|
|
s.conns = nil
|
|
}()
|
|
}()
|
|
for {
|
|
lnconn, err := s.ln.Accept()
|
|
if err != nil {
|
|
s.mu.Lock()
|
|
done := s.done
|
|
s.mu.Unlock()
|
|
if done {
|
|
return nil
|
|
}
|
|
if s.AcceptError != nil {
|
|
s.AcceptError(err)
|
|
}
|
|
continue
|
|
}
|
|
c := &conn{
|
|
conn: lnconn,
|
|
addr: lnconn.RemoteAddr().String(),
|
|
wr: NewWriter(lnconn),
|
|
rd: NewReader(lnconn),
|
|
}
|
|
s.mu.Lock()
|
|
c.idleClose = s.idleClose
|
|
s.conns[c] = true
|
|
s.mu.Unlock()
|
|
if s.accept != nil && !s.accept(c) {
|
|
s.mu.Lock()
|
|
delete(s.conns, c)
|
|
s.mu.Unlock()
|
|
c.Close()
|
|
continue
|
|
}
|
|
go handle(s, c)
|
|
}
|
|
}
|
|
|
|
// handle manages the server connection.
|
|
func handle(s *Server, c *conn) {
|
|
var err error
|
|
defer func() {
|
|
if err != errDetached {
|
|
// do not close the connection when a detach is detected.
|
|
c.conn.Close()
|
|
}
|
|
func() {
|
|
// remove the conn from the server
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
delete(s.conns, c)
|
|
if s.closed != nil {
|
|
if err == io.EOF {
|
|
err = nil
|
|
}
|
|
s.closed(c, err)
|
|
}
|
|
}()
|
|
}()
|
|
|
|
err = func() error {
|
|
// read commands and feed back to the client
|
|
for {
|
|
// read pipeline commands
|
|
if c.idleClose != 0 {
|
|
c.conn.SetReadDeadline(time.Now().Add(c.idleClose))
|
|
}
|
|
cmds, err := c.rd.readCommands(nil)
|
|
if err != nil {
|
|
if err, ok := err.(*errProtocol); ok {
|
|
// All protocol errors should attempt a response to
|
|
// the client. Ignore write errors.
|
|
c.wr.WriteError("ERR " + err.Error())
|
|
c.wr.Flush()
|
|
}
|
|
return err
|
|
}
|
|
c.cmds = cmds
|
|
for len(c.cmds) > 0 {
|
|
cmd := c.cmds[0]
|
|
if len(c.cmds) == 1 {
|
|
c.cmds = nil
|
|
} else {
|
|
c.cmds = c.cmds[1:]
|
|
}
|
|
s.handler(c, cmd)
|
|
}
|
|
if c.detached {
|
|
// client has been detached
|
|
return errDetached
|
|
}
|
|
if c.closed {
|
|
return nil
|
|
}
|
|
if err := c.wr.Flush(); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
// conn represents a client connection
|
|
type conn struct {
|
|
conn net.Conn
|
|
wr *Writer
|
|
rd *Reader
|
|
addr string
|
|
ctx interface{}
|
|
detached bool
|
|
closed bool
|
|
cmds []Command
|
|
idleClose time.Duration
|
|
}
|
|
|
|
func (c *conn) Close() error {
|
|
c.wr.Flush()
|
|
c.closed = true
|
|
return c.conn.Close()
|
|
}
|
|
func (c *conn) Context() interface{} { return c.ctx }
|
|
func (c *conn) SetContext(v interface{}) { c.ctx = v }
|
|
func (c *conn) SetReadBuffer(n int) {}
|
|
func (c *conn) WriteString(str string) { c.wr.WriteString(str) }
|
|
func (c *conn) WriteBulk(bulk []byte) { c.wr.WriteBulk(bulk) }
|
|
func (c *conn) WriteBulkString(bulk string) { c.wr.WriteBulkString(bulk) }
|
|
func (c *conn) WriteInt(num int) { c.wr.WriteInt(num) }
|
|
func (c *conn) WriteInt64(num int64) { c.wr.WriteInt64(num) }
|
|
func (c *conn) WriteUint64(num uint64) { c.wr.WriteUint64(num) }
|
|
func (c *conn) WriteError(msg string) { c.wr.WriteError(msg) }
|
|
func (c *conn) WriteArray(count int) { c.wr.WriteArray(count) }
|
|
func (c *conn) WriteNull() { c.wr.WriteNull() }
|
|
func (c *conn) WriteRaw(data []byte) { c.wr.WriteRaw(data) }
|
|
func (c *conn) WriteAny(v interface{}) { c.wr.WriteAny(v) }
|
|
func (c *conn) RemoteAddr() string { return c.addr }
|
|
func (c *conn) ReadPipeline() []Command {
|
|
cmds := c.cmds
|
|
c.cmds = nil
|
|
return cmds
|
|
}
|
|
func (c *conn) PeekPipeline() []Command {
|
|
return c.cmds
|
|
}
|
|
func (c *conn) NetConn() net.Conn {
|
|
return c.conn
|
|
}
|
|
|
|
// BaseWriter returns the underlying connection writer, if any
|
|
func BaseWriter(c Conn) *Writer {
|
|
if c, ok := c.(*conn); ok {
|
|
return c.wr
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// DetachedConn represents a connection that is detached from the server
|
|
type DetachedConn interface {
|
|
// Conn is the original connection
|
|
Conn
|
|
// ReadCommand reads the next client command.
|
|
ReadCommand() (Command, error)
|
|
// Flush flushes any writes to the network.
|
|
Flush() error
|
|
}
|
|
|
|
// Detach removes the current connection from the server loop and returns
|
|
// a detached connection. This is useful for operations such as PubSub.
|
|
// The detached connection must be closed by calling Close() when done.
|
|
// All writes such as WriteString() will not be written to the client
|
|
// until Flush() is called.
|
|
func (c *conn) Detach() DetachedConn {
|
|
c.detached = true
|
|
cmds := c.cmds
|
|
c.cmds = nil
|
|
return &detachedConn{conn: c, cmds: cmds}
|
|
}
|
|
|
|
type detachedConn struct {
|
|
*conn
|
|
cmds []Command
|
|
}
|
|
|
|
// Flush writes and Write* calls to the client.
|
|
func (dc *detachedConn) Flush() error {
|
|
return dc.conn.wr.Flush()
|
|
}
|
|
|
|
// ReadCommand read the next command from the client.
|
|
func (dc *detachedConn) ReadCommand() (Command, error) {
|
|
if len(dc.cmds) > 0 {
|
|
cmd := dc.cmds[0]
|
|
if len(dc.cmds) == 1 {
|
|
dc.cmds = nil
|
|
} else {
|
|
dc.cmds = dc.cmds[1:]
|
|
}
|
|
return cmd, nil
|
|
}
|
|
cmd, err := dc.rd.ReadCommand()
|
|
if err != nil {
|
|
return Command{}, err
|
|
}
|
|
return cmd, nil
|
|
}
|
|
|
|
// Command is a single client command.
type Command struct {
	// Raw holds the command encoded as RESP. Inline (telnet-style) commands
	// are converted to RESP by the Reader.
	Raw []byte
	// Args holds the command name followed by its arguments.
	Args [][]byte
}
|
|
|
|
// Server defines a server for clients for managing client connections.
|
|
type Server struct {
|
|
mu sync.Mutex
|
|
net string
|
|
laddr string
|
|
handler func(conn Conn, cmd Command)
|
|
accept func(conn Conn) bool
|
|
closed func(conn Conn, err error)
|
|
conns map[*conn]bool
|
|
ln net.Listener
|
|
done bool
|
|
idleClose time.Duration
|
|
|
|
// AcceptError is an optional function used to handle Accept errors.
|
|
AcceptError func(err error)
|
|
}
|
|
|
|
// TLSServer defines a server for clients for managing client connections.
|
|
type TLSServer struct {
|
|
*Server
|
|
config *tls.Config
|
|
}
|
|
|
|
// Writer accumulates RESP-encoded output in memory and sends it to the
// wrapped io.Writer when Flush is called.
type Writer struct {
	w io.Writer // destination used by Flush
	b []byte    // encoded bytes not yet flushed
}

// NewWriter returns a Writer that flushes to wr. Nothing is written to wr
// until Flush is called.
func NewWriter(wr io.Writer) *Writer {
	w := new(Writer)
	w.w = wr
	return w
}
|
|
|
|
// WriteNull writes a null to the client
|
|
func (w *Writer) WriteNull() {
|
|
w.b = AppendNull(w.b)
|
|
}
|
|
|
|
// WriteArray writes an array header. You must then write additional
|
|
// sub-responses to the client to complete the response.
|
|
// For example to write two strings:
|
|
//
|
|
// c.WriteArray(2)
|
|
// c.WriteBulkString("item 1")
|
|
// c.WriteBulkString("item 2")
|
|
func (w *Writer) WriteArray(count int) {
|
|
w.b = AppendArray(w.b, count)
|
|
}
|
|
|
|
// WriteBulk writes bulk bytes to the client.
|
|
func (w *Writer) WriteBulk(bulk []byte) {
|
|
w.b = AppendBulk(w.b, bulk)
|
|
}
|
|
|
|
// WriteBulkString writes a bulk string to the client.
|
|
func (w *Writer) WriteBulkString(bulk string) {
|
|
w.b = AppendBulkString(w.b, bulk)
|
|
}
|
|
|
|
// Buffer returns the unflushed buffer. This is a copy so changes
|
|
// to the resulting []byte will not affect the writer.
|
|
func (w *Writer) Buffer() []byte {
|
|
return append([]byte(nil), w.b...)
|
|
}
|
|
|
|
// SetBuffer replaces the unflushed buffer with new bytes.
|
|
func (w *Writer) SetBuffer(raw []byte) {
|
|
w.b = w.b[:0]
|
|
w.b = append(w.b, raw...)
|
|
}
|
|
|
|
// Flush writes all unflushed Write* calls to the underlying writer.
|
|
func (w *Writer) Flush() error {
|
|
if _, err := w.w.Write(w.b); err != nil {
|
|
return err
|
|
}
|
|
w.b = w.b[:0]
|
|
return nil
|
|
}
|
|
|
|
// WriteError writes an error to the client.
|
|
func (w *Writer) WriteError(msg string) {
|
|
w.b = AppendError(w.b, msg)
|
|
}
|
|
|
|
// WriteString writes a string to the client.
|
|
func (w *Writer) WriteString(msg string) {
|
|
w.b = AppendString(w.b, msg)
|
|
}
|
|
|
|
// WriteInt writes an integer to the client.
|
|
func (w *Writer) WriteInt(num int) {
|
|
w.WriteInt64(int64(num))
|
|
}
|
|
|
|
// WriteInt64 writes a 64-bit signed integer to the client.
|
|
func (w *Writer) WriteInt64(num int64) {
|
|
w.b = AppendInt(w.b, num)
|
|
}
|
|
|
|
// WriteUint64 writes a 64-bit unsigned integer to the client.
|
|
func (w *Writer) WriteUint64(num uint64) {
|
|
w.b = AppendUint(w.b, num)
|
|
}
|
|
|
|
// WriteRaw writes raw data to the client.
|
|
func (w *Writer) WriteRaw(data []byte) {
|
|
w.b = append(w.b, data...)
|
|
}
|
|
|
|
// WriteAny writes any type to client.
|
|
// nil -> null
|
|
// error -> error (adds "ERR " when first word is not uppercase)
|
|
// string -> bulk-string
|
|
// numbers -> bulk-string
|
|
// []byte -> bulk-string
|
|
// bool -> bulk-string ("0" or "1")
|
|
// slice -> array
|
|
// map -> array with key/value pairs
|
|
// SimpleString -> string
|
|
// SimpleInt -> integer
|
|
// everything-else -> bulk-string representation using fmt.Sprint()
|
|
func (w *Writer) WriteAny(v interface{}) {
|
|
w.b = AppendAny(w.b, v)
|
|
}
|
|
|
|
// Reader represent a reader for RESP or telnet commands.
|
|
type Reader struct {
|
|
rd *bufio.Reader
|
|
buf []byte
|
|
start int
|
|
end int
|
|
cmds []Command
|
|
}
|
|
|
|
// NewReader returns a command reader which will read RESP or telnet commands.
|
|
func NewReader(rd io.Reader) *Reader {
|
|
return &Reader{
|
|
rd: bufio.NewReader(rd),
|
|
buf: make([]byte, 4096),
|
|
}
|
|
}
|
|
|
|
// parseInt parses b as an optionally negative base-10 integer.
//
// It reports false when:
//   - b is empty or is a lone "-";
//   - any byte after the optional sign is not an ASCII digit;
//   - the magnitude exceeds the maximum int (so the most negative int is
//     rejected as well).
//
// Rejecting overflow matters because readCommands uses the result as a
// multibulk count and a bulk length taken directly from client input.
func parseInt(b []byte) (int, bool) {
	// fast path for the very common single-digit case
	if len(b) == 1 && b[0] >= '0' && b[0] <= '9' {
		return int(b[0] - '0'), true
	}
	const maxInt = int(^uint(0) >> 1)
	var neg bool
	i := 0
	if len(b) > 0 && b[0] == '-' {
		neg = true
		i++
	}
	if i == len(b) {
		// no digits at all: "" or "-"
		return 0, false
	}
	n := 0
	for ; i < len(b); i++ {
		if b[i] < '0' || b[i] > '9' {
			return 0, false
		}
		d := int(b[i] - '0')
		// reject before n*10+d could wrap around
		if n > (maxInt-d)/10 {
			return 0, false
		}
		n = n*10 + d
	}
	if neg {
		n = -n
	}
	return n, true
}
|
|
|
|
func (rd *Reader) readCommands(leftover *int) ([]Command, error) {
|
|
var cmds []Command
|
|
b := rd.buf[rd.start:rd.end]
|
|
if rd.end-rd.start == 0 && len(rd.buf) > 4096 {
|
|
rd.buf = rd.buf[:4096]
|
|
rd.start = 0
|
|
rd.end = 0
|
|
}
|
|
if len(b) > 0 {
|
|
// we have data, yay!
|
|
// but is this enough data for a complete command? or multiple?
|
|
next:
|
|
switch b[0] {
|
|
default:
|
|
// just a plain text command
|
|
for i := 0; i < len(b); i++ {
|
|
if b[i] == '\n' {
|
|
var line []byte
|
|
if i > 0 && b[i-1] == '\r' {
|
|
line = b[:i-1]
|
|
} else {
|
|
line = b[:i]
|
|
}
|
|
var cmd Command
|
|
var quote bool
|
|
var quotech byte
|
|
var escape bool
|
|
outer:
|
|
for {
|
|
nline := make([]byte, 0, len(line))
|
|
for i := 0; i < len(line); i++ {
|
|
c := line[i]
|
|
if !quote {
|
|
if c == ' ' {
|
|
if len(nline) > 0 {
|
|
cmd.Args = append(cmd.Args, nline)
|
|
}
|
|
line = line[i+1:]
|
|
continue outer
|
|
}
|
|
if c == '"' || c == '\'' {
|
|
if i != 0 {
|
|
return nil, errUnbalancedQuotes
|
|
}
|
|
quotech = c
|
|
quote = true
|
|
line = line[i+1:]
|
|
continue outer
|
|
}
|
|
} else {
|
|
if escape {
|
|
escape = false
|
|
switch c {
|
|
case 'n':
|
|
c = '\n'
|
|
case 'r':
|
|
c = '\r'
|
|
case 't':
|
|
c = '\t'
|
|
}
|
|
} else if c == quotech {
|
|
quote = false
|
|
quotech = 0
|
|
cmd.Args = append(cmd.Args, nline)
|
|
line = line[i+1:]
|
|
if len(line) > 0 && line[0] != ' ' {
|
|
return nil, errUnbalancedQuotes
|
|
}
|
|
continue outer
|
|
} else if c == '\\' {
|
|
escape = true
|
|
continue
|
|
}
|
|
}
|
|
nline = append(nline, c)
|
|
}
|
|
if quote {
|
|
return nil, errUnbalancedQuotes
|
|
}
|
|
if len(line) > 0 {
|
|
cmd.Args = append(cmd.Args, line)
|
|
}
|
|
break
|
|
}
|
|
if len(cmd.Args) > 0 {
|
|
// convert this to resp command syntax
|
|
var wr Writer
|
|
wr.WriteArray(len(cmd.Args))
|
|
for i := range cmd.Args {
|
|
wr.WriteBulk(cmd.Args[i])
|
|
cmd.Args[i] = append([]byte(nil), cmd.Args[i]...)
|
|
}
|
|
cmd.Raw = wr.b
|
|
cmds = append(cmds, cmd)
|
|
}
|
|
b = b[i+1:]
|
|
if len(b) > 0 {
|
|
goto next
|
|
} else {
|
|
goto done
|
|
}
|
|
}
|
|
}
|
|
case '*':
|
|
// resp formatted command
|
|
marks := make([]int, 0, 16)
|
|
outer2:
|
|
for i := 1; i < len(b); i++ {
|
|
if b[i] == '\n' {
|
|
if b[i-1] != '\r' {
|
|
return nil, errInvalidMultiBulkLength
|
|
}
|
|
count, ok := parseInt(b[1 : i-1])
|
|
if !ok || count <= 0 {
|
|
return nil, errInvalidMultiBulkLength
|
|
}
|
|
marks = marks[:0]
|
|
for j := 0; j < count; j++ {
|
|
// read bulk length
|
|
i++
|
|
if i < len(b) {
|
|
if b[i] != '$' {
|
|
return nil, &errProtocol{"expected '$', got '" +
|
|
string(b[i]) + "'"}
|
|
}
|
|
si := i
|
|
for ; i < len(b); i++ {
|
|
if b[i] == '\n' {
|
|
if b[i-1] != '\r' {
|
|
return nil, errInvalidBulkLength
|
|
}
|
|
size, ok := parseInt(b[si+1 : i-1])
|
|
if !ok || size < 0 {
|
|
return nil, errInvalidBulkLength
|
|
}
|
|
if i+size+2 >= len(b) {
|
|
// not ready
|
|
break outer2
|
|
}
|
|
if b[i+size+2] != '\n' ||
|
|
b[i+size+1] != '\r' {
|
|
return nil, errInvalidBulkLength
|
|
}
|
|
i++
|
|
marks = append(marks, i, i+size)
|
|
i += size + 1
|
|
break
|
|
}
|
|
}
|
|
}
|
|
}
|
|
if len(marks) == count*2 {
|
|
var cmd Command
|
|
if rd.rd != nil {
|
|
// make a raw copy of the entire command when
|
|
// there's a underlying reader.
|
|
cmd.Raw = append([]byte(nil), b[:i+1]...)
|
|
} else {
|
|
// just assign the slice
|
|
cmd.Raw = b[:i+1]
|
|
}
|
|
cmd.Args = make([][]byte, len(marks)/2)
|
|
// slice up the raw command into the args based on
|
|
// the recorded marks.
|
|
for h := 0; h < len(marks); h += 2 {
|
|
cmd.Args[h/2] = cmd.Raw[marks[h]:marks[h+1]]
|
|
}
|
|
cmds = append(cmds, cmd)
|
|
b = b[i+1:]
|
|
if len(b) > 0 {
|
|
goto next
|
|
} else {
|
|
goto done
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
done:
|
|
rd.start = rd.end - len(b)
|
|
}
|
|
if leftover != nil {
|
|
*leftover = rd.end - rd.start
|
|
}
|
|
if len(cmds) > 0 {
|
|
return cmds, nil
|
|
}
|
|
if rd.rd == nil {
|
|
return nil, errIncompleteCommand
|
|
}
|
|
if rd.end == len(rd.buf) {
|
|
// at the end of the buffer.
|
|
if rd.start == rd.end {
|
|
// rewind the to the beginning
|
|
rd.start, rd.end = 0, 0
|
|
} else {
|
|
// must grow the buffer
|
|
newbuf := make([]byte, len(rd.buf)*2)
|
|
copy(newbuf, rd.buf)
|
|
rd.buf = newbuf
|
|
}
|
|
}
|
|
n, err := rd.rd.Read(rd.buf[rd.end:])
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
rd.end += n
|
|
return rd.readCommands(leftover)
|
|
}
|
|
|
|
// ReadCommands reads the next pipeline commands.
|
|
func (rd *Reader) ReadCommands() ([]Command, error) {
|
|
for {
|
|
if len(rd.cmds) > 0 {
|
|
cmds := rd.cmds
|
|
rd.cmds = nil
|
|
return cmds, nil
|
|
}
|
|
cmds, err := rd.readCommands(nil)
|
|
if err != nil {
|
|
return []Command{}, err
|
|
}
|
|
rd.cmds = cmds
|
|
}
|
|
}
|
|
|
|
// ReadCommand reads the next command.
|
|
func (rd *Reader) ReadCommand() (Command, error) {
|
|
if len(rd.cmds) > 0 {
|
|
cmd := rd.cmds[0]
|
|
rd.cmds = rd.cmds[1:]
|
|
return cmd, nil
|
|
}
|
|
cmds, err := rd.readCommands(nil)
|
|
if err != nil {
|
|
return Command{}, err
|
|
}
|
|
rd.cmds = cmds
|
|
return rd.ReadCommand()
|
|
}
|
|
|
|
// Parse parses a raw RESP message and returns a command.
|
|
func Parse(raw []byte) (Command, error) {
|
|
rd := Reader{buf: raw, end: len(raw)}
|
|
var leftover int
|
|
cmds, err := rd.readCommands(&leftover)
|
|
if err != nil {
|
|
return Command{}, err
|
|
}
|
|
if leftover > 0 {
|
|
return Command{}, errTooMuchData
|
|
}
|
|
return cmds[0], nil
|
|
|
|
}
|
|
|
|
// A Handler responds to an RESP request.
|
|
type Handler interface {
|
|
ServeRESP(conn Conn, cmd Command)
|
|
}
|
|
|
|
// The HandlerFunc type is an adapter to allow the use of
|
|
// ordinary functions as RESP handlers. If f is a function
|
|
// with the appropriate signature, HandlerFunc(f) is a
|
|
// Handler that calls f.
|
|
type HandlerFunc func(conn Conn, cmd Command)
|
|
|
|
// ServeRESP calls f(w, r)
|
|
func (f HandlerFunc) ServeRESP(conn Conn, cmd Command) {
|
|
f(conn, cmd)
|
|
}
|
|
|
|
// ServeMux is an RESP command multiplexer.
|
|
type ServeMux struct {
|
|
handlers map[string]Handler
|
|
}
|
|
|
|
// NewServeMux allocates and returns a new ServeMux.
|
|
func NewServeMux() *ServeMux {
|
|
return &ServeMux{
|
|
handlers: make(map[string]Handler),
|
|
}
|
|
}
|
|
|
|
// HandleFunc registers the handler function for the given command.
|
|
func (m *ServeMux) HandleFunc(command string, handler func(conn Conn, cmd Command)) {
|
|
if handler == nil {
|
|
panic("redcon: nil handler")
|
|
}
|
|
m.Handle(command, HandlerFunc(handler))
|
|
}
|
|
|
|
// Handle registers the handler for the given command.
|
|
// If a handler already exists for command, Handle panics.
|
|
func (m *ServeMux) Handle(command string, handler Handler) {
|
|
if command == "" {
|
|
panic("redcon: invalid command")
|
|
}
|
|
if handler == nil {
|
|
panic("redcon: nil handler")
|
|
}
|
|
if _, exist := m.handlers[command]; exist {
|
|
panic("redcon: multiple registrations for " + command)
|
|
}
|
|
|
|
m.handlers[command] = handler
|
|
}
|
|
|
|
// ServeRESP dispatches the command to the handler.
|
|
func (m *ServeMux) ServeRESP(conn Conn, cmd Command) {
|
|
command := strings.ToLower(string(cmd.Args[0]))
|
|
|
|
if handler, ok := m.handlers[command]; ok {
|
|
handler.ServeRESP(conn, cmd)
|
|
} else {
|
|
conn.WriteError("ERR unknown command '" + command + "'")
|
|
}
|
|
}
|
|
|
|
// PubSub is a Redis compatible pub/sub server
|
|
type PubSub struct {
|
|
mu sync.RWMutex
|
|
nextid uint64
|
|
initd bool
|
|
chans *btree.BTree
|
|
conns map[Conn]*pubSubConn
|
|
}
|
|
|
|
// Subscribe a connection to PubSub
|
|
func (ps *PubSub) Subscribe(conn Conn, channel string) {
|
|
ps.subscribe(conn, false, channel)
|
|
}
|
|
|
|
// Psubscribe a connection to PubSub
|
|
func (ps *PubSub) Psubscribe(conn Conn, channel string) {
|
|
ps.subscribe(conn, true, channel)
|
|
}
|
|
|
|
// Publish a message to subscribers
|
|
func (ps *PubSub) Publish(channel, message string) int {
|
|
ps.mu.RLock()
|
|
defer ps.mu.RUnlock()
|
|
if !ps.initd {
|
|
return 0
|
|
}
|
|
var sent int
|
|
// write messages to all clients that are subscribed on the channel
|
|
pivot := &pubSubEntry{pattern: false, channel: channel}
|
|
ps.chans.Ascend(pivot, func(item interface{}) bool {
|
|
entry := item.(*pubSubEntry)
|
|
if entry.channel != pivot.channel || entry.pattern != pivot.pattern {
|
|
return false
|
|
}
|
|
entry.sconn.writeMessage(entry.pattern, "", channel, message)
|
|
sent++
|
|
return true
|
|
})
|
|
|
|
// match on and write all psubscribe clients
|
|
pivot = &pubSubEntry{pattern: true}
|
|
ps.chans.Ascend(pivot, func(item interface{}) bool {
|
|
entry := item.(*pubSubEntry)
|
|
if match.Match(channel, entry.channel) {
|
|
entry.sconn.writeMessage(entry.pattern, entry.channel, channel,
|
|
message)
|
|
}
|
|
sent++
|
|
return true
|
|
})
|
|
|
|
return sent
|
|
}
|
|
|
|
type pubSubConn struct {
|
|
id uint64
|
|
mu sync.Mutex
|
|
conn Conn
|
|
dconn DetachedConn
|
|
entries map[*pubSubEntry]bool
|
|
}
|
|
|
|
type pubSubEntry struct {
|
|
pattern bool
|
|
sconn *pubSubConn
|
|
channel string
|
|
}
|
|
|
|
func (sconn *pubSubConn) writeMessage(pat bool, pchan, channel, msg string) {
|
|
sconn.mu.Lock()
|
|
defer sconn.mu.Unlock()
|
|
if pat {
|
|
sconn.dconn.WriteArray(4)
|
|
sconn.dconn.WriteBulkString("pmessage")
|
|
sconn.dconn.WriteBulkString(pchan)
|
|
sconn.dconn.WriteBulkString(channel)
|
|
sconn.dconn.WriteBulkString(msg)
|
|
} else {
|
|
sconn.dconn.WriteArray(3)
|
|
sconn.dconn.WriteBulkString("message")
|
|
sconn.dconn.WriteBulkString(channel)
|
|
sconn.dconn.WriteBulkString(msg)
|
|
}
|
|
sconn.dconn.Flush()
|
|
}
|
|
|
|
// bgrunner runs in the background and reads incoming commands from the
|
|
// detached client.
|
|
func (sconn *pubSubConn) bgrunner(ps *PubSub) {
|
|
defer func() {
|
|
// client connection has ended, disconnect from the PubSub instances
|
|
// and close the network connection.
|
|
ps.mu.Lock()
|
|
defer ps.mu.Unlock()
|
|
for entry := range sconn.entries {
|
|
ps.chans.Delete(entry)
|
|
}
|
|
delete(ps.conns, sconn.conn)
|
|
sconn.mu.Lock()
|
|
defer sconn.mu.Unlock()
|
|
sconn.dconn.Close()
|
|
}()
|
|
for {
|
|
cmd, err := sconn.dconn.ReadCommand()
|
|
if err != nil {
|
|
return
|
|
}
|
|
if len(cmd.Args) == 0 {
|
|
continue
|
|
}
|
|
switch strings.ToLower(string(cmd.Args[0])) {
|
|
case "psubscribe", "subscribe":
|
|
if len(cmd.Args) < 2 {
|
|
func() {
|
|
sconn.mu.Lock()
|
|
defer sconn.mu.Unlock()
|
|
sconn.dconn.WriteError(fmt.Sprintf("ERR wrong number of "+
|
|
"arguments for '%s'", cmd.Args[0]))
|
|
sconn.dconn.Flush()
|
|
}()
|
|
continue
|
|
}
|
|
command := strings.ToLower(string(cmd.Args[0]))
|
|
for i := 1; i < len(cmd.Args); i++ {
|
|
if command == "psubscribe" {
|
|
ps.Psubscribe(sconn.conn, string(cmd.Args[i]))
|
|
} else {
|
|
ps.Subscribe(sconn.conn, string(cmd.Args[i]))
|
|
}
|
|
}
|
|
case "unsubscribe", "punsubscribe":
|
|
pattern := strings.ToLower(string(cmd.Args[0])) == "punsubscribe"
|
|
if len(cmd.Args) == 1 {
|
|
ps.unsubscribe(sconn.conn, pattern, true, "")
|
|
} else {
|
|
for i := 1; i < len(cmd.Args); i++ {
|
|
channel := string(cmd.Args[i])
|
|
ps.unsubscribe(sconn.conn, pattern, false, channel)
|
|
}
|
|
}
|
|
case "quit":
|
|
func() {
|
|
sconn.mu.Lock()
|
|
defer sconn.mu.Unlock()
|
|
sconn.dconn.WriteString("OK")
|
|
sconn.dconn.Flush()
|
|
sconn.dconn.Close()
|
|
}()
|
|
return
|
|
case "ping":
|
|
var msg string
|
|
switch len(cmd.Args) {
|
|
case 1:
|
|
case 2:
|
|
msg = string(cmd.Args[1])
|
|
default:
|
|
func() {
|
|
sconn.mu.Lock()
|
|
defer sconn.mu.Unlock()
|
|
sconn.dconn.WriteError(fmt.Sprintf("ERR wrong number of "+
|
|
"arguments for '%s'", cmd.Args[0]))
|
|
sconn.dconn.Flush()
|
|
}()
|
|
continue
|
|
}
|
|
func() {
|
|
sconn.mu.Lock()
|
|
defer sconn.mu.Unlock()
|
|
sconn.dconn.WriteArray(2)
|
|
sconn.dconn.WriteBulkString("pong")
|
|
sconn.dconn.WriteBulkString(msg)
|
|
sconn.dconn.Flush()
|
|
}()
|
|
default:
|
|
func() {
|
|
sconn.mu.Lock()
|
|
defer sconn.mu.Unlock()
|
|
sconn.dconn.WriteError(fmt.Sprintf("ERR Can't execute '%s': "+
|
|
"only (P)SUBSCRIBE / (P)UNSUBSCRIBE / PING / QUIT are "+
|
|
"allowed in this context", cmd.Args[0]))
|
|
sconn.dconn.Flush()
|
|
}()
|
|
}
|
|
}
|
|
}
|
|
|
|
// byEntry is a "less" function that sorts the entries in a btree. The tree
|
|
// is sorted be (pattern, channel, conn.id). All pattern=true entries are at
|
|
// the end (right) of the tree.
|
|
func byEntry(a, b interface{}) bool {
|
|
aa := a.(*pubSubEntry)
|
|
bb := b.(*pubSubEntry)
|
|
if !aa.pattern && bb.pattern {
|
|
return true
|
|
}
|
|
if aa.pattern && !bb.pattern {
|
|
return false
|
|
}
|
|
if aa.channel < bb.channel {
|
|
return true
|
|
}
|
|
if aa.channel > bb.channel {
|
|
return false
|
|
}
|
|
var aid uint64
|
|
var bid uint64
|
|
if aa.sconn != nil {
|
|
aid = aa.sconn.id
|
|
}
|
|
if bb.sconn != nil {
|
|
bid = bb.sconn.id
|
|
}
|
|
return aid < bid
|
|
}
|
|
|
|
func (ps *PubSub) subscribe(conn Conn, pattern bool, channel string) {
|
|
ps.mu.Lock()
|
|
defer ps.mu.Unlock()
|
|
|
|
// initialize the PubSub instance
|
|
if !ps.initd {
|
|
ps.conns = make(map[Conn]*pubSubConn)
|
|
ps.chans = btree.New(byEntry)
|
|
ps.initd = true
|
|
}
|
|
|
|
// fetch the pubSubConn
|
|
sconn, ok := ps.conns[conn]
|
|
if !ok {
|
|
// initialize a new pubSubConn, which runs on a detached connection,
|
|
// and attach it to the PubSub channels/conn btree
|
|
ps.nextid++
|
|
dconn := conn.Detach()
|
|
sconn = &pubSubConn{
|
|
id: ps.nextid,
|
|
conn: conn,
|
|
dconn: dconn,
|
|
entries: make(map[*pubSubEntry]bool),
|
|
}
|
|
ps.conns[conn] = sconn
|
|
}
|
|
sconn.mu.Lock()
|
|
defer sconn.mu.Unlock()
|
|
|
|
// add an entry to the pubsub btree
|
|
entry := &pubSubEntry{
|
|
pattern: pattern,
|
|
channel: channel,
|
|
sconn: sconn,
|
|
}
|
|
ps.chans.Set(entry)
|
|
sconn.entries[entry] = true
|
|
|
|
// send a message to the client
|
|
sconn.dconn.WriteArray(3)
|
|
if pattern {
|
|
sconn.dconn.WriteBulkString("psubscribe")
|
|
} else {
|
|
sconn.dconn.WriteBulkString("subscribe")
|
|
}
|
|
sconn.dconn.WriteBulkString(channel)
|
|
var count int
|
|
for entry := range sconn.entries {
|
|
if entry.pattern == pattern {
|
|
count++
|
|
}
|
|
}
|
|
sconn.dconn.WriteInt(count)
|
|
sconn.dconn.Flush()
|
|
|
|
// start the background client operation
|
|
if !ok {
|
|
go sconn.bgrunner(ps)
|
|
}
|
|
}
|
|
|
|
func (ps *PubSub) unsubscribe(conn Conn, pattern, all bool, channel string) {
|
|
ps.mu.Lock()
|
|
defer ps.mu.Unlock()
|
|
// fetch the pubSubConn. This must exist
|
|
sconn := ps.conns[conn]
|
|
sconn.mu.Lock()
|
|
defer sconn.mu.Unlock()
|
|
|
|
removeEntry := func(entry *pubSubEntry) {
|
|
if entry != nil {
|
|
ps.chans.Delete(entry)
|
|
delete(sconn.entries, entry)
|
|
}
|
|
sconn.dconn.WriteArray(3)
|
|
if pattern {
|
|
sconn.dconn.WriteBulkString("punsubscribe")
|
|
} else {
|
|
sconn.dconn.WriteBulkString("unsubscribe")
|
|
}
|
|
if entry != nil {
|
|
sconn.dconn.WriteBulkString(entry.channel)
|
|
} else {
|
|
sconn.dconn.WriteNull()
|
|
}
|
|
var count int
|
|
for entry := range sconn.entries {
|
|
if entry.pattern == pattern {
|
|
count++
|
|
}
|
|
}
|
|
sconn.dconn.WriteInt(count)
|
|
}
|
|
if all {
|
|
// unsubscribe from all (p)subscribe entries
|
|
var entries []*pubSubEntry
|
|
for entry := range sconn.entries {
|
|
if entry.pattern == pattern {
|
|
entries = append(entries, entry)
|
|
}
|
|
}
|
|
if len(entries) == 0 {
|
|
removeEntry(nil)
|
|
} else {
|
|
for _, entry := range entries {
|
|
removeEntry(entry)
|
|
}
|
|
}
|
|
} else {
|
|
// unsubscribe single channel from (p)subscribe.
|
|
for entry := range sconn.entries {
|
|
if entry.pattern == pattern && entry.channel == channel {
|
|
removeEntry(entry)
|
|
break
|
|
}
|
|
}
|
|
}
|
|
sconn.dconn.Flush()
|
|
}
|
|
|
|
// SetIdleClose will automatically close idle connections after the specified
|
|
// duration. Use zero to disable this feature.
|
|
func (s *Server) SetIdleClose(dur time.Duration) {
|
|
s.mu.Lock()
|
|
s.idleClose = dur
|
|
s.mu.Unlock()
|
|
}
|