mirror of https://github.com/tidwall/tile38.git
1588 lines
41 KiB
Go
1588 lines
41 KiB
Go
package server
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/rand"
|
|
"crypto/sha1"
|
|
"encoding/base64"
|
|
"encoding/binary"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/url"
|
|
"os"
|
|
"path"
|
|
"path/filepath"
|
|
"runtime"
|
|
"runtime/debug"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/tidwall/boxtree/d2"
|
|
"github.com/tidwall/buntdb"
|
|
"github.com/tidwall/evio"
|
|
"github.com/tidwall/geojson"
|
|
"github.com/tidwall/geojson/geometry"
|
|
"github.com/tidwall/redcon"
|
|
"github.com/tidwall/resp"
|
|
"github.com/tidwall/tile38/core"
|
|
"github.com/tidwall/tile38/internal/collection"
|
|
"github.com/tidwall/tile38/internal/endpoint"
|
|
"github.com/tidwall/tile38/internal/expire"
|
|
"github.com/tidwall/tile38/internal/log"
|
|
"github.com/tidwall/tinybtree"
|
|
)
|
|
|
|
var errOOM = errors.New("OOM command not allowed when used memory > 'maxmemory'")
|
|
|
|
const (
|
|
goingLive = "going live"
|
|
hookLogPrefix = "hook:log:"
|
|
)
|
|
|
|
// commandDetails is detailed information about a mutable command. It's used
|
|
// for geofence formulas.
|
|
type commandDetails struct {
|
|
command string // client command, like "SET" or "DEL"
|
|
key, id string // collection key and object id of object
|
|
newKey string // new key, for RENAME command
|
|
obj geojson.Object // new object
|
|
fmap map[string]int // map of field names to value indexes
|
|
fields *collection.Fields // array of field values
|
|
oldObj geojson.Object // previous object, if any
|
|
oldFields *collection.Fields // previous object field values
|
|
updated bool // object was updated
|
|
timestamp time.Time // timestamp when the update occured
|
|
parent bool // when true, only children are forwarded
|
|
pattern string // PDEL key pattern
|
|
children []*commandDetails // for multi actions such as "PDEL"
|
|
}
|
|
|
|
// Server is a tile38 controller
|
|
type Server struct {
|
|
// static values
|
|
host string
|
|
port int
|
|
http bool
|
|
dir string
|
|
started time.Time
|
|
config *Config
|
|
epc *endpoint.Manager
|
|
|
|
// env opts
|
|
geomParseOpts geojson.ParseOptions
|
|
|
|
// atomics
|
|
followc aint // counter increases when follow property changes
|
|
statsTotalConns aint // counter for total connections
|
|
statsTotalCommands aint // counter for total commands
|
|
statsExpired aint // item expiration counter
|
|
lastShrinkDuration aint
|
|
stopServer abool
|
|
outOfMemory abool
|
|
|
|
connsmu sync.RWMutex
|
|
conns map[int]*Client
|
|
|
|
exlistmu sync.RWMutex
|
|
exlist []exitem
|
|
|
|
mu sync.RWMutex
|
|
aof *os.File // active aof file
|
|
aofdirty int32 // mark the aofbuf as having data
|
|
aofbuf []byte // prewrite buffer
|
|
aofsz int // active size of the aof file
|
|
qdb *buntdb.DB // hook queue log
|
|
qidx uint64 // hook queue log last idx
|
|
cols tinybtree.BTree // data collections
|
|
expires map[string]map[string]time.Time // synced with cols
|
|
|
|
follows map[*bytes.Buffer]bool
|
|
fcond *sync.Cond
|
|
lstack []*commandDetails
|
|
lives map[*liveBuffer]bool
|
|
lcond *sync.Cond
|
|
fcup bool // follow caught up
|
|
fcuponce bool // follow caught up once
|
|
shrinking bool // aof shrinking flag
|
|
shrinklog [][]string // aof shrinking log
|
|
hooks map[string]*Hook // hook name
|
|
hookTree d2.BoxTree // hook spatial tree containing all
|
|
hooksOut map[string]*Hook // hooks with "outside" detection
|
|
aofconnM map[net.Conn]bool
|
|
luascripts *lScriptMap
|
|
luapool *lStatePool
|
|
|
|
pubsub *pubsub
|
|
hookex expire.List
|
|
}
|
|
|
|
// Serve starts a new tile38 server
|
|
func Serve(host string, port int, dir string, http bool) error {
|
|
if core.AppendFileName == "" {
|
|
core.AppendFileName = path.Join(dir, "appendonly.aof")
|
|
}
|
|
if core.QueueFileName == "" {
|
|
core.QueueFileName = path.Join(dir, "queue.db")
|
|
}
|
|
log.Infof("Server started, Tile38 version %s, git %s", core.Version, core.GitSHA)
|
|
|
|
// Initialize the server
|
|
server := &Server{
|
|
host: host,
|
|
port: port,
|
|
dir: dir,
|
|
follows: make(map[*bytes.Buffer]bool),
|
|
fcond: sync.NewCond(&sync.Mutex{}),
|
|
lives: make(map[*liveBuffer]bool),
|
|
lcond: sync.NewCond(&sync.Mutex{}),
|
|
hooks: make(map[string]*Hook),
|
|
hooksOut: make(map[string]*Hook),
|
|
aofconnM: make(map[net.Conn]bool),
|
|
expires: make(map[string]map[string]time.Time),
|
|
started: time.Now(),
|
|
conns: make(map[int]*Client),
|
|
http: http,
|
|
pubsub: newPubsub(),
|
|
}
|
|
|
|
server.hookex.Expired = func(item expire.Item) {
|
|
switch v := item.(type) {
|
|
case *Hook:
|
|
server.possiblyExpireHook(v.Name)
|
|
}
|
|
}
|
|
server.epc = endpoint.NewManager(server)
|
|
server.luascripts = server.newScriptMap()
|
|
server.luapool = server.newPool()
|
|
defer server.luapool.Shutdown()
|
|
|
|
if err := os.MkdirAll(dir, 0700); err != nil {
|
|
return err
|
|
}
|
|
var err error
|
|
server.config, err = loadConfig(filepath.Join(dir, "config"))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
server.geomParseOpts = *geojson.DefaultParseOptions
|
|
server.geomParseOpts.AllowSimplePoints = true
|
|
|
|
// Allow for geometry indexing options through environment variables:
|
|
// T38IDXGEOMKIND -- None, RTree, QuadTree
|
|
// T38IDXGEOM -- Min number of points in a geometry for indexing.
|
|
// T38IDXMULTI -- Min number of object in a Multi/Collection for indexing.
|
|
n, err := strconv.ParseUint(os.Getenv("T38IDXGEOM"), 10, 32)
|
|
if err == nil {
|
|
server.geomParseOpts.IndexGeometry = int(n)
|
|
}
|
|
n, err = strconv.ParseUint(os.Getenv("T38IDXMULTI"), 10, 32)
|
|
if err == nil {
|
|
server.geomParseOpts.IndexChildren = int(n)
|
|
}
|
|
requireValid := os.Getenv("REQUIREVALID")
|
|
if requireValid != "" {
|
|
server.geomParseOpts.RequireValid = true
|
|
}
|
|
indexKind := os.Getenv("T38IDXGEOMKIND")
|
|
switch indexKind {
|
|
default:
|
|
log.Errorf("Unknown index kind: %s", indexKind)
|
|
case "":
|
|
case "None":
|
|
server.geomParseOpts.IndexGeometryKind = geometry.None
|
|
case "RTree":
|
|
server.geomParseOpts.IndexGeometryKind = geometry.RTree
|
|
case "QuadTree":
|
|
server.geomParseOpts.IndexGeometryKind = geometry.QuadTree
|
|
}
|
|
if server.geomParseOpts.IndexGeometryKind == geometry.None {
|
|
log.Debugf("Geom indexing: %s",
|
|
server.geomParseOpts.IndexGeometryKind,
|
|
)
|
|
} else {
|
|
log.Debugf("Geom indexing: %s (%d points)",
|
|
server.geomParseOpts.IndexGeometryKind,
|
|
server.geomParseOpts.IndexGeometry,
|
|
)
|
|
}
|
|
log.Debugf("Multi indexing: RTree (%d points)", server.geomParseOpts.IndexChildren)
|
|
|
|
// Load the queue before the aof
|
|
qdb, err := buntdb.Open(core.QueueFileName)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
var qidx uint64
|
|
if err := qdb.View(func(tx *buntdb.Tx) error {
|
|
val, err := tx.Get("hook:idx")
|
|
if err != nil {
|
|
if err == buntdb.ErrNotFound {
|
|
return nil
|
|
}
|
|
return err
|
|
}
|
|
qidx = stringToUint64(val)
|
|
return nil
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
err = qdb.CreateIndex("hooks", hookLogPrefix+"*", buntdb.IndexJSONCaseSensitive("hook"))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
server.qdb = qdb
|
|
server.qidx = qidx
|
|
if err := server.migrateAOF(); err != nil {
|
|
return err
|
|
}
|
|
if core.AppendOnly == true {
|
|
f, err := os.OpenFile(core.AppendFileName, os.O_CREATE|os.O_RDWR, 0600)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
server.aof = f
|
|
if err := server.loadAOF(); err != nil {
|
|
return err
|
|
}
|
|
defer func() {
|
|
server.flushAOF(false)
|
|
server.aof.Sync()
|
|
}()
|
|
}
|
|
server.fillExpiresList()
|
|
|
|
// Start background routines
|
|
if server.config.followHost() != "" {
|
|
go server.follow(server.config.followHost(), server.config.followPort(),
|
|
server.followc.get())
|
|
}
|
|
go server.processLives()
|
|
go server.watchOutOfMemory()
|
|
go server.watchLuaStatePool()
|
|
go server.watchAutoGC()
|
|
go server.backgroundExpiring()
|
|
go server.backgroundSyncAOF()
|
|
defer func() {
|
|
// Stop background routines
|
|
server.followc.add(1) // this will force any follow communication to die
|
|
server.stopServer.set(true)
|
|
|
|
// notify the live geofence connections that we are stopping.
|
|
server.lcond.L.Lock()
|
|
server.lcond.Wait()
|
|
server.lcond.L.Lock()
|
|
}()
|
|
|
|
// Start the network server
|
|
if core.Evio {
|
|
return server.evioServe()
|
|
}
|
|
return server.netServe()
|
|
}
|
|
|
|
func (server *Server) isProtected() bool {
|
|
if core.ProtectedMode == "no" {
|
|
// --protected-mode no
|
|
return false
|
|
}
|
|
if server.host != "" && server.host != "127.0.0.1" &&
|
|
server.host != "::1" && server.host != "localhost" {
|
|
// -h address
|
|
return false
|
|
}
|
|
is := server.config.protectedMode() != "no" && server.config.requirePass() == ""
|
|
return is
|
|
}
|
|
|
|
func (server *Server) evioServe() error {
|
|
var events evio.Events
|
|
if core.NumThreads == 0 {
|
|
events.NumLoops = -1
|
|
} else {
|
|
events.NumLoops = core.NumThreads
|
|
}
|
|
events.LoadBalance = evio.LeastConnections
|
|
events.Serving = func(eserver evio.Server) (action evio.Action) {
|
|
if eserver.NumLoops == 1 {
|
|
log.Infof("Running single-threaded")
|
|
} else {
|
|
log.Infof("Running on %d threads", eserver.NumLoops)
|
|
}
|
|
for _, addr := range eserver.Addrs {
|
|
log.Infof("Ready to accept connections at %s",
|
|
addr)
|
|
}
|
|
return
|
|
}
|
|
var clientID int64
|
|
events.Opened = func(econn evio.Conn) (out []byte, opts evio.Options, action evio.Action) {
|
|
// create the client
|
|
client := new(Client)
|
|
client.id = int(atomic.AddInt64(&clientID, 1))
|
|
client.opened = time.Now()
|
|
client.remoteAddr = econn.RemoteAddr().String()
|
|
|
|
// keep track of the client
|
|
econn.SetContext(client)
|
|
|
|
// add client to server map
|
|
server.connsmu.Lock()
|
|
server.conns[client.id] = client
|
|
server.connsmu.Unlock()
|
|
server.statsTotalConns.add(1)
|
|
|
|
// set the client keep-alive, if needed
|
|
if server.config.keepAlive() > 0 {
|
|
opts.TCPKeepAlive = time.Duration(server.config.keepAlive()) * time.Second
|
|
}
|
|
log.Debugf("Opened connection: %s", client.remoteAddr)
|
|
|
|
// check if the connection is protected
|
|
if !strings.HasPrefix(client.remoteAddr, "127.0.0.1:") &&
|
|
!strings.HasPrefix(client.remoteAddr, "[::1]:") {
|
|
if server.isProtected() {
|
|
// This is a protected server. Only loopback is allowed.
|
|
out = append(out, deniedMessage...)
|
|
action = evio.Close
|
|
return
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
events.Closed = func(econn evio.Conn, err error) (action evio.Action) {
|
|
// load the client
|
|
client := econn.Context().(*Client)
|
|
|
|
// delete from server map
|
|
server.connsmu.Lock()
|
|
delete(server.conns, client.id)
|
|
server.connsmu.Unlock()
|
|
|
|
log.Debugf("Closed connection: %s", client.remoteAddr)
|
|
return
|
|
}
|
|
|
|
events.Data = func(econn evio.Conn, in []byte) (out []byte, action evio.Action) {
|
|
// load the client
|
|
client := econn.Context().(*Client)
|
|
|
|
// read the payload packet from the client input stream.
|
|
packet := client.in.Begin(in)
|
|
|
|
// load the pipeline reader
|
|
pr := &client.pr
|
|
rdbuf := bytes.NewBuffer(packet)
|
|
pr.rd = rdbuf
|
|
pr.wr = client
|
|
|
|
msgs, err := pr.ReadMessages()
|
|
if err != nil {
|
|
log.Error(err)
|
|
action = evio.Close
|
|
return
|
|
}
|
|
for _, msg := range msgs {
|
|
// Just closing connection if we have deprecated HTTP or WS connection,
|
|
// And --http-transport = false
|
|
if !server.http && (msg.ConnType == WebSocket ||
|
|
msg.ConnType == HTTP) {
|
|
action = evio.Close
|
|
break
|
|
}
|
|
if msg != nil && msg.Command() != "" {
|
|
if client.outputType != Null {
|
|
msg.OutputType = client.outputType
|
|
}
|
|
if msg.Command() == "quit" {
|
|
if msg.OutputType == RESP {
|
|
io.WriteString(client, "+OK\r\n")
|
|
}
|
|
action = evio.Close
|
|
break
|
|
}
|
|
|
|
// increment last used
|
|
client.mu.Lock()
|
|
client.last = time.Now()
|
|
client.mu.Unlock()
|
|
|
|
// update total command count
|
|
server.statsTotalCommands.add(1)
|
|
|
|
// handle the command
|
|
err := server.handleInputCommand(client, msg)
|
|
if err != nil {
|
|
if err.Error() == goingLive {
|
|
client.goLiveErr = err
|
|
client.goLiveMsg = msg
|
|
action = evio.Detach
|
|
return
|
|
}
|
|
log.Error(err)
|
|
action = evio.Close
|
|
return
|
|
}
|
|
|
|
client.outputType = msg.OutputType
|
|
} else {
|
|
client.Write([]byte("HTTP/1.1 500 Bad Request\r\nConnection: close\r\n\r\n"))
|
|
action = evio.Close
|
|
break
|
|
}
|
|
if msg.ConnType == HTTP || msg.ConnType == WebSocket {
|
|
action = evio.Close
|
|
break
|
|
}
|
|
}
|
|
|
|
packet = packet[len(packet)-rdbuf.Len():]
|
|
client.in.End(packet)
|
|
|
|
out = client.out
|
|
client.out = nil
|
|
return
|
|
}
|
|
|
|
events.Detached = func(econn evio.Conn, rwc io.ReadWriteCloser) (action evio.Action) {
|
|
client := econn.Context().(*Client)
|
|
client.conn = rwc
|
|
if len(client.out) > 0 {
|
|
rwc.Write(client.out)
|
|
client.out = nil
|
|
}
|
|
client.in = evio.InputStream{}
|
|
client.pr.rd = rwc
|
|
client.pr.wr = rwc
|
|
|
|
log.Debugf("Detached connection: %s", client.remoteAddr)
|
|
go func() {
|
|
defer func() {
|
|
// close connection
|
|
rwc.Close()
|
|
server.connsmu.Lock()
|
|
delete(server.conns, client.id)
|
|
server.connsmu.Unlock()
|
|
log.Debugf("Closed connection: %s", client.remoteAddr)
|
|
}()
|
|
err := server.goLive(
|
|
client.goLiveErr,
|
|
&liveConn{econn.RemoteAddr(), rwc},
|
|
&client.pr,
|
|
client.goLiveMsg,
|
|
client.goLiveMsg.ConnType == WebSocket,
|
|
)
|
|
if err != nil {
|
|
log.Error(err)
|
|
}
|
|
}()
|
|
return
|
|
}
|
|
|
|
events.PreWrite = func() {
|
|
if atomic.LoadInt32(&server.aofdirty) != 0 {
|
|
server.mu.Lock()
|
|
defer server.mu.Unlock()
|
|
server.flushAOF(false)
|
|
atomic.StoreInt32(&server.aofdirty, 1)
|
|
}
|
|
}
|
|
|
|
return evio.Serve(events, fmt.Sprintf("%s:%d", server.host, server.port))
|
|
}
|
|
|
|
func (server *Server) netServe() error {
|
|
ln, err := net.Listen("tcp", fmt.Sprintf("%s:%d", server.host, server.port))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer ln.Close()
|
|
log.Infof("Ready to accept connections at %s", ln.Addr())
|
|
var clientID int64
|
|
for {
|
|
conn, err := ln.Accept()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
go func(conn net.Conn) {
|
|
// open connection
|
|
// create the client
|
|
client := new(Client)
|
|
client.id = int(atomic.AddInt64(&clientID, 1))
|
|
client.opened = time.Now()
|
|
client.remoteAddr = conn.RemoteAddr().String()
|
|
|
|
// add client to server map
|
|
server.connsmu.Lock()
|
|
server.conns[client.id] = client
|
|
server.connsmu.Unlock()
|
|
server.statsTotalConns.add(1)
|
|
|
|
// set the client keep-alive, if needed
|
|
if server.config.keepAlive() > 0 {
|
|
if conn, ok := conn.(*net.TCPConn); ok {
|
|
conn.SetKeepAlive(true)
|
|
conn.SetKeepAlivePeriod(
|
|
time.Duration(server.config.keepAlive()) * time.Second,
|
|
)
|
|
}
|
|
}
|
|
log.Debugf("Opened connection: %s", client.remoteAddr)
|
|
|
|
defer func() {
|
|
// close connection
|
|
// delete from server map
|
|
server.connsmu.Lock()
|
|
delete(server.conns, client.id)
|
|
server.connsmu.Unlock()
|
|
log.Debugf("Closed connection: %s", client.remoteAddr)
|
|
conn.Close()
|
|
}()
|
|
|
|
// check if the connection is protected
|
|
if !strings.HasPrefix(client.remoteAddr, "127.0.0.1:") &&
|
|
!strings.HasPrefix(client.remoteAddr, "[::1]:") {
|
|
if server.isProtected() {
|
|
// This is a protected server. Only loopback is allowed.
|
|
conn.Write(deniedMessage)
|
|
return // close connection
|
|
}
|
|
}
|
|
packet := make([]byte, 0xFFFF)
|
|
for {
|
|
var close bool
|
|
n, err := conn.Read(packet)
|
|
if err != nil {
|
|
return
|
|
}
|
|
in := packet[:n]
|
|
|
|
// read the payload packet from the client input stream.
|
|
packet := client.in.Begin(in)
|
|
|
|
// load the pipeline reader
|
|
pr := &client.pr
|
|
rdbuf := bytes.NewBuffer(packet)
|
|
pr.rd = rdbuf
|
|
pr.wr = client
|
|
|
|
msgs, err := pr.ReadMessages()
|
|
if err != nil {
|
|
log.Error(err)
|
|
return // close connection
|
|
}
|
|
for _, msg := range msgs {
|
|
// Just closing connection if we have deprecated HTTP or WS connection,
|
|
// And --http-transport = false
|
|
if !server.http && (msg.ConnType == WebSocket ||
|
|
msg.ConnType == HTTP) {
|
|
close = true // close connection
|
|
break
|
|
}
|
|
if msg != nil && msg.Command() != "" {
|
|
if client.outputType != Null {
|
|
msg.OutputType = client.outputType
|
|
}
|
|
if msg.Command() == "quit" {
|
|
if msg.OutputType == RESP {
|
|
io.WriteString(client, "+OK\r\n")
|
|
}
|
|
close = true // close connection
|
|
break
|
|
}
|
|
|
|
// increment last used
|
|
client.mu.Lock()
|
|
client.last = time.Now()
|
|
client.mu.Unlock()
|
|
|
|
// update total command count
|
|
server.statsTotalCommands.add(1)
|
|
|
|
// handle the command
|
|
err := server.handleInputCommand(client, msg)
|
|
if err != nil {
|
|
if err.Error() == goingLive {
|
|
client.goLiveErr = err
|
|
client.goLiveMsg = msg
|
|
// detach
|
|
var rwc io.ReadWriteCloser = conn
|
|
client.conn = rwc
|
|
if len(client.out) > 0 {
|
|
client.conn.Write(client.out)
|
|
client.out = nil
|
|
}
|
|
client.in = evio.InputStream{}
|
|
client.pr.rd = rwc
|
|
client.pr.wr = rwc
|
|
log.Debugf("Detached connection: %s", client.remoteAddr)
|
|
|
|
var wg sync.WaitGroup
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
err := server.goLive(
|
|
client.goLiveErr,
|
|
&liveConn{conn.RemoteAddr(), rwc},
|
|
&client.pr,
|
|
client.goLiveMsg,
|
|
client.goLiveMsg.ConnType == WebSocket,
|
|
)
|
|
if err != nil {
|
|
log.Error(err)
|
|
}
|
|
}()
|
|
wg.Wait()
|
|
return // close connection
|
|
}
|
|
log.Error(err)
|
|
return // close connection, NOW
|
|
}
|
|
|
|
client.outputType = msg.OutputType
|
|
} else {
|
|
client.Write([]byte("HTTP/1.1 500 Bad Request\r\nConnection: close\r\n\r\n"))
|
|
break
|
|
}
|
|
if msg.ConnType == HTTP || msg.ConnType == WebSocket {
|
|
close = true // close connection
|
|
break
|
|
}
|
|
}
|
|
|
|
packet = packet[len(packet)-rdbuf.Len():]
|
|
client.in.End(packet)
|
|
|
|
// write to client
|
|
if len(client.out) > 0 {
|
|
if atomic.LoadInt32(&server.aofdirty) != 0 {
|
|
func() {
|
|
// prewrite
|
|
server.mu.Lock()
|
|
defer server.mu.Unlock()
|
|
server.flushAOF(false)
|
|
}()
|
|
atomic.StoreInt32(&server.aofdirty, 0)
|
|
}
|
|
conn.Write(client.out)
|
|
client.out = nil
|
|
|
|
}
|
|
if close {
|
|
break
|
|
}
|
|
}
|
|
}(conn)
|
|
}
|
|
}
|
|
|
|
type liveConn struct {
|
|
remoteAddr net.Addr
|
|
rwc io.ReadWriteCloser
|
|
}
|
|
|
|
func (conn *liveConn) Close() error {
|
|
return conn.rwc.Close()
|
|
}
|
|
|
|
func (conn *liveConn) LocalAddr() net.Addr {
|
|
panic("not supported")
|
|
}
|
|
|
|
func (conn *liveConn) RemoteAddr() net.Addr {
|
|
return conn.remoteAddr
|
|
}
|
|
func (conn *liveConn) Read(b []byte) (n int, err error) {
|
|
return conn.rwc.Read(b)
|
|
}
|
|
|
|
func (conn *liveConn) Write(b []byte) (n int, err error) {
|
|
return conn.rwc.Write(b)
|
|
}
|
|
|
|
func (conn *liveConn) SetDeadline(deadline time.Time) error {
|
|
panic("not supported")
|
|
}
|
|
|
|
func (conn *liveConn) SetReadDeadline(deadline time.Time) error {
|
|
panic("not supported")
|
|
}
|
|
|
|
func (conn *liveConn) SetWriteDeadline(deadline time.Time) error {
|
|
panic("not supported")
|
|
}
|
|
|
|
func (server *Server) watchAutoGC() {
|
|
t := time.NewTicker(time.Second)
|
|
defer t.Stop()
|
|
s := time.Now()
|
|
for range t.C {
|
|
if server.stopServer.on() {
|
|
return
|
|
}
|
|
autoGC := server.config.autoGC()
|
|
if autoGC == 0 {
|
|
continue
|
|
}
|
|
if time.Now().Sub(s) < time.Second*time.Duration(autoGC) {
|
|
continue
|
|
}
|
|
var mem1, mem2 runtime.MemStats
|
|
runtime.ReadMemStats(&mem1)
|
|
log.Debugf("autogc(before): "+
|
|
"alloc: %v, heap_alloc: %v, heap_released: %v",
|
|
mem1.Alloc, mem1.HeapAlloc, mem1.HeapReleased)
|
|
|
|
runtime.GC()
|
|
debug.FreeOSMemory()
|
|
runtime.ReadMemStats(&mem2)
|
|
log.Debugf("autogc(after): "+
|
|
"alloc: %v, heap_alloc: %v, heap_released: %v",
|
|
mem2.Alloc, mem2.HeapAlloc, mem2.HeapReleased)
|
|
s = time.Now()
|
|
}
|
|
}
|
|
|
|
func (server *Server) watchOutOfMemory() {
|
|
t := time.NewTicker(time.Second * 2)
|
|
defer t.Stop()
|
|
var mem runtime.MemStats
|
|
for range t.C {
|
|
if server.stopServer.on() {
|
|
return
|
|
}
|
|
oom := server.outOfMemory.on()
|
|
if server.config.maxMemory() == 0 {
|
|
if oom {
|
|
server.outOfMemory.set(false)
|
|
}
|
|
return
|
|
}
|
|
if oom {
|
|
runtime.GC()
|
|
}
|
|
runtime.ReadMemStats(&mem)
|
|
server.outOfMemory.set(int(mem.HeapAlloc) > server.config.maxMemory())
|
|
}
|
|
}
|
|
|
|
func (server *Server) watchLuaStatePool() {
|
|
t := time.NewTicker(time.Second * 10)
|
|
defer t.Stop()
|
|
for range t.C {
|
|
if server.stopServer.on() {
|
|
return
|
|
}
|
|
server.luapool.Prune()
|
|
}
|
|
}
|
|
|
|
func (server *Server) backgroundSyncAOF() {
|
|
t := time.NewTicker(time.Second)
|
|
defer t.Stop()
|
|
for range t.C {
|
|
if server.stopServer.on() {
|
|
return
|
|
}
|
|
server.mu.Lock()
|
|
if len(server.aofbuf) > 0 {
|
|
server.flushAOF(true)
|
|
}
|
|
server.aofbuf = nil
|
|
server.mu.Unlock()
|
|
}
|
|
}
|
|
|
|
func (server *Server) setCol(key string, col *collection.Collection) {
|
|
server.cols.Set(key, col)
|
|
}
|
|
|
|
func (server *Server) getCol(key string) *collection.Collection {
|
|
if value, ok := server.cols.Get(key); ok {
|
|
return value.(*collection.Collection)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (server *Server) scanGreaterOrEqual(
|
|
key string, iterator func(key string, col *collection.Collection) bool,
|
|
) {
|
|
server.cols.Ascend(key, func(ikey string, ivalue interface{}) bool {
|
|
return iterator(ikey, ivalue.(*collection.Collection))
|
|
})
|
|
}
|
|
|
|
func (server *Server) deleteCol(key string) *collection.Collection {
|
|
if prev, ok := server.cols.Delete(key); ok {
|
|
return prev.(*collection.Collection)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func isReservedFieldName(field string) bool {
|
|
switch field {
|
|
case "z", "lat", "lon":
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (server *Server) handleInputCommand(client *Client, msg *Message) error {
|
|
start := time.Now()
|
|
serializeOutput := func(res resp.Value) (string, error) {
|
|
var resStr string
|
|
var err error
|
|
switch msg.OutputType {
|
|
case JSON:
|
|
resStr = res.String()
|
|
case RESP:
|
|
var resBytes []byte
|
|
resBytes, err = res.MarshalRESP()
|
|
resStr = string(resBytes)
|
|
}
|
|
return resStr, err
|
|
}
|
|
writeOutput := func(res string) error {
|
|
switch msg.ConnType {
|
|
default:
|
|
err := fmt.Errorf("unsupported conn type: %v", msg.ConnType)
|
|
log.Error(err)
|
|
return err
|
|
case WebSocket:
|
|
return WriteWebSocketMessage(client, []byte(res))
|
|
case HTTP:
|
|
_, err := fmt.Fprintf(client, "HTTP/1.1 200 OK\r\n"+
|
|
"Connection: close\r\n"+
|
|
"Content-Length: %d\r\n"+
|
|
"Content-Type: application/json; charset=utf-8\r\n"+
|
|
"\r\n", len(res)+2)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = io.WriteString(client, res)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = io.WriteString(client, "\r\n")
|
|
return err
|
|
case RESP:
|
|
var err error
|
|
if msg.OutputType == JSON {
|
|
_, err = fmt.Fprintf(client, "$%d\r\n%s\r\n", len(res), res)
|
|
} else {
|
|
_, err = io.WriteString(client, res)
|
|
}
|
|
return err
|
|
case Native:
|
|
_, err := fmt.Fprintf(client, "$%d %s\r\n", len(res), res)
|
|
return err
|
|
}
|
|
}
|
|
|
|
// Ping. Just send back the response. No need to put through the pipeline.
|
|
if msg.Command() == "ping" || msg.Command() == "echo" {
|
|
switch msg.OutputType {
|
|
case JSON:
|
|
if len(msg.Args) > 1 {
|
|
return writeOutput(`{"ok":true,"` + msg.Command() + `":` + jsonString(msg.Args[1]) + `,"elapsed":"` + time.Now().Sub(start).String() + `"}`)
|
|
}
|
|
return writeOutput(`{"ok":true,"` + msg.Command() + `":"pong","elapsed":"` + time.Now().Sub(start).String() + `"}`)
|
|
case RESP:
|
|
if len(msg.Args) > 1 {
|
|
data := redcon.AppendBulkString(nil, msg.Args[1])
|
|
return writeOutput(string(data))
|
|
}
|
|
return writeOutput("+PONG\r\n")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
writeErr := func(errMsg string) error {
|
|
switch msg.OutputType {
|
|
case JSON:
|
|
return writeOutput(`{"ok":false,"err":` + jsonString(errMsg) + `,"elapsed":"` + time.Now().Sub(start).String() + "\"}")
|
|
case RESP:
|
|
if errMsg == errInvalidNumberOfArguments.Error() {
|
|
return writeOutput("-ERR wrong number of arguments for '" + msg.Command() + "' command\r\n")
|
|
}
|
|
v, _ := resp.ErrorValue(errors.New("ERR " + errMsg)).MarshalRESP()
|
|
return writeOutput(string(v))
|
|
}
|
|
return nil
|
|
}
|
|
|
|
var write bool
|
|
|
|
if !client.authd || msg.Command() == "auth" {
|
|
if server.config.requirePass() != "" {
|
|
password := ""
|
|
// This better be an AUTH command or the Message should contain an Auth
|
|
if msg.Command() != "auth" && msg.Auth == "" {
|
|
// Just shut down the pipeline now. The less the client connection knows the better.
|
|
return writeErr("authentication required")
|
|
}
|
|
if msg.Auth != "" {
|
|
password = msg.Auth
|
|
} else {
|
|
if len(msg.Args) > 1 {
|
|
password = msg.Args[1]
|
|
}
|
|
}
|
|
if server.config.requirePass() != strings.TrimSpace(password) {
|
|
return writeErr("invalid password")
|
|
}
|
|
client.authd = true
|
|
if msg.ConnType != HTTP {
|
|
resStr, _ := serializeOutput(OKMessage(msg, start))
|
|
return writeOutput(resStr)
|
|
}
|
|
} else if msg.Command() == "auth" {
|
|
return writeErr("invalid password")
|
|
}
|
|
}
|
|
|
|
// choose the locking strategy
|
|
switch msg.Command() {
|
|
default:
|
|
server.mu.RLock()
|
|
defer server.mu.RUnlock()
|
|
case "set", "del", "drop", "fset", "flushdb",
|
|
"setchan", "pdelchan", "delchan",
|
|
"sethook", "pdelhook", "delhook",
|
|
"expire", "persist", "jset", "pdel", "rename", "renamenx":
|
|
// write operations
|
|
write = true
|
|
server.mu.Lock()
|
|
defer server.mu.Unlock()
|
|
if server.config.followHost() != "" {
|
|
return writeErr("not the leader")
|
|
}
|
|
if server.config.readOnly() {
|
|
return writeErr("read only")
|
|
}
|
|
case "eval", "evalsha":
|
|
// write operations (potentially) but no AOF for the script command itself
|
|
server.mu.Lock()
|
|
defer server.mu.Unlock()
|
|
if server.config.followHost() != "" {
|
|
return writeErr("not the leader")
|
|
}
|
|
if server.config.readOnly() {
|
|
return writeErr("read only")
|
|
}
|
|
case "get", "keys", "scan", "nearby", "within", "intersects", "hooks",
|
|
"chans", "search", "ttl", "bounds", "server", "info", "type", "jget",
|
|
"evalro", "evalrosha":
|
|
// read operations
|
|
|
|
server.mu.RLock()
|
|
defer server.mu.RUnlock()
|
|
if server.config.followHost() != "" && !server.fcuponce {
|
|
return writeErr("catching up to leader")
|
|
}
|
|
case "follow", "slaveof", "replconf", "readonly", "config":
|
|
// system operations
|
|
// does not write to aof, but requires a write lock.
|
|
server.mu.Lock()
|
|
defer server.mu.Unlock()
|
|
case "output":
|
|
// this is local connection operation. Locks not needed.
|
|
case "echo":
|
|
case "massinsert":
|
|
// dev operation
|
|
server.mu.Lock()
|
|
defer server.mu.Unlock()
|
|
case "sleep":
|
|
// dev operation
|
|
server.mu.RLock()
|
|
defer server.mu.RUnlock()
|
|
case "shutdown":
|
|
// dev operation
|
|
server.mu.Lock()
|
|
defer server.mu.Unlock()
|
|
case "aofshrink":
|
|
server.mu.RLock()
|
|
defer server.mu.RUnlock()
|
|
case "client":
|
|
server.mu.Lock()
|
|
defer server.mu.Unlock()
|
|
case "evalna", "evalnasha":
|
|
// No locking for scripts, otherwise writes cannot happen within scripts
|
|
case "subscribe", "psubscribe", "publish":
|
|
// No locking for pubsub
|
|
}
|
|
|
|
res, d, err := server.command(msg, client)
|
|
if res.Type() == resp.Error {
|
|
return writeErr(res.String())
|
|
}
|
|
if err != nil {
|
|
if err.Error() == goingLive {
|
|
return err
|
|
}
|
|
return writeErr(err.Error())
|
|
}
|
|
if write {
|
|
if err := server.writeAOF(msg.Args, &d); err != nil {
|
|
if _, ok := err.(errAOFHook); ok {
|
|
return writeErr(err.Error())
|
|
}
|
|
log.Fatal(err)
|
|
return err
|
|
}
|
|
}
|
|
if !isRespValueEmptyString(res) {
|
|
var resStr string
|
|
resStr, err := serializeOutput(res)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err := writeOutput(resStr); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func isRespValueEmptyString(val resp.Value) bool {
|
|
return !val.IsNull() && (val.Type() == resp.SimpleString || val.Type() == resp.BulkString) && len(val.Bytes()) == 0
|
|
}
|
|
|
|
func randomKey(n int) string {
|
|
b := make([]byte, n)
|
|
nn, err := rand.Read(b)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
if nn != n {
|
|
panic("random failed")
|
|
}
|
|
return fmt.Sprintf("%x", b)
|
|
}
|
|
|
|
func (server *Server) reset() {
|
|
server.aofsz = 0
|
|
server.cols = tinybtree.BTree{}
|
|
server.exlistmu.Lock()
|
|
server.exlist = nil
|
|
server.exlistmu.Unlock()
|
|
server.expires = make(map[string]map[string]time.Time)
|
|
}
|
|
|
|
func (server *Server) command(msg *Message, client *Client) (
|
|
res resp.Value, d commandDetails, err error,
|
|
) {
|
|
switch msg.Command() {
|
|
default:
|
|
err = fmt.Errorf("unknown command '%s'", msg.Args[0])
|
|
case "set":
|
|
res, d, err = server.cmdSet(msg)
|
|
case "fset":
|
|
res, d, err = server.cmdFset(msg)
|
|
case "del":
|
|
res, d, err = server.cmdDel(msg)
|
|
case "pdel":
|
|
res, d, err = server.cmdPdel(msg)
|
|
case "drop":
|
|
res, d, err = server.cmdDrop(msg)
|
|
case "flushdb":
|
|
res, d, err = server.cmdFlushDB(msg)
|
|
case "rename":
|
|
res, d, err = server.cmdRename(msg, false)
|
|
case "renamenx":
|
|
res, d, err = server.cmdRename(msg, true)
|
|
case "sethook":
|
|
res, d, err = server.cmdSetHook(msg, false)
|
|
case "delhook":
|
|
res, d, err = server.cmdDelHook(msg, false)
|
|
case "pdelhook":
|
|
res, d, err = server.cmdPDelHook(msg, false)
|
|
case "hooks":
|
|
res, err = server.cmdHooks(msg, false)
|
|
case "setchan":
|
|
res, d, err = server.cmdSetHook(msg, true)
|
|
case "delchan":
|
|
res, d, err = server.cmdDelHook(msg, true)
|
|
case "pdelchan":
|
|
res, d, err = server.cmdPDelHook(msg, true)
|
|
case "chans":
|
|
res, err = server.cmdHooks(msg, true)
|
|
case "expire":
|
|
res, d, err = server.cmdExpire(msg)
|
|
case "persist":
|
|
res, d, err = server.cmdPersist(msg)
|
|
case "ttl":
|
|
res, err = server.cmdTTL(msg)
|
|
case "shutdown":
|
|
if !core.DevMode {
|
|
err = fmt.Errorf("unknown command '%s'", msg.Args[0])
|
|
return
|
|
}
|
|
log.Fatal("shutdown requested by developer")
|
|
case "massinsert":
|
|
if !core.DevMode {
|
|
err = fmt.Errorf("unknown command '%s'", msg.Args[0])
|
|
return
|
|
}
|
|
res, err = server.cmdMassInsert(msg)
|
|
case "sleep":
|
|
if !core.DevMode {
|
|
err = fmt.Errorf("unknown command '%s'", msg.Args[0])
|
|
return
|
|
}
|
|
res, err = server.cmdSleep(msg)
|
|
case "follow", "slaveof":
|
|
res, err = server.cmdFollow(msg)
|
|
case "replconf":
|
|
res, err = server.cmdReplConf(msg, client)
|
|
case "readonly":
|
|
res, err = server.cmdReadOnly(msg)
|
|
case "stats":
|
|
res, err = server.cmdStats(msg)
|
|
case "server":
|
|
res, err = server.cmdServer(msg)
|
|
case "info":
|
|
res, err = server.cmdInfo(msg)
|
|
case "scan":
|
|
res, err = server.cmdScan(msg)
|
|
case "nearby":
|
|
res, err = server.cmdNearby(msg)
|
|
case "within":
|
|
res, err = server.cmdWithin(msg)
|
|
case "intersects":
|
|
res, err = server.cmdIntersects(msg)
|
|
case "search":
|
|
res, err = server.cmdSearch(msg)
|
|
case "bounds":
|
|
res, err = server.cmdBounds(msg)
|
|
case "get":
|
|
res, err = server.cmdGet(msg)
|
|
case "jget":
|
|
res, err = server.cmdJget(msg)
|
|
case "jset":
|
|
res, d, err = server.cmdJset(msg)
|
|
case "jdel":
|
|
res, d, err = server.cmdJdel(msg)
|
|
case "type":
|
|
res, err = server.cmdType(msg)
|
|
case "keys":
|
|
res, err = server.cmdKeys(msg)
|
|
case "output":
|
|
res, err = server.cmdOutput(msg)
|
|
case "aof":
|
|
res, err = server.cmdAOF(msg)
|
|
case "aofmd5":
|
|
res, err = server.cmdAOFMD5(msg)
|
|
case "gc":
|
|
runtime.GC()
|
|
debug.FreeOSMemory()
|
|
res = OKMessage(msg, time.Now())
|
|
case "aofshrink":
|
|
go server.aofshrink()
|
|
res = OKMessage(msg, time.Now())
|
|
case "config get":
|
|
res, err = server.cmdConfigGet(msg)
|
|
case "config set":
|
|
res, err = server.cmdConfigSet(msg)
|
|
case "config rewrite":
|
|
res, err = server.cmdConfigRewrite(msg)
|
|
case "config", "script":
|
|
// These get rewritten into "config foo" and "script bar"
|
|
err = fmt.Errorf("unknown command '%s'", msg.Args[0])
|
|
if len(msg.Args) > 1 {
|
|
msg.Args[1] = msg.Args[0] + " " + msg.Args[1]
|
|
msg.Args = msg.Args[1:]
|
|
msg._command = ""
|
|
return server.command(msg, client)
|
|
}
|
|
case "client":
|
|
res, err = server.cmdClient(msg, client)
|
|
case "eval", "evalro", "evalna":
|
|
res, err = server.cmdEvalUnified(false, msg)
|
|
case "evalsha", "evalrosha", "evalnasha":
|
|
res, err = server.cmdEvalUnified(true, msg)
|
|
case "script load":
|
|
res, err = server.cmdScriptLoad(msg)
|
|
case "script exists":
|
|
res, err = server.cmdScriptExists(msg)
|
|
case "script flush":
|
|
res, err = server.cmdScriptFlush(msg)
|
|
case "subscribe":
|
|
res, err = server.cmdSubscribe(msg)
|
|
case "psubscribe":
|
|
res, err = server.cmdPsubscribe(msg)
|
|
case "publish":
|
|
res, err = server.cmdPublish(msg)
|
|
case "test":
|
|
res, err = server.cmdTest(msg)
|
|
}
|
|
return
|
|
}
|
|
|
|
// This phrase is copied nearly verbatim from Redis.
|
|
var deniedMessage = []byte(strings.Replace(strings.TrimSpace(`
|
|
-DENIED Tile38 is running in protected mode because protected mode is enabled,
|
|
no bind address was specified, no authentication password is requested to
|
|
clients. In this mode connections are only accepted from the loopback
|
|
interface. If you want to connect from external computers to Tile38 you may
|
|
adopt one of the following solutions: 1) Just disable protected mode sending
|
|
the command 'CONFIG SET protected-mode no' from the loopback interface by
|
|
connecting to Tile38 from the same host the server is running, however MAKE
|
|
SURE Tile38 is not publicly accessible from internet if you do so. Use CONFIG
|
|
REWRITE to make this change permanent. 2) Alternatively you can just disable
|
|
the protected mode by editing the Tile38 configuration file, and setting the
|
|
protected mode option to 'no', and then restarting the server. 3) If you
|
|
started the server manually just for testing, restart it with the
|
|
'--protected-mode no' option. 4) Setup a bind address or an authentication
|
|
password. NOTE: You only need to do one of the above things in order for the
|
|
server to start accepting connections from the outside.
|
|
`), "\n", " ", -1) + "\r\n")
|
|
|
|
// SetKeepAlive sets the connection keepalive
|
|
func setKeepAlive(conn net.Conn, period time.Duration) error {
|
|
if tcp, ok := conn.(*net.TCPConn); ok {
|
|
if err := tcp.SetKeepAlive(true); err != nil {
|
|
return err
|
|
}
|
|
return tcp.SetKeepAlivePeriod(period)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
var errCloseHTTP = errors.New("close http")
|
|
|
|
// WriteWebSocketMessage write a websocket message to an io.Writer.
|
|
func WriteWebSocketMessage(w io.Writer, data []byte) error {
|
|
var msg []byte
|
|
buf := make([]byte, 10+len(data))
|
|
buf[0] = 129 // FIN + TEXT
|
|
if len(data) <= 125 {
|
|
buf[1] = byte(len(data))
|
|
copy(buf[2:], data)
|
|
msg = buf[:2+len(data)]
|
|
} else if len(data) <= 0xFFFF {
|
|
buf[1] = 126
|
|
binary.BigEndian.PutUint16(buf[2:], uint16(len(data)))
|
|
copy(buf[4:], data)
|
|
msg = buf[:4+len(data)]
|
|
} else {
|
|
buf[1] = 127
|
|
binary.BigEndian.PutUint64(buf[2:], uint64(len(data)))
|
|
copy(buf[10:], data)
|
|
msg = buf[:10+len(data)]
|
|
}
|
|
_, err := w.Write(msg)
|
|
return err
|
|
}
|
|
|
|
// OKMessage returns a default OK message in JSON or RESP.
|
|
func OKMessage(msg *Message, start time.Time) resp.Value {
|
|
switch msg.OutputType {
|
|
case JSON:
|
|
return resp.StringValue(`{"ok":true,"elapsed":"` + time.Now().Sub(start).String() + "\"}")
|
|
case RESP:
|
|
return resp.SimpleStringValue("OK")
|
|
}
|
|
return resp.SimpleStringValue("")
|
|
}
|
|
|
|
// NOMessage is no message
|
|
var NOMessage = resp.SimpleStringValue("")
|
|
|
|
var errInvalidHTTP = errors.New("invalid HTTP request")
|
|
|
|
// Type is resp type
|
|
type Type byte
|
|
|
|
// Protocol Types
|
|
const (
|
|
Null Type = iota
|
|
RESP
|
|
Telnet
|
|
Native
|
|
HTTP
|
|
WebSocket
|
|
JSON
|
|
)
|
|
|
|
// Message is a resp message
|
|
type Message struct {
|
|
_command string
|
|
Args []string
|
|
ConnType Type
|
|
OutputType Type
|
|
Auth string
|
|
}
|
|
|
|
// Command returns the first argument as a lowercase string
|
|
func (msg *Message) Command() string {
|
|
if msg._command == "" {
|
|
msg._command = strings.ToLower(msg.Args[0])
|
|
}
|
|
return msg._command
|
|
}
|
|
|
|
// PipelineReader ...
|
|
type PipelineReader struct {
|
|
rd io.Reader
|
|
wr io.Writer
|
|
packet [0xFFFF]byte
|
|
buf []byte
|
|
}
|
|
|
|
const kindHTTP redcon.Kind = 9999
|
|
|
|
// NewPipelineReader ...
|
|
func NewPipelineReader(rd io.ReadWriter) *PipelineReader {
|
|
return &PipelineReader{rd: rd, wr: rd}
|
|
}
|
|
|
|
func readcrlfline(packet []byte) (line string, leftover []byte, ok bool) {
|
|
for i := 1; i < len(packet); i++ {
|
|
if packet[i] == '\n' && packet[i-1] == '\r' {
|
|
return string(packet[:i-1]), packet[i+1:], true
|
|
}
|
|
}
|
|
return "", packet, false
|
|
}
|
|
|
|
func readNextHTTPCommand(packet []byte, argsIn [][]byte, msg *Message, wr io.Writer) (
|
|
complete bool, args [][]byte, kind redcon.Kind, leftover []byte, err error,
|
|
) {
|
|
args = argsIn[:0]
|
|
msg.ConnType = HTTP
|
|
msg.OutputType = JSON
|
|
opacket := packet
|
|
|
|
ready, err := func() (bool, error) {
|
|
var line string
|
|
var ok bool
|
|
|
|
// read header
|
|
var headers []string
|
|
for {
|
|
line, packet, ok = readcrlfline(packet)
|
|
if !ok {
|
|
return false, nil
|
|
}
|
|
if line == "" {
|
|
break
|
|
}
|
|
headers = append(headers, line)
|
|
}
|
|
parts := strings.Split(headers[0], " ")
|
|
if len(parts) != 3 {
|
|
return false, errInvalidHTTP
|
|
}
|
|
method := parts[0]
|
|
path := parts[1]
|
|
if len(path) == 0 || path[0] != '/' {
|
|
return false, errInvalidHTTP
|
|
}
|
|
path, err = url.QueryUnescape(path[1:])
|
|
if err != nil {
|
|
return false, errInvalidHTTP
|
|
}
|
|
if method != "GET" && method != "POST" {
|
|
return false, errInvalidHTTP
|
|
}
|
|
contentLength := 0
|
|
websocket := false
|
|
websocketVersion := 0
|
|
websocketKey := ""
|
|
for _, header := range headers[1:] {
|
|
if header[0] == 'a' || header[0] == 'A' {
|
|
if strings.HasPrefix(strings.ToLower(header), "authorization:") {
|
|
msg.Auth = strings.TrimSpace(header[len("authorization:"):])
|
|
}
|
|
} else if header[0] == 'u' || header[0] == 'U' {
|
|
if strings.HasPrefix(strings.ToLower(header), "upgrade:") && strings.ToLower(strings.TrimSpace(header[len("upgrade:"):])) == "websocket" {
|
|
websocket = true
|
|
}
|
|
} else if header[0] == 's' || header[0] == 'S' {
|
|
if strings.HasPrefix(strings.ToLower(header), "sec-websocket-version:") {
|
|
var n uint64
|
|
n, err = strconv.ParseUint(strings.TrimSpace(header[len("sec-websocket-version:"):]), 10, 64)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
websocketVersion = int(n)
|
|
} else if strings.HasPrefix(strings.ToLower(header), "sec-websocket-key:") {
|
|
websocketKey = strings.TrimSpace(header[len("sec-websocket-key:"):])
|
|
}
|
|
} else if header[0] == 'c' || header[0] == 'C' {
|
|
if strings.HasPrefix(strings.ToLower(header), "content-length:") {
|
|
var n uint64
|
|
n, err = strconv.ParseUint(strings.TrimSpace(header[len("content-length:"):]), 10, 64)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
contentLength = int(n)
|
|
}
|
|
}
|
|
}
|
|
if websocket && websocketVersion >= 13 && websocketKey != "" {
|
|
msg.ConnType = WebSocket
|
|
if wr == nil {
|
|
return false, errors.New("connection is nil")
|
|
}
|
|
sum := sha1.Sum([]byte(websocketKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"))
|
|
accept := base64.StdEncoding.EncodeToString(sum[:])
|
|
wshead := "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: " + accept + "\r\n\r\n"
|
|
if _, err = wr.Write([]byte(wshead)); err != nil {
|
|
return false, err
|
|
}
|
|
} else if contentLength > 0 {
|
|
msg.ConnType = HTTP
|
|
if len(packet) < contentLength {
|
|
return false, nil
|
|
}
|
|
path += string(packet[:contentLength])
|
|
packet = packet[contentLength:]
|
|
}
|
|
if path == "" {
|
|
return true, nil
|
|
}
|
|
nmsg, err := readNativeMessageLine([]byte(path))
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
msg.OutputType = JSON
|
|
msg.Args = nmsg.Args
|
|
return true, nil
|
|
}()
|
|
if err != nil || !ready {
|
|
return false, args[:0], kindHTTP, opacket, err
|
|
}
|
|
return true, args[:0], kindHTTP, packet, nil
|
|
}
|
|
func readNextCommand(packet []byte, argsIn [][]byte, msg *Message, wr io.Writer) (
|
|
complete bool, args [][]byte, kind redcon.Kind, leftover []byte, err error,
|
|
) {
|
|
if packet[0] == 'G' || packet[0] == 'P' {
|
|
// could be an HTTP request
|
|
var line []byte
|
|
for i := 1; i < len(packet); i++ {
|
|
if packet[i] == '\n' {
|
|
if packet[i-1] == '\r' {
|
|
line = packet[:i+1]
|
|
break
|
|
}
|
|
}
|
|
}
|
|
if len(line) == 0 {
|
|
return false, argsIn[:0], redcon.Redis, packet, nil
|
|
}
|
|
if len(line) > 11 && string(line[len(line)-11:len(line)-5]) == " HTTP/" {
|
|
return readNextHTTPCommand(packet, argsIn, msg, wr)
|
|
}
|
|
}
|
|
return redcon.ReadNextCommand(packet, args)
|
|
}
|
|
|
|
// ReadMessages ...
|
|
func (rd *PipelineReader) ReadMessages() ([]*Message, error) {
|
|
var msgs []*Message
|
|
moreData:
|
|
n, err := rd.rd.Read(rd.packet[:])
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if n == 0 {
|
|
// need more data
|
|
goto moreData
|
|
}
|
|
data := rd.packet[:n]
|
|
if len(rd.buf) > 0 {
|
|
data = append(rd.buf, data...)
|
|
}
|
|
for len(data) > 0 {
|
|
msg := &Message{}
|
|
complete, args, kind, leftover, err := readNextCommand(data, nil, msg, rd.wr)
|
|
if err != nil {
|
|
break
|
|
}
|
|
if !complete {
|
|
break
|
|
}
|
|
if kind == kindHTTP {
|
|
if len(msg.Args) == 0 {
|
|
return nil, errInvalidHTTP
|
|
}
|
|
msgs = append(msgs, msg)
|
|
} else if len(args) > 0 {
|
|
for i := 0; i < len(args); i++ {
|
|
msg.Args = append(msg.Args, string(args[i]))
|
|
}
|
|
switch kind {
|
|
case redcon.Redis:
|
|
msg.ConnType = RESP
|
|
msg.OutputType = RESP
|
|
case redcon.Tile38:
|
|
msg.ConnType = Native
|
|
msg.OutputType = JSON
|
|
case redcon.Telnet:
|
|
msg.ConnType = RESP
|
|
msg.OutputType = RESP
|
|
}
|
|
msgs = append(msgs, msg)
|
|
}
|
|
data = leftover
|
|
}
|
|
if len(data) > 0 {
|
|
rd.buf = append(rd.buf[:0], data...)
|
|
} else if len(rd.buf) > 0 {
|
|
rd.buf = rd.buf[:0]
|
|
}
|
|
if err != nil && len(msgs) == 0 {
|
|
return nil, err
|
|
}
|
|
return msgs, nil
|
|
}
|
|
|
|
func readNativeMessageLine(line []byte) (*Message, error) {
|
|
var args []string
|
|
reading:
|
|
for len(line) != 0 {
|
|
if line[0] == '{' {
|
|
// The native protocol cannot understand json boundaries so it assumes that
|
|
// a json element must be at the end of the line.
|
|
args = append(args, string(line))
|
|
break
|
|
}
|
|
if line[0] == '"' && line[len(line)-1] == '"' {
|
|
if len(args) > 0 &&
|
|
strings.ToLower(args[0]) == "set" &&
|
|
strings.ToLower(args[len(args)-1]) == "string" {
|
|
// Setting a string value that is contained inside double quotes.
|
|
// This is only because of the boundary issues of the native protocol.
|
|
args = append(args, string(line[1:len(line)-1]))
|
|
break
|
|
}
|
|
}
|
|
i := 0
|
|
for ; i < len(line); i++ {
|
|
if line[i] == ' ' {
|
|
arg := string(line[:i])
|
|
if arg != "" {
|
|
args = append(args, arg)
|
|
}
|
|
line = line[i+1:]
|
|
continue reading
|
|
}
|
|
}
|
|
args = append(args, string(line))
|
|
break
|
|
}
|
|
return &Message{Args: args, ConnType: Native, OutputType: JSON}, nil
|
|
}
|