mirror of https://github.com/tidwall/tile38.git
284 lines
7.1 KiB
Go
284 lines
7.1 KiB
Go
package main
|
|
|
|
import (
|
|
"crypto/sha1"
|
|
"encoding/hex"
|
|
"errors"
|
|
"fmt"
|
|
"math"
|
|
"runtime"
|
|
"runtime/debug"
|
|
|
|
"github.com/tidwall/resp"
|
|
"github.com/yuin/gopher-lua"
|
|
"strings"
|
|
)
|
|
|
|
var (
	// errCmdNotSupported is a sentinel error whose message states that a
	// command is not supported in scripts.
	// NOTE(review): nothing in this file returns or compares against it.
	errCmdNotSupported = errors.New("command not supported in scripts")
)
|
|
|
|
// Sha1Sum returns the SHA-1 digest of s as a lowercase hexadecimal string
// (40 characters).
func Sha1Sum(s string) string {
	digest := sha1.Sum([]byte(s))
	return hex.EncodeToString(digest[:])
}
|
|
|
|
// Convert lua LValue to RESP value
|
|
func ConvertToResp(val lua.LValue) resp.Value {
|
|
switch val.Type() {
|
|
case lua.LTNil:
|
|
return resp.NullValue()
|
|
case lua.LTBool:
|
|
if val == lua.LTrue {
|
|
return resp.IntegerValue(1)
|
|
} else {
|
|
return resp.NullValue()
|
|
}
|
|
case lua.LTNumber:
|
|
if float := float64(val.(lua.LNumber)); math.IsNaN(float) || math.IsInf(float, 0) {
|
|
return resp.FloatValue(float)
|
|
} else {
|
|
return resp.IntegerValue(int(math.Floor(float)))
|
|
}
|
|
case lua.LTString:
|
|
return resp.StringValue(val.String())
|
|
case lua.LTTable:
|
|
var values []resp.Value
|
|
var specialValues []resp.Value
|
|
var cb func(lk lua.LValue, lv lua.LValue)
|
|
tbl := val.(*lua.LTable)
|
|
|
|
if tbl.Len() != 0 { // list
|
|
cb = func(lk lua.LValue, lv lua.LValue) {
|
|
values = append(values, ConvertToResp(lv))
|
|
}
|
|
} else { // map
|
|
cb = func(lk lua.LValue, lv lua.LValue) {
|
|
if lk.Type() == lua.LTString {
|
|
lks := lk.String()
|
|
switch lks {
|
|
case "ok":
|
|
specialValues = append(specialValues, resp.SimpleStringValue(lv.String()))
|
|
case "err":
|
|
specialValues = append(specialValues, resp.ErrorValue(errors.New(lv.String())))
|
|
}
|
|
}
|
|
values = append(values, resp.ArrayValue(
|
|
[]resp.Value{ConvertToResp(lk), ConvertToResp(lv)}))
|
|
}
|
|
}
|
|
tbl.ForEach(cb)
|
|
if len(values) == 1 && len(specialValues) == 1 {
|
|
return specialValues[0]
|
|
}
|
|
return resp.ArrayValue(values)
|
|
}
|
|
return resp.ErrorValue(errors.New("Unsupported lua type: " + val.Type().String()))
|
|
}
|
|
|
|
// Convert RESP value to lua LValue
|
|
func ConvertToLua(L *lua.LState, val resp.Value) lua.LValue {
|
|
if val.IsNull() {
|
|
return lua.LFalse
|
|
}
|
|
switch val.Type() {
|
|
case resp.Integer:
|
|
return lua.LNumber(val.Integer())
|
|
case resp.BulkString:
|
|
return lua.LString(val.String())
|
|
case resp.Error:
|
|
tbl := L.CreateTable(0, 1)
|
|
tbl.RawSetString("err", lua.LString(val.String()))
|
|
return tbl
|
|
case resp.SimpleString:
|
|
tbl := L.CreateTable(0, 1)
|
|
tbl.RawSetString("ok", lua.LString(val.String()))
|
|
return tbl
|
|
case resp.Array:
|
|
tbl := L.CreateTable(len(val.Array()), 0)
|
|
for _, item := range val.Array() {
|
|
tbl.Append(ConvertToLua(L, item))
|
|
}
|
|
return tbl
|
|
}
|
|
return lua.LString("ERR: unknown RESP type: " + val.Type().String())
|
|
}
|
|
|
|
func luaTile38Call(evalcmd string, cmd string, args ...string) (resp.Value, error) {
|
|
var values []resp.Value
|
|
values = append(values, resp.StringValue("RUNNING:"))
|
|
values = append(values, resp.StringValue(evalcmd))
|
|
values = append(values, resp.StringValue(cmd))
|
|
for _, arg := range args {
|
|
values = append(values, resp.StringValue(arg))
|
|
}
|
|
|
|
return resp.ArrayValue(values), nil
|
|
}
|
|
|
|
func NewLuaState() *lua.LState {
|
|
L := lua.NewState()
|
|
|
|
get_args := func(ls *lua.LState) (evalCmd string, args []string) {
|
|
evalCmd = ls.GetGlobal("EVAL_CMD").String()
|
|
//log.Debugf("EVAL_CMD %s\n", evalCmd)
|
|
|
|
// Trying to work with unknown number of args.
|
|
// When we see empty arg we call it enough.
|
|
for i := 1; ; i++ {
|
|
if arg := ls.ToString(i); arg == "" {
|
|
break
|
|
} else {
|
|
args = append(args, arg)
|
|
}
|
|
}
|
|
return
|
|
}
|
|
call := func(ls *lua.LState) int {
|
|
evalCmd, args := get_args(ls)
|
|
if res, err := luaTile38Call(evalCmd, args[0], args[1:]...); err != nil {
|
|
//log.Debugf("RES type: %s value: %s ERR %s\n", res.Type(), res.String(), err);
|
|
ls.RaiseError("ERR %s", err.Error())
|
|
return 0
|
|
} else {
|
|
//log.Debugf("RES type: %s value: %s\n", res.Type(), res.String());
|
|
ls.Push(ConvertToLua(ls, res))
|
|
return 1
|
|
}
|
|
}
|
|
pcall := func(ls *lua.LState) int {
|
|
evalCmd, args := get_args(ls)
|
|
if res, err := luaTile38Call(evalCmd, args[0], args[1:]...); err != nil {
|
|
//log.Debugf("RES type: %s value: %s ERR %s\n", res.Type(), res.String(), err);
|
|
ls.Push(ConvertToLua(ls, resp.ErrorValue(err)))
|
|
return 1
|
|
} else {
|
|
//log.Debugf("RES type: %s value: %s\n", res.Type(), res.String());
|
|
ls.Push(ConvertToLua(ls, res))
|
|
return 1
|
|
}
|
|
}
|
|
error_reply := func(ls *lua.LState) int {
|
|
tbl := L.CreateTable(0, 1)
|
|
tbl.RawSetString("err", lua.LString(ls.ToString(1)))
|
|
ls.Push(tbl)
|
|
return 1
|
|
}
|
|
status_reply := func(ls *lua.LState) int {
|
|
tbl := L.CreateTable(0, 1)
|
|
tbl.RawSetString("ok", lua.LString(ls.ToString(1)))
|
|
ls.Push(tbl)
|
|
return 1
|
|
}
|
|
sha1hex := func(ls *lua.LState) int {
|
|
sha_sum := Sha1Sum(ls.ToString(1))
|
|
ls.Push(lua.LString(sha_sum))
|
|
return 1
|
|
}
|
|
var exports = map[string]lua.LGFunction{
|
|
"call": call,
|
|
"pcall": pcall,
|
|
"error_reply": error_reply,
|
|
"status_reply": status_reply,
|
|
"sha1hex": sha1hex,
|
|
}
|
|
L.SetGlobal("tile38", L.SetFuncs(L.NewTable(), exports))
|
|
return L
|
|
}
|
|
|
|
// makeSafeErr returns a new error whose message is err's message with every
// newline replaced by the two-character sequence `\n`, keeping the message
// on a single line. err must be non-nil (its Error method is called).
func makeSafeErr(err error) error {
	escaped := strings.NewReplacer("\n", `\n`).Replace(err.Error())
	return errors.New(escaped)
}
|
|
|
|
func runLuaFunc(luaState *lua.LState, script string, name string) resp.Value {
|
|
luaState.SetGlobal("EVAL_CMD", lua.LString("FAKE_EVAL"))
|
|
fn, err := luaState.Load(strings.NewReader(script), name)
|
|
if err != nil {
|
|
return resp.ErrorValue(makeSafeErr(err))
|
|
}
|
|
luaState.Push(fn)
|
|
if err := luaState.PCall(0, 1, nil); err != nil {
|
|
return resp.ErrorValue(makeSafeErr(err))
|
|
}
|
|
ret := luaState.Get(-1) // returned value
|
|
luaState.Pop(1)
|
|
luaState.SetGlobal("EVAL_CMD", lua.LNil)
|
|
return ConvertToResp(ret)
|
|
}
|
|
|
|
func runMany(luaState *lua.LState, start int, num int) int {
|
|
fmt.Printf("\nRunning %d lua calls... ", num)
|
|
for i := 0; i < num; i++ {
|
|
script := fmt.Sprintf("return tile38.call('foo', 'bar', %d)", i)
|
|
name := fmt.Sprintf("f_%020d", i)
|
|
ret := runLuaFunc(luaState, script, name)
|
|
if ret.Type() == resp.Error {
|
|
panic(ret.String())
|
|
}
|
|
}
|
|
fmt.Printf("done.\n")
|
|
return start + num
|
|
}
|
|
|
|
// printMemStats runs four rounds of runtime.GC followed by
// debug.FreeOSMemory, then prints one line with selected runtime.MemStats
// fields (Alloc, HeapAlloc, HeapSys, GCSys, HeapObjects).
func printMemStats() {
	// NOTE(review): the repeated rounds appear intended to let the heap
	// settle before sampling.
	for round := 0; round < 4; round++ {
		runtime.GC()
		debug.FreeOSMemory()
	}
	var stats runtime.MemStats
	runtime.ReadMemStats(&stats)
	fmt.Printf("MemStats: Alloc %d, HeapAlloc %d, HeapSys %d, GCSys %d, HeapObjects %d.\n",
		stats.Alloc, stats.HeapAlloc, stats.HeapSys, stats.GCSys, stats.HeapObjects)
}
|
|
|
|
func testLua() {
|
|
var luaState *lua.LState
|
|
start := 12345
|
|
luaState = NewLuaState()
|
|
|
|
printMemStats()
|
|
|
|
fmt.Printf("\nRunning single call as a test\n")
|
|
ret := runLuaFunc(luaState, "return tile38.call('fake_cmd', 'a', 'b')", "test_call")
|
|
fmt.Printf("Result: %s\n", ret.String())
|
|
|
|
printMemStats()
|
|
|
|
start = runMany(luaState, start, 100)
|
|
printMemStats()
|
|
|
|
start = runMany(luaState, start, 100)
|
|
printMemStats()
|
|
|
|
start = runMany(luaState, start, 100)
|
|
printMemStats()
|
|
|
|
start = runMany(luaState, start, 100)
|
|
printMemStats()
|
|
|
|
start = runMany(luaState, start, 1000)
|
|
printMemStats()
|
|
|
|
start = runMany(luaState, start, 10000)
|
|
printMemStats()
|
|
|
|
start = runMany(luaState, start, 1000)
|
|
printMemStats()
|
|
|
|
start = runMany(luaState, start, 100)
|
|
printMemStats()
|
|
|
|
start = runMany(luaState, start, 1000)
|
|
printMemStats()
|
|
|
|
luaState.Close()
|
|
}
|
|
|
|
func main() {
|
|
fmt.Printf("Starting memtest.\n")
|
|
testLua()
|
|
}
|