From ab8e1cc202ca675b48e2def9b50633078ce69ea2 Mon Sep 17 00:00:00 2001 From: program-- Date: Wed, 9 Nov 2022 13:33:37 -0800 Subject: [PATCH] fix: handle EVAL vulnerability; open subset of lua modules --- internal/server/scripts.go | 112 ++++++++++++++++++++++++++++++++++++- tests/scripts_test.go | 13 +++++ 2 files changed, 124 insertions(+), 1 deletion(-) diff --git a/internal/server/scripts.go b/internal/server/scripts.go index 3aae4114..c5bc904d 100644 --- a/internal/server/scripts.go +++ b/internal/server/scripts.go @@ -26,6 +26,9 @@ const ( maxLuaPoolSize = 1000 ) +// For Lua os.clock() impl +var startedAt time.Time + var errShaNotFound = errors.New("sha not found") var errCmdNotSupported = errors.New("command not supported in scripts") var errNotLeader = errors.New("not the leader") @@ -42,6 +45,10 @@ type lStatePool struct { total int } +func init() { + startedAt = time.Now() +} + // newPool returns a new pool of lua states func (s *Server) newPool() *lStatePool { pl := &lStatePool{ @@ -91,7 +98,31 @@ func (pl *lStatePool) Prune() { } func (pl *lStatePool) New() *lua.LState { - L := lua.NewState() + // Prevent opening all Lua modules + L := lua.NewState(lua.Options{SkipOpenLibs: true}) + + allowedModules := []struct { + moduleName string + moduleFn lua.LGFunction + }{ + {lua.LoadLibName, lua.OpenPackage}, + {lua.BaseLibName, openBaseSubset}, + {lua.TabLibName, lua.OpenTable}, + {lua.MathLibName, lua.OpenMath}, + {lua.StringLibName, lua.OpenString}, + {lua.OsLibName, openOsSubset}, // See below for impl, only opens clock/difftime + } + + // Open non-vulnerable modules (i.e. NOT io/os) + for _, pair := range allowedModules { + if err := L.CallByParam(lua.P{ + Fn: L.NewFunction(pair.moduleFn), + NRet: 0, + Protect: true, + }, lua.LString(pair.moduleName)); err != nil { + panic(err) + } + } getArgs := func(ls *lua.LState) (evalCmd string, args []string) { evalCmd = ls.GetGlobal("EVAL_CMD").String() @@ -844,3 +875,82 @@ func (s *Server) luaTile38NonAtomic(msg *Message) (resp.Value, error) { return res, nil } + +// Opens a subset of the Lua 5.1 base module (tonumber, tostring) +func openBaseSubset(L *lua.LState) int { + basefns := map[string]lua.LGFunction{ + "tonumber": baseToNumber, + "tostring": baseToString, + } + + global := L.Get(lua.GlobalsIndex).(*lua.LTable) + L.SetGlobal("_G", global) + L.SetGlobal("_VERSION", lua.LString(lua.LuaVersion)) + L.SetGlobal("_GOPHER_LUA_VERSION", lua.LString(lua.PackageName+" "+lua.PackageVersion)) + basemod := L.RegisterModule("_G", basefns) + L.Push(basemod) + return 1 + +} + +// Opens a subset of the Lua 5.1 os module (clock, difftime) +func openOsSubset(L *lua.LState) int { + osfns := map[string]lua.LGFunction{ + "clock": osClock, + "difftime": osDiffTime, + } + osmod := L.RegisterModule(lua.OsLibName, osfns) + L.Push(osmod) + return 1 +} + +// Lua tonumber() +func baseToNumber(L *lua.LState) int { + base := L.OptInt(2, 10) + noBase := L.Get(2) == lua.LNil + + switch lv := L.CheckAny(1).(type) { + case lua.LNumber: + L.Push(lv) + case lua.LString: + str := strings.Trim(string(lv), " \n\t") + if strings.Index(str, ".") > -1 { + if v, err := strconv.ParseFloat(str, lua.LNumberBit); err != nil { + L.Push(lua.LNil) + } else { + L.Push(lua.LNumber(v)) + } + } else { + if noBase && strings.HasPrefix(strings.ToLower(str), "0x") { + base, str = 16, str[2:] // Hex number + } + if v, err := strconv.ParseInt(str, base, lua.LNumberBit); err != nil { + L.Push(lua.LNil) + } else { + L.Push(lua.LNumber(v)) + } + } + default: + L.Push(lua.LNil) + } + return 1 +} + +// Lua tostring() +func baseToString(L *lua.LState) int { + v1 := L.CheckAny(1) + L.Push(L.ToStringMeta(v1)) + return 1 +} + +// Lua os.clock() +func osClock(L *lua.LState) int { + L.Push(lua.LNumber(float64(time.Now().Sub(startedAt)) / float64(time.Second))) + return 1 +} + +// Lua os.difftime() +func osDiffTime(L *lua.LState) int { + L.Push(lua.LNumber(L.CheckInt64(1) - L.CheckInt64(2))) + return 1 +} diff --git a/tests/scripts_test.go b/tests/scripts_test.go index 8dbefc7f..13d03979 100644 --- a/tests/scripts_test.go +++ b/tests/scripts_test.go @@ -10,6 +10,7 @@ func subTestScripts(g *testGroup) { g.regSubTest("ATOMIC", scripts_ATOMIC_test) g.regSubTest("READONLY", scripts_READONLY_test) g.regSubTest("NONATOMIC", scripts_NONATOMIC_test) + g.regSubTest("VULN", scripts_VULN_test) } func scripts_BASIC_test(mc *mockServer) error { @@ -61,3 +62,15 @@ func scripts_NONATOMIC_test(mc *mockServer) error { {"EVALNA", "return tile38.call('get', KEYS[1], ARGV[1], ARGV[2])", "1", "mykey", "myid1", "point"}, {"[33 -115]"}, }) } + +func scripts_VULN_test(mc *mockServer) error { + return mc.DoBatch([][]interface{}{ + {"EVAL", "return io", "0"}, {nil}, + {"EVAL", "return file", "0"}, {nil}, + {"EVAL", "return os.execute", "0"}, {nil}, + {"EVAL", "return os.getenv", "0"}, {nil}, + {"EVAL", "return os.clock", "0"}, {"ERR Unsupported lua type: function"}, + {"EVAL", "return loadfile", "0"}, {nil}, + {"EVAL", "return tonumber", "0"}, {"ERR Unsupported lua type: function"}, + }) +}