mirror of https://github.com/tidwall/tile38.git
Prohibit creation of new globals in scripts (#227)
This commit is contained in:
parent
5753f3dc43
commit
d0a510d9ff
|
@ -13,8 +13,8 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/tidwall/tile38/controller/server"
|
|
||||||
"github.com/tidwall/resp"
|
"github.com/tidwall/resp"
|
||||||
|
"github.com/tidwall/tile38/controller/server"
|
||||||
"github.com/yuin/gopher-lua"
|
"github.com/yuin/gopher-lua"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -78,7 +78,7 @@ func (pl *lStatePool) Prune() {
|
||||||
if dropNum < 1 {
|
if dropNum < 1 {
|
||||||
dropNum = 1
|
dropNum = 1
|
||||||
}
|
}
|
||||||
newSaved := make([]*lua.LState, n - dropNum)
|
newSaved := make([]*lua.LState, n-dropNum)
|
||||||
copy(newSaved, pl.saved[dropNum:])
|
copy(newSaved, pl.saved[dropNum:])
|
||||||
pl.saved = newSaved
|
pl.saved = newSaved
|
||||||
}
|
}
|
||||||
|
@ -149,6 +149,16 @@ func (pl *lStatePool) New() *lua.LState {
|
||||||
"sha1hex": sha1hex,
|
"sha1hex": sha1hex,
|
||||||
}
|
}
|
||||||
L.SetGlobal("tile38", L.SetFuncs(L.NewTable(), exports))
|
L.SetGlobal("tile38", L.SetFuncs(L.NewTable(), exports))
|
||||||
|
|
||||||
|
// Prohibit creating new globals in this state
|
||||||
|
lockNewGlobals := func(ls *lua.LState) int {
|
||||||
|
ls.RaiseError("attempt to create global variable '%s'", ls.ToString(2))
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
mt := L.CreateTable(0, 1)
|
||||||
|
mt.RawSetString("__newindex", L.NewFunction(lockNewGlobals))
|
||||||
|
L.SetMetatable(L.Get(lua.GlobalsIndex), mt)
|
||||||
|
|
||||||
return L
|
return L
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -322,10 +332,11 @@ func ConvertToJSON(val lua.LValue) string {
|
||||||
return "Unsupported lua type: " + val.Type().String()
|
return "Unsupported lua type: " + val.Type().String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func luaStateCleanup(ls *lua.LState) {
|
func luaSetRawGlobals(ls *lua.LState, tbl map[string]lua.LValue) {
|
||||||
ls.SetGlobal("KEYS", lua.LNil)
|
gt := ls.Get(lua.GlobalsIndex).(*lua.LTable)
|
||||||
ls.SetGlobal("ARGV", lua.LNil)
|
for key, val := range tbl {
|
||||||
ls.SetGlobal("EVAL_CMD", lua.LNil)
|
gt.RawSetString(key, val)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sha1Sum returns a string with hex representation of sha1 sum of a given string
|
// Sha1Sum returns a string with hex representation of sha1 sum of a given string
|
||||||
|
@ -392,9 +403,12 @@ func (c *Controller) cmdEvalUnified(scriptIsSha bool, msg *server.Message) (res
|
||||||
shaSum = Sha1Sum(script)
|
shaSum = Sha1Sum(script)
|
||||||
}
|
}
|
||||||
|
|
||||||
luaState.SetGlobal("KEYS", keysTbl)
|
luaSetRawGlobals(
|
||||||
luaState.SetGlobal("ARGV", argsTbl)
|
luaState, map[string]lua.LValue{
|
||||||
luaState.SetGlobal("EVAL_CMD", lua.LString(msg.Command))
|
"KEYS": keysTbl,
|
||||||
|
"ARGV": argsTbl,
|
||||||
|
"EVAL_CMD": lua.LString(msg.Command),
|
||||||
|
})
|
||||||
|
|
||||||
compiled, ok := c.luascripts.Get(shaSum)
|
compiled, ok := c.luascripts.Get(shaSum)
|
||||||
var fn *lua.LFunction
|
var fn *lua.LFunction
|
||||||
|
@ -418,7 +432,12 @@ func (c *Controller) cmdEvalUnified(scriptIsSha bool, msg *server.Message) (res
|
||||||
c.luascripts.Put(shaSum, fn.Proto)
|
c.luascripts.Put(shaSum, fn.Proto)
|
||||||
}
|
}
|
||||||
luaState.Push(fn)
|
luaState.Push(fn)
|
||||||
defer luaStateCleanup(luaState)
|
defer luaSetRawGlobals(
|
||||||
|
luaState, map[string]lua.LValue{
|
||||||
|
"KEYS": lua.LNil,
|
||||||
|
"ARGV": lua.LNil,
|
||||||
|
"EVAL_CMD": lua.LNil,
|
||||||
|
})
|
||||||
|
|
||||||
if err := luaState.PCall(0, 1, nil); err != nil {
|
if err := luaState.PCall(0, 1, nil); err != nil {
|
||||||
return server.NOMessage, makeSafeErr(err)
|
return server.NOMessage, makeSafeErr(err)
|
||||||
|
|
Loading…
Reference in New Issue