tile38/internal/server/server.go

1591 lines
41 KiB
Go

package server
import (
"bytes"
"crypto/rand"
"crypto/sha1"
"encoding/base64"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"net/url"
"os"
"path"
"path/filepath"
"runtime"
"runtime/debug"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/tidwall/boxtree/d2"
"github.com/tidwall/buntdb"
"github.com/tidwall/evio"
"github.com/tidwall/geojson"
"github.com/tidwall/geojson/geometry"
"github.com/tidwall/redcon"
"github.com/tidwall/resp"
"github.com/tidwall/tile38/core"
"github.com/tidwall/tile38/internal/collection"
"github.com/tidwall/tile38/internal/endpoint"
"github.com/tidwall/tile38/internal/expire"
"github.com/tidwall/tile38/internal/log"
"github.com/tidwall/tinybtree"
)
// errOOM is the sentinel error carrying the out-of-memory rejection message.
// NOTE(review): it is not referenced in this part of the file; presumably it
// is returned by write commands while Server.outOfMemory is set -- confirm.
var errOOM = errors.New("OOM command not allowed when used memory > 'maxmemory'")
const (
// goingLive is the error text that marks a connection switching to live
// mode; the serve loops and handleInputCommand compare err.Error() with it.
goingLive = "going live"
// hookLogPrefix is the key prefix matched by the "hooks" index that Serve
// creates on the queue database.
hookLogPrefix = "hook:log:"
)
// commandDetails is detailed information about a mutable command. It's used
// for geofence formulas. Command handlers return one alongside their response,
// and handleInputCommand passes it to writeAOF for write commands.
type commandDetails struct {
command string // client command, like "SET" or "DEL"
key, id string // collection key and object id of object
newKey string // new key, for RENAME command
obj geojson.Object // new object
fmap map[string]int // map of field names to value indexes
fields *collection.Fields // array of field values
oldObj geojson.Object // previous object, if any
oldFields *collection.Fields // previous object field values
updated bool // object was updated
timestamp time.Time // timestamp when the update occurred
parent bool // when true, only children are forwarded
pattern string // PDEL key pattern
children []*commandDetails // for multi actions such as "PDEL"
}
// Server is a tile38 controller. A single instance is created by Serve and
// shared by every connection handler and background routine.
type Server struct {
// static values
host string
port int
http bool
dir string
started time.Time
config *Config
epc *endpoint.Manager
// env opts
// geomParseOpts is filled in by Serve from the T38IDXGEOM, T38IDXMULTI,
// T38IDXGEOMKIND and REQUIREVALID environment variables.
geomParseOpts geojson.ParseOptions
// atomics
followc aint // counter increases when follow property changes
statsTotalConns aint // counter for total connections
statsTotalCommands aint // counter for total commands
statsExpired aint // item expiration counter
lastShrinkDuration aint
stopServer abool
outOfMemory abool
// connsmu is held around every access to conns in the serve loops.
connsmu sync.RWMutex
conns map[int]*Client
// NOTE(review): exlistmu is held while reset clears exlist; presumably it
// guards every exlist access -- confirm against the expiration code.
exlistmu sync.RWMutex
exlist []exitem
// mu is the main lock; handleInputCommand takes it in read or write mode
// depending on the command before dispatching.
mu sync.RWMutex
aof *os.File // active aof file
aofdirty int32 // mark the aofbuf as having data
aofbuf []byte // prewrite buffer
aofsz int // active size of the aof file
qdb *buntdb.DB // hook queue log
qidx uint64 // hook queue log last idx
cols tinybtree.BTree // data collections
expires map[string]map[string]time.Time // synced with cols
follows map[*bytes.Buffer]bool
fcond *sync.Cond
lstack []*commandDetails
lives map[*liveBuffer]bool
lcond *sync.Cond
fcup bool // follow caught up
fcuponce bool // follow caught up once
shrinking bool // aof shrinking flag
shrinklog [][]string // aof shrinking log
hooks map[string]*Hook // hook name
hookTree d2.BoxTree // hook spatial tree containing all
hooksOut map[string]*Hook // hooks with "outside" detection
aofconnM map[net.Conn]bool
luascripts *lScriptMap
luapool *lStatePool
pubsub *pubsub
// hookex's Expired callback is set in Serve to call possiblyExpireHook.
hookex expire.List
}
// Serve starts a new tile38 server and blocks until the network listener
// returns.
//
// host and port are the listen address, dir is the data directory (created
// with mode 0700 if missing) and http enables the HTTP/WebSocket transports.
// Before listening it loads the config file, opens the hook queue database,
// migrates and (when core.AppendOnly is set) loads the AOF, then starts the
// background routines. Any setup failure, or the error from the listener, is
// returned.
func Serve(host string, port int, dir string, http bool) error {
	if core.AppendFileName == "" {
		core.AppendFileName = path.Join(dir, "appendonly.aof")
	}
	if core.QueueFileName == "" {
		core.QueueFileName = path.Join(dir, "queue.db")
	}
	log.Infof("Server started, Tile38 version %s, git %s", core.Version, core.GitSHA)

	// Initialize the server
	server := &Server{
		host:     host,
		port:     port,
		dir:      dir,
		follows:  make(map[*bytes.Buffer]bool),
		fcond:    sync.NewCond(&sync.Mutex{}),
		lives:    make(map[*liveBuffer]bool),
		lcond:    sync.NewCond(&sync.Mutex{}),
		hooks:    make(map[string]*Hook),
		hooksOut: make(map[string]*Hook),
		aofconnM: make(map[net.Conn]bool),
		expires:  make(map[string]map[string]time.Time),
		started:  time.Now(),
		conns:    make(map[int]*Client),
		http:     http,
		pubsub:   newPubsub(),
	}

	server.hookex.Expired = func(item expire.Item) {
		switch v := item.(type) {
		case *Hook:
			server.possiblyExpireHook(v.Name)
		}
	}
	server.epc = endpoint.NewManager(server)
	server.luascripts = server.newScriptMap()
	server.luapool = server.newPool()
	defer server.luapool.Shutdown()

	if err := os.MkdirAll(dir, 0700); err != nil {
		return err
	}
	var err error
	server.config, err = loadConfig(filepath.Join(dir, "config"))
	if err != nil {
		return err
	}

	server.geomParseOpts = *geojson.DefaultParseOptions
	server.geomParseOpts.AllowSimplePoints = true

	// Allow for geometry indexing options through environment variables:
	// T38IDXGEOMKIND -- None, RTree, QuadTree
	// T38IDXGEOM -- Min number of points in a geometry for indexing.
	// T38IDXMULTI -- Min number of object in a Multi/Collection for indexing.
	// Unset or unparsable numeric values leave the defaults in place.
	n, err := strconv.ParseUint(os.Getenv("T38IDXGEOM"), 10, 32)
	if err == nil {
		server.geomParseOpts.IndexGeometry = int(n)
	}
	n, err = strconv.ParseUint(os.Getenv("T38IDXMULTI"), 10, 32)
	if err == nil {
		server.geomParseOpts.IndexChildren = int(n)
	}
	requireValid := os.Getenv("REQUIREVALID")
	if requireValid != "" {
		server.geomParseOpts.RequireValid = true
	}
	indexKind := os.Getenv("T38IDXGEOMKIND")
	switch indexKind {
	default:
		log.Errorf("Unknown index kind: %s", indexKind)
	case "":
	case "None":
		server.geomParseOpts.IndexGeometryKind = geometry.None
	case "RTree":
		server.geomParseOpts.IndexGeometryKind = geometry.RTree
	case "QuadTree":
		server.geomParseOpts.IndexGeometryKind = geometry.QuadTree
	}
	if server.geomParseOpts.IndexGeometryKind == geometry.None {
		log.Debugf("Geom indexing: %s",
			server.geomParseOpts.IndexGeometryKind,
		)
	} else {
		log.Debugf("Geom indexing: %s (%d points)",
			server.geomParseOpts.IndexGeometryKind,
			server.geomParseOpts.IndexGeometry,
		)
	}
	log.Debugf("Multi indexing: RTree (%d points)", server.geomParseOpts.IndexChildren)

	// Load the queue before the aof
	qdb, err := buntdb.Open(core.QueueFileName)
	if err != nil {
		return err
	}
	var qidx uint64
	if err := qdb.View(func(tx *buntdb.Tx) error {
		val, err := tx.Get("hook:idx")
		if err != nil {
			// A missing index key just means the queue starts at zero.
			if err == buntdb.ErrNotFound {
				return nil
			}
			return err
		}
		qidx = stringToUint64(val)
		return nil
	}); err != nil {
		return err
	}
	err = qdb.CreateIndex("hooks", hookLogPrefix+"*", buntdb.IndexJSONCaseSensitive("hook"))
	if err != nil {
		return err
	}
	server.qdb = qdb
	server.qidx = qidx

	if err := server.migrateAOF(); err != nil {
		return err
	}
	if core.AppendOnly {
		f, err := os.OpenFile(core.AppendFileName, os.O_CREATE|os.O_RDWR, 0600)
		if err != nil {
			return err
		}
		server.aof = f
		if err := server.loadAOF(); err != nil {
			return err
		}
		defer func() {
			server.flushAOF(false)
			server.aof.Sync()
		}()
	}
	server.fillExpiresList()

	// Start background routines
	if server.config.followHost() != "" {
		go server.follow(server.config.followHost(), server.config.followPort(),
			server.followc.get())
	}
	go server.processLives()
	go server.watchOutOfMemory()
	go server.watchLuaStatePool()
	go server.watchAutoGC()
	go server.backgroundExpiring()
	go server.backgroundSyncAOF()
	defer func() {
		// Stop background routines
		server.followc.add(1) // this will force any follow communication to die
		server.stopServer.set(true)

		// Wait on the live geofence condition while stopping. Cond.Wait
		// returns with L held, so it must be released here; locking it a
		// second time would block forever on the non-reentrant mutex.
		// NOTE(review): Wait only returns once some other goroutine signals
		// lcond after stopServer is set -- confirm processLives does so.
		server.lcond.L.Lock()
		server.lcond.Wait()
		server.lcond.L.Unlock()
	}()

	// Start the network server
	if core.Evio {
		return server.evioServe()
	}
	return server.netServe()
}
// isProtected reports whether the server is running in protected mode, in
// which the serve loops reject connections that do not come from loopback.
// Protection is off when it was disabled on the command line or when an
// explicit non-loopback bind host was given; otherwise it is on as long as
// the config neither disables it nor sets a password.
func (server *Server) isProtected() bool {
	// --protected-mode no
	if core.ProtectedMode == "no" {
		return false
	}
	switch server.host {
	case "", "127.0.0.1", "::1", "localhost":
		// no bind address, or a loopback one: fall through to the config
	default:
		// -h address
		return false
	}
	return server.config.protectedMode() != "no" &&
		server.config.requirePass() == ""
}
// evioServe runs the server on the evio event loop and blocks until
// evio.Serve returns.
//
// It installs one callback per connection event:
//   - Opened registers the client and enforces protected mode.
//   - Data parses pipelined messages and runs each through
//     handleInputCommand.
//   - Detached hands a connection that went live over to goLive.
//   - Closed unregisters the client.
//   - PreWrite flushes any pending AOF data before replies are written.
func (server *Server) evioServe() error {
	var events evio.Events
	if core.NumThreads == 0 {
		events.NumLoops = -1
	} else {
		events.NumLoops = core.NumThreads
	}
	events.LoadBalance = evio.LeastConnections
	events.Serving = func(eserver evio.Server) (action evio.Action) {
		if eserver.NumLoops == 1 {
			log.Infof("Running single-threaded")
		} else {
			log.Infof("Running on %d threads", eserver.NumLoops)
		}
		for _, addr := range eserver.Addrs {
			log.Infof("Ready to accept connections at %s",
				addr)
		}
		return
	}
	var clientID int64
	events.Opened = func(econn evio.Conn) (out []byte, opts evio.Options, action evio.Action) {
		// create the client
		client := new(Client)
		client.id = int(atomic.AddInt64(&clientID, 1))
		client.opened = time.Now()
		client.remoteAddr = econn.RemoteAddr().String()

		// keep track of the client
		econn.SetContext(client)

		// add client to server map
		server.connsmu.Lock()
		server.conns[client.id] = client
		server.connsmu.Unlock()
		server.statsTotalConns.add(1)

		// set the client keep-alive, if needed
		if server.config.keepAlive() > 0 {
			opts.TCPKeepAlive = time.Duration(server.config.keepAlive()) * time.Second
		}
		log.Debugf("Opened connection: %s", client.remoteAddr)

		// check if the connection is protected
		if !strings.HasPrefix(client.remoteAddr, "127.0.0.1:") &&
			!strings.HasPrefix(client.remoteAddr, "[::1]:") {
			if server.isProtected() {
				// This is a protected server. Only loopback is allowed.
				out = append(out, deniedMessage...)
				action = evio.Close
				return
			}
		}
		return
	}
	events.Closed = func(econn evio.Conn, err error) (action evio.Action) {
		// load the client
		client := econn.Context().(*Client)

		// delete from server map
		server.connsmu.Lock()
		delete(server.conns, client.id)
		server.connsmu.Unlock()
		log.Debugf("Closed connection: %s", client.remoteAddr)
		return
	}
	events.Data = func(econn evio.Conn, in []byte) (out []byte, action evio.Action) {
		// load the client
		client := econn.Context().(*Client)

		// read the payload packet from the client input stream.
		packet := client.in.Begin(in)

		// load the pipeline reader
		pr := &client.pr
		rdbuf := bytes.NewBuffer(packet)
		pr.rd = rdbuf
		pr.wr = client
		msgs, err := pr.ReadMessages()
		if err != nil {
			log.Error(err)
			action = evio.Close
			return
		}
		for _, msg := range msgs {
			// Just closing connection if we have deprecated HTTP or WS connection,
			// And --http-transport = false
			if !server.http && (msg.ConnType == WebSocket ||
				msg.ConnType == HTTP) {
				action = evio.Close
				break
			}
			if msg != nil && msg.Command() != "" {
				if client.outputType != Null {
					msg.OutputType = client.outputType
				}
				if msg.Command() == "quit" {
					if msg.OutputType == RESP {
						io.WriteString(client, "+OK\r\n")
					}
					action = evio.Close
					break
				}

				// increment last used
				client.mu.Lock()
				client.last = time.Now()
				client.mu.Unlock()

				// update total command count
				server.statsTotalCommands.add(1)

				// handle the command
				err := server.handleInputCommand(client, msg)
				if err != nil {
					if err.Error() == goingLive {
						// Remember why and with what message the client went
						// live; the Detached callback picks these up.
						client.goLiveErr = err
						client.goLiveMsg = msg
						action = evio.Detach
						return
					}
					log.Error(err)
					action = evio.Close
					return
				}
				client.outputType = msg.OutputType
			} else {
				client.Write([]byte("HTTP/1.1 500 Bad Request\r\nConnection: close\r\n\r\n"))
				action = evio.Close
				break
			}
			// HTTP and WebSocket requests are one-shot.
			if msg.ConnType == HTTP || msg.ConnType == WebSocket {
				action = evio.Close
				break
			}
		}
		// Hand the bytes the reader did not consume back to the input stream.
		packet = packet[len(packet)-rdbuf.Len():]
		client.in.End(packet)
		out = client.out
		client.out = nil
		return
	}
	events.Detached = func(econn evio.Conn, rwc io.ReadWriteCloser) (action evio.Action) {
		client := econn.Context().(*Client)
		client.conn = rwc
		if len(client.out) > 0 {
			rwc.Write(client.out)
			client.out = nil
		}
		client.in = evio.InputStream{}
		client.pr.rd = rwc
		client.pr.wr = rwc
		log.Debugf("Detached connection: %s", client.remoteAddr)
		go func() {
			defer func() {
				// close connection
				rwc.Close()
				server.connsmu.Lock()
				delete(server.conns, client.id)
				server.connsmu.Unlock()
				log.Debugf("Closed connection: %s", client.remoteAddr)
			}()
			err := server.goLive(
				client.goLiveErr,
				&liveConn{econn.RemoteAddr(), rwc},
				&client.pr,
				client.goLiveMsg,
				client.goLiveMsg.ConnType == WebSocket,
			)
			if err != nil {
				log.Error(err)
			}
		}()
		return
	}
	events.PreWrite = func() {
		if atomic.LoadInt32(&server.aofdirty) != 0 {
			server.mu.Lock()
			defer server.mu.Unlock()
			server.flushAOF(false)
			// Clear the dirty flag now that the buffer has been flushed, as
			// netServe does. Storing 1 here left the flag set forever and
			// forced a write lock on every pre-write.
			atomic.StoreInt32(&server.aofdirty, 0)
		}
	}
	return evio.Serve(events, fmt.Sprintf("%s:%d", server.host, server.port))
}
// netServe runs the server on the standard net package, one goroutine per
// accepted connection, and blocks until Accept fails; that error is returned.
//
// Each connection goroutine registers the client, enforces protected mode,
// then loops: read a packet, parse the pipelined messages, run each through
// handleInputCommand, flush pending AOF data and write the buffered replies.
// A command that goes live hands the connection to goLive until it finishes.
func (server *Server) netServe() error {
	ln, err := net.Listen("tcp", fmt.Sprintf("%s:%d", server.host, server.port))
	if err != nil {
		return err
	}
	defer ln.Close()
	log.Infof("Ready to accept connections at %s", ln.Addr())
	var clientID int64
	for {
		conn, err := ln.Accept()
		if err != nil {
			return err
		}
		go func(conn net.Conn) {
			// open connection
			// create the client
			client := new(Client)
			client.id = int(atomic.AddInt64(&clientID, 1))
			client.opened = time.Now()
			client.remoteAddr = conn.RemoteAddr().String()

			// add client to server map
			server.connsmu.Lock()
			server.conns[client.id] = client
			server.connsmu.Unlock()
			server.statsTotalConns.add(1)

			// set the client keep-alive, if needed
			if server.config.keepAlive() > 0 {
				if conn, ok := conn.(*net.TCPConn); ok {
					conn.SetKeepAlive(true)
					conn.SetKeepAlivePeriod(
						time.Duration(server.config.keepAlive()) * time.Second,
					)
				}
			}
			log.Debugf("Opened connection: %s", client.remoteAddr)

			defer func() {
				// close connection
				// delete from server map
				server.connsmu.Lock()
				delete(server.conns, client.id)
				server.connsmu.Unlock()
				log.Debugf("Closed connection: %s", client.remoteAddr)
				conn.Close()
			}()

			// check if the connection is protected
			if !strings.HasPrefix(client.remoteAddr, "127.0.0.1:") &&
				!strings.HasPrefix(client.remoteAddr, "[::1]:") {
				if server.isProtected() {
					// This is a protected server. Only loopback is allowed.
					conn.Write(deniedMessage)
					return // close connection
				}
			}

			readBuf := make([]byte, 0xFFFF)
			for {
				var closeConn bool
				n, err := conn.Read(readBuf)
				if err != nil {
					return
				}
				in := readBuf[:n]

				// read the payload packet from the client input stream.
				packet := client.in.Begin(in)

				// load the pipeline reader
				pr := &client.pr
				rdbuf := bytes.NewBuffer(packet)
				pr.rd = rdbuf
				pr.wr = client
				msgs, err := pr.ReadMessages()
				if err != nil {
					log.Error(err)
					return // close connection
				}
				for _, msg := range msgs {
					// Just closing connection if we have deprecated HTTP or WS connection,
					// And --http-transport = false
					if !server.http && (msg.ConnType == WebSocket ||
						msg.ConnType == HTTP) {
						closeConn = true // close connection
						break
					}
					if msg != nil && msg.Command() != "" {
						if client.outputType != Null {
							msg.OutputType = client.outputType
						}
						if msg.Command() == "quit" {
							if msg.OutputType == RESP {
								io.WriteString(client, "+OK\r\n")
							}
							closeConn = true // close connection
							break
						}

						// increment last used
						client.mu.Lock()
						client.last = time.Now()
						client.mu.Unlock()

						// update total command count
						server.statsTotalCommands.add(1)

						// handle the command
						err := server.handleInputCommand(client, msg)
						if err != nil {
							if err.Error() == goingLive {
								client.goLiveErr = err
								client.goLiveMsg = msg
								// detach
								var rwc io.ReadWriteCloser = conn
								client.conn = rwc
								if len(client.out) > 0 {
									client.conn.Write(client.out)
									client.out = nil
								}
								client.in = evio.InputStream{}
								client.pr.rd = rwc
								client.pr.wr = rwc
								log.Debugf("Detached connection: %s", client.remoteAddr)

								// Run the live session and block until it ends.
								var wg sync.WaitGroup
								wg.Add(1)
								go func() {
									defer wg.Done()
									err := server.goLive(
										client.goLiveErr,
										&liveConn{conn.RemoteAddr(), rwc},
										&client.pr,
										client.goLiveMsg,
										client.goLiveMsg.ConnType == WebSocket,
									)
									if err != nil {
										log.Error(err)
									}
								}()
								wg.Wait()
								return // close connection
							}
							log.Error(err)
							return // close connection, NOW
						}
						client.outputType = msg.OutputType
					} else {
						client.Write([]byte("HTTP/1.1 500 Bad Request\r\nConnection: close\r\n\r\n"))
						// The reply announces "Connection: close" and
						// evioServe closes here too, so close after flushing.
						closeConn = true
						break
					}
					// HTTP and WebSocket requests are one-shot.
					if msg.ConnType == HTTP || msg.ConnType == WebSocket {
						closeConn = true // close connection
						break
					}
				}
				// Hand the bytes the reader did not consume back to the
				// input stream.
				packet = packet[len(packet)-rdbuf.Len():]
				client.in.End(packet)

				// write to client
				if len(client.out) > 0 {
					if atomic.LoadInt32(&server.aofdirty) != 0 {
						func() {
							// prewrite
							server.mu.Lock()
							defer server.mu.Unlock()
							server.flushAOF(false)
						}()
						atomic.StoreInt32(&server.aofdirty, 0)
					}
					conn.Write(client.out)
					client.out = nil
				}
				if closeConn {
					break
				}
			}
		}(conn)
	}
}
// liveConn wraps an io.ReadWriteCloser together with a fixed remote address
// and exposes net.Conn-style methods. Read, Write and Close are forwarded to
// the wrapped stream; LocalAddr and the deadline setters are not supported
// and panic.
type liveConn struct {
	remoteAddr net.Addr
	rwc        io.ReadWriteCloser
}

