mirror of https://github.com/tidwall/tile38.git
847 lines
20 KiB
Go
847 lines
20 KiB
Go
package server
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/sha1"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"math"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/tidwall/geojson/geo"
|
|
"github.com/tidwall/resp"
|
|
"github.com/tidwall/tile38/internal/log"
|
|
lua "github.com/yuin/gopher-lua"
|
|
luajson "layeh.com/gopher-json"
|
|
)
|
|
|
|
// Sizing of the lua interpreter pool (see lStatePool).
const (
	// iniLuaPoolSize is the number of lua states newPool creates up front.
	// Prune never shrinks the idle part of the pool below this size.
	iniLuaPoolSize = 5
	// maxLuaPoolSize caps how many lua states the pool will have created at
	// once; Get returns errNoLuasAvailable when it is reached and no idle
	// state is available.
	maxLuaPoolSize = 1000
)
|
|
|
|
// Errors reported by the scripting subsystem.
var (
	// errShaNotFound: EVAL*SHA referenced a script that is not cached.
	errShaNotFound = errors.New("sha not found")
	// errCmdNotSupported: a script called a command it may not run.
	errCmdNotSupported = errors.New("command not supported in scripts")
	// errNotLeader: a script tried to write on a follower.
	errNotLeader = errors.New("not the leader")
	// errReadOnly: a write was attempted in read-only mode or from EVALRO.
	errReadOnly = errors.New("read only")
	// errCatchingUp: a follower has not yet caught up with its leader.
	errCatchingUp = errors.New("catching up to leader")
	// errNoLuasAvailable: the lua state pool is exhausted.
	errNoLuasAvailable = errors.New("no interpreters available")
	// errTimeout: a scripted command exceeded the request deadline.
	errTimeout = errors.New("timeout")
)
|
|
|
|
// lStatePool is a goroutine-safe pool of ready-to-go lua states.
type lStatePool struct {
	// m guards saved and total.
	m sync.Mutex
	// s is the server that tile38.call / tile38.pcall dispatch to.
	s *Server
	// saved holds the idle states that Get hands out (LIFO).
	saved []*lua.LState
	// total counts states created by the pool and not yet pruned, whether
	// idle or currently checked out.
	total int
}
|
|
|
|
// newPool returns a new pool of lua states
|
|
func (s *Server) newPool() *lStatePool {
|
|
pl := &lStatePool{
|
|
saved: make([]*lua.LState, iniLuaPoolSize),
|
|
s: s,
|
|
}
|
|
// Fill the pool with some ready handlers
|
|
for i := 0; i < iniLuaPoolSize; i++ {
|
|
pl.saved[i] = pl.New()
|
|
pl.total++
|
|
}
|
|
return pl
|
|
}
|
|
|
|
func (pl *lStatePool) Get() (*lua.LState, error) {
|
|
pl.m.Lock()
|
|
defer pl.m.Unlock()
|
|
n := len(pl.saved)
|
|
if n == 0 {
|
|
if pl.total >= maxLuaPoolSize {
|
|
return nil, errNoLuasAvailable
|
|
}
|
|
pl.total++
|
|
return pl.New(), nil
|
|
}
|
|
x := pl.saved[n-1]
|
|
pl.saved = pl.saved[0 : n-1]
|
|
return x, nil
|
|
}
|
|
|
|
// Prune removes some of the idle lua states from the pool
|
|
func (pl *lStatePool) Prune() {
|
|
pl.m.Lock()
|
|
n := len(pl.saved)
|
|
if n > iniLuaPoolSize {
|
|
// drop half of the idle states that is above the minimum
|
|
dropNum := (n - iniLuaPoolSize) / 2
|
|
if dropNum < 1 {
|
|
dropNum = 1
|
|
}
|
|
newSaved := make([]*lua.LState, n-dropNum)
|
|
copy(newSaved, pl.saved[dropNum:])
|
|
pl.saved = newSaved
|
|
pl.total -= dropNum
|
|
}
|
|
pl.m.Unlock()
|
|
}
|
|
|
|
func (pl *lStatePool) New() *lua.LState {
|
|
L := lua.NewState()
|
|
|
|
getArgs := func(ls *lua.LState) (evalCmd string, args []string) {
|
|
evalCmd = ls.GetGlobal("EVAL_CMD").String()
|
|
|
|
// Trying to work with unknown number of args.
|
|
// When we see empty arg we call it enough.
|
|
for i := 1; ; i++ {
|
|
if arg := ls.ToString(i); arg == "" {
|
|
break
|
|
} else {
|
|
args = append(args, arg)
|
|
}
|
|
}
|
|
return
|
|
}
|
|
call := func(ls *lua.LState) int {
|
|
evalCmd, args := getArgs(ls)
|
|
var numRet int
|
|
if res, err := pl.s.luaTile38Call(evalCmd, args[0], args[1:]...); err != nil {
|
|
ls.RaiseError("ERR %s", err.Error())
|
|
numRet = 0
|
|
} else {
|
|
ls.Push(ConvertToLua(ls, res))
|
|
numRet = 1
|
|
}
|
|
return numRet
|
|
}
|
|
pcall := func(ls *lua.LState) int {
|
|
evalCmd, args := getArgs(ls)
|
|
if res, err := pl.s.luaTile38Call(evalCmd, args[0], args[1:]...); err != nil {
|
|
ls.Push(ConvertToLua(ls, resp.ErrorValue(err)))
|
|
} else {
|
|
ls.Push(ConvertToLua(ls, res))
|
|
}
|
|
return 1
|
|
|
|
}
|
|
errorReply := func(ls *lua.LState) int {
|
|
tbl := L.CreateTable(0, 1)
|
|
tbl.RawSetString("err", lua.LString(ls.ToString(1)))
|
|
ls.Push(tbl)
|
|
return 1
|
|
}
|
|
statusReply := func(ls *lua.LState) int {
|
|
tbl := L.CreateTable(0, 1)
|
|
tbl.RawSetString("ok", lua.LString(ls.ToString(1)))
|
|
ls.Push(tbl)
|
|
return 1
|
|
}
|
|
sha1hex := func(ls *lua.LState) int {
|
|
shaSum := Sha1Sum(ls.ToString(1))
|
|
ls.Push(lua.LString(shaSum))
|
|
return 1
|
|
}
|
|
distanceTo := func(ls *lua.LState) int {
|
|
dt := geo.DistanceTo(
|
|
float64(ls.ToNumber(1)),
|
|
float64(ls.ToNumber(2)),
|
|
float64(ls.ToNumber(3)),
|
|
float64(ls.ToNumber(4)))
|
|
ls.Push(lua.LNumber(dt))
|
|
return 1
|
|
}
|
|
var exports = map[string]lua.LGFunction{
|
|
"call": call,
|
|
"pcall": pcall,
|
|
"error_reply": errorReply,
|
|
"status_reply": statusReply,
|
|
"sha1hex": sha1hex,
|
|
"distance_to": distanceTo,
|
|
}
|
|
L.SetGlobal("tile38", L.SetFuncs(L.NewTable(), exports))
|
|
|
|
// Load json
|
|
L.SetGlobal("json", L.Get(luajson.Loader(L)))
|
|
|
|
// Prohibit creating new globals in this state
|
|
lockNewGlobals := func(ls *lua.LState) int {
|
|
ls.RaiseError("attempt to create global variable '%s'", ls.ToString(2))
|
|
return 0
|
|
}
|
|
mt := L.CreateTable(0, 1)
|
|
mt.RawSetString("__newindex", L.NewFunction(lockNewGlobals))
|
|
L.SetMetatable(L.Get(lua.GlobalsIndex), mt)
|
|
|
|
return L
|
|
}
|
|
|
|
func (pl *lStatePool) Put(L *lua.LState) {
|
|
pl.m.Lock()
|
|
pl.saved = append(pl.saved, L)
|
|
pl.m.Unlock()
|
|
}
|
|
|
|
func (pl *lStatePool) Shutdown() {
|
|
pl.m.Lock()
|
|
for _, L := range pl.saved {
|
|
L.Close()
|
|
}
|
|
pl.m.Unlock()
|
|
}
|
|
|
|
// lScriptMap is a goroutine-safe map of compiled scripts, keyed by the SHA1
// hex digest of the script source (as produced by Sha1Sum).
type lScriptMap struct {
	// m guards scripts.
	m sync.Mutex
	// scripts maps SHA1 hex digest -> compiled function prototype.
	scripts map[string]*lua.FunctionProto
}
|
|
|
|
func (sm *lScriptMap) Get(key string) (script *lua.FunctionProto, ok bool) {
|
|
sm.m.Lock()
|
|
script, ok = sm.scripts[key]
|
|
sm.m.Unlock()
|
|
return
|
|
}
|
|
|
|
func (sm *lScriptMap) Put(key string, script *lua.FunctionProto) {
|
|
sm.m.Lock()
|
|
sm.scripts[key] = script
|
|
sm.m.Unlock()
|
|
}
|
|
|
|
func (sm *lScriptMap) Flush() {
|
|
sm.m.Lock()
|
|
sm.scripts = make(map[string]*lua.FunctionProto)
|
|
sm.m.Unlock()
|
|
}
|
|
|
|
// NewScriptMap returns a new map with lua scripts
|
|
func (s *Server) newScriptMap() *lScriptMap {
|
|
return &lScriptMap{
|
|
scripts: make(map[string]*lua.FunctionProto),
|
|
}
|
|
}
|
|
|
|
// ConvertToLua converts RESP value to lua LValue
|
|
func ConvertToLua(L *lua.LState, val resp.Value) lua.LValue {
|
|
if val.IsNull() {
|
|
return lua.LFalse
|
|
}
|
|
switch val.Type() {
|
|
case resp.Integer:
|
|
return lua.LNumber(val.Integer())
|
|
case resp.BulkString:
|
|
return lua.LString(val.String())
|
|
case resp.Error:
|
|
tbl := L.CreateTable(0, 1)
|
|
tbl.RawSetString("err", lua.LString(val.String()))
|
|
return tbl
|
|
case resp.SimpleString:
|
|
tbl := L.CreateTable(0, 1)
|
|
tbl.RawSetString("ok", lua.LString(val.String()))
|
|
return tbl
|
|
case resp.Array:
|
|
tbl := L.CreateTable(len(val.Array()), 0)
|
|
for _, item := range val.Array() {
|
|
tbl.Append(ConvertToLua(L, item))
|
|
}
|
|
return tbl
|
|
}
|
|
return lua.LString("ERR: unknown RESP type: " + val.Type().String())
|
|
}
|
|
|
|
// ConvertToRESP convert lua LValue to RESP value
|
|
func ConvertToRESP(val lua.LValue) resp.Value {
|
|
switch val.Type() {
|
|
case lua.LTNil:
|
|
return resp.NullValue()
|
|
case lua.LTBool:
|
|
if val == lua.LTrue {
|
|
return resp.IntegerValue(1)
|
|
}
|
|
return resp.NullValue()
|
|
case lua.LTNumber:
|
|
float := float64(val.(lua.LNumber))
|
|
if math.IsNaN(float) || math.IsInf(float, 0) {
|
|
return resp.FloatValue(float)
|
|
}
|
|
return resp.IntegerValue(int(math.Floor(float)))
|
|
case lua.LTString:
|
|
return resp.StringValue(val.String())
|
|
case lua.LTTable:
|
|
var values []resp.Value
|
|
var specialValues []resp.Value
|
|
var cb func(lk lua.LValue, lv lua.LValue)
|
|
tbl := val.(*lua.LTable)
|
|
|
|
if tbl.Len() != 0 { // list
|
|
cb = func(lk lua.LValue, lv lua.LValue) {
|
|
values = append(values, ConvertToRESP(lv))
|
|
}
|
|
} else { // map
|
|
cb = func(lk lua.LValue, lv lua.LValue) {
|
|
if lk.Type() == lua.LTString {
|
|
lks := lk.String()
|
|
switch lks {
|
|
case "ok":
|
|
specialValues = append(specialValues, resp.SimpleStringValue(lv.String()))
|
|
case "err":
|
|
specialValues = append(specialValues, resp.ErrorValue(errors.New(lv.String())))
|
|
}
|
|
}
|
|
values = append(values, resp.ArrayValue(
|
|
[]resp.Value{ConvertToRESP(lk), ConvertToRESP(lv)}))
|
|
}
|
|
}
|
|
tbl.ForEach(cb)
|
|
if len(values) == 1 && len(specialValues) == 1 {
|
|
return specialValues[0]
|
|
}
|
|
return resp.ArrayValue(values)
|
|
}
|
|
return resp.ErrorValue(errors.New("Unsupported lua type: " + val.Type().String()))
|
|
}
|
|
|
|
// ConvertToJSON converts lua LValue to JSON string
|
|
func ConvertToJSON(val lua.LValue) string {
|
|
switch val.Type() {
|
|
case lua.LTNil:
|
|
return "null"
|
|
case lua.LTBool:
|
|
if val == lua.LTrue {
|
|
return "true"
|
|
}
|
|
return "false"
|
|
case lua.LTNumber:
|
|
return val.String()
|
|
case lua.LTString:
|
|
if b, err := json.Marshal(val.String()); err != nil {
|
|
panic(err)
|
|
} else {
|
|
return string(b)
|
|
}
|
|
case lua.LTTable:
|
|
var values []string
|
|
var cb func(lk lua.LValue, lv lua.LValue)
|
|
var start, end string
|
|
|
|
tbl := val.(*lua.LTable)
|
|
if tbl.Len() != 0 { // list
|
|
start = `[`
|
|
end = `]`
|
|
cb = func(lk lua.LValue, lv lua.LValue) {
|
|
values = append(values, ConvertToJSON(lv))
|
|
}
|
|
} else { // map
|
|
start = `{`
|
|
end = `}`
|
|
cb = func(lk lua.LValue, lv lua.LValue) {
|
|
values = append(
|
|
values, ConvertToJSON(lk)+`:`+ConvertToJSON(lv))
|
|
}
|
|
}
|
|
tbl.ForEach(cb)
|
|
return start + strings.Join(values, `,`) + end
|
|
}
|
|
return "Unsupported lua type: " + val.Type().String()
|
|
}
|
|
|
|
func luaSetRawGlobals(ls *lua.LState, tbl map[string]lua.LValue) {
|
|
gt := ls.Get(lua.GlobalsIndex).(*lua.LTable)
|
|
for key, val := range tbl {
|
|
gt.RawSetString(key, val)
|
|
}
|
|
}
|
|
|
|
// Sha1Sum returns the SHA1 digest of s as a lowercase hex string.
func Sha1Sum(s string) string {
	digest := sha1.Sum([]byte(s))
	return hex.EncodeToString(digest[:])
}
|
|
|
|
// makeSafeErr returns a new error whose message is err's message with every
// carriage return and line feed replaced by the literal two-character
// sequences `\r` and `\n`. RESP error replies are a single line terminated
// by CRLF, so a raw CR or LF inside the message would corrupt the stream.
func makeSafeErr(err error) error {
	escaper := strings.NewReplacer("\r", `\r`, "\n", `\n`)
	return errors.New(escaper.Replace(err.Error()))
}
|
|
|
|
// Run eval/evalro/evalna command or it's -sha variant
|
|
func (s *Server) cmdEvalUnified(scriptIsSha bool, msg *Message) (res resp.Value, err error) {
|
|
start := time.Now()
|
|
vs := msg.Args[1:]
|
|
|
|
var ok bool
|
|
var script, numkeysStr, key, arg string
|
|
if vs, script, ok = tokenval(vs); !ok || script == "" {
|
|
return NOMessage, errInvalidNumberOfArguments
|
|
}
|
|
|
|
if vs, numkeysStr, ok = tokenval(vs); !ok || numkeysStr == "" {
|
|
return NOMessage, errInvalidNumberOfArguments
|
|
}
|
|
|
|
var i, numkeys uint64
|
|
if numkeys, err = strconv.ParseUint(numkeysStr, 10, 64); err != nil {
|
|
err = errInvalidArgument(numkeysStr)
|
|
return
|
|
}
|
|
|
|
luaState, err := s.luapool.Get()
|
|
if err != nil {
|
|
return
|
|
}
|
|
luaDeadline := lua.LNil
|
|
if msg.Deadline != nil {
|
|
dlTime := msg.Deadline.GetDeadlineTime()
|
|
ctx, cancel := context.WithDeadline(context.Background(), dlTime)
|
|
defer cancel()
|
|
luaState.SetContext(ctx)
|
|
defer luaState.RemoveContext()
|
|
luaDeadline = lua.LNumber(float64(dlTime.UnixNano()) / 1e9)
|
|
}
|
|
defer s.luapool.Put(luaState)
|
|
|
|
keysTbl := luaState.CreateTable(int(numkeys), 0)
|
|
for i = 0; i < numkeys; i++ {
|
|
if vs, key, ok = tokenval(vs); !ok || key == "" {
|
|
err = errInvalidNumberOfArguments
|
|
return
|
|
}
|
|
keysTbl.Append(lua.LString(key))
|
|
}
|
|
|
|
argsTbl := luaState.CreateTable(len(vs), 0)
|
|
for len(vs) > 0 {
|
|
if vs, arg, ok = tokenval(vs); !ok || arg == "" {
|
|
err = errInvalidNumberOfArguments
|
|
return
|
|
}
|
|
argsTbl.Append(lua.LString(arg))
|
|
}
|
|
|
|
var shaSum string
|
|
if scriptIsSha {
|
|
shaSum = script
|
|
} else {
|
|
shaSum = Sha1Sum(script)
|
|
}
|
|
|
|
luaSetRawGlobals(
|
|
luaState, map[string]lua.LValue{
|
|
"KEYS": keysTbl,
|
|
"ARGV": argsTbl,
|
|
"DEADLINE": luaDeadline,
|
|
"EVAL_CMD": lua.LString(msg.Command()),
|
|
})
|
|
|
|
compiled, ok := s.luascripts.Get(shaSum)
|
|
var fn *lua.LFunction
|
|
if ok {
|
|
fn = &lua.LFunction{
|
|
IsG: false,
|
|
Env: luaState.Env,
|
|
|
|
Proto: compiled,
|
|
GFunction: nil,
|
|
Upvalues: make([]*lua.Upvalue, 0),
|
|
}
|
|
} else if scriptIsSha {
|
|
err = errShaNotFound
|
|
return
|
|
} else {
|
|
fn, err = luaState.Load(strings.NewReader(script), "f_"+shaSum)
|
|
if err != nil {
|
|
return NOMessage, makeSafeErr(err)
|
|
}
|
|
s.luascripts.Put(shaSum, fn.Proto)
|
|
}
|
|
luaState.Push(fn)
|
|
defer luaSetRawGlobals(
|
|
luaState, map[string]lua.LValue{
|
|
"KEYS": lua.LNil,
|
|
"ARGV": lua.LNil,
|
|
"DEADLINE": lua.LNil,
|
|
"EVAL_CMD": lua.LNil,
|
|
})
|
|
if err := luaState.PCall(0, 1, nil); err != nil {
|
|
if strings.Contains(err.Error(), "context deadline exceeded") {
|
|
msg.Deadline.Check()
|
|
}
|
|
log.Debugf("%v", err.Error())
|
|
return NOMessage, makeSafeErr(err)
|
|
}
|
|
ret := luaState.Get(-1) // returned value
|
|
luaState.Pop(1)
|
|
|
|
switch msg.OutputType {
|
|
case JSON:
|
|
var buf bytes.Buffer
|
|
buf.WriteString(`{"ok":true`)
|
|
buf.WriteString(`,"result":` + ConvertToJSON(ret))
|
|
buf.WriteString(`,"elapsed":"` + time.Since(start).String() + "\"}")
|
|
return resp.StringValue(buf.String()), nil
|
|
case RESP:
|
|
return ConvertToRESP(ret), nil
|
|
}
|
|
return NOMessage, nil
|
|
}
|
|
|
|
func (s *Server) cmdScriptLoad(msg *Message) (resp.Value, error) {
|
|
start := time.Now()
|
|
vs := msg.Args[1:]
|
|
|
|
var ok bool
|
|
var script string
|
|
if _, script, ok = tokenval(vs); !ok || script == "" {
|
|
return NOMessage, errInvalidNumberOfArguments
|
|
}
|
|
|
|
shaSum := Sha1Sum(script)
|
|
|
|
luaState, err := s.luapool.Get()
|
|
if err != nil {
|
|
return NOMessage, err
|
|
}
|
|
defer s.luapool.Put(luaState)
|
|
|
|
fn, err := luaState.Load(strings.NewReader(script), "f_"+shaSum)
|
|
if err != nil {
|
|
return NOMessage, makeSafeErr(err)
|
|
}
|
|
s.luascripts.Put(shaSum, fn.Proto)
|
|
|
|
switch msg.OutputType {
|
|
case JSON:
|
|
var buf bytes.Buffer
|
|
buf.WriteString(`{"ok":true`)
|
|
buf.WriteString(`,"result":"` + shaSum + `"`)
|
|
buf.WriteString(`,"elapsed":"` + time.Since(start).String() + "\"}")
|
|
return resp.StringValue(buf.String()), nil
|
|
case RESP:
|
|
return resp.StringValue(shaSum), nil
|
|
}
|
|
return NOMessage, nil
|
|
}
|
|
|
|
func (s *Server) cmdScriptExists(msg *Message) (resp.Value, error) {
|
|
start := time.Now()
|
|
vs := msg.Args[1:]
|
|
|
|
var ok bool
|
|
var shaSum string
|
|
var results []int
|
|
var ires int
|
|
for len(vs) > 0 {
|
|
if vs, shaSum, ok = tokenval(vs); !ok || shaSum == "" {
|
|
return NOMessage, errInvalidNumberOfArguments
|
|
}
|
|
_, ok = s.luascripts.Get(shaSum)
|
|
if ok {
|
|
ires = 1
|
|
} else {
|
|
ires = 0
|
|
}
|
|
results = append(results, ires)
|
|
}
|
|
|
|
switch msg.OutputType {
|
|
case JSON:
|
|
var buf bytes.Buffer
|
|
buf.WriteString(`{"ok":true`)
|
|
var resArray []string
|
|
for _, ires := range results {
|
|
resArray = append(resArray, fmt.Sprintf("%d", ires))
|
|
}
|
|
buf.WriteString(`,"result":[` + strings.Join(resArray, ",") + `]`)
|
|
buf.WriteString(`,"elapsed":"` + time.Since(start).String() + "\"}")
|
|
return resp.StringValue(buf.String()), nil
|
|
case RESP:
|
|
var resArray []resp.Value
|
|
for _, ires := range results {
|
|
resArray = append(resArray, resp.IntegerValue(ires))
|
|
}
|
|
return resp.ArrayValue(resArray), nil
|
|
}
|
|
return resp.SimpleStringValue(""), nil
|
|
}
|
|
|
|
func (s *Server) cmdScriptFlush(msg *Message) (resp.Value, error) {
|
|
start := time.Now()
|
|
s.luascripts.Flush()
|
|
|
|
switch msg.OutputType {
|
|
case JSON:
|
|
var buf bytes.Buffer
|
|
buf.WriteString(`{"ok":true`)
|
|
buf.WriteString(`,"elapsed":"` + time.Since(start).String() + "\"}")
|
|
return resp.StringValue(buf.String()), nil
|
|
case RESP:
|
|
return resp.StringValue("OK"), nil
|
|
}
|
|
return resp.SimpleStringValue(""), nil
|
|
}
|
|
|
|
func (s *Server) commandInScript(msg *Message) (
|
|
res resp.Value, d commandDetails, err error,
|
|
) {
|
|
switch msg.Command() {
|
|
default:
|
|
err = fmt.Errorf("unknown command '%s'", msg.Args[0])
|
|
case "set":
|
|
res, d, err = s.cmdSET(msg)
|
|
case "fset":
|
|
res, d, err = s.cmdFSET(msg)
|
|
case "del":
|
|
res, d, err = s.cmdDel(msg)
|
|
case "pdel":
|
|
res, d, err = s.cmdPdel(msg)
|
|
case "drop":
|
|
res, d, err = s.cmdDrop(msg)
|
|
case "expire":
|
|
res, d, err = s.cmdEXPIRE(msg)
|
|
case "rename":
|
|
res, d, err = s.cmdRename(msg)
|
|
case "renamenx":
|
|
res, d, err = s.cmdRename(msg)
|
|
case "persist":
|
|
res, d, err = s.cmdPERSIST(msg)
|
|
case "ttl":
|
|
res, err = s.cmdTTL(msg)
|
|
case "stats":
|
|
res, err = s.cmdStats(msg)
|
|
case "scan":
|
|
res, err = s.cmdScan(msg)
|
|
case "nearby":
|
|
res, err = s.cmdNearby(msg)
|
|
case "within":
|
|
res, err = s.cmdWITHIN(msg)
|
|
case "intersects":
|
|
res, err = s.cmdINTERSECTS(msg)
|
|
case "search":
|
|
res, err = s.cmdSearch(msg)
|
|
case "bounds":
|
|
res, err = s.cmdBounds(msg)
|
|
case "get":
|
|
res, err = s.cmdGet(msg)
|
|
case "jget":
|
|
res, err = s.cmdJget(msg)
|
|
case "jset":
|
|
res, d, err = s.cmdJset(msg)
|
|
case "jdel":
|
|
res, d, err = s.cmdJdel(msg)
|
|
case "type":
|
|
res, err = s.cmdType(msg)
|
|
case "keys":
|
|
res, err = s.cmdKeys(msg)
|
|
case "test":
|
|
res, err = s.cmdTest(msg)
|
|
case "server":
|
|
res, err = s.cmdServer(msg)
|
|
}
|
|
s.sendMonitor(err, msg, nil, true)
|
|
return
|
|
}
|
|
|
|
func (s *Server) luaTile38Call(evalcmd string, cmd string, args ...string) (resp.Value, error) {
|
|
msg := &Message{}
|
|
msg.OutputType = RESP
|
|
msg.Args = append([]string{cmd}, args...)
|
|
|
|
if msg.Command() == "timeout" {
|
|
if err := rewriteTimeoutMsg(msg); err != nil {
|
|
return resp.NullValue(), err
|
|
}
|
|
}
|
|
|
|
switch msg.Command() {
|
|
case "ping", "echo", "auth", "massinsert", "shutdown", "gc",
|
|
"sethook", "pdelhook", "delhook",
|
|
"follow", "readonly", "config", "output", "client",
|
|
"aofshrink",
|
|
"script load", "script exists", "script flush",
|
|
"eval", "evalsha", "evalro", "evalrosha", "evalna", "evalnasha":
|
|
return resp.NullValue(), errCmdNotSupported
|
|
}
|
|
|
|
switch evalcmd {
|
|
case "eval", "evalsha":
|
|
return s.luaTile38AtomicRW(msg)
|
|
case "evalro", "evalrosha":
|
|
return s.luaTile38AtomicRO(msg)
|
|
case "evalna", "evalnasha":
|
|
return s.luaTile38NonAtomic(msg)
|
|
}
|
|
|
|
return resp.NullValue(), errCmdNotSupported
|
|
}
|
|
|
|
// The eval command has already got the lock. No locking on the call from within the script.
|
|
func (s *Server) luaTile38AtomicRW(msg *Message) (resp.Value, error) {
|
|
var write bool
|
|
|
|
switch msg.Command() {
|
|
default:
|
|
return resp.NullValue(), errCmdNotSupported
|
|
case "set", "del", "drop", "fset", "flushdb", "expire", "persist", "jset", "pdel",
|
|
"rename", "renamenx":
|
|
// write operations
|
|
write = true
|
|
if s.config.followHost() != "" {
|
|
return resp.NullValue(), errNotLeader
|
|
}
|
|
if s.config.readOnly() {
|
|
return resp.NullValue(), errReadOnly
|
|
}
|
|
case "get", "keys", "scan", "nearby", "within", "intersects", "hooks", "search",
|
|
"ttl", "bounds", "server", "info", "type", "jget", "test":
|
|
// read operations
|
|
if s.config.followHost() != "" && !s.fcuponce {
|
|
return resp.NullValue(), errCatchingUp
|
|
}
|
|
}
|
|
|
|
res, d, err := func() (res resp.Value, d commandDetails, err error) {
|
|
if msg.Deadline != nil {
|
|
if write {
|
|
res = NOMessage
|
|
err = errTimeoutOnCmd(msg.Command())
|
|
return
|
|
}
|
|
defer func() {
|
|
if msg.Deadline.Hit() {
|
|
v := recover()
|
|
if v != nil {
|
|
if s, ok := v.(string); !ok || s != "deadline" {
|
|
panic(v)
|
|
}
|
|
}
|
|
res = NOMessage
|
|
err = errTimeout
|
|
}
|
|
}()
|
|
}
|
|
return s.commandInScript(msg)
|
|
}()
|
|
if err != nil {
|
|
return resp.NullValue(), err
|
|
}
|
|
|
|
if write {
|
|
if err := s.writeAOF(msg.Args, &d); err != nil {
|
|
return resp.NullValue(), err
|
|
}
|
|
}
|
|
|
|
return res, nil
|
|
}
|
|
|
|
func (s *Server) luaTile38AtomicRO(msg *Message) (resp.Value, error) {
|
|
switch msg.Command() {
|
|
default:
|
|
return resp.NullValue(), errCmdNotSupported
|
|
|
|
case "set", "del", "drop", "fset", "flushdb", "expire", "persist", "jset", "pdel",
|
|
"rename", "renamenx":
|
|
// write operations
|
|
return resp.NullValue(), errReadOnly
|
|
|
|
case "get", "keys", "scan", "nearby", "within", "intersects", "hooks", "search",
|
|
"ttl", "bounds", "server", "info", "type", "jget", "test":
|
|
// read operations
|
|
if s.config.followHost() != "" && !s.fcuponce {
|
|
return resp.NullValue(), errCatchingUp
|
|
}
|
|
}
|
|
|
|
res, _, err := func() (res resp.Value, d commandDetails, err error) {
|
|
if msg.Deadline != nil {
|
|
defer func() {
|
|
if msg.Deadline.Hit() {
|
|
v := recover()
|
|
if v != nil {
|
|
if s, ok := v.(string); !ok || s != "deadline" {
|
|
panic(v)
|
|
}
|
|
}
|
|
res = NOMessage
|
|
err = errTimeout
|
|
}
|
|
}()
|
|
}
|
|
return s.commandInScript(msg)
|
|
}()
|
|
if err != nil {
|
|
return resp.NullValue(), err
|
|
}
|
|
|
|
return res, nil
|
|
}
|
|
|
|
func (s *Server) luaTile38NonAtomic(msg *Message) (resp.Value, error) {
|
|
var write bool
|
|
|
|
// choose the locking strategy
|
|
switch msg.Command() {
|
|
default:
|
|
return resp.NullValue(), errCmdNotSupported
|
|
case "set", "del", "drop", "fset", "flushdb", "expire", "persist", "jset", "pdel",
|
|
"rename", "renamenx":
|
|
// write operations
|
|
write = true
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if s.config.followHost() != "" {
|
|
return resp.NullValue(), errNotLeader
|
|
}
|
|
if s.config.readOnly() {
|
|
return resp.NullValue(), errReadOnly
|
|
}
|
|
case "get", "keys", "scan", "nearby", "within", "intersects", "hooks", "search",
|
|
"ttl", "bounds", "server", "info", "type", "jget", "test":
|
|
// read operations
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
if s.config.followHost() != "" && !s.fcuponce {
|
|
return resp.NullValue(), errCatchingUp
|
|
}
|
|
}
|
|
|
|
res, d, err := func() (res resp.Value, d commandDetails, err error) {
|
|
if msg.Deadline != nil {
|
|
if write {
|
|
res = NOMessage
|
|
err = errTimeoutOnCmd(msg.Command())
|
|
return
|
|
}
|
|
defer func() {
|
|
if msg.Deadline.Hit() {
|
|
v := recover()
|
|
if v != nil {
|
|
if s, ok := v.(string); !ok || s != "deadline" {
|
|
panic(v)
|
|
}
|
|
}
|
|
res = NOMessage
|
|
err = errTimeout
|
|
}
|
|
}()
|
|
}
|
|
return s.commandInScript(msg)
|
|
}()
|
|
if err != nil {
|
|
return resp.NullValue(), err
|
|
}
|
|
|
|
if write {
|
|
if err := s.writeAOF(msg.Args, &d); err != nil {
|
|
return resp.NullValue(), err
|
|
}
|
|
}
|
|
|
|
return res, nil
|
|
}
|