tile38/internal/server/scripts.go

847 lines
20 KiB
Go

package server
import (
"bytes"
"context"
"crypto/sha1"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"math"
"strconv"
"strings"
"sync"
"time"
"github.com/tidwall/geojson/geo"
"github.com/tidwall/resp"
"github.com/tidwall/tile38/internal/log"
lua "github.com/yuin/gopher-lua"
luajson "layeh.com/gopher-json"
)
const (
	// iniLuaPoolSize is the number of lua states newPool creates up front.
	// Prune never shrinks the set of idle states below this size.
	iniLuaPoolSize = 5
	// maxLuaPoolSize caps how many lua states the pool creates; Get returns
	// errNoLuasAvailable once the pool's total reaches it and none is idle.
	maxLuaPoolSize = 1000
)
// Errors reported by the scripting commands and by commands issued from
// inside lua scripts.
var (
	errShaNotFound      = errors.New("sha not found")
	errCmdNotSupported  = errors.New("command not supported in scripts")
	errNotLeader        = errors.New("not the leader")
	errReadOnly         = errors.New("read only")
	errCatchingUp       = errors.New("catching up to leader")
	errNoLuasAvailable  = errors.New("no interpreters available")
	errTimeout          = errors.New("timeout")
)
// lStatePool is a goroutine-safe pool of ready-to-go lua states.
type lStatePool struct {
	// m guards saved and total.
	m sync.Mutex
	// s is the server that the pooled states' tile38.call/pcall run against.
	s *Server
	// saved holds the idle states; Get takes from the end, Put appends.
	saved []*lua.LState
	// total counts states created by the pool (idle or checked out),
	// minus those removed by Prune.
	total int
}
// newPool builds the server's lua state pool and pre-fills it with
// iniLuaPoolSize idle states, each counted in the pool's total.
func (s *Server) newPool() *lStatePool {
	pool := &lStatePool{s: s}
	pool.saved = make([]*lua.LState, 0, iniLuaPoolSize)
	for len(pool.saved) < iniLuaPoolSize {
		pool.saved = append(pool.saved, pool.New())
		pool.total++
	}
	return pool
}
// Get hands out a lua state. The most recently returned idle state is
// reused when available; otherwise a new one is created, unless the pool
// has already created maxLuaPoolSize states, in which case
// errNoLuasAvailable is returned.
func (pl *lStatePool) Get() (*lua.LState, error) {
	pl.m.Lock()
	defer pl.m.Unlock()
	last := len(pl.saved) - 1
	if last >= 0 {
		state := pl.saved[last]
		pl.saved = pl.saved[:last]
		return state, nil
	}
	if pl.total >= maxLuaPoolSize {
		return nil, errNoLuasAvailable
	}
	pl.total++
	return pl.New(), nil
}
// Prune shrinks the set of idle lua states. When more than iniLuaPoolSize
// states are idle, half of the surplus (at least one) is removed from the
// pool and closed. The idle set never drops below iniLuaPoolSize.
func (pl *lStatePool) Prune() {
	var dropped []*lua.LState
	pl.m.Lock()
	if n := len(pl.saved); n > iniLuaPoolSize {
		// drop half of the idle states that is above the minimum
		dropNum := (n - iniLuaPoolSize) / 2
		if dropNum < 1 {
			dropNum = 1
		}
		dropped = pl.saved[:dropNum]
		newSaved := make([]*lua.LState, n-dropNum)
		copy(newSaved, pl.saved[dropNum:])
		pl.saved = newSaved
		pl.total -= dropNum
	}
	pl.m.Unlock()
	// The dropped states were idle and are no longer reachable through the
	// pool, so nobody else can be using them. Close them (outside the lock)
	// to release their resources instead of just forgetting them.
	for _, L := range dropped {
		L.Close()
	}
}
// New creates a lua state prepared for running Tile38 scripts.
//
// The state gets a global "tile38" table with call, pcall, error_reply,
// status_reply, sha1hex and distance_to, a global "json" module, and a
// metatable on the globals table whose __newindex raises an error, so
// scripts cannot create new global variables.
func (pl *lStatePool) New() *lua.LState {
	L := lua.NewState()

	// getArgs returns the eval command that runs the script (read from the
	// EVAL_CMD global) and the arguments passed to the calling Go function,
	// converted with ToString. Arguments are collected up to the first nil
	// or the top of the stack. Empty strings are forwarded as-is rather than
	// silently cutting the argument list short.
	getArgs := func(ls *lua.LState) (evalCmd string, args []string) {
		evalCmd = ls.GetGlobal("EVAL_CMD").String()
		top := ls.GetTop()
		for i := 1; i <= top; i++ {
			if ls.Get(i).Type() == lua.LTNil {
				break
			}
			args = append(args, ls.ToString(i))
		}
		return
	}

	// call runs a Tile38 command and raises a lua error when it fails.
	call := func(ls *lua.LState) int {
		evalCmd, args := getArgs(ls)
		if len(args) == 0 {
			// No command name: report it instead of indexing an empty slice.
			ls.RaiseError("ERR %s", errInvalidNumberOfArguments.Error())
			return 0
		}
		res, err := pl.s.luaTile38Call(evalCmd, args[0], args[1:]...)
		if err != nil {
			ls.RaiseError("ERR %s", err.Error())
			return 0
		}
		ls.Push(ConvertToLua(ls, res))
		return 1
	}

	// pcall runs a Tile38 command and, instead of raising, returns the
	// failure as an {err=...} table.
	pcall := func(ls *lua.LState) int {
		evalCmd, args := getArgs(ls)
		if len(args) == 0 {
			ls.Push(ConvertToLua(ls, resp.ErrorValue(errInvalidNumberOfArguments)))
			return 1
		}
		res, err := pl.s.luaTile38Call(evalCmd, args[0], args[1:]...)
		if err != nil {
			res = resp.ErrorValue(err)
		}
		ls.Push(ConvertToLua(ls, res))
		return 1
	}

	// errorReply returns {err=<first argument>}.
	errorReply := func(ls *lua.LState) int {
		tbl := ls.CreateTable(0, 1)
		tbl.RawSetString("err", lua.LString(ls.ToString(1)))
		ls.Push(tbl)
		return 1
	}

	// statusReply returns {ok=<first argument>}.
	statusReply := func(ls *lua.LState) int {
		tbl := ls.CreateTable(0, 1)
		tbl.RawSetString("ok", lua.LString(ls.ToString(1)))
		ls.Push(tbl)
		return 1
	}

	// sha1hex returns the hex SHA1 digest of its first argument.
	sha1hex := func(ls *lua.LState) int {
		ls.Push(lua.LString(Sha1Sum(ls.ToString(1))))
		return 1
	}

	// distanceTo returns geo.DistanceTo for its four numeric arguments.
	distanceTo := func(ls *lua.LState) int {
		dt := geo.DistanceTo(
			float64(ls.ToNumber(1)),
			float64(ls.ToNumber(2)),
			float64(ls.ToNumber(3)),
			float64(ls.ToNumber(4)))
		ls.Push(lua.LNumber(dt))
		return 1
	}

	exports := map[string]lua.LGFunction{
		"call":         call,
		"pcall":        pcall,
		"error_reply":  errorReply,
		"status_reply": statusReply,
		"sha1hex":      sha1hex,
		"distance_to":  distanceTo,
	}
	L.SetGlobal("tile38", L.SetFuncs(L.NewTable(), exports))

	// Load json. The loader pushes the module table onto the stack; take it
	// from the top and pop it so the fresh state starts with an empty stack.
	luajson.Loader(L)
	L.SetGlobal("json", L.Get(-1))
	L.Pop(1)

	// Prohibit creating new globals in this state
	lockNewGlobals := func(ls *lua.LState) int {
		ls.RaiseError("attempt to create global variable '%s'", ls.ToString(2))
		return 0
	}
	mt := L.CreateTable(0, 1)
	mt.RawSetString("__newindex", L.NewFunction(lockNewGlobals))
	L.SetMetatable(L.Get(lua.GlobalsIndex), mt)
	return L
}
// Put hands a lua state back to the pool, making it idle again.
func (pl *lStatePool) Put(L *lua.LState) {
	pl.m.Lock()
	defer pl.m.Unlock()
	pl.saved = append(pl.saved, L)
}
// Shutdown closes every idle lua state held by the pool and removes them
// from it, so a later Get cannot hand out an already-closed state. States
// that are checked out at the time of the call are not touched here.
func (pl *lStatePool) Shutdown() {
	pl.m.Lock()
	defer pl.m.Unlock()
	for _, L := range pl.saved {
		L.Close()
	}
	pl.total -= len(pl.saved)
	pl.saved = nil
}
// lScriptMap is a goroutine-safe map of compiled lua scripts.
type lScriptMap struct {
	// m guards scripts.
	m sync.Mutex
	// scripts maps a script's SHA1 hex digest to its compiled function proto.
	scripts map[string]*lua.FunctionProto
}
// Get looks up the compiled script stored under key and reports whether it
// was present.
func (sm *lScriptMap) Get(key string) (script *lua.FunctionProto, ok bool) {
	sm.m.Lock()
	defer sm.m.Unlock()
	script, ok = sm.scripts[key]
	return script, ok
}
// Put stores a compiled script under key, replacing any previous entry.
func (sm *lScriptMap) Put(key string, script *lua.FunctionProto) {
	sm.m.Lock()
	defer sm.m.Unlock()
	sm.scripts[key] = script
}
// Flush forgets every stored script by swapping in a fresh, empty map.
func (sm *lScriptMap) Flush() {
	fresh := make(map[string]*lua.FunctionProto)
	sm.m.Lock()
	defer sm.m.Unlock()
	sm.scripts = fresh
}
// newScriptMap returns an empty, ready-to-use map of compiled lua scripts.
func (s *Server) newScriptMap() *lScriptMap {
	sm := new(lScriptMap)
	sm.scripts = make(map[string]*lua.FunctionProto)
	return sm
}
// ConvertToLua converts a RESP value into a lua value.
//
// Null becomes false, integers become numbers, bulk strings become strings,
// errors become {err=...} tables, simple strings become {ok=...} tables and
// arrays become lua lists converted element by element. Any other RESP type
// yields an "ERR: unknown RESP type: ..." string.
func ConvertToLua(L *lua.LState, val resp.Value) lua.LValue {
	if val.IsNull() {
		return lua.LFalse
	}
	typ := val.Type()
	switch typ {
	case resp.Integer:
		return lua.LNumber(val.Integer())
	case resp.BulkString:
		return lua.LString(val.String())
	case resp.Error, resp.SimpleString:
		field := "ok"
		if typ == resp.Error {
			field = "err"
		}
		tbl := L.CreateTable(0, 1)
		tbl.RawSetString(field, lua.LString(val.String()))
		return tbl
	case resp.Array:
		items := val.Array()
		tbl := L.CreateTable(len(items), 0)
		for i := range items {
			tbl.Append(ConvertToLua(L, items[i]))
		}
		return tbl
	default:
		return lua.LString("ERR: unknown RESP type: " + typ.String())
	}
}
// ConvertToRESP converts a lua value into a RESP value.
//
// nil and false become null, true becomes the integer 1, and numbers are
// floored to integers (NaN and infinities are kept as floats). Strings become
// bulk strings. A table with a non-empty array part becomes an array of all
// its converted values. Any other table becomes an array of [key, value]
// pairs, except that a table whose single entry is "ok" or "err" becomes a
// simple-string or error reply respectively.
func ConvertToRESP(val lua.LValue) resp.Value {
	switch val.Type() {
	case lua.LTNil:
		return resp.NullValue()
	case lua.LTBool:
		if val != lua.LTrue {
			return resp.NullValue()
		}
		return resp.IntegerValue(1)
	case lua.LTNumber:
		f := float64(val.(lua.LNumber))
		if math.IsNaN(f) || math.IsInf(f, 0) {
			return resp.FloatValue(f)
		}
		return resp.IntegerValue(int(math.Floor(f)))
	case lua.LTString:
		return resp.StringValue(val.String())
	case lua.LTTable:
		tbl := val.(*lua.LTable)
		isList := tbl.Len() != 0
		var items, special []resp.Value
		tbl.ForEach(func(k lua.LValue, v lua.LValue) {
			if isList {
				items = append(items, ConvertToRESP(v))
				return
			}
			if k.Type() == lua.LTString {
				switch k.String() {
				case "ok":
					special = append(special, resp.SimpleStringValue(v.String()))
				case "err":
					special = append(special, resp.ErrorValue(errors.New(v.String())))
				}
			}
			pair := []resp.Value{ConvertToRESP(k), ConvertToRESP(v)}
			items = append(items, resp.ArrayValue(pair))
		})
		if len(items) == 1 && len(special) == 1 {
			return special[0]
		}
		return resp.ArrayValue(items)
	}
	return resp.ErrorValue(errors.New("Unsupported lua type: " + val.Type().String()))
}
// ConvertToJSON converts lua LValue to JSON string.
//
// nil becomes null, booleans become true/false, numbers use their lua string
// form and strings are JSON-quoted. A table with a non-empty array part
// (Len() != 0) becomes a JSON array of all its values; any other table
// becomes a JSON object. Object keys that are not lua strings are encoded
// from their string form, because JSON object keys must be strings (a number
// key is written as "2", not the invalid bare 2). Unsupported types produce
// the plain text "Unsupported lua type: ...".
func ConvertToJSON(val lua.LValue) string {
	switch val.Type() {
	case lua.LTNil:
		return "null"
	case lua.LTBool:
		if val == lua.LTrue {
			return "true"
		}
		return "false"
	case lua.LTNumber:
		return val.String()
	case lua.LTString:
		// Marshalling a Go string does not fail in practice; an error here
		// would indicate a programming bug.
		b, err := json.Marshal(val.String())
		if err != nil {
			panic(err)
		}
		return string(b)
	case lua.LTTable:
		tbl := val.(*lua.LTable)
		var values []string
		var start, end string
		var cb func(lk lua.LValue, lv lua.LValue)
		if tbl.Len() != 0 { // list
			start, end = `[`, `]`
			cb = func(lk lua.LValue, lv lua.LValue) {
				values = append(values, ConvertToJSON(lv))
			}
		} else { // map
			start, end = `{`, `}`
			cb = func(lk lua.LValue, lv lua.LValue) {
				key := lk
				if key.Type() != lua.LTString {
					// JSON object keys must be strings.
					key = lua.LString(lk.String())
				}
				values = append(values, ConvertToJSON(key)+`:`+ConvertToJSON(lv))
			}
		}
		tbl.ForEach(cb)
		return start + strings.Join(values, `,`) + end
	}
	return "Unsupported lua type: " + val.Type().String()
}
// luaSetRawGlobals writes every entry of tbl straight into the globals table
// of ls. RawSetString does not trigger metamethods, so this works even
// though pooled states install a __newindex that rejects new globals.
func luaSetRawGlobals(ls *lua.LState, tbl map[string]lua.LValue) {
	globals := ls.Get(lua.GlobalsIndex).(*lua.LTable)
	for name, value := range tbl {
		globals.RawSetString(name, value)
	}
}
// Sha1Sum returns the lowercase hexadecimal SHA1 digest of s.
func Sha1Sum(s string) string {
	digest := sha1.Sum([]byte(s))
	return hex.EncodeToString(digest[:])
}
// makeSafeErr returns err with CR and LF characters replaced by the literal
// escapes `\r` and `\n`, since RESP error replies cannot contain either
// character. A nil err is returned as nil.
func makeSafeErr(err error) error {
	if err == nil {
		return nil
	}
	return errors.New(strings.NewReplacer("\r", `\r`, "\n", `\n`).Replace(err.Error()))
}
// cmdEvalUnified runs an EVAL, EVALRO or EVALNA command, or its -SHA variant
// when scriptIsSha is true.
//
// After the command name, msg.Args holds the script source (or its SHA1
// digest for the -SHA variants), the number of keys, that many keys, and then
// any additional arguments. The keys and arguments are exposed to the script
// as the KEYS and ARGV tables. DEADLINE (seconds since the Unix epoch, or nil
// when the message has no deadline) and EVAL_CMD (the command name) are set
// as globals too. Scripts compiled here are cached under their SHA1 digest.
//
// Errors: errInvalidNumberOfArguments for a malformed argument list,
// errInvalidArgument for a non-numeric key count, errShaNotFound for an
// unknown digest, and the script's own error (CR/LF escaped) when it fails.
func (s *Server) cmdEvalUnified(scriptIsSha bool, msg *Message) (res resp.Value, err error) {
	start := time.Now()
	vs := msg.Args[1:]

	var ok bool
	var script, numkeysStr, key, arg string
	if vs, script, ok = tokenval(vs); !ok || script == "" {
		return NOMessage, errInvalidNumberOfArguments
	}
	if vs, numkeysStr, ok = tokenval(vs); !ok || numkeysStr == "" {
		return NOMessage, errInvalidNumberOfArguments
	}
	var i, numkeys uint64
	if numkeys, err = strconv.ParseUint(numkeysStr, 10, 64); err != nil {
		err = errInvalidArgument(numkeysStr)
		return
	}
	// Every key has to be present in the remaining arguments. Rejecting an
	// oversized count here, before it is used to preallocate the KEYS table,
	// stops a client-supplied number from triggering a huge allocation (or an
	// int overflow) for a request that would fail anyway.
	if numkeys > uint64(len(vs)) {
		err = errInvalidNumberOfArguments
		return
	}

	luaState, err := s.luapool.Get()
	if err != nil {
		return
	}
	// Deferred calls run last-in-first-out, so registering Put first makes it
	// the final cleanup step. The state goes back to the pool only after its
	// context has been removed and its per-call globals cleared; otherwise
	// another goroutine could pick it up while it is still being reset.
	defer s.luapool.Put(luaState)

	luaDeadline := lua.LNil
	if msg.Deadline != nil {
		dlTime := msg.Deadline.GetDeadlineTime()
		ctx, cancel := context.WithDeadline(context.Background(), dlTime)
		defer cancel()
		luaState.SetContext(ctx)
		defer luaState.RemoveContext()
		luaDeadline = lua.LNumber(float64(dlTime.UnixNano()) / 1e9)
	}

	keysTbl := luaState.CreateTable(int(numkeys), 0)
	for i = 0; i < numkeys; i++ {
		if vs, key, ok = tokenval(vs); !ok || key == "" {
			err = errInvalidNumberOfArguments
			return
		}
		keysTbl.Append(lua.LString(key))
	}
	argsTbl := luaState.CreateTable(len(vs), 0)
	for len(vs) > 0 {
		if vs, arg, ok = tokenval(vs); !ok || arg == "" {
			err = errInvalidNumberOfArguments
			return
		}
		argsTbl.Append(lua.LString(arg))
	}

	shaSum := script
	if !scriptIsSha {
		shaSum = Sha1Sum(script)
	}

	luaSetRawGlobals(
		luaState, map[string]lua.LValue{
			"KEYS":     keysTbl,
			"ARGV":     argsTbl,
			"DEADLINE": luaDeadline,
			"EVAL_CMD": lua.LString(msg.Command()),
		})
	// Clear the per-call globals on every exit path from here on, including
	// an unknown SHA or a compile error, so nothing from this call lingers in
	// the pooled state.
	defer luaSetRawGlobals(
		luaState, map[string]lua.LValue{
			"KEYS":     lua.LNil,
			"ARGV":     lua.LNil,
			"DEADLINE": lua.LNil,
			"EVAL_CMD": lua.LNil,
		})

	compiled, ok := s.luascripts.Get(shaSum)
	var fn *lua.LFunction
	if ok {
		fn = &lua.LFunction{
			IsG:       false,
			Env:       luaState.Env,
			Proto:     compiled,
			GFunction: nil,
			Upvalues:  make([]*lua.Upvalue, 0),
		}
	} else if scriptIsSha {
		err = errShaNotFound
		return
	} else {
		fn, err = luaState.Load(strings.NewReader(script), "f_"+shaSum)
		if err != nil {
			return NOMessage, makeSafeErr(err)
		}
		s.luascripts.Put(shaSum, fn.Proto)
	}

	luaState.Push(fn)
	if err := luaState.PCall(0, 1, nil); err != nil {
		// A script can raise an error with this text itself, so only consult
		// the deadline when the message actually has one.
		if msg.Deadline != nil &&
			strings.Contains(err.Error(), "context deadline exceeded") {
			msg.Deadline.Check()
		}
		log.Debugf("%v", err.Error())
		return NOMessage, makeSafeErr(err)
	}
	ret := luaState.Get(-1) // returned value
	luaState.Pop(1)

	switch msg.OutputType {
	case JSON:
		var buf bytes.Buffer
		buf.WriteString(`{"ok":true`)
		buf.WriteString(`,"result":` + ConvertToJSON(ret))
		buf.WriteString(`,"elapsed":"` + time.Since(start).String() + "\"}")
		return resp.StringValue(buf.String()), nil
	case RESP:
		return ConvertToRESP(ret), nil
	}
	return NOMessage, nil
}
// cmdScriptLoad compiles the given lua script, caches the compiled form
// under the script's SHA1 digest and replies with that digest.
func (s *Server) cmdScriptLoad(msg *Message) (resp.Value, error) {
	start := time.Now()
	_, script, ok := tokenval(msg.Args[1:])
	if !ok || script == "" {
		return NOMessage, errInvalidNumberOfArguments
	}
	shaSum := Sha1Sum(script)

	// A pooled state is borrowed only to compile the script.
	luaState, err := s.luapool.Get()
	if err != nil {
		return NOMessage, err
	}
	defer s.luapool.Put(luaState)

	fn, err := luaState.Load(strings.NewReader(script), "f_"+shaSum)
	if err != nil {
		return NOMessage, makeSafeErr(err)
	}
	s.luascripts.Put(shaSum, fn.Proto)

	switch msg.OutputType {
	case JSON:
		elapsed := time.Since(start).String()
		return resp.StringValue(`{"ok":true,"result":"` + shaSum +
			`","elapsed":"` + elapsed + "\"}"), nil
	case RESP:
		return resp.StringValue(shaSum), nil
	}
	return NOMessage, nil
}
// cmdScriptExists reports, for each SHA1 digest argument, 1 if a compiled
// script is cached under it and 0 otherwise.
func (s *Server) cmdScriptExists(msg *Message) (resp.Value, error) {
	start := time.Now()
	var results []int
	for vs := msg.Args[1:]; len(vs) > 0; {
		rest, shaSum, ok := tokenval(vs)
		if !ok || shaSum == "" {
			return NOMessage, errInvalidNumberOfArguments
		}
		vs = rest
		found := 0
		if _, exists := s.luascripts.Get(shaSum); exists {
			found = 1
		}
		results = append(results, found)
	}

	switch msg.OutputType {
	case JSON:
		var parts []string
		for _, r := range results {
			parts = append(parts, strconv.Itoa(r))
		}
		return resp.StringValue(`{"ok":true,"result":[` + strings.Join(parts, ",") +
			`],"elapsed":"` + time.Since(start).String() + "\"}"), nil
	case RESP:
		var vals []resp.Value
		for _, r := range results {
			vals = append(vals, resp.IntegerValue(r))
		}
		return resp.ArrayValue(vals), nil
	}
	return resp.SimpleStringValue(""), nil
}
// cmdScriptFlush removes every cached compiled script.
func (s *Server) cmdScriptFlush(msg *Message) (resp.Value, error) {
	start := time.Now()
	s.luascripts.Flush()
	if msg.OutputType == JSON {
		var buf bytes.Buffer
		buf.WriteString(`{"ok":true,"elapsed":"`)
		buf.WriteString(time.Since(start).String())
		buf.WriteString(`"}`)
		return resp.StringValue(buf.String()), nil
	}
	if msg.OutputType == RESP {
		return resp.StringValue("OK"), nil
	}
	return resp.SimpleStringValue(""), nil
}
// commandInScript dispatches a command issued from inside a lua script to
// its handler, then reports it to monitors via sendMonitor.
//
// Only the commands listed below can run from scripts; anything else yields
// an "unknown command" error. The commandDetails result is filled in only by
// the handlers that return one (set, fset, del, pdel, drop, expire,
// rename/renamenx, persist, jset, jdel).
func (s *Server) commandInScript(msg *Message) (
	res resp.Value, d commandDetails, err error,
) {
	switch msg.Command() {
	// handlers that also return commandDetails
	case "set":
		res, d, err = s.cmdSet(msg)
	case "fset":
		res, d, err = s.cmdFset(msg)
	case "del":
		res, d, err = s.cmdDel(msg)
	case "pdel":
		res, d, err = s.cmdPdel(msg)
	case "drop":
		res, d, err = s.cmdDrop(msg)
	case "expire":
		res, d, err = s.cmdExpire(msg)
	case "rename", "renamenx":
		// cmdRename handles both forms.
		res, d, err = s.cmdRename(msg)
	case "persist":
		res, d, err = s.cmdPersist(msg)
	case "jset":
		res, d, err = s.cmdJset(msg)
	case "jdel":
		res, d, err = s.cmdJdel(msg)
	// handlers that return only a reply
	case "ttl":
		res, err = s.cmdTTL(msg)
	case "stats":
		res, err = s.cmdStats(msg)
	case "scan":
		res, err = s.cmdScan(msg)
	case "nearby":
		res, err = s.cmdNearby(msg)
	case "within":
		res, err = s.cmdWithin(msg)
	case "intersects":
		res, err = s.cmdIntersects(msg)
	case "search":
		res, err = s.cmdSearch(msg)
	case "bounds":
		res, err = s.cmdBounds(msg)
	case "get":
		res, err = s.cmdGet(msg)
	case "jget":
		res, err = s.cmdJget(msg)
	case "type":
		res, err = s.cmdType(msg)
	case "keys":
		res, err = s.cmdKeys(msg)
	case "test":
		res, err = s.cmdTest(msg)
	case "server":
		res, err = s.cmdServer(msg)
	default:
		err = fmt.Errorf("unknown command '%s'", msg.Args[0])
	}
	s.sendMonitor(err, msg, nil, true)
	return res, d, err
}
// luaTile38Call runs a Tile38 command on behalf of tile38.call/tile38.pcall.
//
// cmd and args form the command (TIMEOUT-prefixed commands are rewritten
// first). Commands that make no sense inside a script are rejected with
// errCmdNotSupported. The rest are routed by evalcmd, the eval command that
// started the script, to the atomic read-write, atomic read-only or
// non-atomic executor.
func (s *Server) luaTile38Call(evalcmd string, cmd string, args ...string) (resp.Value, error) {
	msg := &Message{}
	msg.Args = append([]string{cmd}, args...)
	msg.OutputType = RESP

	if msg.Command() == "timeout" {
		if err := rewriteTimeoutMsg(msg); err != nil {
			return resp.NullValue(), err
		}
	}

	switch msg.Command() {
	case "ping", "echo", "auth", "massinsert", "shutdown", "gc",
		"sethook", "pdelhook", "delhook",
		"follow", "readonly", "config", "output", "client",
		"aofshrink",
		"script load", "script exists", "script flush",
		"eval", "evalsha", "evalro", "evalrosha", "evalna", "evalnasha":
		return resp.NullValue(), errCmdNotSupported
	}

	var run func(*Message) (resp.Value, error)
	switch evalcmd {
	case "eval", "evalsha":
		run = s.luaTile38AtomicRW
	case "evalro", "evalrosha":
		run = s.luaTile38AtomicRO
	case "evalna", "evalnasha":
		run = s.luaTile38NonAtomic
	default:
		return resp.NullValue(), errCmdNotSupported
	}
	return run(msg)
}
// luaTile38AtomicRW runs a command issued by an EVAL or EVALSHA script.
//
// The eval command has already got the lock, so there is no locking on the
// call from within the script. Write commands are refused with errNotLeader
// while following another host, with errReadOnly on a read-only server, and
// with errTimeoutOnCmd when the message carries a deadline; on success they
// are appended to the AOF. Read commands are refused with errCatchingUp while
// following a host and s.fcuponce is false. Everything else gets
// errCmdNotSupported.
func (s *Server) luaTile38AtomicRW(msg *Message) (resp.Value, error) {
	var write bool
	switch msg.Command() {
	default:
		return resp.NullValue(), errCmdNotSupported
	case "set", "del", "drop", "fset", "flushdb", "expire", "persist", "jset",
		"jdel", "pdel", "rename", "renamenx":
		// write operations (jdel included: commandInScript dispatches it to
		// cmdJdel, which returns details for the AOF like the other writes)
		write = true
		if s.config.followHost() != "" {
			return resp.NullValue(), errNotLeader
		}
		if s.config.readOnly() {
			return resp.NullValue(), errReadOnly
		}
	case "get", "keys", "scan", "nearby", "within", "intersects", "hooks", "search",
		"ttl", "bounds", "server", "info", "type", "jget", "test":
		// read operations
		if s.config.followHost() != "" && !s.fcuponce {
			return resp.NullValue(), errCatchingUp
		}
	}

	// exec runs the command. With a deadline, writes are refused up front,
	// and a read that hits the deadline reports errTimeout; a "deadline"
	// panic raised in that case is swallowed, any other panic is re-raised.
	exec := func() (res resp.Value, d commandDetails, err error) {
		if msg.Deadline == nil {
			return s.commandInScript(msg)
		}
		if write {
			return NOMessage, d, errTimeoutOnCmd(msg.Command())
		}
		defer func() {
			if !msg.Deadline.Hit() {
				return
			}
			if v := recover(); v != nil {
				if str, ok := v.(string); !ok || str != "deadline" {
					panic(v)
				}
			}
			res, err = NOMessage, errTimeout
		}()
		return s.commandInScript(msg)
	}

	res, d, err := exec()
	if err != nil {
		return resp.NullValue(), err
	}
	if write {
		if err := s.writeAOF(msg.Args, &d); err != nil {
			return resp.NullValue(), err
		}
	}
	return res, nil
}
// luaTile38AtomicRO runs a command issued by an EVALRO or EVALROSHA script.
//
// Write commands (including jdel, which commandInScript dispatches to
// cmdJdel) are always refused with errReadOnly. Read commands are refused
// with errCatchingUp while following a host and s.fcuponce is false.
// Everything else gets errCmdNotSupported. With a deadline, a read that hits
// it reports errTimeout; a "deadline" panic raised in that case is
// swallowed, any other panic is re-raised.
func (s *Server) luaTile38AtomicRO(msg *Message) (resp.Value, error) {
	switch msg.Command() {
	default:
		return resp.NullValue(), errCmdNotSupported
	case "set", "del", "drop", "fset", "flushdb", "expire", "persist", "jset",
		"jdel", "pdel", "rename", "renamenx":
		// write operations
		return resp.NullValue(), errReadOnly
	case "get", "keys", "scan", "nearby", "within", "intersects", "hooks", "search",
		"ttl", "bounds", "server", "info", "type", "jget", "test":
		// read operations
		if s.config.followHost() != "" && !s.fcuponce {
			return resp.NullValue(), errCatchingUp
		}
	}

	exec := func() (res resp.Value, d commandDetails, err error) {
		if msg.Deadline == nil {
			return s.commandInScript(msg)
		}
		defer func() {
			if !msg.Deadline.Hit() {
				return
			}
			if v := recover(); v != nil {
				if str, ok := v.(string); !ok || str != "deadline" {
					panic(v)
				}
			}
			res, err = NOMessage, errTimeout
		}()
		return s.commandInScript(msg)
	}

	res, _, err := exec()
	if err != nil {
		return resp.NullValue(), err
	}
	return res, nil
}
// luaTile38NonAtomic runs a command issued by an EVALNA or EVALNASHA script.
//
// Each command takes s.mu itself: the write lock for write commands and the
// read lock for read commands, held until this function returns (so a
// write's AOF entry is also written under the lock). Write commands
// (including jdel, which commandInScript dispatches to cmdJdel) are refused
// with errNotLeader while following another host, with errReadOnly on a
// read-only server, and with errTimeoutOnCmd when the message carries a
// deadline. Read commands are refused with errCatchingUp while following a
// host and s.fcuponce is false. Everything else gets errCmdNotSupported.
func (s *Server) luaTile38NonAtomic(msg *Message) (resp.Value, error) {
	var write bool
	// choose the locking strategy
	switch msg.Command() {
	default:
		return resp.NullValue(), errCmdNotSupported
	case "set", "del", "drop", "fset", "flushdb", "expire", "persist", "jset",
		"jdel", "pdel", "rename", "renamenx":
		// write operations
		write = true
		s.mu.Lock()
		defer s.mu.Unlock()
		if s.config.followHost() != "" {
			return resp.NullValue(), errNotLeader
		}
		if s.config.readOnly() {
			return resp.NullValue(), errReadOnly
		}
	case "get", "keys", "scan", "nearby", "within", "intersects", "hooks", "search",
		"ttl", "bounds", "server", "info", "type", "jget", "test":
		// read operations
		s.mu.RLock()
		defer s.mu.RUnlock()
		if s.config.followHost() != "" && !s.fcuponce {
			return resp.NullValue(), errCatchingUp
		}
	}

	// exec runs the command. With a deadline, writes are refused up front,
	// and a read that hits the deadline reports errTimeout; a "deadline"
	// panic raised in that case is swallowed, any other panic is re-raised.
	exec := func() (res resp.Value, d commandDetails, err error) {
		if msg.Deadline == nil {
			return s.commandInScript(msg)
		}
		if write {
			return NOMessage, d, errTimeoutOnCmd(msg.Command())
		}
		defer func() {
			if !msg.Deadline.Hit() {
				return
			}
			if v := recover(); v != nil {
				if str, ok := v.(string); !ok || str != "deadline" {
					panic(v)
				}
			}
			res, err = NOMessage, errTimeout
		}()
		return s.commandInScript(msg)
	}

	res, d, err := exec()
	if err != nil {
		return resp.NullValue(), err
	}
	if write {
		if err := s.writeAOF(msg.Args, &d); err != nil {
			return resp.NullValue(), err
		}
	}
	return res, nil
}