// Close closes the wrapped stream.
func (lc *liveConn) Close() error {
	return lc.rwc.Close()
}

// LocalAddr is not supported and always panics.
func (lc *liveConn) LocalAddr() net.Addr {
	panic("not supported")
}

// RemoteAddr returns the address supplied when the liveConn was built.
func (lc *liveConn) RemoteAddr() net.Addr {
	return lc.remoteAddr
}

// Read reads from the wrapped stream.
func (lc *liveConn) Read(b []byte) (n int, err error) {
	return lc.rwc.Read(b)
}

// Write writes to the wrapped stream.
func (lc *liveConn) Write(b []byte) (n int, err error) {
	return lc.rwc.Write(b)
}

// SetDeadline is not supported and always panics.
func (lc *liveConn) SetDeadline(deadline time.Time) error {
	panic("not supported")
}

// SetReadDeadline is not supported and always panics.
func (lc *liveConn) SetReadDeadline(deadline time.Time) error {
	panic("not supported")
}

// SetWriteDeadline is not supported and always panics.
func (lc *liveConn) SetWriteDeadline(deadline time.Time) error {
	panic("not supported")
}
// watchAutoGC wakes once a second and, whenever at least config.autoGC()
// seconds have passed since the previous forced collection, runs a GC cycle
// and returns freed memory to the OS, logging memory stats before and after.
// An autoGC of zero disables the collection. The loop exits once stopServer
// is set.
func (server *Server) watchAutoGC() {
	ticker := time.NewTicker(time.Second)
	defer ticker.Stop()
	last := time.Now()
	for range ticker.C {
		if server.stopServer.on() {
			return
		}
		interval := server.config.autoGC()
		if interval == 0 || time.Since(last) < time.Second*time.Duration(interval) {
			continue
		}
		var before, after runtime.MemStats
		runtime.ReadMemStats(&before)
		log.Debugf("autogc(before): "+
			"alloc: %v, heap_alloc: %v, heap_released: %v",
			before.Alloc, before.HeapAlloc, before.HeapReleased)

		runtime.GC()
		debug.FreeOSMemory()

		runtime.ReadMemStats(&after)
		log.Debugf("autogc(after): "+
			"alloc: %v, heap_alloc: %v, heap_released: %v",
			after.Alloc, after.HeapAlloc, after.HeapReleased)
		last = time.Now()
	}
}
// watchOutOfMemory wakes every two seconds and sets the outOfMemory flag to
// whether the heap allocation exceeds config.maxMemory(). When the flag is
// already set, a GC cycle is forced before measuring. A maxMemory of zero
// means no limit: the flag is cleared and the check is skipped for that
// tick. The loop exits once stopServer is set.
func (server *Server) watchOutOfMemory() {
	t := time.NewTicker(time.Second * 2)
	defer t.Stop()
	var mem runtime.MemStats
	for range t.C {
		if server.stopServer.on() {
			return
		}
		oom := server.outOfMemory.on()
		if server.config.maxMemory() == 0 {
			if oom {
				server.outOfMemory.set(false)
			}
			// Keep watching: returning here would end the goroutine for
			// good, so a limit configured later would never be enforced.
			continue
		}
		if oom {
			runtime.GC()
		}
		runtime.ReadMemStats(&mem)
		server.outOfMemory.set(int(mem.HeapAlloc) > server.config.maxMemory())
	}
}
// watchLuaStatePool calls Prune on the Lua state pool every ten seconds
// until stopServer is set.
func (server *Server) watchLuaStatePool() {
	ticker := time.NewTicker(10 * time.Second)
	defer ticker.Stop()
	for range ticker.C {
		if stopping := server.stopServer.on(); stopping {
			return
		}
		server.luapool.Prune()
	}
}
// backgroundSyncAOF ensures that the aof buffer does not grow too big. Once
// a second, under the server write lock, it flushes a non-empty buffer with
// sync requested and then drops the buffer. The loop exits once stopServer
// is set.
func (server *Server) backgroundSyncAOF() {
	ticker := time.NewTicker(time.Second)
	defer ticker.Stop()

	// syncOnce performs a single locked flush-and-drop of the buffer.
	syncOnce := func() {
		server.mu.Lock()
		defer server.mu.Unlock()
		if len(server.aofbuf) != 0 {
			server.flushAOF(true)
		}
		server.aofbuf = nil
	}

	for range ticker.C {
		if server.stopServer.on() {
			return
		}
		syncOnce()
	}
}
// setCol stores col in the collections tree under key.
func (server *Server) setCol(key string, col *collection.Collection) {
	cols := &server.cols
	cols.Set(key, col)
}
// getCol returns the collection stored under key, or nil when there is none.
func (server *Server) getCol(key string) *collection.Collection {
	value, found := server.cols.Get(key)
	if !found {
		return nil
	}
	return value.(*collection.Collection)
}
// scanGreaterOrEqual ascends the collections tree starting at key, passing
// each entry to iterator as a typed collection. The iterator's boolean
// result is handed back to the tree's Ascend callback unchanged.
func (server *Server) scanGreaterOrEqual(
	key string, iterator func(key string, col *collection.Collection) bool,
) {
	visit := func(ikey string, ivalue interface{}) bool {
		col := ivalue.(*collection.Collection)
		return iterator(ikey, col)
	}
	server.cols.Ascend(key, visit)
}
// deleteCol removes the collection stored under key and returns it, or nil
// when nothing was stored there.
func (server *Server) deleteCol(key string) *collection.Collection {
	prev, found := server.cols.Delete(key)
	if !found {
		return nil
	}
	return prev.(*collection.Collection)
}
// isReservedFieldName reports whether field is one of the reserved names
// "z", "lat" or "lon". The comparison is case-sensitive.
func isReservedFieldName(field string) bool {
	return field == "z" || field == "lat" || field == "lon"
}
// handleInputCommand authenticates, locks, executes and answers a single
// client message.
//
// PING and ECHO are answered directly. If the client is not yet
// authenticated (or the command is AUTH) the password is checked first. The
// server lock is then taken in read or write mode depending on the command,
// the command is dispatched through server.command, write commands are
// appended to the AOF, and the response is serialized for the message's
// output type and framed for its connection type.
//
// Command-level failures are reported to the client and nil is returned. A
// non-nil error is returned when the reply cannot be serialized or written,
// when writeAOF fails with something other than errAOFHook, or, with the
// text goingLive, when the command asks the connection to go live.
func (server *Server) handleInputCommand(client *Client, msg *Message) error {
	start := time.Now()

	// serializeOutput renders res as a string for the message's output type.
	serializeOutput := func(res resp.Value) (string, error) {
		var resStr string
		var err error
		switch msg.OutputType {
		case JSON:
			resStr = res.String()
		case RESP:
			var resBytes []byte
			resBytes, err = res.MarshalRESP()
			resStr = string(resBytes)
		}
		return resStr, err
	}

	// writeOutput frames an already serialized response for the message's
	// connection type and writes it to the client.
	writeOutput := func(res string) error {
		switch msg.ConnType {
		default:
			err := fmt.Errorf("unsupported conn type: %v", msg.ConnType)
			log.Error(err)
			return err
		case WebSocket:
			return WriteWebSocketMessage(client, []byte(res))
		case HTTP:
			// Content-Length covers the body plus the trailing CRLF.
			_, err := fmt.Fprintf(client, "HTTP/1.1 200 OK\r\n"+
				"Connection: close\r\n"+
				"Content-Length: %d\r\n"+
				"Content-Type: application/json; charset=utf-8\r\n"+
				"\r\n", len(res)+2)
			if err != nil {
				return err
			}
			_, err = io.WriteString(client, res)
			if err != nil {
				return err
			}
			_, err = io.WriteString(client, "\r\n")
			return err
		case RESP:
			var err error
			if msg.OutputType == JSON {
				// JSON over a RESP connection is sent as a bulk string.
				_, err = fmt.Fprintf(client, "$%d\r\n%s\r\n", len(res), res)
			} else {
				_, err = io.WriteString(client, res)
			}
			return err
		case Native:
			_, err := fmt.Fprintf(client, "$%d %s\r\n", len(res), res)
			return err
		}
	}

	// Ping. Just send back the response. No need to put through the pipeline.
	if msg.Command() == "ping" || msg.Command() == "echo" {
		switch msg.OutputType {
		case JSON:
			if len(msg.Args) > 1 {
				return writeOutput(`{"ok":true,"` + msg.Command() + `":` + jsonString(msg.Args[1]) + `,"elapsed":"` + time.Since(start).String() + `"}`)
			}
			return writeOutput(`{"ok":true,"` + msg.Command() + `":"pong","elapsed":"` + time.Since(start).String() + `"}`)
		case RESP:
			if len(msg.Args) > 1 {
				data := redcon.AppendBulkString(nil, msg.Args[1])
				return writeOutput(string(data))
			}
			return writeOutput("+PONG\r\n")
		}
		return nil
	}

	// writeErr reports a command-level error to the client in the message's
	// output type.
	writeErr := func(errMsg string) error {
		switch msg.OutputType {
		case JSON:
			return writeOutput(`{"ok":false,"err":` + jsonString(errMsg) + `,"elapsed":"` + time.Since(start).String() + "\"}")
		case RESP:
			if errMsg == errInvalidNumberOfArguments.Error() {
				return writeOutput("-ERR wrong number of arguments for '" + msg.Command() + "' command\r\n")
			}
			v, _ := resp.ErrorValue(errors.New("ERR " + errMsg)).MarshalRESP()
			return writeOutput(string(v))
		}
		return nil
	}

	var write bool

	if !client.authd || msg.Command() == "auth" {
		if server.config.requirePass() != "" {
			password := ""
			// This better be an AUTH command or the Message should contain an Auth
			if msg.Command() != "auth" && msg.Auth == "" {
				// Just shut down the pipeline now. The less the client connection knows the better.
				return writeErr("authentication required")
			}
			if msg.Auth != "" {
				password = msg.Auth
			} else {
				if len(msg.Args) > 1 {
					password = msg.Args[1]
				}
			}
			if server.config.requirePass() != strings.TrimSpace(password) {
				return writeErr("invalid password")
			}
			client.authd = true
			// An HTTP request carries its password alongside the actual
			// command, so it continues on; other connection types are
			// answered with OK right away.
			if msg.ConnType != HTTP {
				resStr, _ := serializeOutput(OKMessage(msg, start))
				return writeOutput(resStr)
			}
		} else if msg.Command() == "auth" {
			return writeErr("invalid password")
		}
	}

	// choose the locking strategy
	switch msg.Command() {
	default:
		server.mu.RLock()
		defer server.mu.RUnlock()
	case "set", "del", "drop", "fset", "flushdb",
		"setchan", "pdelchan", "delchan",
		"sethook", "pdelhook", "delhook",
		"expire", "persist", "jset", "jdel", "pdel", "rename", "renamenx":
		// write operations
		// "jdel" belongs here with "jset": its handler returns command
		// details like every other mutating command, so it needs the write
		// lock, the follower/read-only checks and an AOF entry.
		write = true
		server.mu.Lock()
		defer server.mu.Unlock()
		if server.config.followHost() != "" {
			return writeErr("not the leader")
		}
		if server.config.readOnly() {
			return writeErr("read only")
		}
	case "eval", "evalsha":
		// write operations (potentially) but no AOF for the script command itself
		server.mu.Lock()
		defer server.mu.Unlock()
		if server.config.followHost() != "" {
			return writeErr("not the leader")
		}
		if server.config.readOnly() {
			return writeErr("read only")
		}
	case "get", "keys", "scan", "nearby", "within", "intersects", "hooks",
		"chans", "search", "ttl", "bounds", "server", "info", "type", "jget",
		"evalro", "evalrosha":
		// read operations
		server.mu.RLock()
		defer server.mu.RUnlock()
		if server.config.followHost() != "" && !server.fcuponce {
			return writeErr("catching up to leader")
		}
	case "follow", "slaveof", "replconf", "readonly", "config":
		// system operations
		// does not write to aof, but requires a write lock.
		server.mu.Lock()
		defer server.mu.Unlock()
	case "output":
		// this is local connection operation. Locks not needed.
	case "echo":
	case "massinsert":
		// dev operation
		server.mu.Lock()
		defer server.mu.Unlock()
	case "sleep":
		// dev operation
		server.mu.RLock()
		defer server.mu.RUnlock()
	case "shutdown":
		// dev operation
		server.mu.Lock()
		defer server.mu.Unlock()
	case "aofshrink":
		server.mu.RLock()
		defer server.mu.RUnlock()
	case "client":
		server.mu.Lock()
		defer server.mu.Unlock()
	case "evalna", "evalnasha":
		// No locking for scripts, otherwise writes cannot happen within scripts
	case "subscribe", "psubscribe", "publish":
		// No locking for pubsub
	}

	res, d, err := server.command(msg, client)
	if res.Type() == resp.Error {
		return writeErr(res.String())
	}
	if err != nil {
		// goingLive is a control signal for the serve loop, not a failure.
		if err.Error() == goingLive {
			return err
		}
		return writeErr(err.Error())
	}
	if write {
		if err := server.writeAOF(msg.Args, &d); err != nil {
			if _, ok := err.(errAOFHook); ok {
				return writeErr(err.Error())
			}
			log.Fatal(err)
			return err
		}
	}
	// An empty string response means "send nothing".
	if !isRespValueEmptyString(res) {
		resStr, err := serializeOutput(res)
		if err != nil {
			return err
		}
		if err := writeOutput(resStr); err != nil {
			return err
		}
	}
	return nil
}
// isRespValueEmptyString reports whether val is a non-null simple or bulk
// string with zero bytes of content.
func isRespValueEmptyString(val resp.Value) bool {
	if val.IsNull() {
		return false
	}
	if val.Type() != resp.SimpleString && val.Type() != resp.BulkString {
		return false
	}
	return len(val.Bytes()) == 0
}
// randomKey returns n bytes read from crypto/rand, encoded as a lowercase
// hexadecimal string of length 2*n. It panics if the random source fails or
// returns fewer bytes than requested.
func randomKey(n int) string {
	raw := make([]byte, n)
	got, err := rand.Read(raw)
	switch {
	case err != nil:
		panic(err)
	case got != n:
		panic("random failed")
	}
	return fmt.Sprintf("%x", raw)
}
// reset clears the in-memory dataset bookkeeping: the tracked AOF size, the
// collections tree, the expiration list (under exlistmu) and the expires
// map.
// NOTE(review): no lock other than exlistmu is taken here; presumably the
// caller holds server.mu -- confirm.
func (server *Server) reset() {
	server.aofsz = 0
	server.cols = tinybtree.BTree{}

	func() {
		server.exlistmu.Lock()
		defer server.exlistmu.Unlock()
		server.exlist = nil
	}()

	server.expires = map[string]map[string]time.Time{}
}
// command dispatches msg to the handler registered for its lowercased
// command name and returns what the handler produced: the response value,
// the command details (set only by the handlers that return them) and an
// error. Unknown names yield an "unknown command" error.
//
// No locking happens here; handleInputCommand picks the lock mode before
// calling. The two-word "config ..." and "script ..." commands are handled
// by merging the first two arguments into one and dispatching again.
func (server *Server) command(msg *Message, client *Client) (
res resp.Value, d commandDetails, err error,
) {
switch msg.Command() {
default:
err = fmt.Errorf("unknown command '%s'", msg.Args[0])
case "set":
res, d, err = server.cmdSet(msg)
case "fset":
res, d, err = server.cmdFset(msg)
case "del":
res, d, err = server.cmdDel(msg)
case "pdel":
res, d, err = server.cmdPdel(msg)
case "drop":
res, d, err = server.cmdDrop(msg)
case "flushdb":
res, d, err = server.cmdFlushDB(msg)
case "rename":
res, d, err = server.cmdRename(msg, false)
case "renamenx":
res, d, err = server.cmdRename(msg, true)
// The hook and chan commands share handlers; the boolean argument is false
// for the "hook" variants and true for the "chan" variants.
case "sethook":
res, d, err = server.cmdSetHook(msg, false)
case "delhook":
res, d, err = server.cmdDelHook(msg, false)
case "pdelhook":
res, d, err = server.cmdPDelHook(msg, false)
case "hooks":
res, err = server.cmdHooks(msg, false)
case "setchan":
res, d, err = server.cmdSetHook(msg, true)
case "delchan":
res, d, err = server.cmdDelHook(msg, true)
case "pdelchan":
res, d, err = server.cmdPDelHook(msg, true)
case "chans":
res, err = server.cmdHooks(msg, true)
case "expire":
res, d, err = server.cmdExpire(msg)
case "persist":
res, d, err = server.cmdPersist(msg)
case "ttl":
res, err = server.cmdTTL(msg)
// shutdown, massinsert and sleep exist only in dev mode; otherwise they are
// reported as unknown commands.
case "shutdown":
if !core.DevMode {
err = fmt.Errorf("unknown command '%s'", msg.Args[0])
return
}
log.Fatal("shutdown requested by developer")
case "massinsert":
if !core.DevMode {
err = fmt.Errorf("unknown command '%s'", msg.Args[0])
return
}
res, err = server.cmdMassInsert(msg)
case "sleep":
if !core.DevMode {
err = fmt.Errorf("unknown command '%s'", msg.Args[0])
return
}
res, err = server.cmdSleep(msg)
case "follow", "slaveof":
res, err = server.cmdFollow(msg)
case "replconf":
res, err = server.cmdReplConf(msg, client)
case "readonly":
res, err = server.cmdReadOnly(msg)
case "stats":
res, err = server.cmdStats(msg)
case "server":
res, err = server.cmdServer(msg)
case "info":
res, err = server.cmdInfo(msg)
case "scan":
res, err = server.cmdScan(msg)
case "nearby":
res, err = server.cmdNearby(msg)
case "within":
res, err = server.cmdWithin(msg)
case "intersects":
res, err = server.cmdIntersects(msg)
case "search":
res, err = server.cmdSearch(msg)
case "bounds":
res, err = server.cmdBounds(msg)
case "get":
res, err = server.cmdGet(msg)
case "jget":
res, err = server.cmdJget(msg)
case "jset":
res, d, err = server.cmdJset(msg)
case "jdel":
res, d, err = server.cmdJdel(msg)
case "type":
res, err = server.cmdType(msg)
case "keys":
res, err = server.cmdKeys(msg)
case "output":
res, err = server.cmdOutput(msg)
case "aof":
res, err = server.cmdAOF(msg)
case "aofmd5":
res, err = server.cmdAOFMD5(msg)
case "gc":
runtime.GC()
debug.FreeOSMemory()
res = OKMessage(msg, time.Now())
case "aofshrink":
// The shrink runs in its own goroutine; OK is returned immediately.
go server.aofshrink()
res = OKMessage(msg, time.Now())
case "config get":
res, err = server.cmdConfigGet(msg)
case "config set":
res, err = server.cmdConfigSet(msg)
case "config rewrite":
res, err = server.cmdConfigRewrite(msg)
case "config", "script":
// These get rewritten into "config foo" and "script bar"
// With no sub-command the "unknown command" error set here stands.
// Otherwise the first two arguments are merged, the cached command name is
// cleared so Command() recomputes it, and the message is dispatched again.
err = fmt.Errorf("unknown command '%s'", msg.Args[0])
if len(msg.Args) > 1 {
msg.Args[1] = msg.Args[0] + " " + msg.Args[1]
msg.Args = msg.Args[1:]
msg._command = ""
return server.command(msg, client)
}
case "client":
res, err = server.cmdClient(msg, client)
// The eval family shares one handler; the boolean argument is true for the
// "sha" variants.
case "eval", "evalro", "evalna":
res, err = server.cmdEvalUnified(false, msg)
case "evalsha", "evalrosha", "evalnasha":
res, err = server.cmdEvalUnified(true, msg)
case "script load":
res, err = server.cmdScriptLoad(msg)
case "script exists":
res, err = server.cmdScriptExists(msg)
case "script flush":
res, err = server.cmdScriptFlush(msg)
case "subscribe":
res, err = server.cmdSubscribe(msg)
case "psubscribe":
res, err = server.cmdPsubscribe(msg)
case "publish":
res, err = server.cmdPublish(msg)
case "test":
res, err = server.cmdTest(msg)
}
return
}
// This phrase is copied nearly verbatim from Redis.
// deniedMessage is the "-DENIED" reply written to non-loopback clients when
// the server is in protected mode. The text below is trimmed, its newlines
// are replaced with spaces so it forms a single line, and "\r\n" is appended.
var deniedMessage = []byte(strings.Replace(strings.TrimSpace(`
-DENIED Tile38 is running in protected mode because protected mode is enabled,
no bind address was specified, no authentication password is requested to
clients. In this mode connections are only accepted from the loopback
interface. If you want to connect from external computers to Tile38 you may
adopt one of the following solutions: 1) Just disable protected mode sending
the command 'CONFIG SET protected-mode no' from the loopback interface by
connecting to Tile38 from the same host the server is running, however MAKE
SURE Tile38 is not publicly accessible from internet if you do so. Use CONFIG
REWRITE to make this change permanent. 2) Alternatively you can just disable
the protected mode by editing the Tile38 configuration file, and setting the
protected mode option to 'no', and then restarting the server. 3) If you
started the server manually just for testing, restart it with the
'--protected-mode no' option. 4) Setup a bind address or an authentication
password. NOTE: You only need to do one of the above things in order for the
server to start accepting connections from the outside.
`), "\n", " ", -1) + "\r\n")
// setKeepAlive enables TCP keep-alive on conn with the given period. A
// connection that is not a *net.TCPConn is left untouched and nil is
// returned; otherwise the first error from the TCP calls is returned.
func setKeepAlive(conn net.Conn, period time.Duration) error {
	tcpConn, isTCP := conn.(*net.TCPConn)
	if !isTCP {
		return nil
	}
	if err := tcpConn.SetKeepAlive(true); err != nil {
		return err
	}
	return tcpConn.SetKeepAlivePeriod(period)
}
// errCloseHTTP is a sentinel error with the text "close http".
// NOTE(review): it is not referenced in this part of the file; presumably it
// signals that an HTTP connection should be closed -- confirm at its uses.
var errCloseHTTP = errors.New("close http")
// WriteWebSocketMessage writes data to w as a single unmasked WebSocket text
// frame with the FIN bit set. The payload length is encoded in one byte for
// up to 125 bytes, as a 16-bit big-endian value for up to 65535 bytes, and
// as a 64-bit big-endian value beyond that. The frame is sent with one Write
// call, whose error is returned.
func WriteWebSocketMessage(w io.Writer, data []byte) error {
	size := len(data)
	// Room for the largest possible header (10 bytes) plus the payload.
	frame := make([]byte, 10+size)
	frame[0] = 129 // FIN + TEXT

	var headerLen int
	switch {
	case size <= 125:
		frame[1] = byte(size)
		headerLen = 2
	case size <= 0xFFFF:
		frame[1] = 126
		binary.BigEndian.PutUint16(frame[2:], uint16(size))
		headerLen = 4
	default:
		frame[1] = 127
		binary.BigEndian.PutUint64(frame[2:], uint64(size))
		headerLen = 10
	}
	copy(frame[headerLen:], data)

	_, err := w.Write(frame[:headerLen+size])
	return err
}
// OKMessage returns a default OK message in JSON or RESP, chosen by the
// message's output type. The JSON form includes the time elapsed since
// start. Any other output type yields an empty simple string.
func OKMessage(msg *Message, start time.Time) resp.Value {
	if msg.OutputType == JSON {
		elapsed := time.Since(start).String()
		return resp.StringValue(`{"ok":true,"elapsed":"` + elapsed + "\"}")
	}
	if msg.OutputType == RESP {
		return resp.SimpleStringValue("OK")
	}
	return resp.SimpleStringValue("")
}
// NOMessage is no message: an empty simple string, which
// isRespValueEmptyString recognizes so handleInputCommand writes nothing.
var NOMessage = resp.SimpleStringValue("")
// errInvalidHTTP is returned by readNextHTTPCommand for a malformed HTTP
// request.
var errInvalidHTTP = errors.New("invalid HTTP request")
// Type is resp type. It identifies both the protocol a connection speaks
// (Message.ConnType) and the format used for responses
// (Message.OutputType).
type Type byte

