ledisdb/server/cmd_script.go

214 lines
3.3 KiB
Go
Raw Normal View History

2014-09-01 19:26:35 +04:00
package server
import (
2014-09-02 18:04:18 +04:00
"crypto/sha1"
2014-09-02 13:55:12 +04:00
"encoding/hex"
2014-09-01 19:26:35 +04:00
"fmt"
2015-05-04 17:42:28 +03:00
2014-09-24 08:34:21 +04:00
"github.com/siddontang/go/hack"
2014-09-02 13:55:12 +04:00
"strconv"
"strings"
2015-05-04 17:42:28 +03:00
2017-04-15 16:51:03 +03:00
"github.com/yuin/gopher-lua"
2014-09-01 19:26:35 +04:00
)
2017-04-15 16:51:03 +03:00
func parseEvalArgs(l *lua.LState, c *client) error {
2014-09-02 13:55:12 +04:00
args := c.args
if len(args) < 2 {
return ErrCmdParams
}
2014-09-01 19:26:35 +04:00
2014-09-02 13:55:12 +04:00
args = args[1:]
2014-09-01 19:26:35 +04:00
2014-09-24 08:34:21 +04:00
n, err := strconv.Atoi(hack.String(args[0]))
2014-09-02 13:55:12 +04:00
if err != nil {
return err
}
2014-09-01 19:26:35 +04:00
2014-09-02 13:55:12 +04:00
if n > len(args)-1 {
return ErrCmdParams
}
luaSetGlobalArray(l, "KEYS", args[1:n+1])
luaSetGlobalArray(l, "ARGV", args[n+1:])
2014-09-01 19:26:35 +04:00
2014-09-02 13:55:12 +04:00
return nil
2014-09-01 19:26:35 +04:00
}
2017-04-15 16:51:03 +03:00
func evalGenericCommand(c *client, evalSha1 bool) (err error) {
s := c.app.script
2014-09-02 13:55:12 +04:00
luaClient := s.c
l := s.l
2014-09-01 19:26:35 +04:00
2014-09-02 13:55:12 +04:00
s.Lock()
2014-09-01 19:26:35 +04:00
2014-09-02 13:55:12 +04:00
defer func() {
luaClient.db = nil
// luaClient.script = nil
2014-09-02 13:55:12 +04:00
s.Unlock()
}()
luaClient.db = c.db
// luaClient.script = m
2014-09-02 13:55:12 +04:00
luaClient.remoteAddr = c.remoteAddr
if err := parseEvalArgs(l, c); err != nil {
return err
2014-09-01 19:26:35 +04:00
}
2014-09-02 18:04:18 +04:00
var key string
2014-09-02 13:55:12 +04:00
if !evalSha1 {
2014-09-02 18:04:18 +04:00
h := sha1.Sum(c.args[0])
key = hex.EncodeToString(h[0:20])
2014-09-02 13:55:12 +04:00
} else {
2014-09-24 08:34:21 +04:00
key = strings.ToLower(hack.String(c.args[0]))
2014-09-01 19:26:35 +04:00
}
2017-04-15 16:51:03 +03:00
global := l.GetGlobal(key)
2014-09-02 13:55:12 +04:00
2017-04-15 16:51:03 +03:00
if global.Type() == lua.LTNil {
2014-09-02 13:55:12 +04:00
if evalSha1 {
2014-09-02 18:04:18 +04:00
return fmt.Errorf("missing %s script", key)
2014-09-01 19:26:35 +04:00
}
2017-04-15 16:51:03 +03:00
val, err := l.LoadString(hack.String(c.args[0]))
if err != nil {
2014-09-02 13:55:12 +04:00
return err
2014-09-01 19:26:35 +04:00
}
2017-04-15 16:51:03 +03:00
l.SetGlobal(key, val)
s.chunks[key] = struct{}{}
global = val
2014-09-02 13:55:12 +04:00
}
2014-09-01 19:26:35 +04:00
2017-04-15 16:51:03 +03:00
l.Push(global)
2014-09-03 03:49:46 +04:00
2017-04-15 16:51:03 +03:00
// catch any uncaught panic
// this happens for example when the user,
// makes a mistake using `ledis.call`
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic: %v", r)
2014-09-03 03:49:46 +04:00
}
2017-04-15 16:51:03 +03:00
}()
l.Call(0, lua.MultRet)
2014-09-03 03:49:46 +04:00
2017-04-15 16:51:03 +03:00
r := luaReplyToLedisReply(l)
if v, ok := r.(error); ok {
return v
2014-09-01 19:26:35 +04:00
}
2014-09-02 13:55:12 +04:00
2017-04-15 16:51:03 +03:00
writeValue(c.resp, r)
2014-09-02 13:55:12 +04:00
return nil
}
func evalCommand(c *client) error {
return evalGenericCommand(c, false)
2014-09-01 19:26:35 +04:00
}
2014-09-02 13:55:12 +04:00
// evalshaCommand implements EVALSHA: the first argument is the sha1 hex digest
// of a script previously cached by EVAL or SCRIPT LOAD.
func evalshaCommand(c *client) error {
	const bySha1 = true
	return evalGenericCommand(c, bySha1)
}
func scriptCommand(c *client) error {
s := c.app.script
2014-09-02 13:55:12 +04:00
l := s.l
s.Lock()
base := l.GetTop()
defer func() {
l.SetTop(base)
s.Unlock()
}()
args := c.args
2014-09-01 19:26:35 +04:00
2014-09-03 13:00:03 +04:00
if len(args) < 1 {
return ErrCmdParams
}
2014-09-24 08:34:21 +04:00
switch strings.ToLower(hack.String(args[0])) {
2014-09-03 13:00:03 +04:00
case "load":
2014-09-02 13:55:12 +04:00
return scriptLoadCommand(c)
2014-09-03 13:00:03 +04:00
case "exists":
2014-09-02 13:55:12 +04:00
return scriptExistsCommand(c)
2014-09-03 13:00:03 +04:00
case "flush":
2014-09-02 13:55:12 +04:00
return scriptFlushCommand(c)
default:
2014-09-03 13:00:03 +04:00
return fmt.Errorf("invalid script %s", args[0])
2014-09-01 19:26:35 +04:00
}
}
2014-09-02 13:55:12 +04:00
func scriptLoadCommand(c *client) error {
s := c.app.script
2014-09-02 13:55:12 +04:00
l := s.l
2014-09-03 13:00:03 +04:00
if len(c.args) != 2 {
2014-09-02 13:55:12 +04:00
return ErrCmdParams
2014-09-01 19:26:35 +04:00
}
2014-09-03 13:00:03 +04:00
h := sha1.Sum(c.args[1])
2014-09-02 18:04:18 +04:00
key := hex.EncodeToString(h[0:20])
2014-09-01 19:26:35 +04:00
2017-04-15 16:51:03 +03:00
val, err := l.LoadString(hack.String(c.args[1]))
if err != nil {
2014-09-02 13:55:12 +04:00
return err
2014-09-01 19:26:35 +04:00
}
2017-04-15 16:51:03 +03:00
l.Push(val)
l.SetGlobal(key, val)
s.chunks[key] = struct{}{}
2014-09-02 13:55:12 +04:00
2014-09-24 08:34:21 +04:00
c.resp.writeBulk(hack.Slice(key))
2014-09-02 13:55:12 +04:00
return nil
2014-09-01 19:26:35 +04:00
}
2014-09-02 13:55:12 +04:00
func scriptExistsCommand(c *client) error {
s := c.app.script
2014-09-01 19:26:35 +04:00
2014-09-03 13:00:03 +04:00
if len(c.args) < 2 {
2014-09-02 18:04:18 +04:00
return ErrCmdParams
}
2014-09-03 13:00:03 +04:00
ay := make([]interface{}, len(c.args[1:]))
for i, n := range c.args[1:] {
2014-09-24 08:34:21 +04:00
if _, ok := s.chunks[hack.String(n)]; ok {
2014-09-02 13:55:12 +04:00
ay[i] = int64(1)
} else {
ay[i] = int64(0)
2014-09-01 19:26:35 +04:00
}
}
2014-09-02 13:55:12 +04:00
c.resp.writeArray(ay)
return nil
2014-09-01 19:26:35 +04:00
}
2014-09-02 13:55:12 +04:00
func scriptFlushCommand(c *client) error {
s := c.app.script
2014-09-02 13:55:12 +04:00
l := s.l
2014-09-03 13:00:03 +04:00
if len(c.args) != 1 {
return ErrCmdParams
}
for n := range s.chunks {
2017-04-15 16:51:03 +03:00
l.SetGlobal(n, lua.LNil)
2014-09-02 13:55:12 +04:00
}
2014-09-02 18:04:18 +04:00
s.chunks = map[string]struct{}{}
2014-09-02 13:55:12 +04:00
c.resp.writeStatus(OK)
return nil
}
2014-09-01 19:26:35 +04:00
2014-09-02 13:55:12 +04:00
func init() {
register("eval", evalCommand)
register("evalsha", evalshaCommand)
2014-09-03 13:00:03 +04:00
register("script", scriptCommand)
2014-09-01 19:26:35 +04:00
}