tile38/internal/server/server.go

1529 lines
40 KiB
Go
Raw Normal View History

2016-03-05 02:08:16 +03:00
package server
import (
"bytes"
"crypto/rand"
"crypto/sha1"
"encoding/base64"
2016-03-29 15:53:53 +03:00
"encoding/binary"
2016-03-05 02:08:16 +03:00
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"path"
"path/filepath"
"runtime"
"runtime/debug"
"strconv"
2016-03-05 02:08:16 +03:00
"strings"
"sync"
"sync/atomic"
2016-03-29 15:53:53 +03:00
"time"
2016-03-05 02:08:16 +03:00
2021-01-26 00:34:40 +03:00
"github.com/tidwall/btree"
"github.com/tidwall/buntdb"
"github.com/tidwall/geojson"
"github.com/tidwall/geojson/geometry"
2020-11-07 15:22:11 +03:00
"github.com/tidwall/gjson"
"github.com/tidwall/redcon"
2017-10-05 18:20:40 +03:00
"github.com/tidwall/resp"
2021-02-08 03:54:56 +03:00
"github.com/tidwall/rtree"
"github.com/tidwall/tile38/core"
"github.com/tidwall/tile38/internal/collection"
2019-04-24 15:09:41 +03:00
"github.com/tidwall/tile38/internal/deadline"
"github.com/tidwall/tile38/internal/endpoint"
"github.com/tidwall/tile38/internal/expire"
"github.com/tidwall/tile38/internal/log"
2021-05-04 04:44:54 +03:00
"github.com/prometheus/client_golang/prometheus"
2016-03-05 02:08:16 +03:00
)
// errOOM is the error reported when a command is refused because used memory
// exceeds the configured 'maxmemory'.
var errOOM = errors.New("OOM command not allowed when used memory > 'maxmemory'")
2019-04-26 21:50:49 +03:00
// errTimeoutOnCmd builds the error reported when a timeout is requested for a
// command that does not support one. cmd is the offending command name.
func errTimeoutOnCmd(cmd string) error {
	const format = "timeout not supported for '%s'"
	return fmt.Errorf(format, cmd)
}
2016-03-08 03:37:39 +03:00
2018-11-24 01:53:33 +03:00
const (
	// goingLive is the error text used to signal that a connection should be
	// detached and switched into live mode.
	goingLive = "going live"
	// hookLogPrefix is the key prefix of hook log entries in the queue database.
	hookLogPrefix = "hook:log:"
)
// commandDetails is detailed information about a mutable command. It's used
// for geofence formulas.
type commandDetails struct {
command string // client command, like "SET" or "DEL"
key, id string // collection key and object id of object
2018-12-28 04:15:53 +03:00
newKey string // new key, for RENAME command
2018-11-24 01:53:33 +03:00
fmap map[string]int // map of field names to value indexes
obj geojson.Object // new object
fields []float64 // array of field values
oldObj geojson.Object // previous object, if any
oldFields []float64 // previous object field values
updated bool // object was updated
timestamp time.Time // timestamp when the update occured
parent bool // when true, only children are forwarded
pattern string // PDEL key pattern
children []*commandDetails // for multi actions such as "PDEL"
2016-03-08 03:37:39 +03:00
}
// Server is a tile38 controller
type Server struct {
// static values
host string
port int
http bool
dir string
started time.Time
config *Config
epc *endpoint.Manager
// env opts
geomParseOpts geojson.ParseOptions
geomIndexOpts geometry.IndexOptions
2020-11-07 15:22:11 +03:00
http500Errors bool
// atomics
2018-11-23 12:14:26 +03:00
followc aint // counter increases when follow property changes
statsTotalConns aint // counter for total connections
statsTotalCommands aint // counter for total commands
2019-03-14 21:23:23 +03:00
statsTotalMsgsSent aint // counter for total sent webhook messages
2018-11-23 12:14:26 +03:00
statsExpired aint // item expiration counter
lastShrinkDuration aint
stopServer abool
outOfMemory abool
connsmu sync.RWMutex
conns map[int]*Client
2018-11-11 02:16:04 +03:00
mu sync.RWMutex
2021-01-26 00:34:40 +03:00
aof *os.File // active aof file
aofdirty int32 // mark the aofbuf as having data
aofbuf []byte // prewrite buffer
aofsz int // active size of the aof file
qdb *buntdb.DB // hook queue log
qidx uint64 // hook queue log last idx
cols *btree.BTree // data collections
follows map[*bytes.Buffer]bool
fcond *sync.Cond
2018-11-24 01:53:33 +03:00
lstack []*commandDetails
lives map[*liveBuffer]bool
lcond *sync.Cond
2018-11-24 04:15:14 +03:00
fcup bool // follow caught up
fcuponce bool // follow caught up once
shrinking bool // aof shrinking flag
shrinklog [][]string // aof shrinking log
hooks map[string]*Hook // hook name
2021-02-08 03:54:56 +03:00
hookCross rtree.RTree // hook spatial tree for "cross" geofences
hookTree rtree.RTree // hook spatial tree for all
2018-11-24 04:15:14 +03:00
hooksOut map[string]*Hook // hooks with "outside" detection
aofconnM map[net.Conn]io.Closer
luascripts *lScriptMap
luapool *lStatePool
pubsub *pubsub
hookex expire.List
2020-08-12 22:38:35 +03:00
monconnsMu sync.RWMutex
2021-04-28 15:09:48 +03:00
monconns map[net.Conn]bool // monitor connections
}
// Serve starts a new tile38 server
func Serve(host string, port int, dir string, useHTTP bool, metricsAddr string) error {
if core.AppendFileName == "" {
core.AppendFileName = path.Join(dir, "appendonly.aof")
}
if core.QueueFileName == "" {
core.QueueFileName = path.Join(dir, "queue.db")
}
log.Infof("Server started, Tile38 version %s, git %s", core.Version, core.GitSHA)
// Initialize the server
server := &Server{
host: host,
port: port,
dir: dir,
follows: make(map[*bytes.Buffer]bool),
fcond: sync.NewCond(&sync.Mutex{}),
lives: make(map[*liveBuffer]bool),
lcond: sync.NewCond(&sync.Mutex{}),
hooks: make(map[string]*Hook),
2018-11-24 04:15:14 +03:00
hooksOut: make(map[string]*Hook),
aofconnM: make(map[net.Conn]io.Closer),
started: time.Now(),
conns: make(map[int]*Client),
http: useHTTP,
pubsub: newPubsub(),
2020-08-12 22:38:35 +03:00
monconns: make(map[net.Conn]bool),
2021-07-31 17:42:58 +03:00
cols: btree.NewNonConcurrent(byCollectionKey),
}
server.hookex.Expired = func(item expire.Item) {
switch v := item.(type) {
case *Hook:
server.possiblyExpireHook(v.Name)
}
}
server.epc = endpoint.NewManager(server)
server.luascripts = server.newScriptMap()
server.luapool = server.newPool()
defer server.luapool.Shutdown()
if err := os.MkdirAll(dir, 0700); err != nil {
return err
}
var err error
server.config, err = loadConfig(filepath.Join(dir, "config"))
if err != nil {
return err
}
2016-03-05 02:08:16 +03:00
2020-11-07 15:22:11 +03:00
// Send "500 Internal Server" error instead of "200 OK" for json responses
// with `"ok":false`. T38HTTP500ERRORS=1
server.http500Errors, _ = strconv.ParseBool(os.Getenv("T38HTTP500ERRORS"))
// Allow for geometry indexing options through environment variables:
// T38IDXGEOMKIND -- None, RTree, QuadTree
// T38IDXGEOM -- Min number of points in a geometry for indexing.
// T38IDXMULTI -- Min number of object in a Multi/Collection for indexing.
server.geomParseOpts = *geojson.DefaultParseOptions
server.geomIndexOpts = *geometry.DefaultIndexOptions
n, err := strconv.ParseUint(os.Getenv("T38IDXGEOM"), 10, 32)
if err == nil {
server.geomParseOpts.IndexGeometry = int(n)
server.geomIndexOpts.MinPoints = int(n)
}
n, err = strconv.ParseUint(os.Getenv("T38IDXMULTI"), 10, 32)
if err == nil {
server.geomParseOpts.IndexChildren = int(n)
}
requireValid := os.Getenv("REQUIREVALID")
if requireValid != "" {
server.geomParseOpts.RequireValid = true
}
indexKind := os.Getenv("T38IDXGEOMKIND")
switch indexKind {
default:
log.Errorf("Unknown index kind: %s", indexKind)
case "":
case "None":
server.geomParseOpts.IndexGeometryKind = geometry.None
server.geomIndexOpts.Kind = geometry.None
case "RTree":
server.geomParseOpts.IndexGeometryKind = geometry.RTree
server.geomIndexOpts.Kind = geometry.RTree
case "QuadTree":
server.geomParseOpts.IndexGeometryKind = geometry.QuadTree
server.geomIndexOpts.Kind = geometry.QuadTree
}
if server.geomParseOpts.IndexGeometryKind == geometry.None {
log.Debugf("Geom indexing: %s",
server.geomParseOpts.IndexGeometryKind,
)
} else {
log.Debugf("Geom indexing: %s (%d points)",
server.geomParseOpts.IndexGeometryKind,
server.geomParseOpts.IndexGeometry,
)
}
log.Debugf("Multi indexing: RTree (%d points)", server.geomParseOpts.IndexChildren)
// Load the queue before the aof
qdb, err := buntdb.Open(core.QueueFileName)
if err != nil {
return err
}
var qidx uint64
if err := qdb.View(func(tx *buntdb.Tx) error {
val, err := tx.Get("hook:idx")
if err != nil {
if err == buntdb.ErrNotFound {
return nil
}
return err
}
qidx = stringToUint64(val)
return nil
}); err != nil {
return err
}
err = qdb.CreateIndex("hooks", hookLogPrefix+"*", buntdb.IndexJSONCaseSensitive("hook"))
2016-03-05 02:08:16 +03:00
if err != nil {
return err
}
server.qdb = qdb
server.qidx = qidx
if err := server.migrateAOF(); err != nil {
return err
}
if core.AppendOnly {
f, err := os.OpenFile(core.AppendFileName, os.O_CREATE|os.O_RDWR, 0600)
2016-03-05 02:08:16 +03:00
if err != nil {
return err
2016-03-05 02:08:16 +03:00
}
server.aof = f
if err := server.loadAOF(); err != nil {
return err
}
defer func() {
2019-03-10 20:48:14 +03:00
server.flushAOF(false)
server.aof.Sync()
}()
2016-03-05 02:08:16 +03:00
}
Fix excessive memory usage for objects with TTLs This commit fixes an issue where Tile38 was using lots of extra memory to track objects that are marked to expire. This was creating problems with applications that set big TTLs. How it worked before: Every collection had a unique hashmap that stores expiration timestamps for every object in that collection. Along with the hashmaps, there's also one big server-wide list that gets appended every time a new SET+EX is performed. From a background routine, this list is looped over at least 10 times per second and is randomly searched for potential candidates that might need expiring. The routine then removes those entries from the list and tests if the objects matching the entries have actually expired. If so, these objects are deleted them from the database. When at least 25% of the 20 candidates are deleted the loop is immediately continued, otherwise the loop backs off with a 100ms pause. Why this was a problem. The list grows one entry for every SET+EX. When TTLs are long, like 24-hours or more, it would take at least that much time before the entry is removed. So for databased that have objects that use TTLs and are updated often this could lead to a very large list. How it was fixed. The list was removed and the hashmap is now search randomly. This required a new hashmap implementation, as the built-in Go map does not provide an operation for randomly geting entries. The chosen implementation is a robinhood-hash because it provides open-addressing, which makes for simple random bucket selections. Issue #502
2019-10-29 21:04:07 +03:00
// server.fillExpiresList()
2016-03-05 02:08:16 +03:00
// Start background routines
if server.config.followHost() != "" {
go server.follow(server.config.followHost(), server.config.followPort(),
server.followc.get())
2016-03-08 16:11:03 +03:00
}
2021-05-04 04:44:54 +03:00
if metricsAddr != "" {
log.Infof("Listening for metrics at: %s", metricsAddr)
go func() {
http.HandleFunc("/", server.MetricsIndexHandler)
http.HandleFunc("/metrics", server.MetricsHandler)
log.Fatal(http.ListenAndServe(metricsAddr, nil))
2021-05-04 04:44:54 +03:00
}()
}
go server.processLives()
go server.watchOutOfMemory()
go server.watchLuaStatePool()
go server.watchAutoGC()
go server.backgroundExpiring()
2019-03-10 20:48:14 +03:00
go server.backgroundSyncAOF()
defer func() {
// Stop background routines
server.followc.add(1) // this will force any follow communication to die
2018-11-23 12:14:26 +03:00
server.stopServer.set(true)
// notify the live geofence connections that we are stopping.
server.lcond.L.Lock()
server.lcond.Wait()
server.lcond.L.Lock()
}()
// Start the network server
2018-11-06 01:24:45 +03:00
return server.netServe()
}
// isProtected reports whether the server is running in protected mode.
// Protected mode is off when it was disabled on the command line, when an
// explicit non-loopback bind host was given, when the config disables it, or
// when a password is required.
func (server *Server) isProtected() bool {
	// --protected-mode no
	if core.ProtectedMode == "no" {
		return false
	}
	// -h address
	switch server.host {
	case "", "127.0.0.1", "::1", "localhost":
		// no bind address, or a loopback one: keep checking the config
	default:
		return false
	}
	if server.config.protectedMode() == "no" {
		return false
	}
	return server.config.requirePass() == ""
}
2018-11-06 01:24:45 +03:00
func (server *Server) netServe() error {
ln, err := net.Listen("tcp", fmt.Sprintf("%s:%d", server.host, server.port))
if err != nil {
return err
}
defer ln.Close()
log.Infof("Ready to accept connections at %s", ln.Addr())
var clientID int64
for {
conn, err := ln.Accept()
if err != nil {
return err
}
go func(conn net.Conn) {
// open connection
// create the client
client := new(Client)
client.id = int(atomic.AddInt64(&clientID, 1))
client.opened = time.Now()
client.remoteAddr = conn.RemoteAddr().String()
// add client to server map
server.connsmu.Lock()
server.conns[client.id] = client
server.connsmu.Unlock()
server.statsTotalConns.add(1)
// set the client keep-alive, if needed
if server.config.keepAlive() > 0 {
if conn, ok := conn.(*net.TCPConn); ok {
conn.SetKeepAlive(true)
conn.SetKeepAlivePeriod(
time.Duration(server.config.keepAlive()) * time.Second,
)
}
}
log.Debugf("Opened connection: %s", client.remoteAddr)
defer func() {
// close connection
// delete from server map
server.connsmu.Lock()
delete(server.conns, client.id)
server.connsmu.Unlock()
log.Debugf("Closed connection: %s", client.remoteAddr)
conn.Close()
}()
var lastConnType Type
var lastOutputType Type
2018-11-06 01:24:45 +03:00
// check if the connection is protected
if !strings.HasPrefix(client.remoteAddr, "127.0.0.1:") &&
!strings.HasPrefix(client.remoteAddr, "[::1]:") {
if server.isProtected() {
// This is a protected server. Only loopback is allowed.
conn.Write(deniedMessage)
return // close connection
}
}
packet := make([]byte, 0xFFFF)
for {
var close bool
n, err := conn.Read(packet)
if err != nil {
return
}
in := packet[:n]
// read the payload packet from the client input stream.
packet := client.in.Begin(in)
// load the pipeline reader
pr := &client.pr
rdbuf := bytes.NewBuffer(packet)
pr.rd = rdbuf
pr.wr = client
msgs, err := pr.ReadMessages()
for _, msg := range msgs {
// Just closing connection if we have deprecated HTTP or WS connection,
// And --http-transport = false
if !server.http && (msg.ConnType == WebSocket ||
msg.ConnType == HTTP) {
close = true // close connection
break
}
if msg != nil && msg.Command() != "" {
if client.outputType != Null {
msg.OutputType = client.outputType
}
if msg.Command() == "quit" {
if msg.OutputType == RESP {
io.WriteString(client, "+OK\r\n")
}
close = true // close connection
break
}
// increment last used
client.mu.Lock()
client.last = time.Now()
client.mu.Unlock()
// update total command count
server.statsTotalCommands.add(1)
// handle the command
err := server.handleInputCommand(client, msg)
if err != nil {
if err.Error() == goingLive {
client.goLiveErr = err
client.goLiveMsg = msg
// detach
var rwc io.ReadWriteCloser = conn
client.conn = rwc
if len(client.out) > 0 {
client.conn.Write(client.out)
client.out = nil
}
2019-04-26 21:50:49 +03:00
client.in = InputStream{}
2018-11-06 01:24:45 +03:00
client.pr.rd = rwc
client.pr.wr = rwc
log.Debugf("Detached connection: %s", client.remoteAddr)
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
err := server.goLive(
client.goLiveErr,
&liveConn{conn.RemoteAddr(), rwc},
&client.pr,
client.goLiveMsg,
client.goLiveMsg.ConnType == WebSocket,
)
if err != nil {
log.Error(err)
}
}()
wg.Wait()
return // close connection
}
log.Error(err)
return // close connection, NOW
}
client.outputType = msg.OutputType
} else {
client.Write([]byte("HTTP/1.1 500 Bad Request\r\nConnection: close\r\n\r\n"))
break
}
if msg.ConnType == HTTP || msg.ConnType == WebSocket {
close = true // close connection
break
}
lastOutputType = msg.OutputType
lastConnType = msg.ConnType
2018-11-06 01:24:45 +03:00
}
packet = packet[len(packet)-rdbuf.Len():]
client.in.End(packet)
// write to client
if len(client.out) > 0 {
2018-11-11 02:21:07 +03:00
if atomic.LoadInt32(&server.aofdirty) != 0 {
func() {
// prewrite
server.mu.Lock()
defer server.mu.Unlock()
2019-03-10 20:48:14 +03:00
server.flushAOF(false)
2018-11-11 02:21:07 +03:00
}()
atomic.StoreInt32(&server.aofdirty, 0)
}
2018-11-06 01:24:45 +03:00
conn.Write(client.out)
2018-11-10 23:30:56 +03:00
client.out = nil
2018-11-06 01:24:45 +03:00
}
if close {
break
}
if err != nil {
log.Error(err)
if lastConnType == RESP {
var value resp.Value
switch lastOutputType {
case JSON:
value = resp.StringValue(`{"ok":false,"err":` +
jsonString(err.Error()) + "}")
case RESP:
value = resp.ErrorValue(err)
}
bytes, _ := value.MarshalRESP()
conn.Write(bytes)
}
break // close connection
}
2018-11-06 01:24:45 +03:00
}
}(conn)
}
}
// liveConn wraps an io.ReadWriteCloser together with a remote address so it
// exposes connection-style methods. The local address and the deadline
// methods are not supported and panic when called.
type liveConn struct {
	remoteAddr net.Addr
	rwc        io.ReadWriteCloser
}

