add lua support

This commit is contained in:
siddontang 2014-09-02 17:55:12 +08:00
parent 93f3cc5343
commit ab1ae62bf7
15 changed files with 1132 additions and 357 deletions

105
ledis/batch.go Normal file
View File

@ -0,0 +1,105 @@
package ledis
import (
"github.com/siddontang/ledisdb/store"
"sync"
)
type batch struct {
l *Ledis
store.WriteBatch
sync.Locker
logs [][]byte
tx *Tx
}
func (b *batch) Commit() error {
b.l.commitLock.Lock()
defer b.l.commitLock.Unlock()
err := b.WriteBatch.Commit()
if b.l.binlog != nil {
if err == nil {
if b.tx == nil {
b.l.binlog.Log(b.logs...)
} else {
b.tx.logs = append(b.tx.logs, b.logs...)
}
}
b.logs = [][]byte{}
}
return err
}
func (b *batch) Lock() {
b.Locker.Lock()
}
func (b *batch) Unlock() {
if b.l.binlog != nil {
b.logs = [][]byte{}
}
b.WriteBatch.Rollback()
b.Locker.Unlock()
}
func (b *batch) Put(key []byte, value []byte) {
if b.l.binlog != nil {
buf := encodeBinLogPut(key, value)
b.logs = append(b.logs, buf)
}
b.WriteBatch.Put(key, value)
}
func (b *batch) Delete(key []byte) {
if b.l.binlog != nil {
buf := encodeBinLogDelete(key)
b.logs = append(b.logs, buf)
}
b.WriteBatch.Delete(key)
}
type dbBatchLocker struct {
l *sync.Mutex
wrLock *sync.RWMutex
}
func (l *dbBatchLocker) Lock() {
l.wrLock.RLock()
l.l.Lock()
}
func (l *dbBatchLocker) Unlock() {
l.l.Unlock()
l.wrLock.RUnlock()
}
type txBatchLocker struct {
}
func (l *txBatchLocker) Lock() {}
func (l *txBatchLocker) Unlock() {}
type multiBatchLocker struct {
}
func (l *multiBatchLocker) Lock() {}
func (l *multiBatchLocker) Unlock() {}
func (l *Ledis) newBatch(wb store.WriteBatch, locker sync.Locker, tx *Tx) *batch {
b := new(batch)
b.l = l
b.WriteBatch = wb
b.tx = tx
b.Locker = locker
b.logs = [][]byte{}
return b
}

View File

@ -86,3 +86,9 @@ const (
BinLogTypePut uint8 = 0x1
BinLogTypeCommand uint8 = 0x2
)
const (
DBAutoCommit uint8 = 0x0
DBInTransaction uint8 = 0x1
DBInMulti uint8 = 0x2
)

View File

