Prohibit creation of new globals in scripts (#227)

This commit is contained in:
Alex Roitman 2017-10-06 07:32:04 -07:00 committed by Josh Baker
parent 5753f3dc43
commit d0a510d9ff
1 changed files with 29 additions and 10 deletions

View File

@ -13,8 +13,8 @@ import (
"sync"
"time"
"github.com/tidwall/tile38/controller/server"
"github.com/tidwall/resp"
"github.com/tidwall/tile38/controller/server"
"github.com/yuin/gopher-lua"
)
@ -149,6 +149,16 @@ func (pl *lStatePool) New() *lua.LState {
"sha1hex": sha1hex,
}
L.SetGlobal("tile38", L.SetFuncs(L.NewTable(), exports))
// Prohibit creating new globals in this state
lockNewGlobals := func(ls *lua.LState) int {
ls.RaiseError("attempt to create global variable '%s'", ls.ToString(2))
return 0
}
mt := L.CreateTable(0, 1)
mt.RawSetString("__newindex", L.NewFunction(lockNewGlobals))
L.SetMetatable(L.Get(lua.GlobalsIndex), mt)
return L
}
@ -322,10 +332,11 @@ func ConvertToJSON(val lua.LValue) string {
return "Unsupported lua type: " + val.Type().String()
}
func luaStateCleanup(ls *lua.LState) {
ls.SetGlobal("KEYS", lua.LNil)
ls.SetGlobal("ARGV", lua.LNil)
ls.SetGlobal("EVAL_CMD", lua.LNil)
func luaSetRawGlobals(ls *lua.LState, tbl map[string]lua.LValue) {
gt := ls.Get(lua.GlobalsIndex).(*lua.LTable)
for key, val := range tbl {
gt.RawSetString(key, val)
}
}
// Sha1Sum returns a string with hex representation of sha1 sum of a given string
@ -392,9 +403,12 @@ func (c *Controller) cmdEvalUnified(scriptIsSha bool, msg *server.Message) (res
shaSum = Sha1Sum(script)
}
luaState.SetGlobal("KEYS", keysTbl)
luaState.SetGlobal("ARGV", argsTbl)
luaState.SetGlobal("EVAL_CMD", lua.LString(msg.Command))
luaSetRawGlobals(
luaState, map[string]lua.LValue{
"KEYS": keysTbl,
"ARGV": argsTbl,
"EVAL_CMD": lua.LString(msg.Command),
})
compiled, ok := c.luascripts.Get(shaSum)
var fn *lua.LFunction
@ -418,7 +432,12 @@ func (c *Controller) cmdEvalUnified(scriptIsSha bool, msg *server.Message) (res
c.luascripts.Put(shaSum, fn.Proto)
}
luaState.Push(fn)
defer luaStateCleanup(luaState)
defer luaSetRawGlobals(
luaState, map[string]lua.LValue{
"KEYS": lua.LNil,
"ARGV": lua.LNil,
"EVAL_CMD": lua.LNil,
})
if err := luaState.PCall(0, 1, nil); err != nil {
return server.NOMessage, makeSafeErr(err)