Prohibit creation of new globals in scripts (#227)

This commit is contained in:
Alex Roitman 2017-10-06 07:32:04 -07:00 committed by Josh Baker
parent 5753f3dc43
commit d0a510d9ff
1 changed files with 29 additions and 10 deletions

View File

@ -13,8 +13,8 @@ import (
"sync" "sync"
"time" "time"
"github.com/tidwall/tile38/controller/server"
"github.com/tidwall/resp" "github.com/tidwall/resp"
"github.com/tidwall/tile38/controller/server"
"github.com/yuin/gopher-lua" "github.com/yuin/gopher-lua"
) )
@ -78,7 +78,7 @@ func (pl *lStatePool) Prune() {
if dropNum < 1 { if dropNum < 1 {
dropNum = 1 dropNum = 1
} }
newSaved := make([]*lua.LState, n - dropNum) newSaved := make([]*lua.LState, n-dropNum)
copy(newSaved, pl.saved[dropNum:]) copy(newSaved, pl.saved[dropNum:])
pl.saved = newSaved pl.saved = newSaved
} }
@ -149,6 +149,16 @@ func (pl *lStatePool) New() *lua.LState {
"sha1hex": sha1hex, "sha1hex": sha1hex,
} }
L.SetGlobal("tile38", L.SetFuncs(L.NewTable(), exports)) L.SetGlobal("tile38", L.SetFuncs(L.NewTable(), exports))
// Prohibit creating new globals in this state
lockNewGlobals := func(ls *lua.LState) int {
ls.RaiseError("attempt to create global variable '%s'", ls.ToString(2))
return 0
}
mt := L.CreateTable(0, 1)
mt.RawSetString("__newindex", L.NewFunction(lockNewGlobals))
L.SetMetatable(L.Get(lua.GlobalsIndex), mt)
return L return L
} }
@ -322,10 +332,11 @@ func ConvertToJSON(val lua.LValue) string {
return "Unsupported lua type: " + val.Type().String() return "Unsupported lua type: " + val.Type().String()
} }
func luaStateCleanup(ls *lua.LState) { func luaSetRawGlobals(ls *lua.LState, tbl map[string]lua.LValue) {
ls.SetGlobal("KEYS", lua.LNil) gt := ls.Get(lua.GlobalsIndex).(*lua.LTable)
ls.SetGlobal("ARGV", lua.LNil) for key, val := range tbl {
ls.SetGlobal("EVAL_CMD", lua.LNil) gt.RawSetString(key, val)
}
} }
// Sha1Sum returns a string with hex representation of sha1 sum of a given string // Sha1Sum returns a string with hex representation of sha1 sum of a given string
@ -392,9 +403,12 @@ func (c *Controller) cmdEvalUnified(scriptIsSha bool, msg *server.Message) (res
shaSum = Sha1Sum(script) shaSum = Sha1Sum(script)
} }
luaState.SetGlobal("KEYS", keysTbl) luaSetRawGlobals(
luaState.SetGlobal("ARGV", argsTbl) luaState, map[string]lua.LValue{
luaState.SetGlobal("EVAL_CMD", lua.LString(msg.Command)) "KEYS": keysTbl,
"ARGV": argsTbl,
"EVAL_CMD": lua.LString(msg.Command),
})
compiled, ok := c.luascripts.Get(shaSum) compiled, ok := c.luascripts.Get(shaSum)
var fn *lua.LFunction var fn *lua.LFunction
@ -418,7 +432,12 @@ func (c *Controller) cmdEvalUnified(scriptIsSha bool, msg *server.Message) (res
c.luascripts.Put(shaSum, fn.Proto) c.luascripts.Put(shaSum, fn.Proto)
} }
luaState.Push(fn) luaState.Push(fn)
defer luaStateCleanup(luaState) defer luaSetRawGlobals(
luaState, map[string]lua.LValue{
"KEYS": lua.LNil,
"ARGV": lua.LNil,
"EVAL_CMD": lua.LNil,
})
if err := luaState.PCall(0, 1, nil); err != nil { if err := luaState.PCall(0, 1, nil); err != nil {
return server.NOMessage, makeSafeErr(err) return server.NOMessage, makeSafeErr(err)