@ -3,6 +3,7 @@ package ledis
import (
"fmt"
"github.com/siddontang/ledisdb/store"
"sync"
)
type ibucket interface {
@ -37,7 +38,7 @@ type DB struct {
binBatch *batch
setBatch *batch
isTx bool
status uint8
}
func (l *Ledis) newDB(index uint8) *DB {
@ -49,7 +50,7 @@ func (l *Ledis) newDB(index uint8) *DB {
d.bucket = d.sdb
d.isTx = false
d.status = DBAutoCommit
d.index = index
d.kvBatch = d.newBatch()
@ -62,10 +63,18 @@ func (l *Ledis) newDB(index uint8) *DB {
return d
}
func (db *DB) newBatch() *batch {
return db.l.newBatch(db.bucket.NewWriteBatch(), &dbBatchLocker{l: &sync.Mutex{}, wrLock: &db.l.wLock}, nil)
}
func (db *DB) Index() int {
return int(db.index)
}
func (db *DB) IsAutoCommit() bool {
return db.status == DBAutoCommit
}
func (db *DB) FlushAll() (drop int64, err error) {
all := [...](func() (int64, error)){
db.flush,

73
ledis/multi.go Normal file
View File

@ -0,0 +1,73 @@
package ledis
import (
"errors"
"fmt"
)
var (
ErrNestMulti = errors.New("nest multi not supported")
ErrMultiDone = errors.New("multi has been closed")
)
type Multi struct {
*DB
}
func (db *DB) IsInMulti() bool {
return db.status == DBInMulti
}
// begin a mutli to execute commands,
// it will block any other write operations before you close the multi, unlike transaction, mutli can not rollback
func (db *DB) Multi() (*Multi, error) {
if db.IsInMulti() {
return nil, ErrNestMulti
}
m := new(Multi)
m.DB = new(DB)
m.DB.status = DBInMulti
m.DB.l = db.l
m.l.wLock.Lock()
m.DB.sdb = db.sdb
m.DB.bucket = db.sdb
m.DB.index = db.index
m.DB.kvBatch = m.newBatch()
m.DB.listBatch = m.newBatch()
m.DB.hashBatch = m.newBatch()
m.DB.zsetBatch = m.newBatch()
m.DB.binBatch = m.newBatch()
m.DB.setBatch = m.newBatch()
return m, nil
}
func (m *Multi) newBatch() *batch {
return m.l.newBatch(m.bucket.NewWriteBatch(), &multiBatchLocker{}, nil)
}
func (m *Multi) Close() error {
if m.bucket == nil {
return ErrMultiDone
}
m.l.wLock.Unlock()
m.bucket = nil
return nil
}
func (m *Multi) Select(index int) error {
if index < 0 || index >= int(MaxDBNumber) {
return fmt.Errorf("invalid db index %d", index)
}
m.DB.index = uint8(index)
return nil
}

51
ledis/multi_test.go Normal file
View File

@ -0,0 +1,51 @@
package ledis
import (
"sync"
"testing"
)
func TestMulti(t *testing.T) {
db := getTestDB()
key := []byte("test_multi_1")
v1 := []byte("v1")
v2 := []byte("v2")
m, err := db.Multi()
if err != nil {
t.Fatal(err)
}
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
if err := db.Set(key, v2); err != nil {
t.Fatal(err)
}
wg.Done()
}()
if err := m.Set(key, v1); err != nil {
t.Fatal(err)
}
if v, err := m.Get(key); err != nil {
t.Fatal(err)
} else if string(v) != string(v1) {
t.Fatal(string(v))
}
m.Close()
wg.Wait()
if v, err := db.Get(key); err != nil {
t.Fatal(err)
} else if string(v) != string(v2) {
t.Fatal(string(v))
}
}

View File

@ -70,6 +70,12 @@ func TestReplication(t *testing.T) {
db.HSet([]byte("c"), []byte("3"), []byte("value"))
}
m, _ := db.Multi()
m.Set([]byte("a1"), []byte("value"))
m.Set([]byte("b1"), []byte("value"))
m.Set([]byte("c1"), []byte("value"))
m.Close()
for _, name := range master.binlog.LogNames() {
p := path.Join(master.binlog.LogPath(), name)

View File

@ -1,209 +0,0 @@
package ledis
import (
"errors"
"github.com/siddontang/ledisdb/store"
"sync"
)
var (
ErrNestTx = errors.New("nest transaction not supported")
ErrTxDone = errors.New("Transaction has already been committed or rolled back")
)
type batch struct {
l *Ledis
store.WriteBatch
sync.Locker
logs [][]byte
tx *Tx
}
type dbBatchLocker struct {
l *sync.Mutex
wrLock *sync.RWMutex
}
func (l *dbBatchLocker) Lock() {
l.wrLock.RLock()
l.l.Lock()
}
func (l *dbBatchLocker) Unlock() {
l.l.Unlock()
l.wrLock.RUnlock()
}
type txBatchLocker struct {
}
func (l *txBatchLocker) Lock() {}
func (l *txBatchLocker) Unlock() {}
func (l *Ledis) newBatch(wb store.WriteBatch, tx *Tx) *batch {
b := new(batch)
b.l = l
b.WriteBatch = wb
b.tx = tx
if tx == nil {
b.Locker = &dbBatchLocker{l: &sync.Mutex{}, wrLock: &l.wLock}
} else {
b.Locker = &txBatchLocker{}
}
b.logs = [][]byte{}
return b
}
func (db *DB) newBatch() *batch {
return db.l.newBatch(db.bucket.NewWriteBatch(), nil)
}
func (b *batch) Commit() error {
b.l.commitLock.Lock()
defer b.l.commitLock.Unlock()
err := b.WriteBatch.Commit()
if b.l.binlog != nil {
if err == nil {
if b.tx == nil {
b.l.binlog.Log(b.logs...)
} else {
b.tx.logs = append(b.tx.logs, b.logs...)
}
}
b.logs = [][]byte{}
}
return err
}
func (b *batch) Lock() {
b.Locker.Lock()
}
func (b *batch) Unlock() {
if b.l.binlog != nil {
b.logs = [][]byte{}
}
b.WriteBatch.Rollback()
b.Locker.Unlock()
}
func (b *batch) Put(key []byte, value []byte) {
if b.l.binlog != nil {
buf := encodeBinLogPut(key, value)
b.logs = append(b.logs, buf)
}
b.WriteBatch.Put(key, value)
}
func (b *batch) Delete(key []byte) {
if b.l.binlog != nil {
buf := encodeBinLogDelete(key)
b.logs = append(b.logs, buf)
}
b.WriteBatch.Delete(key)
}
type Tx struct {
*DB
tx *store.Tx
logs [][]byte
index uint8
}
func (db *DB) IsTransaction() bool {
return db.isTx
}
// Begin a transaction, it will block all other write operations before calling Commit or Rollback.
// You must be very careful to prevent long-time transaction.
func (db *DB) Begin() (*Tx, error) {
if db.isTx {
return nil, ErrNestTx
}
tx := new(Tx)
tx.DB = new(DB)
tx.DB.l = db.l
tx.l.wLock.Lock()
tx.index = db.index
tx.DB.sdb = db.sdb
var err error
tx.tx, err = db.sdb.Begin()
if err != nil {
tx.l.wLock.Unlock()
return nil, err
}
tx.DB.bucket = tx.tx
tx.DB.isTx = true
tx.DB.index = db.index
tx.DB.kvBatch = tx.newTxBatch()
tx.DB.listBatch = tx.newTxBatch()
tx.DB.hashBatch = tx.newTxBatch()
tx.DB.zsetBatch = tx.newTxBatch()
tx.DB.binBatch = tx.newTxBatch()
tx.DB.setBatch = tx.newTxBatch()
return tx, nil
}
func (tx *Tx) Commit() error {
if tx.tx == nil {
return ErrTxDone
}
tx.l.commitLock.Lock()
err := tx.tx.Commit()
tx.tx = nil
if len(tx.logs) > 0 {
tx.l.binlog.Log(tx.logs...)
}
tx.l.commitLock.Unlock()
tx.l.wLock.Unlock()
tx.DB = nil
return err
}
func (tx *Tx) Rollback() error {
if tx.tx == nil {
return ErrTxDone
}
err := tx.tx.Rollback()
tx.tx = nil
tx.l.wLock.Unlock()
tx.DB = nil
return err
}
func (tx *Tx) newTxBatch() *batch {
return tx.l.newBatch(tx.tx.NewWriteBatch(), tx)
}
func (tx *Tx) Index() int {
return int(tx.index)
}

View File

@ -144,6 +144,51 @@ func testTxCommit(t *testing.T, db *DB) {
}
}
func testTxSelect(t *testing.T, db *DB) {
tx, err := db.Begin()
if err != nil {
t.Fatal(err)
}
defer tx.Rollback()
tx.Set([]byte("tx_select_1"), []byte("a"))
tx.Select(1)
tx.Set([]byte("tx_select_2"), []byte("b"))
if err = tx.Commit(); err != nil {
t.Fatal(err)
}
if v, err := db.Get([]byte("tx_select_1")); err != nil {
t.Fatal(err)
} else if string(v) != "a" {
t.Fatal(string(v))
}
if v, err := db.Get([]byte("tx_select_2")); err != nil {
t.Fatal(err)
} else if v != nil {
t.Fatal("must nil")
}
db, _ = db.l.Select(1)
if v, err := db.Get([]byte("tx_select_2")); err != nil {
t.Fatal(err)
} else if string(v) != "b" {
t.Fatal(string(v))
}
if v, err := db.Get([]byte("tx_select_1")); err != nil {
t.Fatal(err)
} else if v != nil {
t.Fatal("must nil")
}
}
func testTx(t *testing.T, name string) {
cfg := new(config.Config)
cfg.DataDir = "/tmp/ledis_test_tx"
@ -164,6 +209,7 @@ func testTx(t *testing.T, name string) {
testTxRollback(t, db)
testTxCommit(t, db)
testTxSelect(t, db)
}
//only lmdb, boltdb support Transaction

View File

@ -27,6 +27,8 @@ type App struct {
m *master
info *info
s *script
}
func netType(s string) string {
@ -85,6 +87,8 @@ func NewApp(cfg *config.Config) (*App, error) {
app.m = newMaster(app)
app.openScript()
return app, nil
}
@ -103,6 +107,8 @@ func (app *App) Close() {
app.httpListener.Close()
}
app.closeScript()
app.m.Close()
if app.access != nil {

View File

@ -16,6 +16,18 @@ var txUnsupportedCmds = map[string]struct{}{
"begin": struct{}{},
"flushall": struct{}{},
"flushdb": struct{}{},
"eval": struct{}{},
}
var scriptUnsupportedCmds = map[string]struct{}{
"slaveof": struct{}{},
"fullsync": struct{}{},
"sync": struct{}{},
"begin": struct{}{},
"commit": struct{}{},
"rollback": struct{}{},
"flushall": struct{}{},
"flushdb": struct{}{},
}
type responseWriter interface {
@ -34,6 +46,7 @@ type responseWriter interface {
type client struct {
app *App
ldb *ledis.Ledis
db *ledis.DB
remoteAddr string
@ -50,6 +63,7 @@ type client struct {
buf bytes.Buffer
tx *ledis.Tx
script *ledis.Multi
}
func newClient(app *App) *client {
@ -59,16 +73,12 @@ func newClient(app *App) *client {
c.ldb = app.ldb
c.db, _ = app.ldb.Select(0) //use default db
c.compressBuf = make([]byte, 256)
c.compressBuf = []byte{}
c.reqErr = make(chan error)
return c
}
func (c *client) isInTransaction() bool {
return c.tx != nil
}
func (c *client) perform() {
var err error
@ -79,10 +89,14 @@ func (c *client) perform() {
} else if exeCmd, ok := regCmds[c.cmd]; !ok {
err = ErrNotFound
} else {
if c.isInTransaction() {
if c.db.IsTransaction() {
if _, ok := txUnsupportedCmds[c.cmd]; ok {
err = fmt.Errorf("%s not supported in transaction", c.cmd)
}
} else if c.db.IsInMulti() {
if _, ok := scriptUnsupportedCmds[c.cmd]; ok {
err = fmt.Errorf("%s not supported in multi", c.cmd)
}
}
if err == nil {
@ -128,3 +142,22 @@ func (c *client) catGenericCommand() []byte {
return buffer.Bytes()
}
func writeValue(w responseWriter, value interface{}) {
switch v := value.(type) {
case []interface{}:
w.writeArray(v)
case [][]byte:
w.writeSliceArray(v)
case []byte:
w.writeBulk(v)
case string:
w.writeStatus(v)
case nil:
w.writeBulk(nil)
case int64:
w.writeInteger(v)
default:
panic("invalid value type")
}
}

View File

@ -1,158 +1,207 @@
package server
import (
"encoding/hex"
"fmt"
"github.com/aarzilli/golua/lua"
"github.com/siddontang/ledisdb/ledis"
"io"
"strconv"
"strings"
)
//ledis <-> lua type conversion, same as http://redis.io/commands/eval
type luaClient struct {
l *lua.State
}
type luaWriter struct {
l *lua.State
}
func (w *luaWriter) writeError(err error) {
w.l.NewTable()
top := w.l.GetTop()
w.l.PushString("err")
w.l.PushString(err.Error())
w.l.SetTable(top)
}
func (w *luaWriter) writeStatus(status string) {
w.l.NewTable()
top := w.l.GetTop()
w.l.PushString("ok")
w.l.PushString(status)
w.l.SetTable(top)
}
func (w *luaWriter) writeInteger(n int64) {
w.l.PushInteger(n)
}
func (w *luaWriter) writeBulk(b []byte) {
if b == nil {
w.l.PushBoolean(false)
} else {
w.l.PushString(ledis.String(b))
}
}
func (w *luaWriter) writeArray(lst []interface{}) {
if lst == nil {
w.l.PushBoolean(false)
return
func parseEvalArgs(l *lua.State, c *client) error {
args := c.args
if len(args) < 2 {
return ErrCmdParams
}
base := w.l.GetTop()
args = args[1:]
n, err := strconv.Atoi(ledis.String(args[0]))
if err != nil {
return err
}
if n > len(args)-1 {
return ErrCmdParams
}
luaSetGlobalArray(l, "KEYS", args[1:n+1])
luaSetGlobalArray(l, "ARGV", args[n+1:])
return nil
}
func evalGenericCommand(c *client, evalSha1 bool) error {
m, err := c.db.Multi()
if err != nil {
return err
}
s := c.app.s
luaClient := s.c
l := s.l
s.Lock()
base := l.GetTop()
defer func() {
if e := recover(); e != nil {
w.l.SetTop(base)
w.writeError(fmt.Errorf("%v", e))
}
l.SetTop(base)
luaClient.db = nil
luaClient.script = nil
s.Unlock()
m.Close()
}()
w.l.CreateTable(len(lst), 0)
top := w.l.GetTop()
luaClient.db = m.DB
luaClient.script = m
luaClient.remoteAddr = c.remoteAddr
for i, _ := range lst {
w.l.PushInteger(int64(i) + 1)
switch v := lst[i].(type) {
case []interface{}:
w.writeArray(v)
case [][]byte:
w.writeSliceArray(v)
case []byte:
w.writeBulk(v)
case nil:
w.writeBulk(nil)
case int64:
w.writeInteger(v)
default:
panic("invalid array type")
if err := parseEvalArgs(l, c); err != nil {
return err
}
w.l.SetTable(top)
}
}
func (w *luaWriter) writeSliceArray(lst [][]byte) {
if lst == nil {
w.l.PushBoolean(false)
return
}
w.l.CreateTable(len(lst), 0)
top := w.l.GetTop()
for i, v := range lst {
w.l.PushInteger(int64(i) + 1)
w.l.PushString(ledis.String(v))
w.l.SetTable(top)
}
}
func (w *luaWriter) writeFVPairArray(lst []ledis.FVPair) {
if lst == nil {
w.l.PushBoolean(false)
return
}
w.l.CreateTable(len(lst)*2, 0)
top := w.l.GetTop()
for i, v := range lst {
w.l.PushInteger(int64(2*i) + 1)
w.l.PushString(ledis.String(v.Field))
w.l.SetTable(top)
w.l.PushInteger(int64(2*i) + 2)
w.l.PushString(ledis.String(v.Value))
w.l.SetTable(top)
}
}
func (w *luaWriter) writeScorePairArray(lst []ledis.ScorePair, withScores bool) {
if lst == nil {
w.l.PushBoolean(false)
return
}
if withScores {
w.l.CreateTable(len(lst)*2, 0)
top := w.l.GetTop()
for i, v := range lst {
w.l.PushInteger(int64(2*i) + 1)
w.l.PushString(ledis.String(v.Member))
w.l.SetTable(top)
w.l.PushInteger(int64(2*i) + 2)
w.l.PushString(ledis.String(ledis.StrPutInt64(v.Score)))
w.l.SetTable(top)
}
var sha1 string
if !evalSha1 {
sha1 = hex.EncodeToString(c.args[0])
} else {
w.l.CreateTable(len(lst), 0)
top := w.l.GetTop()
for i, v := range lst {
w.l.PushInteger(int64(i) + 1)
w.l.PushString(ledis.String(v.Member))
w.l.SetTable(top)
sha1 = ledis.String(c.args[0])
}
l.GetGlobal(sha1)
if l.IsNil(-1) {
l.Pop(1)
if evalSha1 {
return fmt.Errorf("missing %s script", sha1)
}
if r := l.LoadString(ledis.String(c.args[0])); r != 0 {
err := fmt.Errorf("%s", l.ToString(-1))
l.Pop(1)
return err
} else {
l.PushValue(-1)
l.SetGlobal(sha1)
s.chunks[sha1] = struct{}{}
}
}
if err := l.Call(0, lua.LUA_MULTRET); err != nil {
return err
} else {
r := luaReplyToLedisReply(l)
m.Close()
writeValue(c.resp, r)
}
return nil
}
func (w *luaWriter) writeBulkFrom(n int64, rb io.Reader) {
w.writeError(fmt.Errorf("unsupport"))
func evalCommand(c *client) error {
return evalGenericCommand(c, false)
}
func (w *luaWriter) flush() {
func evalshaCommand(c *client) error {
return evalGenericCommand(c, true)
}
func scriptCommand(c *client) error {
s := c.app.s
l := s.l
s.Lock()
base := l.GetTop()
defer func() {
l.SetTop(base)
s.Unlock()
}()
args := c.args
if len(args) < 1 {
return ErrCmdParams
}
switch strings.ToLower(c.cmd) {
case "script load":
return scriptLoadCommand(c)
case "script exists":
return scriptExistsCommand(c)
case "script flush":
return scriptFlushCommand(c)
default:
return fmt.Errorf("invalid scirpt cmd %s", args[0])
}
return nil
}
func scriptLoadCommand(c *client) error {
s := c.app.s
l := s.l
if len(c.args) != 1 {
return ErrCmdParams
}
sha1 := hex.EncodeToString(c.args[1])
if r := l.LoadString(ledis.String(c.args[1])); r != 0 {
err := fmt.Errorf("%s", l.ToString(-1))
l.Pop(1)
return err
} else {
l.PushValue(-1)
l.SetGlobal(sha1)
s.chunks[sha1] = struct{}{}
}
c.resp.writeBulk(ledis.Slice(sha1))
return nil
}
func scriptExistsCommand(c *client) error {
s := c.app.s
ay := make([]interface{}, len(c.args[1:]))
for i, n := range c.args[1:] {
if _, ok := s.chunks[ledis.String(n)]; ok {
ay[i] = int64(1)
} else {
ay[i] = int64(0)
}
}
c.resp.writeArray(ay)
return nil
}
func scriptFlushCommand(c *client) error {
s := c.app.s
l := s.l
for n, _ := range s.chunks {
l.PushNil()
l.SetGlobal(n)
}
c.resp.writeStatus(OK)
return nil
}
func init() {
register("eval", evalCommand)
register("evalsha", evalshaCommand)
register("script load", scriptCommand)
register("script flush", scriptCommand)
register("script exists", scriptCommand)
}

29
server/cmd_script_test.go Normal file
View File

@ -0,0 +1,29 @@
package server
import (
"fmt"
"github.com/siddontang/ledisdb/client/go/ledis"
"reflect"
"testing"
)
func TestCmdEval(t *testing.T) {
c := getTestConn()
defer c.Close()
if v, err := ledis.Strings(c.Do("eval", "return {KEYS[1],KEYS[2],ARGV[1],ARGV[2]}", 2, "key1", "key2", "first", "second")); err != nil {
t.Fatal(err)
} else if len(v) != 4 {
t.Fatal(err)
} else if !reflect.DeepEqual(v, []string{"key1", "key2", "first", "second"}) {
t.Fatal(fmt.Sprintf("%v", v))
}
if v, err := ledis.Strings(c.Do("eval", "return {KEYS[1],KEYS[2],ARGV[1],ARGV[2]}", 2, "key1", "key2", "first", "second")); err != nil {
t.Fatal(err)
} else if len(v) != 4 {
t.Fatal(err)
} else if !reflect.DeepEqual(v, []string{"key1", "key2", "first", "second"}) {
t.Fatal(fmt.Sprintf("%v", v))
}
}

View File

@ -40,14 +40,28 @@ func selectCommand(c *client) error {
if index, err := strconv.Atoi(ledis.String(c.args[0])); err != nil {
return err
} else {
if c.db.IsTransaction() {
if err := c.tx.Select(index); err != nil {
return err
} else {
c.db = c.tx.DB
}
} else if c.db.IsInMulti() {
if err := c.script.Select(index); err != nil {
return err
} else {
c.db = c.script.DB
}
} else {
if db, err := c.ldb.Select(index); err != nil {
return err
} else {
c.db = db
c.resp.writeStatus(OK)
}
}
c.resp.writeStatus(OK)
}
return nil
}

380
server/script.go Normal file
View File

@ -0,0 +1,380 @@
package server
import (
"encoding/hex"
"fmt"
"github.com/aarzilli/golua/lua"
"github.com/siddontang/ledisdb/ledis"
"io"
"sync"
)
//ledis <-> lua type conversion, same as http://redis.io/commands/eval
type luaWriter struct {
l *lua.State
}
func (w *luaWriter) writeError(err error) {
panic(err)
}
func (w *luaWriter) writeStatus(status string) {
w.l.NewTable()
top := w.l.GetTop()
w.l.PushString("ok")
w.l.PushString(status)
w.l.SetTable(top)
}
func (w *luaWriter) writeInteger(n int64) {
w.l.PushInteger(n)
}
func (w *luaWriter) writeBulk(b []byte) {
if b == nil {
w.l.PushBoolean(false)
} else {
w.l.PushString(ledis.String(b))
}
}
func (w *luaWriter) writeArray(lst []interface{}) {
if lst == nil {
w.l.PushBoolean(false)
return
}
w.l.CreateTable(len(lst), 0)
top := w.l.GetTop()
for i, _ := range lst {
w.l.PushInteger(int64(i) + 1)
switch v := lst[i].(type) {
case []interface{}:
w.writeArray(v)
case [][]byte:
w.writeSliceArray(v)
case []byte:
w.writeBulk(v)
case nil:
w.writeBulk(nil)
case int64:
w.writeInteger(v)
default:
panic("invalid array type")
}
w.l.SetTable(top)
}
}
func (w *luaWriter) writeSliceArray(lst [][]byte) {
if lst == nil {
w.l.PushBoolean(false)
return
}
w.l.CreateTable(len(lst), 0)
for i, v := range lst {
w.l.PushString(ledis.String(v))
w.l.RawSeti(-2, i+1)
}
}
func (w *luaWriter) writeFVPairArray(lst []ledis.FVPair) {
if lst == nil {
w.l.PushBoolean(false)
return
}
w.l.CreateTable(len(lst)*2, 0)
for i, v := range lst {
w.l.PushString(ledis.String(v.Field))
w.l.RawSeti(-2, 2*i+1)
w.l.PushString(ledis.String(v.Value))
w.l.RawSeti(-2, 2*i+2)
}
}
func (w *luaWriter) writeScorePairArray(lst []ledis.ScorePair, withScores bool) {
if lst == nil {
w.l.PushBoolean(false)
return
}
if withScores {
w.l.CreateTable(len(lst)*2, 0)
for i, v := range lst {
w.l.PushString(ledis.String(v.Member))
w.l.RawSeti(-2, 2*i+1)
w.l.PushString(ledis.String(ledis.StrPutInt64(v.Score)))
w.l.RawSeti(-2, 2*i+2)
}
} else {
w.l.CreateTable(len(lst), 0)
for i, v := range lst {
w.l.PushString(ledis.String(v.Member))
w.l.RawSeti(-2, i+1)
}
}
}
func (w *luaWriter) writeBulkFrom(n int64, rb io.Reader) {
w.writeError(fmt.Errorf("unsupport"))
}
func (w *luaWriter) flush() {
}
type script struct {
sync.Mutex
app *App
l *lua.State
c *client
chunks map[string]struct{}
}
func (app *App) openScript() {
s := new(script)
s.app = app
s.chunks = make(map[string]struct{})
app.s = s
l := lua.NewState()
l.OpenBase()
l.OpenLibs()
l.OpenMath()
l.OpenString()
l.OpenTable()
l.OpenPackage()
s.l = l
s.c = newClient(app)
s.c.db = nil
w := new(luaWriter)
w.l = l
s.c.resp = w
l.NewTable()
l.PushString("call")
l.PushGoFunction(luaCall)
l.SetTable(-3)
l.PushString("pcall")
l.PushGoFunction(luaPCall)
l.SetTable(-3)
l.PushString("sha1hex")
l.PushGoFunction(luaSha1Hex)
l.SetTable(-3)
l.PushString("error_reply")
l.PushGoFunction(luaErrorReply)
l.SetTable(-3)
l.PushString("status_reply")
l.PushGoFunction(luaStatusReply)
l.SetTable(-3)
l.SetGlobal("ledis")
setMapState(l, s)
}
func (app *App) closeScript() {
app.s.l.Close()
delMapState(app.s.l)
app.s = nil
}
var mapState = map[*lua.State]*script{}
var stateLock sync.Mutex
func setMapState(l *lua.State, s *script) {
stateLock.Lock()
defer stateLock.Unlock()
mapState[l] = s
}
func getMapState(l *lua.State) *script {
stateLock.Lock()
defer stateLock.Unlock()
return mapState[l]
}
func delMapState(l *lua.State) {
stateLock.Lock()
defer stateLock.Unlock()
delete(mapState, l)
}
func luaCall(l *lua.State) int {
return luaCallGenericCommand(l)
}
func luaPCall(l *lua.State) (n int) {
defer func() {
if e := recover(); e != nil {
luaPushError(l, fmt.Sprintf("%v", e))
n = 1
}
return
}()
return luaCallGenericCommand(l)
}
func luaErrorReply(l *lua.State) int {
return luaReturnSingleFieldTable(l, "err")
}
func luaStatusReply(l *lua.State) int {
return luaReturnSingleFieldTable(l, "ok")
}
func luaReturnSingleFieldTable(l *lua.State, filed string) int {
if l.GetTop() != 1 || l.Type(-1) != lua.LUA_TSTRING {
luaPushError(l, "wrong number or type of arguments")
return 1
}
l.NewTable()
l.PushString(filed)
l.PushValue(-3)
l.SetTable(-3)
return 1
}
func luaSha1Hex(l *lua.State) int {
argc := l.GetTop()
if argc != 1 {
luaPushError(l, "wrong number of arguments")
return 1
}
s := l.ToString(1)
s = hex.EncodeToString(ledis.Slice(s))
l.PushString(s)
return 1
}
func luaPushError(l *lua.State, msg string) {
l.NewTable()
l.PushString("err")
err := l.NewError(msg)
l.PushString(err.Error())
l.SetTable(-3)
}
func luaCallGenericCommand(l *lua.State) int {
s := getMapState(l)
if s == nil {
panic("Invalid lua call")
} else if s.c.db == nil {
panic("Invalid lua call, not prepared")
}
c := s.c
argc := l.GetTop()
if argc < 1 {
panic("Please specify at least one argument for ledis.call()")
}
c.cmd = l.ToString(1)
c.args = make([][]byte, argc-1)
for i := 2; i <= argc; i++ {
switch l.Type(i) {
case lua.LUA_TNUMBER:
c.args[i-2] = []byte(fmt.Sprintf("%.17g", l.ToNumber(i)))
case lua.LUA_TSTRING:
c.args[i-2] = []byte(l.ToString(i))
default:
panic("Lua ledis() command arguments must be strings or integers")
}
}
c.perform()
return 1
}
func luaSetGlobalArray(l *lua.State, name string, ay [][]byte) {
l.NewTable()
for i := 0; i < len(ay); i++ {
l.PushString(ledis.String(ay[i]))
l.RawSeti(-2, i+1)
}
l.SetGlobal(name)
}
func luaReplyToLedisReply(l *lua.State) interface{} {
base := l.GetTop()
defer func() {
l.SetTop(base - 1)
}()
switch l.Type(-1) {
case lua.LUA_TSTRING:
return ledis.Slice(l.ToString(-1))
case lua.LUA_TBOOLEAN:
if l.ToBoolean(-1) {
return int64(1)
} else {
return nil
}
case lua.LUA_TNUMBER:
return int64(l.ToInteger(-1))
case lua.LUA_TTABLE:
l.PushString("err")
l.GetTable(-2)
if l.Type(-1) == lua.LUA_TSTRING {
return fmt.Errorf("%s", l.ToString(-1))
}
l.Pop(1)
l.PushString("ok")
l.GetTable(-2)
if l.Type(-1) == lua.LUA_TSTRING {
return l.ToString(-1)
} else {
l.Pop(1)
ay := make([]interface{}, 0)
for i := 1; ; i++ {
l.PushInteger(int64(i))
l.GetTable(-2)
if l.Type(-1) == lua.LUA_TNIL {
l.Pop(1)
break
}
ay = append(ay, luaReplyToLedisReply(l))
}
return ay
}
default:
return nil
}
}

177
server/script_test.go Normal file
View File

@ -0,0 +1,177 @@
package server
import (
"fmt"
"github.com/aarzilli/golua/lua"
"github.com/siddontang/ledisdb/config"
"testing"
)
var testLuaWriter = &luaWriter{}
func testLuaWriteError(l *lua.State) int {
testLuaWriter.writeError(fmt.Errorf("test error"))
return 1
}
func testLuaWriteArray(l *lua.State) int {
ay := make([]interface{}, 2)
ay[0] = []byte("1")
b := make([]interface{}, 2)
b[0] = int64(10)
b[1] = []byte("11")
ay[1] = b
testLuaWriter.writeArray(ay)
return 1
}
func TestLuaWriter(t *testing.T) {
l := lua.NewState()
l.OpenBase()
testLuaWriter.l = l
l.Register("WriteError", testLuaWriteError)
str := `
WriteError()
`
err := l.DoString(str)
if err == nil {
t.Fatal("must error")
}
l.Register("WriteArray", testLuaWriteArray)
str = `
local a = WriteArray()
if #a ~= 2 then
error("len a must 2")
elseif a[1] ~= "1" then
error("a[1] must 1")
elseif #a[2] ~= 2 then
error("len a[2] must 2")
elseif a[2][1] ~= 10 then
error("a[2][1] must 10")
elseif a[2][2] ~= "11" then
error("a[2][2] must 11")
end
`
err = l.DoString(str)
if err != nil {
t.Fatal(err)
}
l.Close()
}
var testScript1 = `
return {1,2,3}
`
var testScript2 = `
return ledis.call("ping")
`
var testScript3 = `
ledis.call("set", 1, "a")
local a = ledis.call("get", 1)
if type(a) ~= "string" then
error("must string")
elseif a ~= "a" then
error("must a")
end
`
var testScript4 = `
ledis.call("select", 2)
ledis.call("set", 2, "b")
`
func TestLuaCall(t *testing.T) {
cfg := new(config.Config)
cfg.Addr = ":11188"
cfg.DataDir = "/tmp/testscript"
cfg.DBName = "memory"
app, e := NewApp(cfg)
if e != nil {
t.Fatal(e)
}
go app.Run()
defer app.Close()
db, _ := app.ldb.Select(0)
m, _ := db.Multi()
defer m.Close()
luaClient := app.s.c
luaClient.db = m.DB
luaClient.script = m
l := app.s.l
err := app.s.l.DoString(testScript1)
if err != nil {
t.Fatal(err)
}
v := luaReplyToLedisReply(l)
if vv, ok := v.([]interface{}); ok {
if len(vv) != 3 {
t.Fatal(len(vv))
}
} else {
t.Fatal(fmt.Sprintf("%v %T", v, v))
}
err = app.s.l.DoString(testScript2)
if err != nil {
t.Fatal(err)
}
v = luaReplyToLedisReply(l)
if vv := v.(string); vv != "PONG" {
t.Fatal(fmt.Sprintf("%v %T", v, v))
}
err = app.s.l.DoString(testScript3)
if err != nil {
t.Fatal(err)
}
if v, err := db.Get([]byte("1")); err != nil {
t.Fatal(err)
} else if string(v) != "a" {
t.Fatal(string(v))
}
err = app.s.l.DoString(testScript4)
if err != nil {
t.Fatal(err)
}
if luaClient.db.Index() != 2 {
t.Fatal(luaClient.db.Index())
}
db2, _ := app.ldb.Select(2)
if v, err := db2.Get([]byte("2")); err != nil {
t.Fatal(err)
} else if string(v) != "b" {
t.Fatal(string(v))
}
luaClient.db = nil
}