// Protocol Types
const (
	Null Type = iota
	RESP
	Telnet
	Native
	HTTP
	WebSocket
	JSON
)

// Message is a resp message
type Message struct {
	_command   string   // cached lowercase command name, filled by Command
	Args       []string // command name followed by its arguments
	ConnType   Type     // protocol the message arrived on
	OutputType Type     // format to use for the response
	Auth       string   // password supplied with the message, if any
}

// Command returns the first argument as a lowercase string. The result is
// cached after the first call. A message without arguments yields "", which
// is what callers test for before handling a message, rather than panicking
// on the missing element.
func (msg *Message) Command() string {
	if msg._command == "" {
		if len(msg.Args) == 0 {
			return ""
		}
		msg._command = strings.ToLower(msg.Args[0])
	}
	return msg._command
}
// PipelineReader holds the reader and writer of one client connection for
// parsing pipelined messages. The serve loops repoint rd and wr before each
// ReadMessages call.
type PipelineReader struct {
rd io.Reader
wr io.Writer
// NOTE(review): packet and buf are not used in this part of the file;
// presumably scratch space for ReadMessages -- confirm.
packet [0xFFFF]byte
buf []byte
}
// kindHTTP is an extra redcon.Kind value (9999).
// NOTE(review): presumably it marks HTTP input alongside redcon's own kinds
// -- confirm at its uses.
const kindHTTP redcon.Kind = 9999
// NewPipelineReader returns a PipelineReader that uses rd as both its reader
// and its writer.
func NewPipelineReader(rd io.ReadWriter) *PipelineReader {
	pr := new(PipelineReader)
	pr.rd = rd
	pr.wr = rd
	return pr
}
// readcrlfline extracts the first CRLF-terminated line from packet. On
// success it returns the line without its terminator, the bytes following the
// terminator, and ok=true. When packet contains no CRLF it returns an empty
// line, the unmodified packet, and ok=false.
func readcrlfline(packet []byte) (line string, leftover []byte, ok bool) {
	var prev byte
	for pos, cur := range packet {
		if cur == '\n' && prev == '\r' {
			// pos-1 is the '\r'; everything before it is the line.
			return string(packet[:pos-1]), packet[pos+1:], true
		}
		prev = cur
	}
	return "", packet, false
}
// readNextHTTPCommand tries to parse one HTTP request from packet and turn it
// into a command stored in msg.
//
// The command text is the URL-unescaped request path (without the leading
// slash) followed by the request body when a Content-Length is given. Only
// GET and POST are accepted. When the request is a WebSocket upgrade
// (version 13 or later, with a key), the handshake response is written to wr
// and msg.ConnType is set to WebSocket.
//
// It returns complete=false with the original packet as leftover when more
// data is needed or when an error occurred. On success leftover holds the
// bytes after the request. The returned kind is always kindHTTP and the
// returned args slice is always empty; the parsed arguments are delivered
// through msg.Args. msg.Args stays empty when the request carries no command
// text.
func readNextHTTPCommand(packet []byte, argsIn [][]byte, msg *Message, wr io.Writer) (
complete bool, args [][]byte, kind redcon.Kind, leftover []byte, err error,
) {
args = argsIn[:0]
msg.ConnType = HTTP
msg.OutputType = JSON
// Keep the untouched input so it can be handed back when the request is
// incomplete or invalid.
opacket := packet
// The closure advances packet as it consumes the request and shares the
// outer err variable.
ready, err := func() (bool, error) {
var line string
var ok bool
// read header
var headers []string
for {
line, packet, ok = readcrlfline(packet)
if !ok {
// Header section is not fully received yet.
return false, nil
}
if line == "" {
// A blank line terminates the header section.
break
}
headers = append(headers, line)
}
// NOTE(review): headers[0] is indexed without a length check. The visible
// caller only gets here after seeing a non-empty request line, but a packet
// that begins with a blank line would panic on this statement.
parts := strings.Split(headers[0], " ")
if len(parts) != 3 {
return false, errInvalidHTTP
}
method := parts[0]
path := parts[1]
if len(path) == 0 || path[0] != '/' {
return false, errInvalidHTTP
}
// Drop the leading slash; the rest of the path is the command text.
path, err = url.QueryUnescape(path[1:])
if err != nil {
return false, errInvalidHTTP
}
if method != "GET" && method != "POST" {
return false, errInvalidHTTP
}
contentLength := 0
websocket := false
websocketVersion := 0
websocketKey := ""
// Scan the remaining headers. The first byte is tested before lowercasing
// so that uninteresting headers are skipped cheaply.
for _, header := range headers[1:] {
if header[0] == 'a' || header[0] == 'A' {
if strings.HasPrefix(strings.ToLower(header), "authorization:") {
msg.Auth = strings.TrimSpace(header[len("authorization:"):])
}
} else if header[0] == 'u' || header[0] == 'U' {
if strings.HasPrefix(strings.ToLower(header), "upgrade:") && strings.ToLower(strings.TrimSpace(header[len("upgrade:"):])) == "websocket" {
websocket = true
}
} else if header[0] == 's' || header[0] == 'S' {
if strings.HasPrefix(strings.ToLower(header), "sec-websocket-version:") {
var n uint64
n, err = strconv.ParseUint(strings.TrimSpace(header[len("sec-websocket-version:"):]), 10, 64)
if err != nil {
return false, err
}
websocketVersion = int(n)
} else if strings.HasPrefix(strings.ToLower(header), "sec-websocket-key:") {
websocketKey = strings.TrimSpace(header[len("sec-websocket-key:"):])
}
} else if header[0] == 'c' || header[0] == 'C' {
if strings.HasPrefix(strings.ToLower(header), "content-length:") {
var n uint64
n, err = strconv.ParseUint(strings.TrimSpace(header[len("content-length:"):]), 10, 64)
if err != nil {
return false, err
}
contentLength = int(n)
}
}
}
if websocket && websocketVersion >= 13 && websocketKey != "" {
msg.ConnType = WebSocket
if wr == nil {
return false, errors.New("connection is nil")
}
// The accept token is the base64 SHA-1 of the client key concatenated with
// the fixed WebSocket GUID.
sum := sha1.Sum([]byte(websocketKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"))
accept := base64.StdEncoding.EncodeToString(sum[:])
wshead := "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: " + accept + "\r\n\r\n"
if _, err = wr.Write([]byte(wshead)); err != nil {
return false, err
}
} else if contentLength > 0 {
msg.ConnType = HTTP
if len(packet) < contentLength {
// Body is not fully received yet.
return false, nil
}
// The body is appended directly to the path to form the command text.
path += string(packet[:contentLength])
packet = packet[contentLength:]
}
if path == "" {
// Request is complete but carries no command; msg.Args stays empty.
return true, nil
}
// This err is local to the closure and shadows the shared one.
nmsg, err := readNativeMessageLine([]byte(path))
if err != nil {
return false, err
}
msg.OutputType = JSON
msg.Args = nmsg.Args
return true, nil
}()
if err != nil || !ready {
return false, args[:0], kindHTTP, opacket, err
}
return true, args[:0], kindHTTP, packet, nil
}
// readNextCommand parses the next command from packet. Input whose first
// CRLF-terminated line ends in " HTTP/x.y" is handled as an HTTP request by
// readNextHTTPCommand; everything else is passed to redcon.ReadNextCommand.
//
// argsIn is an optional buffer that is reused for the returned args. msg and
// wr are only used by the HTTP path.
//
// complete is false when packet does not yet contain a whole command; in that
// case leftover is the unmodified packet. An empty packet is reported as
// incomplete.
func readNextCommand(packet []byte, argsIn [][]byte, msg *Message, wr io.Writer) (
	complete bool, args [][]byte, kind redcon.Kind, leftover []byte, err error,
) {
	if len(packet) == 0 {
		// Nothing to inspect; avoid indexing packet[0] below.
		return false, argsIn[:0], redcon.Redis, packet, nil
	}
	if packet[0] == 'G' || packet[0] == 'P' {
		// could be an HTTP request
		var line []byte
		for i := 1; i < len(packet); i++ {
			if packet[i] == '\n' && packet[i-1] == '\r' {
				line = packet[:i+1]
				break
			}
		}
		if len(line) == 0 {
			// The first line is not terminated yet, so the protocol cannot be
			// decided. Wait for more data.
			return false, argsIn[:0], redcon.Redis, packet, nil
		}
		// A request line ends in " HTTP/x.y\r\n": 6 + 3 + 2 = 11 bytes.
		if len(line) > 11 && string(line[len(line)-11:len(line)-5]) == " HTTP/" {
			return readNextHTTPCommand(packet, argsIn, msg, wr)
		}
	}
	// Pass the caller's buffer through so it is reused on this path as well.
	return redcon.ReadNextCommand(packet, argsIn)
}
// ReadMessages blocks until data is available on the underlying reader and
// returns every complete message that can be decoded from it. Bytes belonging
// to a not-yet-complete command are kept and prepended to the next read.
//
// A read error is returned immediately. A parse error is returned when no
// message was decoded in this call; otherwise the messages decoded before the
// error are returned and the undecoded bytes stay buffered. An HTTP request
// that carries no command yields errInvalidHTTP.
func (rd *PipelineReader) ReadMessages() ([]*Message, error) {
	var msgs []*Message
	var n int
	var err error
	for n == 0 {
		// A zero-length read without an error means "try again".
		n, err = rd.rd.Read(rd.packet[:])
		if err != nil {
			return nil, err
		}
	}
	data := rd.packet[:n]
	if len(rd.buf) > 0 {
		// Continue from the partial command left by the previous call.
		data = append(rd.buf, data...)
	}
	for len(data) > 0 {
		msg := &Message{}
		// Use a distinct name for the parse error. Redeclaring err here would
		// shadow the outer variable and hide the failure from the check after
		// the loop.
		complete, args, kind, leftover, perr := readNextCommand(data, nil, msg, rd.wr)
		if perr != nil {
			err = perr
			break
		}
		if !complete {
			break
		}
		if kind == kindHTTP {
			if len(msg.Args) == 0 {
				return nil, errInvalidHTTP
			}
			msgs = append(msgs, msg)
		} else if len(args) > 0 {
			for i := 0; i < len(args); i++ {
				msg.Args = append(msg.Args, string(args[i]))
			}
			switch kind {
			case redcon.Redis:
				msg.ConnType = RESP
				msg.OutputType = RESP
			case redcon.Tile38:
				msg.ConnType = Native
				msg.OutputType = JSON
			case redcon.Telnet:
				msg.ConnType = RESP
				msg.OutputType = RESP
			}
			msgs = append(msgs, msg)
		}
		data = leftover
	}
	if len(data) > 0 {
		// Keep the unconsumed tail for the next call.
		rd.buf = append(rd.buf[:0], data...)
	} else if len(rd.buf) > 0 {
		rd.buf = rd.buf[:0]
	}
	if err != nil && len(msgs) == 0 {
		return nil, err
	}
	return msgs, nil
}
// readNativeMessageLine splits a single command line into space-separated
// arguments and returns them as a Native/JSON Message. Runs of spaces do not
// produce empty arguments.
//
// Two forms end tokenizing early and take the remainder of the line as one
// final argument:
//   - a token starting with '{' is treated as a JSON element running to the
//     end of the line;
//   - for "SET ... STRING", a remainder wrapped in double quotes is taken as
//     the string value with the quotes removed.
//
// The returned error is always nil.
func readNativeMessageLine(line []byte) (*Message, error) {
	var args []string
	for len(line) != 0 {
		if line[0] == '{' {
			// The native protocol cannot understand json boundaries so it assumes that
			// a json element must be at the end of the line.
			args = append(args, string(line))
			break
		}
		// At least two bytes are required so that the opening and closing quote
		// are distinct characters. A lone '"' would otherwise slice line[1:0]
		// and panic.
		if len(line) >= 2 && line[0] == '"' && line[len(line)-1] == '"' &&
			len(args) > 0 &&
			strings.ToLower(args[0]) == "set" &&
			strings.ToLower(args[len(args)-1]) == "string" {
			// Setting a string value that is contained inside double quotes.
			// This is only because of the boundary issues of the native protocol.
			args = append(args, string(line[1:len(line)-1]))
			break
		}
		sp := bytes.IndexByte(line, ' ')
		if sp < 0 {
			// No more separators: the rest of the line is the last argument.
			args = append(args, string(line))
			break
		}
		if sp > 0 {
			args = append(args, string(line[:sp]))
		}
		line = line[sp+1:]
	}
	return &Message{Args: args, ConnType: Native, OutputType: JSON}, nil
}