mirror of https://github.com/tidwall/tile38.git
Prohibit creation of new globals in scripts (#227)
This commit is contained in:
parent
5753f3dc43
commit
d0a510d9ff
|
@ -13,8 +13,8 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/tidwall/tile38/controller/server"
|
||||
"github.com/tidwall/resp"
|
||||
"github.com/tidwall/tile38/controller/server"
|
||||
"github.com/yuin/gopher-lua"
|
||||
)
|
||||
|
||||
|
@ -78,7 +78,7 @@ func (pl *lStatePool) Prune() {
|
|||
if dropNum < 1 {
|
||||
dropNum = 1
|
||||
}
|
||||
newSaved := make([]*lua.LState, n - dropNum)
|
||||
newSaved := make([]*lua.LState, n-dropNum)
|
||||
copy(newSaved, pl.saved[dropNum:])
|
||||
pl.saved = newSaved
|
||||
}
|
||||
|
@ -149,6 +149,16 @@ func (pl *lStatePool) New() *lua.LState {
|
|||
"sha1hex": sha1hex,
|
||||
}
|
||||
L.SetGlobal("tile38", L.SetFuncs(L.NewTable(), exports))
|
||||
|
||||
// Prohibit creating new globals in this state
|
||||
lockNewGlobals := func(ls *lua.LState) int {
|
||||
ls.RaiseError("attempt to create global variable '%s'", ls.ToString(2))
|
||||
return 0
|
||||
}
|
||||
mt := L.CreateTable(0, 1)
|
||||
mt.RawSetString("__newindex", L.NewFunction(lockNewGlobals))
|
||||
L.SetMetatable(L.Get(lua.GlobalsIndex), mt)
|
||||
|
||||
return L
|
||||
}
|
||||
|
||||
|
@ -322,10 +332,11 @@ func ConvertToJSON(val lua.LValue) string {
|
|||
return "Unsupported lua type: " + val.Type().String()
|
||||
}
|
||||
|
||||
func luaStateCleanup(ls *lua.LState) {
|
||||
ls.SetGlobal("KEYS", lua.LNil)
|
||||
ls.SetGlobal("ARGV", lua.LNil)
|
||||
ls.SetGlobal("EVAL_CMD", lua.LNil)
|
||||
func luaSetRawGlobals(ls *lua.LState, tbl map[string]lua.LValue) {
|
||||
gt := ls.Get(lua.GlobalsIndex).(*lua.LTable)
|
||||
for key, val := range tbl {
|
||||
gt.RawSetString(key, val)
|
||||
}
|
||||
}
|
||||
|
||||
// Sha1Sum returns a string with hex representation of sha1 sum of a given string
|
||||
|
@ -392,9 +403,12 @@ func (c *Controller) cmdEvalUnified(scriptIsSha bool, msg *server.Message) (res
|
|||
shaSum = Sha1Sum(script)
|
||||
}
|
||||
|
||||
luaState.SetGlobal("KEYS", keysTbl)
|
||||
luaState.SetGlobal("ARGV", argsTbl)
|
||||
luaState.SetGlobal("EVAL_CMD", lua.LString(msg.Command))
|
||||
luaSetRawGlobals(
|
||||
luaState, map[string]lua.LValue{
|
||||
"KEYS": keysTbl,
|
||||
"ARGV": argsTbl,
|
||||
"EVAL_CMD": lua.LString(msg.Command),
|
||||
})
|
||||
|
||||
compiled, ok := c.luascripts.Get(shaSum)
|
||||
var fn *lua.LFunction
|
||||
|
@ -418,7 +432,12 @@ func (c *Controller) cmdEvalUnified(scriptIsSha bool, msg *server.Message) (res
|
|||
c.luascripts.Put(shaSum, fn.Proto)
|
||||
}
|
||||
luaState.Push(fn)
|
||||
defer luaStateCleanup(luaState)
|
||||
defer luaSetRawGlobals(
|
||||
luaState, map[string]lua.LValue{
|
||||
"KEYS": lua.LNil,
|
||||
"ARGV": lua.LNil,
|
||||
"EVAL_CMD": lua.LNil,
|
||||
})
|
||||
|
||||
if err := luaState.PCall(0, 1, nil); err != nil {
|
||||
return server.NOMessage, makeSafeErr(err)
|
||||
|
|
Loading…
Reference in New Issue