ledisdb/server/script.go

406 lines
6.8 KiB
Go

// +build lua
package server
import (
"encoding/hex"
"fmt"
"io"
"sync"
"github.com/siddontang/go/hack"
"github.com/siddontang/go/num"
lua "github.com/siddontang/golua"
"github.com/siddontang/ledisdb/ledis"
)
//ledis <-> lua type conversion, same as http://redis.io/commands/eval
type luaWriter struct {
l *lua.State
}
func (w *luaWriter) writeError(err error) {
panic(err)
}
func (w *luaWriter) writeStatus(status string) {
w.l.NewTable()
top := w.l.GetTop()
w.l.PushString("ok")
w.l.PushString(status)
w.l.SetTable(top)
}
func (w *luaWriter) writeInteger(n int64) {
w.l.PushInteger(n)
}
func (w *luaWriter) writeBulk(b []byte) {
if b == nil {
w.l.PushBoolean(false)
} else {
w.l.PushString(hack.String(b))
}
}
func (w *luaWriter) writeArray(lst []interface{}) {
if lst == nil {
w.l.PushBoolean(false)
return
}
w.l.CreateTable(len(lst), 0)
top := w.l.GetTop()
for i := range lst {
w.l.PushInteger(int64(i) + 1)
switch v := lst[i].(type) {
case []interface{}:
w.writeArray(v)
case [][]byte:
w.writeSliceArray(v)
case []byte:
w.writeBulk(v)
case nil:
w.writeBulk(nil)
case int64:
w.writeInteger(v)
default:
panic("invalid array type")
}
w.l.SetTable(top)
}
}
func (w *luaWriter) writeSliceArray(lst [][]byte) {
if lst == nil {
w.l.PushBoolean(false)
return
}
w.l.CreateTable(len(lst), 0)
for i, v := range lst {
if v == nil {
w.l.PushBoolean(false)
} else {
w.l.PushString(hack.String(v))
}
w.l.RawSeti(-2, i+1)
}
}
func (w *luaWriter) writeFVPairArray(lst []ledis.FVPair) {
if lst == nil {
w.l.PushBoolean(false)
return
}
w.l.CreateTable(len(lst)*2, 0)
for i, v := range lst {
w.l.PushString(hack.String(v.Field))
w.l.RawSeti(-2, 2*i+1)
w.l.PushString(hack.String(v.Value))
w.l.RawSeti(-2, 2*i+2)
}
}
func (w *luaWriter) writeScorePairArray(lst []ledis.ScorePair, withScores bool) {
if lst == nil {
w.l.PushBoolean(false)
return
}
if withScores {
w.l.CreateTable(len(lst)*2, 0)
for i, v := range lst {
w.l.PushString(hack.String(v.Member))
w.l.RawSeti(-2, 2*i+1)
w.l.PushString(hack.String(num.FormatInt64ToSlice(v.Score)))
w.l.RawSeti(-2, 2*i+2)
}
} else {
w.l.CreateTable(len(lst), 0)
for i, v := range lst {
w.l.PushString(hack.String(v.Member))
w.l.RawSeti(-2, i+1)
}
}
}
func (w *luaWriter) writeBulkFrom(n int64, rb io.Reader) {
w.writeError(fmt.Errorf("unsupport"))
}
func (w *luaWriter) flush() {
}
type script struct {
sync.Mutex
app *App
l *lua.State
c *client
chunks map[string]struct{}
}
func (app *App) openScript() {
s := new(script)
s.app = app
s.chunks = make(map[string]struct{})
app.script = s
l := lua.NewState()
l.OpenBase()
l.OpenLibs()
l.OpenMath()
l.OpenString()
l.OpenTable()
l.OpenPackage()
l.OpenCJson()
l.OpenCMsgpack()
l.OpenStruct()
l.Register("error", luaErrorHandler)
s.l = l
s.c = newClient(app)
s.c.db = nil
w := new(luaWriter)
w.l = l
s.c.resp = w
setGlobalDBScriptVar(l, "ledis")
setGlobalDBScriptVar(l, "redis")
setMapState(l, s)
}
func (app *App) closeScript() {
app.script.l.Close()
delMapState(app.script.l)
app.script = nil
}
var mapState = map[*lua.State]*script{}
var stateLock sync.Mutex
func setGlobalDBScriptVar(l *lua.State, name string) {
l.NewTable()
l.PushString("call")
l.PushGoFunction(luaCall)
l.SetTable(-3)
l.PushString("pcall")
l.PushGoFunction(luaPCall)
l.SetTable(-3)
l.PushString("sha1hex")
l.PushGoFunction(luaSha1Hex)
l.SetTable(-3)
l.PushString("error_reply")
l.PushGoFunction(luaErrorReply)
l.SetTable(-3)
l.PushString("status_reply")
l.PushGoFunction(luaStatusReply)
l.SetTable(-3)
l.SetGlobal(name)
}
func setMapState(l *lua.State, s *script) {
stateLock.Lock()
defer stateLock.Unlock()
mapState[l] = s
}
func getMapState(l *lua.State) *script {
stateLock.Lock()
defer stateLock.Unlock()
return mapState[l]
}
func delMapState(l *lua.State) {
stateLock.Lock()
defer stateLock.Unlock()
delete(mapState, l)
}
func luaErrorHandler(l *lua.State) int {
msg := l.ToString(1)
panic(fmt.Errorf(msg))
}
func luaCall(l *lua.State) int {
return luaCallGenericCommand(l)
}
func luaPCall(l *lua.State) (n int) {
defer func() {
if e := recover(); e != nil {
luaPushError(l, fmt.Sprintf("%v", e))
n = 1
}
return
}()
return luaCallGenericCommand(l)
}
func luaErrorReply(l *lua.State) int {
return luaReturnSingleFieldTable(l, "err")
}
func luaStatusReply(l *lua.State) int {
return luaReturnSingleFieldTable(l, "ok")
}
func luaReturnSingleFieldTable(l *lua.State, filed string) int {
if l.GetTop() != 1 || l.Type(-1) != lua.LUA_TSTRING {
luaPushError(l, "wrong number or type of arguments")
return 1
}
l.NewTable()
l.PushString(filed)
l.PushValue(-3)
l.SetTable(-3)
return 1
}
func luaSha1Hex(l *lua.State) int {
argc := l.GetTop()
if argc != 1 {
luaPushError(l, "wrong number of arguments")
return 1
}
s := l.ToString(1)
s = hex.EncodeToString(hack.Slice(s))
l.PushString(s)
return 1
}
func luaPushError(l *lua.State, msg string) {
l.NewTable()
l.PushString("err")
err := l.NewError(msg)
l.PushString(err.Error())
l.SetTable(-3)
}
func luaCallGenericCommand(l *lua.State) int {
s := getMapState(l)
if s == nil {
panic("Invalid lua call")
} else if s.c.db == nil {
panic("Invalid lua call, not prepared")
}
c := s.c
argc := l.GetTop()
if argc < 1 {
panic("Please specify at least one argument for ledis.call()")
}
c.cmd = l.ToString(1)
c.args = make([][]byte, argc-1)
for i := 2; i <= argc; i++ {
switch l.Type(i) {
case lua.LUA_TNUMBER:
c.args[i-2] = []byte(fmt.Sprintf("%.17g", l.ToNumber(i)))
case lua.LUA_TSTRING:
c.args[i-2] = []byte(l.ToString(i))
default:
panic("Lua ledis() command arguments must be strings or integers")
}
}
c.perform()
return 1
}
func luaSetGlobalArray(l *lua.State, name string, ay [][]byte) {
l.NewTable()
for i := 0; i < len(ay); i++ {
l.PushString(hack.String(ay[i]))
l.RawSeti(-2, i+1)
}
l.SetGlobal(name)
}
func luaReplyToLedisReply(l *lua.State) interface{} {
base := l.GetTop()
defer func() {
l.SetTop(base - 1)
}()
switch l.Type(-1) {
case lua.LUA_TSTRING:
return hack.Slice(l.ToString(-1))
case lua.LUA_TBOOLEAN:
if l.ToBoolean(-1) {
return int64(1)
} else {
return nil
}
case lua.LUA_TNUMBER:
return int64(l.ToInteger(-1))
case lua.LUA_TTABLE:
l.PushString("err")
l.GetTable(-2)
if l.Type(-1) == lua.LUA_TSTRING {
return fmt.Errorf("%s", l.ToString(-1))
}
l.Pop(1)
l.PushString("ok")
l.GetTable(-2)
if l.Type(-1) == lua.LUA_TSTRING {
return l.ToString(-1)
} else {
l.Pop(1)
ay := make([]interface{}, 0)
for i := 1; ; i++ {
l.PushInteger(int64(i))
l.GetTable(-2)
if l.Type(-1) == lua.LUA_TNIL {
l.Pop(1)
break
}
ay = append(ay, luaReplyToLedisReply(l))
}
return ay
}
default:
return nil
}
}