// Close closes the wrapped stream.
func (lc *liveConn) Close() error {
	return lc.rwc.Close()
}

// LocalAddr is not supported and always panics.
func (lc *liveConn) LocalAddr() net.Addr {
	panic("not supported")
}

// RemoteAddr returns the address the liveConn was created with.
func (lc *liveConn) RemoteAddr() net.Addr {
	return lc.remoteAddr
}

// Read reads from the wrapped stream.
func (lc *liveConn) Read(b []byte) (n int, err error) {
	return lc.rwc.Read(b)
}

// Write writes to the wrapped stream.
func (lc *liveConn) Write(b []byte) (n int, err error) {
	return lc.rwc.Write(b)
}

// SetDeadline is not supported and always panics.
func (lc *liveConn) SetDeadline(deadline time.Time) error {
	panic("not supported")
}

// SetReadDeadline is not supported and always panics.
func (lc *liveConn) SetReadDeadline(deadline time.Time) error {
	panic("not supported")
}

// SetWriteDeadline is not supported and always panics.
func (lc *liveConn) SetWriteDeadline(deadline time.Time) error {
	panic("not supported")
}
func (server *Server) watchAutoGC() {
t := time.NewTicker(time.Second)
defer t.Stop()
s := time.Now()
for range t.C {
2018-11-23 12:14:26 +03:00
if server.stopServer.on() {
return
}
autoGC := server.config.autoGC()
if autoGC == 0 {
continue
}
if time.Since(s) < time.Second*time.Duration(autoGC) {
continue
}
var mem1, mem2 runtime.MemStats
runtime.ReadMemStats(&mem1)
log.Debugf("autogc(before): "+
"alloc: %v, heap_alloc: %v, heap_released: %v",
mem1.Alloc, mem1.HeapAlloc, mem1.HeapReleased)
runtime.GC()
debug.FreeOSMemory()
runtime.ReadMemStats(&mem2)
log.Debugf("autogc(after): "+
"alloc: %v, heap_alloc: %v, heap_released: %v",
mem2.Alloc, mem2.HeapAlloc, mem2.HeapReleased)
s = time.Now()
}
}
func (server *Server) watchOutOfMemory() {
t := time.NewTicker(time.Second * 2)
defer t.Stop()
var mem runtime.MemStats
for range t.C {
func() {
2018-11-23 12:14:26 +03:00
if server.stopServer.on() {
return
}
oom := server.outOfMemory.on()
if server.config.maxMemory() == 0 {
if oom {
server.outOfMemory.set(false)
}
return
}
if oom {
runtime.GC()
}
runtime.ReadMemStats(&mem)
server.outOfMemory.set(int(mem.HeapAlloc) > server.config.maxMemory())
}()
}
}
// watchLuaStatePool prunes the Lua state pool every ten seconds. Like the
// other background watchers, the loop ends once the server is stopping.
func (server *Server) watchLuaStatePool() {
	t := time.NewTicker(time.Second * 10)
	defer t.Stop()
	for range t.C {
		if server.stopServer.on() {
			return
		}
		server.luapool.Prune()
	}
}
2019-03-10 20:48:14 +03:00
// backgroundSyncAOF ensures that the aof buffer is does not grow too big.
func (server *Server) backgroundSyncAOF() {
t := time.NewTicker(time.Second)
defer t.Stop()
for range t.C {
if server.stopServer.on() {
return
}
func() {
server.mu.Lock()
defer server.mu.Unlock()
2019-09-04 03:01:26 +03:00
server.flushAOF(true)
2019-03-10 20:48:14 +03:00
}()
}
}
2021-01-26 00:34:40 +03:00
// collectionKeyContainer is a wrapper object around a collection that includes
// the collection and the key. It's needed for support with the btree package,
// which requires a comparator less function.
//
// The field order matters: the type is also built with an unkeyed composite
// literal.
type collectionKeyContainer struct {
	key string                 // collection key; the value items are ordered by
	col *collection.Collection // the collection stored under key
}
// byCollectionKey is the less function for the collections btree. Both
// arguments must be *collectionKeyContainer values; it reports whether the key
// of a sorts before the key of b.
func byCollectionKey(a, b interface{}) bool {
	left, right := a.(*collectionKeyContainer), b.(*collectionKeyContainer)
	return left.key < right.key
}
func (server *Server) setCol(key string, col *collection.Collection) {
2021-01-26 00:34:40 +03:00
server.cols.Set(&collectionKeyContainer{key, col})
}
func (server *Server) getCol(key string) *collection.Collection {
2021-01-26 00:34:40 +03:00
if v := server.cols.Get(&collectionKeyContainer{key: key}); v != nil {
return v.(*collectionKeyContainer).col
}
return nil
}
func (server *Server) scanGreaterOrEqual(
key string, iterator func(key string, col *collection.Collection) bool,
) {
2021-01-26 00:34:40 +03:00
server.cols.Ascend(&collectionKeyContainer{key: key},
func(v interface{}) bool {
vcol := v.(*collectionKeyContainer)
return iterator(vcol.key, vcol.col)
},
)
}
func (server *Server) deleteCol(key string) *collection.Collection {
2021-01-26 00:34:40 +03:00
if v := server.cols.Delete(&collectionKeyContainer{key: key}); v != nil {
return v.(*collectionKeyContainer).col
}
return nil
}
// isReservedFieldName reports whether field is one of the reserved field
// names "z", "lat" or "lon". The comparison is case-sensitive.
func isReservedFieldName(field string) bool {
	return field == "z" || field == "lat" || field == "lon"
}
// rewriteTimeoutMsg strips the leading timeout command and its seconds value
// from msg, leaving the wrapped command in msg.Args, and sets msg.Deadline to
// now plus the requested number of seconds.
//
// It returns errInvalidNumberOfArguments when the seconds value or the wrapped
// command is missing, and an invalid-argument error when the value is not a
// non-negative number. msg is left untouched on error.
func rewriteTimeoutMsg(msg *Message) (err error) {
	rest, valStr, ok := tokenval(msg.Args[1:])
	if !ok || valStr == "" || len(rest) == 0 {
		return errInvalidNumberOfArguments
	}
	timeoutSec, parseErr := strconv.ParseFloat(valStr, 64)
	if parseErr != nil || timeoutSec < 0 {
		return errInvalidArgument(valStr)
	}
	timeout := time.Duration(timeoutSec * float64(time.Second))
	msg.Args = rest[:]
	msg._command = "" // reset the cached command name; the arguments changed
	msg.Deadline = deadline.New(time.Now().Add(timeout))
	return nil
}
func (server *Server) handleInputCommand(client *Client, msg *Message) error {
start := time.Now()
serializeOutput := func(res resp.Value) (string, error) {
var resStr string
var err error
switch msg.OutputType {
case JSON:
resStr = res.String()
case RESP:
var resBytes []byte
resBytes, err = res.MarshalRESP()
resStr = string(resBytes)
2016-03-28 18:57:41 +03:00
}
return resStr, err
2016-03-05 02:08:16 +03:00
}
writeOutput := func(res string) error {
switch msg.ConnType {
default:
err := fmt.Errorf("unsupported conn type: %v", msg.ConnType)
log.Error(err)
return err
case WebSocket:
return WriteWebSocketMessage(client, []byte(res))
case HTTP:
2020-11-07 15:22:11 +03:00
status := "200 OK"
if (server.http500Errors || msg._command == "healthz") &&
!gjson.Get(res, "ok").Bool() {
2020-11-07 15:22:11 +03:00
status = "500 Internal Server Error"
}
_, err := fmt.Fprintf(client, "HTTP/1.1 %s\r\n"+
"Connection: close\r\n"+
"Content-Length: %d\r\n"+
"Content-Type: application/json; charset=utf-8\r\n"+
2020-11-07 15:22:11 +03:00
"\r\n", status, len(res)+2)
if err != nil {
return err
}
_, err = io.WriteString(client, res)
if err != nil {
return err
}
_, err = io.WriteString(client, "\r\n")
return err
case RESP:
var err error
if msg.OutputType == JSON {
_, err = fmt.Fprintf(client, "$%d\r\n%s\r\n", len(res), res)
} else {
_, err = io.WriteString(client, res)
}
return err
case Native:
_, err := fmt.Fprintf(client, "$%d %s\r\n", len(res), res)
return err
}
}
2018-11-10 23:30:56 +03:00
2021-05-04 04:44:54 +03:00
cmd := msg.Command()
defer func() {
2021-07-10 13:59:27 +03:00
took := time.Since(start).Seconds()
2021-05-04 04:44:54 +03:00
cmdDurations.With(prometheus.Labels{"cmd": cmd}).Observe(took)
}()
// Ping. Just send back the response. No need to put through the pipeline.
2021-05-04 04:44:54 +03:00
if cmd == "ping" || cmd == "echo" {
switch msg.OutputType {
case JSON:
if len(msg.Args) > 1 {
2021-05-04 04:44:54 +03:00
return writeOutput(`{"ok":true,"` + cmd + `":` + jsonString(msg.Args[1]) + `,"elapsed":"` + time.Since(start).String() + `"}`)
}
2021-05-04 04:44:54 +03:00
return writeOutput(`{"ok":true,"` + cmd + `":"pong","elapsed":"` + time.Since(start).String() + `"}`)
case RESP:
if len(msg.Args) > 1 {
data := redcon.AppendBulkString(nil, msg.Args[1])
return writeOutput(string(data))
}
return writeOutput("+PONG\r\n")
}
2020-08-12 22:38:35 +03:00
server.sendMonitor(nil, msg, client, false)
return nil
}
2018-11-10 23:30:56 +03:00
writeErr := func(errMsg string) error {
switch msg.OutputType {
case JSON:
return writeOutput(`{"ok":false,"err":` + jsonString(errMsg) + `,"elapsed":"` + time.Since(start).String() + "\"}")
case RESP:
if errMsg == errInvalidNumberOfArguments.Error() {
2021-05-04 04:44:54 +03:00
return writeOutput("-ERR wrong number of arguments for '" + cmd + "' command\r\n")
}
v, _ := resp.ErrorValue(errors.New("ERR " + errMsg)).MarshalRESP()
return writeOutput(string(v))
}
return nil
}
2021-05-04 04:44:54 +03:00
if cmd == "timeout" {
if err := rewriteTimeoutMsg(msg); err != nil {
return writeErr(err.Error())
}
}
var write bool
2021-05-04 04:44:54 +03:00
if (!client.authd || cmd == "auth") && cmd != "output" {
if server.config.requirePass() != "" {
password := ""
// This better be an AUTH command or the Message should contain an Auth
2021-05-04 04:44:54 +03:00
if cmd != "auth" && msg.Auth == "" {
// Just shut down the pipeline now. The less the client connection knows the better.
return writeErr("authentication required")
}
if msg.Auth != "" {
password = msg.Auth
} else {
if len(msg.Args) > 1 {
password = msg.Args[1]
}
}
if server.config.requirePass() != strings.TrimSpace(password) {
return writeErr("invalid password")
}
client.authd = true
if msg.ConnType != HTTP {
resStr, _ := serializeOutput(OKMessage(msg, start))
return writeOutput(resStr)
}
} else if msg.Command() == "auth" {
return writeErr("invalid password")
}
}
2018-11-10 23:30:56 +03:00
// choose the locking strategy
switch msg.Command() {
default:
server.mu.RLock()
defer server.mu.RUnlock()
case "set", "del", "drop", "fset", "flushdb",
"setchan", "pdelchan", "delchan",
"sethook", "pdelhook", "delhook",
2018-12-28 04:15:53 +03:00
"expire", "persist", "jset", "pdel", "rename", "renamenx":
// write operations
write = true
server.mu.Lock()
defer server.mu.Unlock()
if server.config.followHost() != "" {
return writeErr("not the leader")
}
if server.config.readOnly() {
return writeErr("read only")
}
case "eval", "evalsha":
// write operations (potentially) but no AOF for the script command itself
server.mu.Lock()
defer server.mu.Unlock()
if server.config.followHost() != "" {
return writeErr("not the leader")
}
if server.config.readOnly() {
return writeErr("read only")
}
case "get", "keys", "scan", "nearby", "within", "intersects", "hooks",
"chans", "search", "ttl", "bounds", "server", "info", "type", "jget",
"evalro", "evalrosha", "healthz":
// read operations
2018-11-10 23:30:56 +03:00
server.mu.RLock()
defer server.mu.RUnlock()
if server.config.followHost() != "" && !server.fcuponce {
return writeErr("catching up to leader")
}
case "follow", "slaveof", "replconf", "readonly", "config":
// system operations
// does not write to aof, but requires a write lock.
server.mu.Lock()
defer server.mu.Unlock()
case "output":
// this is local connection operation. Locks not needed.
case "echo":
case "massinsert":
// dev operation
case "sleep":
// dev operation
server.mu.RLock()
defer server.mu.RUnlock()
case "shutdown":
// dev operation
server.mu.Lock()
defer server.mu.Unlock()
case "aofshrink":
server.mu.RLock()
defer server.mu.RUnlock()
case "client":
server.mu.Lock()
defer server.mu.Unlock()
case "evalna", "evalnasha":
// No locking for scripts, otherwise writes cannot happen within scripts
case "subscribe", "psubscribe", "publish":
// No locking for pubsub
2021-05-04 04:44:54 +03:00
case "monitor":
2020-08-12 22:38:35 +03:00
// No locking for monitor
}
2019-04-24 15:09:41 +03:00
res, d, err := func() (res resp.Value, d commandDetails, err error) {
if msg.Deadline != nil {
if write {
res = NOMessage
2019-04-26 21:50:49 +03:00
err = errTimeoutOnCmd(msg.Command())
return
2019-04-24 23:20:57 +03:00
}
2019-04-24 15:09:41 +03:00
defer func() {
if msg.Deadline.Hit() {
v := recover()
if v != nil {
if s, ok := v.(string); !ok || s != "deadline" {
panic(v)
}
}
res = NOMessage
err = writeErr("timeout")
}
}()
}
return server.command(msg, client)
}()
if res.Type() == resp.Error {
return writeErr(res.String())
}
if err != nil {
if err.Error() == goingLive {
return err
}
return writeErr(err.Error())
}
if write {
if err := server.writeAOF(msg.Args, &d); err != nil {
if _, ok := err.(errAOFHook); ok {
return writeErr(err.Error())
}
log.Fatal(err)
return err
}
}
if !isRespValueEmptyString(res) {
var resStr string
resStr, err := serializeOutput(res)
if err != nil {
return err
}
if err := writeOutput(resStr); err != nil {
return err
}
}
return nil
}
// isRespValueEmptyString reports whether val is a non-null simple or bulk
// string with zero-length content.
func isRespValueEmptyString(val resp.Value) bool {
	if val.IsNull() {
		return false
	}
	switch val.Type() {
	case resp.SimpleString, resp.BulkString:
		return len(val.Bytes()) == 0
	}
	return false
}
// randomKey returns n cryptographically random bytes encoded as a lowercase
// hexadecimal string of length 2*n. It panics when the random source returns
// an error or fewer than n bytes.
func randomKey(n int) string {
	buf := make([]byte, n)
	switch got, err := rand.Read(buf); {
	case err != nil:
		panic(err)
	case got != n:
		panic("random failed")
	}
	return fmt.Sprintf("%x", buf)
}
func (server *Server) reset() {
server.aofsz = 0
2021-07-31 17:42:58 +03:00
server.cols = btree.NewNonConcurrent(byCollectionKey)
}
func (server *Server) command(msg *Message, client *Client) (
2018-11-24 01:53:33 +03:00
res resp.Value, d commandDetails, err error,
) {
switch msg.Command() {
default:
err = fmt.Errorf("unknown command '%s'", msg.Args[0])
case "set":
res, d, err = server.cmdSet(msg)
case "fset":
res, d, err = server.cmdFset(msg)
case "del":
res, d, err = server.cmdDel(msg)
case "pdel":
res, d, err = server.cmdPdel(msg)
case "drop":
res, d, err = server.cmdDrop(msg)
case "flushdb":
res, d, err = server.cmdFlushDB(msg)
2018-12-28 04:15:53 +03:00
case "rename":
res, d, err = server.cmdRename(msg, false)
case "renamenx":
res, d, err = server.cmdRename(msg, true)
case "sethook":
res, d, err = server.cmdSetHook(msg, false)
case "delhook":
res, d, err = server.cmdDelHook(msg, false)
case "pdelhook":
res, d, err = server.cmdPDelHook(msg, false)
case "hooks":
res, err = server.cmdHooks(msg, false)
case "setchan":
res, d, err = server.cmdSetHook(msg, true)
case "delchan":
res, d, err = server.cmdDelHook(msg, true)
case "pdelchan":
res, d, err = server.cmdPDelHook(msg, true)
case "chans":
res, err = server.cmdHooks(msg, true)
case "expire":
res, d, err = server.cmdExpire(msg)
case "persist":
res, d, err = server.cmdPersist(msg)
case "ttl":
res, err = server.cmdTTL(msg)
case "shutdown":
if !core.DevMode {
err = fmt.Errorf("unknown command '%s'", msg.Args[0])
return
}
log.Fatal("shutdown requested by developer")
case "massinsert":
if !core.DevMode {
err = fmt.Errorf("unknown command '%s'", msg.Args[0])
return
}
res, err = server.cmdMassInsert(msg)
case "sleep":
if !core.DevMode {
err = fmt.Errorf("unknown command '%s'", msg.Args[0])
return
}
res, err = server.cmdSleep(msg)
case "follow", "slaveof":
res, err = server.cmdFollow(msg)
case "replconf":
res, err = server.cmdReplConf(msg, client)
case "readonly":
res, err = server.cmdReadOnly(msg)
case "stats":
res, err = server.cmdStats(msg)
case "server":
res, err = server.cmdServer(msg)
case "healthz":
res, err = server.cmdHealthz(msg)
case "info":
res, err = server.cmdInfo(msg)
case "scan":
res, err = server.cmdScan(msg)
case "nearby":
res, err = server.cmdNearby(msg)
case "within":
res, err = server.cmdWithin(msg)
case "intersects":
res, err = server.cmdIntersects(msg)
case "search":
res, err = server.cmdSearch(msg)
case "bounds":
res, err = server.cmdBounds(msg)
case "get":
res, err = server.cmdGet(msg)
case "jget":
res, err = server.cmdJget(msg)
case "jset":
res, d, err = server.cmdJset(msg)
case "jdel":
res, d, err = server.cmdJdel(msg)
case "type":
res, err = server.cmdType(msg)
case "keys":
res, err = server.cmdKeys(msg)
case "output":
res, err = server.cmdOutput(msg)
case "aof":
res, err = server.cmdAOF(msg)
case "aofmd5":
res, err = server.cmdAOFMD5(msg)
case "gc":
runtime.GC()
debug.FreeOSMemory()
res = OKMessage(msg, time.Now())
case "aofshrink":
go server.aofshrink()
res = OKMessage(msg, time.Now())
case "config get":
res, err = server.cmdConfigGet(msg)
case "config set":
res, err = server.cmdConfigSet(msg)
case "config rewrite":
res, err = server.cmdConfigRewrite(msg)
case "config", "script":
// These get rewritten into "config foo" and "script bar"
err = fmt.Errorf("unknown command '%s'", msg.Args[0])
if len(msg.Args) > 1 {
2018-11-20 21:25:48 +03:00
msg.Args[1] = msg.Args[0] + " " + msg.Args[1]
msg.Args = msg.Args[1:]
2018-11-20 21:25:48 +03:00
msg._command = ""
return server.command(msg, client)
}
case "client":
res, err = server.cmdClient(msg, client)
case "eval", "evalro", "evalna":
res, err = server.cmdEvalUnified(false, msg)
case "evalsha", "evalrosha", "evalnasha":
res, err = server.cmdEvalUnified(true, msg)
case "script load":
res, err = server.cmdScriptLoad(msg)
case "script exists":
res, err = server.cmdScriptExists(msg)
case "script flush":
res, err = server.cmdScriptFlush(msg)
case "subscribe":
res, err = server.cmdSubscribe(msg)
case "psubscribe":
res, err = server.cmdPsubscribe(msg)
case "publish":
res, err = server.cmdPublish(msg)
2019-02-09 00:57:29 +03:00
case "test":
res, err = server.cmdTest(msg)
2020-08-12 22:38:35 +03:00
case "monitor":
res, err = server.cmdMonitor(msg)
}
2020-08-12 22:38:35 +03:00
server.sendMonitor(err, msg, client, false)
return
}
// deniedMessage is the "-DENIED" error line written to a connection that is
// refused because the server is in protected mode. The text is trimmed, its
// newlines are replaced by spaces so that it forms a single line, and it is
// terminated with "\r\n".
//
// This phrase is copied nearly verbatim from Redis.
var deniedMessage = []byte(strings.Replace(strings.TrimSpace(`
-DENIED Tile38 is running in protected mode because protected mode is enabled,
no bind address was specified, no authentication password is requested to
clients. In this mode connections are only accepted from the loopback
interface. If you want to connect from external computers to Tile38 you may
adopt one of the following solutions: 1) Just disable protected mode sending
the command 'CONFIG SET protected-mode no' from the loopback interface by
connecting to Tile38 from the same host the server is running, however MAKE
SURE Tile38 is not publicly accessible from internet if you do so. Use CONFIG
REWRITE to make this change permanent. 2) Alternatively you can just disable
the protected mode by editing the Tile38 configuration file, and setting the
protected mode option to 'no', and then restarting the server. 3) If you
started the server manually just for testing, restart it with the
'--protected-mode no' option. 4) Setup a bind address or an authentication
password. NOTE: You only need to do one of the above things in order for the
server to start accepting connections from the outside.
`), "\n", " ", -1) + "\r\n")
2016-04-03 05:16:36 +03:00
// WriteWebSocketMessage writes a websocket message to an io.Writer.
2016-03-29 15:53:53 +03:00
func WriteWebSocketMessage(w io.Writer, data []byte) error {
	size := len(data)
	// Reserve room for the largest header form: two fixed bytes plus an
	// 8-byte extended length.
	frame := make([]byte, 10+size)
	frame[0] = 129 // FIN + TEXT

	// hdr is the number of header bytes actually used for this payload size.
	var hdr int
	switch {
	case size <= 125:
		// the length fits directly in the second byte
		frame[1] = byte(size)
		hdr = 2
	case size <= 0xFFFF:
		// 126 announces a 16-bit big-endian length
		frame[1] = 126
		binary.BigEndian.PutUint16(frame[2:], uint16(size))
		hdr = 4
	default:
		// 127 announces a 64-bit big-endian length
		frame[1] = 127
		binary.BigEndian.PutUint64(frame[2:], uint64(size))
		hdr = 10
	}
	copy(frame[hdr:], data)

	// Header and payload go out in a single Write call.
	_, err := w.Write(frame[:hdr+size])
	return err
}
2016-04-03 05:16:36 +03:00
// OKMessage returns a default OK message in JSON or RESP.
2017-10-05 18:20:40 +03:00
func OKMessage(msg *Message, start time.Time) resp.Value {
2016-03-29 15:53:53 +03:00
switch msg.OutputType {
case JSON:
return resp.StringValue(`{"ok":true,"elapsed":"` + time.Since(start).String() + "\"}")
2016-03-29 15:53:53 +03:00
case RESP:
2017-10-05 18:20:40 +03:00
return resp.SimpleStringValue("OK")
2016-03-29 15:53:53 +03:00
}
2017-10-05 18:20:40 +03:00
return resp.SimpleStringValue("")
2016-03-29 15:53:53 +03:00
}
2017-10-05 18:20:40 +03:00
// NOMessage is an empty simple string value that stands for "no message".
2017-10-05 18:20:40 +03:00
var NOMessage = resp.SimpleStringValue("")
// errInvalidHTTP is returned when an incoming HTTP request cannot be parsed
// (malformed request line, bad path, unsupported method) or when it yields no
// command arguments.
var errInvalidHTTP = errors.New("invalid HTTP request")
// Type identifies a connection protocol or an output format. Message uses it
// for both its ConnType and OutputType fields.
type Type byte
// Protocol Types
const (
// Null is the zero value of Type.
Null Type = iota
// RESP is the Redis protocol. ReadMessages assigns it to commands that redcon
// reports as Redis or Telnet.
RESP
// Telnet is declared for plain-text connections.
// NOTE(review): ReadMessages labels telnet commands as RESP, not Telnet.
Telnet
// Native is assigned to commands that redcon reports as Tile38, and to
// messages produced by readNativeMessageLine.
Native
// HTTP is assigned to messages parsed from an HTTP request.
HTTP
// WebSocket is assigned once an HTTP request has been upgraded to a websocket.
WebSocket
// JSON is used as an output format (Message.OutputType).
JSON
)
// Message is a resp message
type Message struct {
2018-11-10 23:30:56 +03:00
_command string
Args []string
ConnType Type
OutputType Type
Auth string
2019-04-24 15:09:41 +03:00
Deadline *deadline.Deadline
}
// Command returns the first argument as a lowercase string
func (msg *Message) Command() string {
2018-11-10 23:30:56 +03:00
if msg._command == "" {
msg._command = strings.ToLower(msg.Args[0])
}
return msg._command
}
// PipelineReader reads raw bytes from a connection and splits them into
// Messages (see ReadMessages). Bytes that belong to a command which has not
// fully arrived yet are kept for the next read.
type PipelineReader struct {
// rd is the source of incoming bytes.
rd io.Reader
// wr is handed to the command parser, which writes the websocket handshake
// response to it when a request asks for an upgrade.
wr io.Writer
// packet is the scratch buffer filled by each Read call.
packet [0xFFFF]byte
// buf holds unparsed bytes carried over from the previous ReadMessages call.
buf []byte
}
// kindHTTP is the redcon.Kind reported by readNextHTTPCommand so that callers
// can tell HTTP requests apart from commands parsed by redcon.
// NOTE(review): 9999 is presumably chosen to stay clear of redcon's own Kind
// values — confirm.
const kindHTTP redcon.Kind = 9999
// NewPipelineReader returns a PipelineReader that reads commands from rd and
// uses the same rd for anything that has to be written back while parsing.
func NewPipelineReader(rd io.ReadWriter) *PipelineReader {
	pr := new(PipelineReader)
	pr.rd, pr.wr = rd, rd
	return pr
}
// readcrlfline splits packet at the first CRLF. On success it returns the
// text before the CRLF, the bytes after it, and true. When packet holds no
// CRLF it returns "", the unchanged packet, and false.
func readcrlfline(packet []byte) (line string, leftover []byte, ok bool) {
	at := bytes.Index(packet, []byte("\r\n"))
	if at < 0 {
		return "", packet, false
	}
	return string(packet[:at]), packet[at+2:], true
}
// readNextHTTPCommand parses one HTTP request from the front of packet and
// fills msg from it.
//
// The request path (leading '/' removed, URL-unescaped), followed by the
// request body when a positive Content-Length is given, is parsed as a native
// command line and stored in msg.Args. msg.Auth is taken from the
// Authorization header. When the request asks for a websocket upgrade
// (Upgrade: websocket, Sec-WebSocket-Version >= 13, non-empty
// Sec-WebSocket-Key) the handshake response is written to wr and
// msg.ConnType becomes WebSocket.
//
// complete is true when a whole request was consumed; leftover then holds the
// bytes that follow it. When more data is needed or an error occurred,
// complete is false and leftover is the original packet. args is always
// empty (argsIn[:0]) and kind is always kindHTTP.
func readNextHTTPCommand(packet []byte, argsIn [][]byte, msg *Message, wr io.Writer) (
complete bool, args [][]byte, kind redcon.Kind, leftover []byte, err error,
) {
args = argsIn[:0]
msg.ConnType = HTTP
msg.OutputType = JSON
// Keep the untouched input so it can be handed back when the request is
// incomplete or invalid.
opacket := packet
// The closure advances the captured packet as it parses and reports whether
// a full request was read.
ready, err := func() (bool, error) {
var line string
var ok bool
// read header
var headers []string
for {
line, packet, ok = readcrlfline(packet)
if !ok {
// no complete line yet; wait for more data
return false, nil
}
if line == "" {
// a blank line ends the header section
break
}
headers = append(headers, line)
}
// The first line is the request line: "<method> <path> <version>".
parts := strings.Split(headers[0], " ")
if len(parts) != 3 {
return false, errInvalidHTTP
}
method := parts[0]
path := parts[1]
if len(path) == 0 || path[0] != '/' {
return false, errInvalidHTTP
}
// err here is the enclosing function's named result.
path, err = url.QueryUnescape(path[1:])
if err != nil {
return false, errInvalidHTTP
}
if method != "GET" && method != "POST" {
return false, errInvalidHTTP
}
contentLength := 0
websocket := false
websocketVersion := 0
websocketKey := ""
// Scan the remaining headers; the first byte is checked before the
// case-insensitive name comparison.
for _, header := range headers[1:] {
if header[0] == 'a' || header[0] == 'A' {
if strings.HasPrefix(strings.ToLower(header), "authorization:") {
msg.Auth = strings.TrimSpace(header[len("authorization:"):])
}
} else if header[0] == 'u' || header[0] == 'U' {
if strings.HasPrefix(strings.ToLower(header), "upgrade:") && strings.ToLower(strings.TrimSpace(header[len("upgrade:"):])) == "websocket" {
websocket = true
}
} else if header[0] == 's' || header[0] == 'S' {
if strings.HasPrefix(strings.ToLower(header), "sec-websocket-version:") {
var n uint64
n, err = strconv.ParseUint(strings.TrimSpace(header[len("sec-websocket-version:"):]), 10, 64)
if err != nil {
return false, err
}
websocketVersion = int(n)
} else if strings.HasPrefix(strings.ToLower(header), "sec-websocket-key:") {
websocketKey = strings.TrimSpace(header[len("sec-websocket-key:"):])
}
} else if header[0] == 'c' || header[0] == 'C' {
if strings.HasPrefix(strings.ToLower(header), "content-length:") {
var n uint64
n, err = strconv.ParseUint(strings.TrimSpace(header[len("content-length:"):]), 10, 64)
if err != nil {
return false, err
}
contentLength = int(n)
}
}
}
if websocket && websocketVersion >= 13 && websocketKey != "" {
// Websocket upgrade: answer the handshake right away. A request body, if
// any, is not read in this branch.
msg.ConnType = WebSocket
if wr == nil {
return false, errors.New("connection is nil")
}
// The accept token is base64(sha1(key + fixed GUID)).
sum := sha1.Sum([]byte(websocketKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"))
accept := base64.StdEncoding.EncodeToString(sum[:])
wshead := "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: " + accept + "\r\n\r\n"
if _, err = wr.Write([]byte(wshead)); err != nil {
return false, err
}
} else if contentLength > 0 {
msg.ConnType = HTTP
if len(packet) < contentLength {
// the body has not fully arrived yet
return false, nil
}
// The body is appended to the path, so both form one command line.
path += string(packet[:contentLength])
packet = packet[contentLength:]
}
if path == "" {
// Nothing to run; msg.Args stays unset.
return true, nil
}
nmsg, err := readNativeMessageLine([]byte(path))
if err != nil {
return false, err
}
msg.OutputType = JSON
msg.Args = nmsg.Args
return true, nil
}()
if err != nil || !ready {
return false, args[:0], kindHTTP, opacket, err
}
return true, args[:0], kindHTTP, packet, nil
}
// readNextCommand reads the next command from the front of packet.
//
// A packet that starts with 'G' or 'P' and whose first CRLF-terminated line
// ends in " HTTP/x.y" is treated as an HTTP request and handed to
// readNextHTTPCommand, which fills msg. Everything else, including an empty
// packet, is parsed by redcon.ReadNextCommand.
//
// argsIn is an optional reusable argument buffer; it may be nil. wr is passed
// through to readNextHTTPCommand. The results are those of whichever parser
// handled the packet. When the packet starts with 'G' or 'P' but its first
// line has no CRLF yet, complete is false and leftover is the whole packet.
func readNextCommand(packet []byte, argsIn [][]byte, msg *Message, wr io.Writer) (
	complete bool, args [][]byte, kind redcon.Kind, leftover []byte, err error,
) {
	if len(packet) > 0 && (packet[0] == 'G' || packet[0] == 'P') {
		// could be an HTTP request
		var line []byte
		for i := 1; i < len(packet); i++ {
			if packet[i] == '\n' && packet[i-1] == '\r' {
				line = packet[:i+1]
				break
			}
		}
		if len(line) == 0 {
			// The first line is not complete yet; wait for more data.
			return false, argsIn[:0], redcon.Redis, packet, nil
		}
		// " HTTP/x.y\r\n" is exactly 11 bytes long.
		if len(line) > 11 && string(line[len(line)-11:len(line)-5]) == " HTTP/" {
			return readNextHTTPCommand(packet, argsIn, msg, wr)
		}
	}
	// Hand over the caller's buffer, not the still-nil named result, so that
	// it can actually be reused.
	return redcon.ReadNextCommand(packet, argsIn)
}
// ReadMessages reads from the underlying reader until it yields at least one
// byte, then parses as many complete commands as the available data holds.
//
// Bytes belonging to a command that has not fully arrived are kept and
// prepended to the data of the next call. A read error is returned
// immediately with no messages. A parse error stops parsing and is returned
// together with the messages parsed before it. An HTTP request that carries
// no command yields errInvalidHTTP.
func (rd *PipelineReader) ReadMessages() ([]*Message, error) {
	// Keep reading until the reader hands over some data.
	var nread int
	for nread == 0 {
		var rerr error
		nread, rerr = rd.rd.Read(rd.packet[:])
		if rerr != nil {
			return nil, rerr
		}
	}

	// Put bytes left over from the previous call in front of the new data.
	pending := rd.packet[:nread]
	if len(rd.buf) != 0 {
		pending = append(rd.buf, pending...)
	}

	var (
		out      []*Message
		parseErr error
	)
	for len(pending) != 0 {
		m := new(Message)
		complete, rawArgs, kind, rest, err := readNextCommand(pending, nil, m, rd.wr)
		if err != nil {
			parseErr = err
			break
		}
		if !complete {
			break
		}
		switch {
		case kind == kindHTTP:
			// The HTTP parser fills the message itself.
			if len(m.Args) == 0 {
				return nil, errInvalidHTTP
			}
			out = append(out, m)
		case len(rawArgs) > 0:
			for _, a := range rawArgs {
				m.Args = append(m.Args, string(a))
			}
			switch kind {
			case redcon.Redis, redcon.Telnet:
				m.ConnType, m.OutputType = RESP, RESP
			case redcon.Tile38:
				m.ConnType, m.OutputType = Native, JSON
			}
			out = append(out, m)
		}
		pending = rest
	}

	// Keep whatever was not consumed for the next call.
	switch {
	case len(pending) > 0:
		rd.buf = append(rd.buf[:0], pending...)
	case len(rd.buf) > 0:
		rd.buf = rd.buf[:0]
	}
	return out, parseErr
}
// readNativeMessageLine splits one native-protocol command line into
// arguments and returns them in a Message with ConnType Native and
// OutputType JSON.
//
// Arguments are separated by single spaces; empty arguments produced by
// repeated spaces are dropped. Two forms swallow the rest of the line as a
// single argument:
//   - anything starting with '{' (a JSON element must be last on the line);
//   - a double-quoted value that follows "SET ... STRING", which is stored
//     without its surrounding quotes.
//
// The returned error is always nil.
func readNativeMessageLine(line []byte) (*Message, error) {
	var args []string
	for len(line) != 0 {
		if line[0] == '{' {
			// The native protocol cannot understand json boundaries so it assumes that
			// a json element must be at the end of the line.
			args = append(args, string(line))
			break
		}
		// A quoted value needs an opening and a closing quote. Without the
		// length check a lone '"' would match both tests and the slice
		// expression below would panic.
		if len(line) >= 2 && line[0] == '"' && line[len(line)-1] == '"' {
			if len(args) > 0 &&
				strings.ToLower(args[0]) == "set" &&
				strings.ToLower(args[len(args)-1]) == "string" {
				// Setting a string value that is contained inside double quotes.
				// This is only because of the boundary issues of the native protocol.
				args = append(args, string(line[1:len(line)-1]))
				break
			}
		}
		sp := bytes.IndexByte(line, ' ')
		if sp < 0 {
			// last argument on the line
			args = append(args, string(line))
			break
		}
		if sp > 0 {
			args = append(args, string(line[:sp]))
		}
		line = line[sp+1:]
	}
	return &Message{Args: args, ConnType: Native, OutputType: JSON}, nil
}
2019-04-26 21:50:49 +03:00
// InputStream is a helper type for managing input streams from inside
// the Data event. It remembers the bytes that a previous packet left
// unprocessed.
type InputStream struct {
	b []byte // unprocessed bytes carried over between packets
}

// Begin accepts a new packet and returns the working sequence of unprocessed
// bytes: the packet itself when nothing is pending, otherwise the pending
// bytes followed by the packet.
func (is *InputStream) Begin(packet []byte) (data []byte) {
	if len(is.b) == 0 {
		return packet
	}
	is.b = append(is.b, packet...)
	return is.b
}

// End records data as the bytes that are still unprocessed. Empty data
// clears the pending bytes. Data whose length equals that of the pending
// bytes leaves them untouched.
func (is *InputStream) End(data []byte) {
	switch {
	case len(data) > 0:
		if len(data) != len(is.b) {
			is.b = append(is.b[:0], data...)
		}
	case len(is.b) > 0:
		is.b = is.b[:0]
	}
}
// clientErrorf is the same as fmt.Errorf, but is intended for errors that are
// sent back to the client. Routing such messages through this helper lets the
// Go static checker skip its error-string warnings for them.
// https://staticcheck.io/docs/checks#ST1005
func clientErrorf(format string, args ...interface{}) error {
	err := fmt.Errorf(format, args...)
	return err
}