forked from mirror/ledisdb
add lua support
This commit is contained in:
parent
93f3cc5343
commit
ab1ae62bf7
|
@ -0,0 +1,105 @@
|
|||
package ledis
|
||||
|
||||
import (
|
||||
"github.com/siddontang/ledisdb/store"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type batch struct {
|
||||
l *Ledis
|
||||
|
||||
store.WriteBatch
|
||||
|
||||
sync.Locker
|
||||
|
||||
logs [][]byte
|
||||
|
||||
tx *Tx
|
||||
}
|
||||
|
||||
func (b *batch) Commit() error {
|
||||
b.l.commitLock.Lock()
|
||||
defer b.l.commitLock.Unlock()
|
||||
|
||||
err := b.WriteBatch.Commit()
|
||||
|
||||
if b.l.binlog != nil {
|
||||
if err == nil {
|
||||
if b.tx == nil {
|
||||
b.l.binlog.Log(b.logs...)
|
||||
} else {
|
||||
b.tx.logs = append(b.tx.logs, b.logs...)
|
||||
}
|
||||
}
|
||||
b.logs = [][]byte{}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (b *batch) Lock() {
|
||||
b.Locker.Lock()
|
||||
}
|
||||
|
||||
func (b *batch) Unlock() {
|
||||
if b.l.binlog != nil {
|
||||
b.logs = [][]byte{}
|
||||
}
|
||||
b.WriteBatch.Rollback()
|
||||
b.Locker.Unlock()
|
||||
}
|
||||
|
||||
func (b *batch) Put(key []byte, value []byte) {
|
||||
if b.l.binlog != nil {
|
||||
buf := encodeBinLogPut(key, value)
|
||||
b.logs = append(b.logs, buf)
|
||||
}
|
||||
b.WriteBatch.Put(key, value)
|
||||
}
|
||||
|
||||
func (b *batch) Delete(key []byte) {
|
||||
if b.l.binlog != nil {
|
||||
buf := encodeBinLogDelete(key)
|
||||
b.logs = append(b.logs, buf)
|
||||
}
|
||||
b.WriteBatch.Delete(key)
|
||||
}
|
||||
|
||||
type dbBatchLocker struct {
|
||||
l *sync.Mutex
|
||||
wrLock *sync.RWMutex
|
||||
}
|
||||
|
||||
func (l *dbBatchLocker) Lock() {
|
||||
l.wrLock.RLock()
|
||||
l.l.Lock()
|
||||
}
|
||||
|
||||
func (l *dbBatchLocker) Unlock() {
|
||||
l.l.Unlock()
|
||||
l.wrLock.RUnlock()
|
||||
}
|
||||
|
||||
type txBatchLocker struct {
|
||||
}
|
||||
|
||||
func (l *txBatchLocker) Lock() {}
|
||||
func (l *txBatchLocker) Unlock() {}
|
||||
|
||||
type multiBatchLocker struct {
|
||||
}
|
||||
|
||||
func (l *multiBatchLocker) Lock() {}
|
||||
func (l *multiBatchLocker) Unlock() {}
|
||||
|
||||
func (l *Ledis) newBatch(wb store.WriteBatch, locker sync.Locker, tx *Tx) *batch {
|
||||
b := new(batch)
|
||||
b.l = l
|
||||
b.WriteBatch = wb
|
||||
|
||||
b.tx = tx
|
||||
b.Locker = locker
|
||||
|
||||
b.logs = [][]byte{}
|
||||
return b
|
||||
}
|
|
@ -86,3 +86,9 @@ const (
|
|||
BinLogTypePut uint8 = 0x1
|
||||
BinLogTypeCommand uint8 = 0x2
|
||||
)
|
||||
|
||||
const (
|
||||
DBAutoCommit uint8 = 0x0
|
||||
DBInTransaction uint8 = 0x1
|
||||
DBInMulti uint8 = 0x2
|
||||
)
|
||||
|
|
|
@ -3,6 +3,7 @@ package ledis
|
|||
import (
|
||||
"fmt"
|
||||
"github.com/siddontang/ledisdb/store"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type ibucket interface {
|
||||
|
@ -37,7 +38,7 @@ type DB struct {
|
|||
binBatch *batch
|
||||
setBatch *batch
|
||||
|
||||
isTx bool
|
||||
status uint8
|
||||
}
|
||||
|
||||
func (l *Ledis) newDB(index uint8) *DB {
|
||||
|
@ -49,7 +50,7 @@ func (l *Ledis) newDB(index uint8) *DB {
|
|||
|
||||
d.bucket = d.sdb
|
||||
|
||||
d.isTx = false
|
||||
d.status = DBAutoCommit
|
||||
d.index = index
|
||||
|
||||
d.kvBatch = d.newBatch()
|
||||
|
@ -62,10 +63,18 @@ func (l *Ledis) newDB(index uint8) *DB {
|
|||
return d
|
||||
}
|
||||
|
||||
func (db *DB) newBatch() *batch {
|
||||
return db.l.newBatch(db.bucket.NewWriteBatch(), &dbBatchLocker{l: &sync.Mutex{}, wrLock: &db.l.wLock}, nil)
|
||||
}
|
||||
|
||||
func (db *DB) Index() int {
|
||||
return int(db.index)
|
||||
}
|
||||
|
||||
func (db *DB) IsAutoCommit() bool {
|
||||
return db.status == DBAutoCommit
|
||||
}
|
||||
|
||||
func (db *DB) FlushAll() (drop int64, err error) {
|
||||
all := [...](func() (int64, error)){
|
||||
db.flush,
|
||||
|
|
|
@ -0,0 +1,73 @@
|
|||
package ledis
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNestMulti = errors.New("nest multi not supported")
|
||||
ErrMultiDone = errors.New("multi has been closed")
|
||||
)
|
||||
|
||||
type Multi struct {
|
||||
*DB
|
||||
}
|
||||
|
||||
func (db *DB) IsInMulti() bool {
|
||||
return db.status == DBInMulti
|
||||
}
|
||||
|
||||
// begin a mutli to execute commands,
|
||||
// it will block any other write operations before you close the multi, unlike transaction, mutli can not rollback
|
||||
func (db *DB) Multi() (*Multi, error) {
|
||||
if db.IsInMulti() {
|
||||
return nil, ErrNestMulti
|
||||
}
|
||||
|
||||
m := new(Multi)
|
||||
|
||||
m.DB = new(DB)
|
||||
m.DB.status = DBInMulti
|
||||
|
||||
m.DB.l = db.l
|
||||
|
||||
m.l.wLock.Lock()
|
||||
|
||||
m.DB.sdb = db.sdb
|
||||
|
||||
m.DB.bucket = db.sdb
|
||||
|
||||
m.DB.index = db.index
|
||||
|
||||
m.DB.kvBatch = m.newBatch()
|
||||
m.DB.listBatch = m.newBatch()
|
||||
m.DB.hashBatch = m.newBatch()
|
||||
m.DB.zsetBatch = m.newBatch()
|
||||
m.DB.binBatch = m.newBatch()
|
||||
m.DB.setBatch = m.newBatch()
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *Multi) newBatch() *batch {
|
||||
return m.l.newBatch(m.bucket.NewWriteBatch(), &multiBatchLocker{}, nil)
|
||||
}
|
||||
|
||||
func (m *Multi) Close() error {
|
||||
if m.bucket == nil {
|
||||
return ErrMultiDone
|
||||
}
|
||||
m.l.wLock.Unlock()
|
||||
m.bucket = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Multi) Select(index int) error {
|
||||
if index < 0 || index >= int(MaxDBNumber) {
|
||||
return fmt.Errorf("invalid db index %d", index)
|
||||
}
|
||||
|
||||
m.DB.index = uint8(index)
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,51 @@
|
|||
package ledis
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMulti(t *testing.T) {
|
||||
db := getTestDB()
|
||||
|
||||
key := []byte("test_multi_1")
|
||||
v1 := []byte("v1")
|
||||
v2 := []byte("v2")
|
||||
|
||||
m, err := db.Multi()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
|
||||
wg.Add(1)
|
||||
|
||||
go func() {
|
||||
if err := db.Set(key, v2); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
if err := m.Set(key, v1); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if v, err := m.Get(key); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if string(v) != string(v1) {
|
||||
t.Fatal(string(v))
|
||||
}
|
||||
|
||||
m.Close()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if v, err := db.Get(key); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if string(v) != string(v2) {
|
||||
t.Fatal(string(v))
|
||||
}
|
||||
|
||||
}
|
|
@ -70,6 +70,12 @@ func TestReplication(t *testing.T) {
|
|||
db.HSet([]byte("c"), []byte("3"), []byte("value"))
|
||||
}
|
||||
|
||||
m, _ := db.Multi()
|
||||
m.Set([]byte("a1"), []byte("value"))
|
||||
m.Set([]byte("b1"), []byte("value"))
|
||||
m.Set([]byte("c1"), []byte("value"))
|
||||
m.Close()
|
||||
|
||||
for _, name := range master.binlog.LogNames() {
|
||||
p := path.Join(master.binlog.LogPath(), name)
|
||||
|
||||
|
|
209
ledis/tx.go
209
ledis/tx.go
|
@ -1,209 +0,0 @@
|
|||
package ledis
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/siddontang/ledisdb/store"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNestTx = errors.New("nest transaction not supported")
|
||||
ErrTxDone = errors.New("Transaction has already been committed or rolled back")
|
||||
)
|
||||
|
||||
type batch struct {
|
||||
l *Ledis
|
||||
|
||||
store.WriteBatch
|
||||
|
||||
sync.Locker
|
||||
|
||||
logs [][]byte
|
||||
|
||||
tx *Tx
|
||||
}
|
||||
|
||||
type dbBatchLocker struct {
|
||||
l *sync.Mutex
|
||||
wrLock *sync.RWMutex
|
||||
}
|
||||
|
||||
func (l *dbBatchLocker) Lock() {
|
||||
l.wrLock.RLock()
|
||||
l.l.Lock()
|
||||
}
|
||||
|
||||
func (l *dbBatchLocker) Unlock() {
|
||||
l.l.Unlock()
|
||||
l.wrLock.RUnlock()
|
||||
}
|
||||
|
||||
type txBatchLocker struct {
|
||||
}
|
||||
|
||||
func (l *txBatchLocker) Lock() {}
|
||||
func (l *txBatchLocker) Unlock() {}
|
||||
|
||||
func (l *Ledis) newBatch(wb store.WriteBatch, tx *Tx) *batch {
|
||||
b := new(batch)
|
||||
b.l = l
|
||||
b.WriteBatch = wb
|
||||
|
||||
b.tx = tx
|
||||
if tx == nil {
|
||||
b.Locker = &dbBatchLocker{l: &sync.Mutex{}, wrLock: &l.wLock}
|
||||
} else {
|
||||
b.Locker = &txBatchLocker{}
|
||||
}
|
||||
|
||||
b.logs = [][]byte{}
|
||||
return b
|
||||
}
|
||||
|
||||
func (db *DB) newBatch() *batch {
|
||||
return db.l.newBatch(db.bucket.NewWriteBatch(), nil)
|
||||
}
|
||||
|
||||
func (b *batch) Commit() error {
|
||||
b.l.commitLock.Lock()
|
||||
defer b.l.commitLock.Unlock()
|
||||
|
||||
err := b.WriteBatch.Commit()
|
||||
|
||||
if b.l.binlog != nil {
|
||||
if err == nil {
|
||||
if b.tx == nil {
|
||||
b.l.binlog.Log(b.logs...)
|
||||
} else {
|
||||
b.tx.logs = append(b.tx.logs, b.logs...)
|
||||
}
|
||||
}
|
||||
b.logs = [][]byte{}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (b *batch) Lock() {
|
||||
b.Locker.Lock()
|
||||
}
|
||||
|
||||
func (b *batch) Unlock() {
|
||||
if b.l.binlog != nil {
|
||||
b.logs = [][]byte{}
|
||||
}
|
||||
b.WriteBatch.Rollback()
|
||||
b.Locker.Unlock()
|
||||
}
|
||||
|
||||
func (b *batch) Put(key []byte, value []byte) {
|
||||
if b.l.binlog != nil {
|
||||
buf := encodeBinLogPut(key, value)
|
||||
b.logs = append(b.logs, buf)
|
||||
}
|
||||
b.WriteBatch.Put(key, value)
|
||||
}
|
||||
|
||||
func (b *batch) Delete(key []byte) {
|
||||
if b.l.binlog != nil {
|
||||
buf := encodeBinLogDelete(key)
|
||||
b.logs = append(b.logs, buf)
|
||||
}
|
||||
b.WriteBatch.Delete(key)
|
||||
}
|
||||
|
||||
type Tx struct {
|
||||
*DB
|
||||
|
||||
tx *store.Tx
|
||||
|
||||
logs [][]byte
|
||||
|
||||
index uint8
|
||||
}
|
||||
|
||||
func (db *DB) IsTransaction() bool {
|
||||
return db.isTx
|
||||
}
|
||||
|
||||
// Begin a transaction, it will block all other write operations before calling Commit or Rollback.
|
||||
// You must be very careful to prevent long-time transaction.
|
||||
func (db *DB) Begin() (*Tx, error) {
|
||||
if db.isTx {
|
||||
return nil, ErrNestTx
|
||||
}
|
||||
|
||||
tx := new(Tx)
|
||||
|
||||
tx.DB = new(DB)
|
||||
tx.DB.l = db.l
|
||||
|
||||
tx.l.wLock.Lock()
|
||||
|
||||
tx.index = db.index
|
||||
|
||||
tx.DB.sdb = db.sdb
|
||||
|
||||
var err error
|
||||
tx.tx, err = db.sdb.Begin()
|
||||
if err != nil {
|
||||
tx.l.wLock.Unlock()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tx.DB.bucket = tx.tx
|
||||
|
||||
tx.DB.isTx = true
|
||||
|
||||
tx.DB.index = db.index
|
||||
|
||||
tx.DB.kvBatch = tx.newTxBatch()
|
||||
tx.DB.listBatch = tx.newTxBatch()
|
||||
tx.DB.hashBatch = tx.newTxBatch()
|
||||
tx.DB.zsetBatch = tx.newTxBatch()
|
||||
tx.DB.binBatch = tx.newTxBatch()
|
||||
tx.DB.setBatch = tx.newTxBatch()
|
||||
|
||||
return tx, nil
|
||||
}
|
||||
|
||||
func (tx *Tx) Commit() error {
|
||||
if tx.tx == nil {
|
||||
return ErrTxDone
|
||||
}
|
||||
|
||||
tx.l.commitLock.Lock()
|
||||
err := tx.tx.Commit()
|
||||
tx.tx = nil
|
||||
|
||||
if len(tx.logs) > 0 {
|
||||
tx.l.binlog.Log(tx.logs...)
|
||||
}
|
||||
|
||||
tx.l.commitLock.Unlock()
|
||||
|
||||
tx.l.wLock.Unlock()
|
||||
tx.DB = nil
|
||||
return err
|
||||
}
|
||||
|
||||
func (tx *Tx) Rollback() error {
|
||||
if tx.tx == nil {
|
||||
return ErrTxDone
|
||||
}
|
||||
|
||||
err := tx.tx.Rollback()
|
||||
tx.tx = nil
|
||||
|
||||
tx.l.wLock.Unlock()
|
||||
tx.DB = nil
|
||||
return err
|
||||
}
|
||||
|
||||
func (tx *Tx) newTxBatch() *batch {
|
||||
return tx.l.newBatch(tx.tx.NewWriteBatch(), tx)
|
||||
}
|
||||
|
||||
func (tx *Tx) Index() int {
|
||||
return int(tx.index)
|
||||
}
|
|
@ -144,6 +144,51 @@ func testTxCommit(t *testing.T, db *DB) {
|
|||
}
|
||||
}
|
||||
|
||||
func testTxSelect(t *testing.T, db *DB) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
defer tx.Rollback()
|
||||
|
||||
tx.Set([]byte("tx_select_1"), []byte("a"))
|
||||
|
||||
tx.Select(1)
|
||||
|
||||
tx.Set([]byte("tx_select_2"), []byte("b"))
|
||||
|
||||
if err = tx.Commit(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if v, err := db.Get([]byte("tx_select_1")); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if string(v) != "a" {
|
||||
t.Fatal(string(v))
|
||||
}
|
||||
|
||||
if v, err := db.Get([]byte("tx_select_2")); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if v != nil {
|
||||
t.Fatal("must nil")
|
||||
}
|
||||
|
||||
db, _ = db.l.Select(1)
|
||||
|
||||
if v, err := db.Get([]byte("tx_select_2")); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if string(v) != "b" {
|
||||
t.Fatal(string(v))
|
||||
}
|
||||
|
||||
if v, err := db.Get([]byte("tx_select_1")); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if v != nil {
|
||||
t.Fatal("must nil")
|
||||
}
|
||||
}
|
||||
|
||||
func testTx(t *testing.T, name string) {
|
||||
cfg := new(config.Config)
|
||||
cfg.DataDir = "/tmp/ledis_test_tx"
|
||||
|
@ -164,6 +209,7 @@ func testTx(t *testing.T, name string) {
|
|||
|
||||
testTxRollback(t, db)
|
||||
testTxCommit(t, db)
|
||||
testTxSelect(t, db)
|
||||
}
|
||||
|
||||
//only lmdb, boltdb support Transaction
|
||||
|
|
|
@ -27,6 +27,8 @@ type App struct {
|
|||
m *master
|
||||
|
||||
info *info
|
||||
|
||||
s *script
|
||||
}
|
||||
|
||||
func netType(s string) string {
|
||||
|
@ -85,6 +87,8 @@ func NewApp(cfg *config.Config) (*App, error) {
|
|||
|
||||
app.m = newMaster(app)
|
||||
|
||||
app.openScript()
|
||||
|
||||
return app, nil
|
||||
}
|
||||
|
||||
|
@ -103,6 +107,8 @@ func (app *App) Close() {
|
|||
app.httpListener.Close()
|
||||
}
|
||||
|
||||
app.closeScript()
|
||||
|
||||
app.m.Close()
|
||||
|
||||
if app.access != nil {
|
||||
|
|
|
@ -16,6 +16,18 @@ var txUnsupportedCmds = map[string]struct{}{
|
|||
"begin": struct{}{},
|
||||
"flushall": struct{}{},
|
||||
"flushdb": struct{}{},
|
||||
"eval": struct{}{},
|
||||
}
|
||||
|
||||
var scriptUnsupportedCmds = map[string]struct{}{
|
||||
"slaveof": struct{}{},
|
||||
"fullsync": struct{}{},
|
||||
"sync": struct{}{},
|
||||
"begin": struct{}{},
|
||||
"commit": struct{}{},
|
||||
"rollback": struct{}{},
|
||||
"flushall": struct{}{},
|
||||
"flushdb": struct{}{},
|
||||
}
|
||||
|
||||
type responseWriter interface {
|
||||
|
@ -34,6 +46,7 @@ type responseWriter interface {
|
|||
type client struct {
|
||||
app *App
|
||||
ldb *ledis.Ledis
|
||||
|
||||
db *ledis.DB
|
||||
|
||||
remoteAddr string
|
||||
|
@ -50,6 +63,7 @@ type client struct {
|
|||
buf bytes.Buffer
|
||||
|
||||
tx *ledis.Tx
|
||||
script *ledis.Multi
|
||||
}
|
||||
|
||||
func newClient(app *App) *client {
|
||||
|
@ -59,16 +73,12 @@ func newClient(app *App) *client {
|
|||
c.ldb = app.ldb
|
||||
c.db, _ = app.ldb.Select(0) //use default db
|
||||
|
||||
c.compressBuf = make([]byte, 256)
|
||||
c.compressBuf = []byte{}
|
||||
c.reqErr = make(chan error)
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *client) isInTransaction() bool {
|
||||
return c.tx != nil
|
||||
}
|
||||
|
||||
func (c *client) perform() {
|
||||
var err error
|
||||
|
||||
|
@ -79,10 +89,14 @@ func (c *client) perform() {
|
|||
} else if exeCmd, ok := regCmds[c.cmd]; !ok {
|
||||
err = ErrNotFound
|
||||
} else {
|
||||
if c.isInTransaction() {
|
||||
if c.db.IsTransaction() {
|
||||
if _, ok := txUnsupportedCmds[c.cmd]; ok {
|
||||
err = fmt.Errorf("%s not supported in transaction", c.cmd)
|
||||
}
|
||||
} else if c.db.IsInMulti() {
|
||||
if _, ok := scriptUnsupportedCmds[c.cmd]; ok {
|
||||
err = fmt.Errorf("%s not supported in multi", c.cmd)
|
||||
}
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
|
@ -128,3 +142,22 @@ func (c *client) catGenericCommand() []byte {
|
|||
|
||||
return buffer.Bytes()
|
||||
}
|
||||
|
||||
func writeValue(w responseWriter, value interface{}) {
|
||||
switch v := value.(type) {
|
||||
case []interface{}:
|
||||
w.writeArray(v)
|
||||
case [][]byte:
|
||||
w.writeSliceArray(v)
|
||||
case []byte:
|
||||
w.writeBulk(v)
|
||||
case string:
|
||||
w.writeStatus(v)
|
||||
case nil:
|
||||
w.writeBulk(nil)
|
||||
case int64:
|
||||
w.writeInteger(v)
|
||||
default:
|
||||
panic("invalid value type")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,158 +1,207 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"github.com/aarzilli/golua/lua"
|
||||
"github.com/siddontang/ledisdb/ledis"
|
||||
"io"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
//ledis <-> lua type conversion, same as http://redis.io/commands/eval
|
||||
|
||||
type luaClient struct {
|
||||
l *lua.State
|
||||
}
|
||||
|
||||
type luaWriter struct {
|
||||
l *lua.State
|
||||
}
|
||||
|
||||
func (w *luaWriter) writeError(err error) {
|
||||
w.l.NewTable()
|
||||
top := w.l.GetTop()
|
||||
|
||||
w.l.PushString("err")
|
||||
w.l.PushString(err.Error())
|
||||
w.l.SetTable(top)
|
||||
}
|
||||
|
||||
func (w *luaWriter) writeStatus(status string) {
|
||||
w.l.NewTable()
|
||||
top := w.l.GetTop()
|
||||
|
||||
w.l.PushString("ok")
|
||||
w.l.PushString(status)
|
||||
w.l.SetTable(top)
|
||||
}
|
||||
|
||||
func (w *luaWriter) writeInteger(n int64) {
|
||||
w.l.PushInteger(n)
|
||||
}
|
||||
|
||||
func (w *luaWriter) writeBulk(b []byte) {
|
||||
if b == nil {
|
||||
w.l.PushBoolean(false)
|
||||
} else {
|
||||
w.l.PushString(ledis.String(b))
|
||||
}
|
||||
}
|
||||
|
||||
func (w *luaWriter) writeArray(lst []interface{}) {
|
||||
if lst == nil {
|
||||
w.l.PushBoolean(false)
|
||||
return
|
||||
func parseEvalArgs(l *lua.State, c *client) error {
|
||||
args := c.args
|
||||
if len(args) < 2 {
|
||||
return ErrCmdParams
|
||||
}
|
||||
|
||||
base := w.l.GetTop()
|
||||
args = args[1:]
|
||||
|
||||
n, err := strconv.Atoi(ledis.String(args[0]))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if n > len(args)-1 {
|
||||
return ErrCmdParams
|
||||
}
|
||||
|
||||
luaSetGlobalArray(l, "KEYS", args[1:n+1])
|
||||
luaSetGlobalArray(l, "ARGV", args[n+1:])
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func evalGenericCommand(c *client, evalSha1 bool) error {
|
||||
m, err := c.db.Multi()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s := c.app.s
|
||||
luaClient := s.c
|
||||
l := s.l
|
||||
|
||||
s.Lock()
|
||||
|
||||
base := l.GetTop()
|
||||
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
w.l.SetTop(base)
|
||||
w.writeError(fmt.Errorf("%v", e))
|
||||
}
|
||||
l.SetTop(base)
|
||||
luaClient.db = nil
|
||||
luaClient.script = nil
|
||||
|
||||
s.Unlock()
|
||||
|
||||
m.Close()
|
||||
}()
|
||||
|
||||
w.l.CreateTable(len(lst), 0)
|
||||
top := w.l.GetTop()
|
||||
luaClient.db = m.DB
|
||||
luaClient.script = m
|
||||
luaClient.remoteAddr = c.remoteAddr
|
||||
|
||||
for i, _ := range lst {
|
||||
w.l.PushInteger(int64(i) + 1)
|
||||
|
||||
switch v := lst[i].(type) {
|
||||
case []interface{}:
|
||||
w.writeArray(v)
|
||||
case [][]byte:
|
||||
w.writeSliceArray(v)
|
||||
case []byte:
|
||||
w.writeBulk(v)
|
||||
case nil:
|
||||
w.writeBulk(nil)
|
||||
case int64:
|
||||
w.writeInteger(v)
|
||||
default:
|
||||
panic("invalid array type")
|
||||
if err := parseEvalArgs(l, c); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
w.l.SetTable(top)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *luaWriter) writeSliceArray(lst [][]byte) {
|
||||
if lst == nil {
|
||||
w.l.PushBoolean(false)
|
||||
return
|
||||
}
|
||||
|
||||
w.l.CreateTable(len(lst), 0)
|
||||
top := w.l.GetTop()
|
||||
for i, v := range lst {
|
||||
w.l.PushInteger(int64(i) + 1)
|
||||
w.l.PushString(ledis.String(v))
|
||||
w.l.SetTable(top)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *luaWriter) writeFVPairArray(lst []ledis.FVPair) {
|
||||
if lst == nil {
|
||||
w.l.PushBoolean(false)
|
||||
return
|
||||
}
|
||||
|
||||
w.l.CreateTable(len(lst)*2, 0)
|
||||
top := w.l.GetTop()
|
||||
for i, v := range lst {
|
||||
w.l.PushInteger(int64(2*i) + 1)
|
||||
w.l.PushString(ledis.String(v.Field))
|
||||
w.l.SetTable(top)
|
||||
|
||||
w.l.PushInteger(int64(2*i) + 2)
|
||||
w.l.PushString(ledis.String(v.Value))
|
||||
w.l.SetTable(top)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *luaWriter) writeScorePairArray(lst []ledis.ScorePair, withScores bool) {
|
||||
if lst == nil {
|
||||
w.l.PushBoolean(false)
|
||||
return
|
||||
}
|
||||
|
||||
if withScores {
|
||||
w.l.CreateTable(len(lst)*2, 0)
|
||||
top := w.l.GetTop()
|
||||
for i, v := range lst {
|
||||
w.l.PushInteger(int64(2*i) + 1)
|
||||
w.l.PushString(ledis.String(v.Member))
|
||||
w.l.SetTable(top)
|
||||
|
||||
w.l.PushInteger(int64(2*i) + 2)
|
||||
w.l.PushString(ledis.String(ledis.StrPutInt64(v.Score)))
|
||||
w.l.SetTable(top)
|
||||
}
|
||||
var sha1 string
|
||||
if !evalSha1 {
|
||||
sha1 = hex.EncodeToString(c.args[0])
|
||||
} else {
|
||||
w.l.CreateTable(len(lst), 0)
|
||||
top := w.l.GetTop()
|
||||
for i, v := range lst {
|
||||
w.l.PushInteger(int64(i) + 1)
|
||||
w.l.PushString(ledis.String(v.Member))
|
||||
w.l.SetTable(top)
|
||||
sha1 = ledis.String(c.args[0])
|
||||
}
|
||||
|
||||
l.GetGlobal(sha1)
|
||||
|
||||
if l.IsNil(-1) {
|
||||
l.Pop(1)
|
||||
|
||||
if evalSha1 {
|
||||
return fmt.Errorf("missing %s script", sha1)
|
||||
}
|
||||
|
||||
if r := l.LoadString(ledis.String(c.args[0])); r != 0 {
|
||||
err := fmt.Errorf("%s", l.ToString(-1))
|
||||
l.Pop(1)
|
||||
return err
|
||||
} else {
|
||||
l.PushValue(-1)
|
||||
l.SetGlobal(sha1)
|
||||
|
||||
s.chunks[sha1] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
if err := l.Call(0, lua.LUA_MULTRET); err != nil {
|
||||
return err
|
||||
} else {
|
||||
r := luaReplyToLedisReply(l)
|
||||
m.Close()
|
||||
writeValue(c.resp, r)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *luaWriter) writeBulkFrom(n int64, rb io.Reader) {
|
||||
w.writeError(fmt.Errorf("unsupport"))
|
||||
func evalCommand(c *client) error {
|
||||
return evalGenericCommand(c, false)
|
||||
}
|
||||
|
||||
func (w *luaWriter) flush() {
|
||||
|
||||
func evalshaCommand(c *client) error {
|
||||
return evalGenericCommand(c, true)
|
||||
}
|
||||
|
||||
func scriptCommand(c *client) error {
|
||||
s := c.app.s
|
||||
l := s.l
|
||||
|
||||
s.Lock()
|
||||
|
||||
base := l.GetTop()
|
||||
|
||||
defer func() {
|
||||
l.SetTop(base)
|
||||
s.Unlock()
|
||||
}()
|
||||
|
||||
args := c.args
|
||||
if len(args) < 1 {
|
||||
return ErrCmdParams
|
||||
}
|
||||
|
||||
switch strings.ToLower(c.cmd) {
|
||||
case "script load":
|
||||
return scriptLoadCommand(c)
|
||||
case "script exists":
|
||||
return scriptExistsCommand(c)
|
||||
case "script flush":
|
||||
return scriptFlushCommand(c)
|
||||
default:
|
||||
return fmt.Errorf("invalid scirpt cmd %s", args[0])
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func scriptLoadCommand(c *client) error {
|
||||
s := c.app.s
|
||||
l := s.l
|
||||
|
||||
if len(c.args) != 1 {
|
||||
return ErrCmdParams
|
||||
}
|
||||
|
||||
sha1 := hex.EncodeToString(c.args[1])
|
||||
|
||||
if r := l.LoadString(ledis.String(c.args[1])); r != 0 {
|
||||
err := fmt.Errorf("%s", l.ToString(-1))
|
||||
l.Pop(1)
|
||||
return err
|
||||
} else {
|
||||
l.PushValue(-1)
|
||||
l.SetGlobal(sha1)
|
||||
|
||||
s.chunks[sha1] = struct{}{}
|
||||
}
|
||||
|
||||
c.resp.writeBulk(ledis.Slice(sha1))
|
||||
return nil
|
||||
}
|
||||
|
||||
func scriptExistsCommand(c *client) error {
|
||||
s := c.app.s
|
||||
|
||||
ay := make([]interface{}, len(c.args[1:]))
|
||||
for i, n := range c.args[1:] {
|
||||
if _, ok := s.chunks[ledis.String(n)]; ok {
|
||||
ay[i] = int64(1)
|
||||
} else {
|
||||
ay[i] = int64(0)
|
||||
}
|
||||
}
|
||||
|
||||
c.resp.writeArray(ay)
|
||||
return nil
|
||||
}
|
||||
|
||||
func scriptFlushCommand(c *client) error {
|
||||
s := c.app.s
|
||||
l := s.l
|
||||
|
||||
for n, _ := range s.chunks {
|
||||
l.PushNil()
|
||||
l.SetGlobal(n)
|
||||
}
|
||||
|
||||
c.resp.writeStatus(OK)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
register("eval", evalCommand)
|
||||
register("evalsha", evalshaCommand)
|
||||
register("script load", scriptCommand)
|
||||
register("script flush", scriptCommand)
|
||||
register("script exists", scriptCommand)
|
||||
}
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/siddontang/ledisdb/client/go/ledis"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCmdEval(t *testing.T) {
|
||||
c := getTestConn()
|
||||
defer c.Close()
|
||||
|
||||
if v, err := ledis.Strings(c.Do("eval", "return {KEYS[1],KEYS[2],ARGV[1],ARGV[2]}", 2, "key1", "key2", "first", "second")); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if len(v) != 4 {
|
||||
t.Fatal(err)
|
||||
} else if !reflect.DeepEqual(v, []string{"key1", "key2", "first", "second"}) {
|
||||
t.Fatal(fmt.Sprintf("%v", v))
|
||||
}
|
||||
|
||||
if v, err := ledis.Strings(c.Do("eval", "return {KEYS[1],KEYS[2],ARGV[1],ARGV[2]}", 2, "key1", "key2", "first", "second")); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if len(v) != 4 {
|
||||
t.Fatal(err)
|
||||
} else if !reflect.DeepEqual(v, []string{"key1", "key2", "first", "second"}) {
|
||||
t.Fatal(fmt.Sprintf("%v", v))
|
||||
}
|
||||
}
|
|
@ -40,14 +40,28 @@ func selectCommand(c *client) error {
|
|||
|
||||
if index, err := strconv.Atoi(ledis.String(c.args[0])); err != nil {
|
||||
return err
|
||||
} else {
|
||||
if c.db.IsTransaction() {
|
||||
if err := c.tx.Select(index); err != nil {
|
||||
return err
|
||||
} else {
|
||||
c.db = c.tx.DB
|
||||
}
|
||||
} else if c.db.IsInMulti() {
|
||||
if err := c.script.Select(index); err != nil {
|
||||
return err
|
||||
} else {
|
||||
c.db = c.script.DB
|
||||
}
|
||||
} else {
|
||||
if db, err := c.ldb.Select(index); err != nil {
|
||||
return err
|
||||
} else {
|
||||
c.db = db
|
||||
c.resp.writeStatus(OK)
|
||||
}
|
||||
}
|
||||
c.resp.writeStatus(OK)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -0,0 +1,380 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"github.com/aarzilli/golua/lua"
|
||||
"github.com/siddontang/ledisdb/ledis"
|
||||
"io"
|
||||
"sync"
|
||||
)
|
||||
|
||||
//ledis <-> lua type conversion, same as http://redis.io/commands/eval
|
||||
|
||||
type luaWriter struct {
|
||||
l *lua.State
|
||||
}
|
||||
|
||||
func (w *luaWriter) writeError(err error) {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
func (w *luaWriter) writeStatus(status string) {
|
||||
w.l.NewTable()
|
||||
top := w.l.GetTop()
|
||||
|
||||
w.l.PushString("ok")
|
||||
w.l.PushString(status)
|
||||
w.l.SetTable(top)
|
||||
}
|
||||
|
||||
func (w *luaWriter) writeInteger(n int64) {
|
||||
w.l.PushInteger(n)
|
||||
}
|
||||
|
||||
func (w *luaWriter) writeBulk(b []byte) {
|
||||
if b == nil {
|
||||
w.l.PushBoolean(false)
|
||||
} else {
|
||||
w.l.PushString(ledis.String(b))
|
||||
}
|
||||
}
|
||||
|
||||
func (w *luaWriter) writeArray(lst []interface{}) {
|
||||
if lst == nil {
|
||||
w.l.PushBoolean(false)
|
||||
return
|
||||
}
|
||||
|
||||
w.l.CreateTable(len(lst), 0)
|
||||
top := w.l.GetTop()
|
||||
|
||||
for i, _ := range lst {
|
||||
w.l.PushInteger(int64(i) + 1)
|
||||
|
||||
switch v := lst[i].(type) {
|
||||
case []interface{}:
|
||||
w.writeArray(v)
|
||||
case [][]byte:
|
||||
w.writeSliceArray(v)
|
||||
case []byte:
|
||||
w.writeBulk(v)
|
||||
case nil:
|
||||
w.writeBulk(nil)
|
||||
case int64:
|
||||
w.writeInteger(v)
|
||||
default:
|
||||
panic("invalid array type")
|
||||
}
|
||||
|
||||
w.l.SetTable(top)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *luaWriter) writeSliceArray(lst [][]byte) {
|
||||
if lst == nil {
|
||||
w.l.PushBoolean(false)
|
||||
return
|
||||
}
|
||||
|
||||
w.l.CreateTable(len(lst), 0)
|
||||
for i, v := range lst {
|
||||
w.l.PushString(ledis.String(v))
|
||||
w.l.RawSeti(-2, i+1)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *luaWriter) writeFVPairArray(lst []ledis.FVPair) {
|
||||
if lst == nil {
|
||||
w.l.PushBoolean(false)
|
||||
return
|
||||
}
|
||||
|
||||
w.l.CreateTable(len(lst)*2, 0)
|
||||
for i, v := range lst {
|
||||
w.l.PushString(ledis.String(v.Field))
|
||||
w.l.RawSeti(-2, 2*i+1)
|
||||
|
||||
w.l.PushString(ledis.String(v.Value))
|
||||
w.l.RawSeti(-2, 2*i+2)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *luaWriter) writeScorePairArray(lst []ledis.ScorePair, withScores bool) {
|
||||
if lst == nil {
|
||||
w.l.PushBoolean(false)
|
||||
return
|
||||
}
|
||||
|
||||
if withScores {
|
||||
w.l.CreateTable(len(lst)*2, 0)
|
||||
for i, v := range lst {
|
||||
w.l.PushString(ledis.String(v.Member))
|
||||
w.l.RawSeti(-2, 2*i+1)
|
||||
|
||||
w.l.PushString(ledis.String(ledis.StrPutInt64(v.Score)))
|
||||
w.l.RawSeti(-2, 2*i+2)
|
||||
}
|
||||
} else {
|
||||
w.l.CreateTable(len(lst), 0)
|
||||
for i, v := range lst {
|
||||
w.l.PushString(ledis.String(v.Member))
|
||||
w.l.RawSeti(-2, i+1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *luaWriter) writeBulkFrom(n int64, rb io.Reader) {
|
||||
w.writeError(fmt.Errorf("unsupport"))
|
||||
}
|
||||
|
||||
func (w *luaWriter) flush() {
|
||||
|
||||
}
|
||||
|
||||
type script struct {
|
||||
sync.Mutex
|
||||
|
||||
app *App
|
||||
l *lua.State
|
||||
c *client
|
||||
|
||||
chunks map[string]struct{}
|
||||
}
|
||||
|
||||
func (app *App) openScript() {
|
||||
s := new(script)
|
||||
s.app = app
|
||||
|
||||
s.chunks = make(map[string]struct{})
|
||||
|
||||
app.s = s
|
||||
|
||||
l := lua.NewState()
|
||||
|
||||
l.OpenBase()
|
||||
l.OpenLibs()
|
||||
l.OpenMath()
|
||||
l.OpenString()
|
||||
l.OpenTable()
|
||||
l.OpenPackage()
|
||||
|
||||
s.l = l
|
||||
s.c = newClient(app)
|
||||
s.c.db = nil
|
||||
|
||||
w := new(luaWriter)
|
||||
w.l = l
|
||||
s.c.resp = w
|
||||
|
||||
l.NewTable()
|
||||
l.PushString("call")
|
||||
l.PushGoFunction(luaCall)
|
||||
l.SetTable(-3)
|
||||
|
||||
l.PushString("pcall")
|
||||
l.PushGoFunction(luaPCall)
|
||||
l.SetTable(-3)
|
||||
|
||||
l.PushString("sha1hex")
|
||||
l.PushGoFunction(luaSha1Hex)
|
||||
l.SetTable(-3)
|
||||
|
||||
l.PushString("error_reply")
|
||||
l.PushGoFunction(luaErrorReply)
|
||||
l.SetTable(-3)
|
||||
|
||||
l.PushString("status_reply")
|
||||
l.PushGoFunction(luaStatusReply)
|
||||
l.SetTable(-3)
|
||||
|
||||
l.SetGlobal("ledis")
|
||||
|
||||
setMapState(l, s)
|
||||
}
|
||||
|
||||
func (app *App) closeScript() {
|
||||
app.s.l.Close()
|
||||
delMapState(app.s.l)
|
||||
app.s = nil
|
||||
}
|
||||
|
||||
var mapState = map[*lua.State]*script{}
|
||||
var stateLock sync.Mutex
|
||||
|
||||
func setMapState(l *lua.State, s *script) {
|
||||
stateLock.Lock()
|
||||
defer stateLock.Unlock()
|
||||
|
||||
mapState[l] = s
|
||||
}
|
||||
|
||||
func getMapState(l *lua.State) *script {
|
||||
stateLock.Lock()
|
||||
defer stateLock.Unlock()
|
||||
|
||||
return mapState[l]
|
||||
}
|
||||
|
||||
func delMapState(l *lua.State) {
|
||||
stateLock.Lock()
|
||||
defer stateLock.Unlock()
|
||||
|
||||
delete(mapState, l)
|
||||
}
|
||||
|
||||
func luaCall(l *lua.State) int {
|
||||
return luaCallGenericCommand(l)
|
||||
}
|
||||
|
||||
func luaPCall(l *lua.State) (n int) {
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
luaPushError(l, fmt.Sprintf("%v", e))
|
||||
n = 1
|
||||
}
|
||||
return
|
||||
}()
|
||||
return luaCallGenericCommand(l)
|
||||
}
|
||||
|
||||
func luaErrorReply(l *lua.State) int {
|
||||
return luaReturnSingleFieldTable(l, "err")
|
||||
}
|
||||
|
||||
func luaStatusReply(l *lua.State) int {
|
||||
return luaReturnSingleFieldTable(l, "ok")
|
||||
}
|
||||
|
||||
func luaReturnSingleFieldTable(l *lua.State, filed string) int {
|
||||
if l.GetTop() != 1 || l.Type(-1) != lua.LUA_TSTRING {
|
||||
luaPushError(l, "wrong number or type of arguments")
|
||||
return 1
|
||||
}
|
||||
|
||||
l.NewTable()
|
||||
l.PushString(filed)
|
||||
l.PushValue(-3)
|
||||
l.SetTable(-3)
|
||||
return 1
|
||||
}
|
||||
|
||||
func luaSha1Hex(l *lua.State) int {
|
||||
argc := l.GetTop()
|
||||
if argc != 1 {
|
||||
luaPushError(l, "wrong number of arguments")
|
||||
return 1
|
||||
}
|
||||
|
||||
s := l.ToString(1)
|
||||
s = hex.EncodeToString(ledis.Slice(s))
|
||||
|
||||
l.PushString(s)
|
||||
return 1
|
||||
}
|
||||
|
||||
func luaPushError(l *lua.State, msg string) {
|
||||
l.NewTable()
|
||||
l.PushString("err")
|
||||
err := l.NewError(msg)
|
||||
l.PushString(err.Error())
|
||||
l.SetTable(-3)
|
||||
}
|
||||
|
||||
func luaCallGenericCommand(l *lua.State) int {
|
||||
s := getMapState(l)
|
||||
if s == nil {
|
||||
panic("Invalid lua call")
|
||||
} else if s.c.db == nil {
|
||||
panic("Invalid lua call, not prepared")
|
||||
}
|
||||
|
||||
c := s.c
|
||||
|
||||
argc := l.GetTop()
|
||||
if argc < 1 {
|
||||
panic("Please specify at least one argument for ledis.call()")
|
||||
}
|
||||
|
||||
c.cmd = l.ToString(1)
|
||||
|
||||
c.args = make([][]byte, argc-1)
|
||||
|
||||
for i := 2; i <= argc; i++ {
|
||||
switch l.Type(i) {
|
||||
case lua.LUA_TNUMBER:
|
||||
c.args[i-2] = []byte(fmt.Sprintf("%.17g", l.ToNumber(i)))
|
||||
case lua.LUA_TSTRING:
|
||||
c.args[i-2] = []byte(l.ToString(i))
|
||||
default:
|
||||
panic("Lua ledis() command arguments must be strings or integers")
|
||||
}
|
||||
}
|
||||
|
||||
c.perform()
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
func luaSetGlobalArray(l *lua.State, name string, ay [][]byte) {
|
||||
l.NewTable()
|
||||
|
||||
for i := 0; i < len(ay); i++ {
|
||||
l.PushString(ledis.String(ay[i]))
|
||||
l.RawSeti(-2, i+1)
|
||||
}
|
||||
|
||||
l.SetGlobal(name)
|
||||
}
|
||||
|
||||
func luaReplyToLedisReply(l *lua.State) interface{} {
|
||||
base := l.GetTop()
|
||||
defer func() {
|
||||
l.SetTop(base - 1)
|
||||
}()
|
||||
|
||||
switch l.Type(-1) {
|
||||
case lua.LUA_TSTRING:
|
||||
return ledis.Slice(l.ToString(-1))
|
||||
case lua.LUA_TBOOLEAN:
|
||||
if l.ToBoolean(-1) {
|
||||
return int64(1)
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
case lua.LUA_TNUMBER:
|
||||
return int64(l.ToInteger(-1))
|
||||
case lua.LUA_TTABLE:
|
||||
l.PushString("err")
|
||||
l.GetTable(-2)
|
||||
if l.Type(-1) == lua.LUA_TSTRING {
|
||||
return fmt.Errorf("%s", l.ToString(-1))
|
||||
}
|
||||
|
||||
l.Pop(1)
|
||||
l.PushString("ok")
|
||||
l.GetTable(-2)
|
||||
if l.Type(-1) == lua.LUA_TSTRING {
|
||||
return l.ToString(-1)
|
||||
} else {
|
||||
l.Pop(1)
|
||||
|
||||
ay := make([]interface{}, 0)
|
||||
|
||||
for i := 1; ; i++ {
|
||||
l.PushInteger(int64(i))
|
||||
l.GetTable(-2)
|
||||
if l.Type(-1) == lua.LUA_TNIL {
|
||||
l.Pop(1)
|
||||
break
|
||||
}
|
||||
|
||||
ay = append(ay, luaReplyToLedisReply(l))
|
||||
}
|
||||
return ay
|
||||
|
||||
}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
|
@ -0,0 +1,177 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/aarzilli/golua/lua"
|
||||
"github.com/siddontang/ledisdb/config"
|
||||
|
||||
"testing"
|
||||
)
|
||||
|
||||
var testLuaWriter = &luaWriter{}
|
||||
|
||||
func testLuaWriteError(l *lua.State) int {
|
||||
testLuaWriter.writeError(fmt.Errorf("test error"))
|
||||
return 1
|
||||
}
|
||||
|
||||
func testLuaWriteArray(l *lua.State) int {
|
||||
ay := make([]interface{}, 2)
|
||||
ay[0] = []byte("1")
|
||||
b := make([]interface{}, 2)
|
||||
b[0] = int64(10)
|
||||
b[1] = []byte("11")
|
||||
|
||||
ay[1] = b
|
||||
|
||||
testLuaWriter.writeArray(ay)
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
func TestLuaWriter(t *testing.T) {
|
||||
l := lua.NewState()
|
||||
|
||||
l.OpenBase()
|
||||
|
||||
testLuaWriter.l = l
|
||||
|
||||
l.Register("WriteError", testLuaWriteError)
|
||||
|
||||
str := `
|
||||
WriteError()
|
||||
`
|
||||
|
||||
err := l.DoString(str)
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("must error")
|
||||
}
|
||||
|
||||
l.Register("WriteArray", testLuaWriteArray)
|
||||
|
||||
str = `
|
||||
local a = WriteArray()
|
||||
|
||||
if #a ~= 2 then
|
||||
error("len a must 2")
|
||||
elseif a[1] ~= "1" then
|
||||
error("a[1] must 1")
|
||||
elseif #a[2] ~= 2 then
|
||||
error("len a[2] must 2")
|
||||
elseif a[2][1] ~= 10 then
|
||||
error("a[2][1] must 10")
|
||||
elseif a[2][2] ~= "11" then
|
||||
error("a[2][2] must 11")
|
||||
end
|
||||
`
|
||||
|
||||
err = l.DoString(str)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
l.Close()
|
||||
}
|
||||
|
||||
var testScript1 = `
|
||||
return {1,2,3}
|
||||
`
|
||||
|
||||
var testScript2 = `
|
||||
return ledis.call("ping")
|
||||
`
|
||||
|
||||
var testScript3 = `
|
||||
ledis.call("set", 1, "a")
|
||||
|
||||
local a = ledis.call("get", 1)
|
||||
if type(a) ~= "string" then
|
||||
error("must string")
|
||||
elseif a ~= "a" then
|
||||
error("must a")
|
||||
end
|
||||
`
|
||||
|
||||
var testScript4 = `
|
||||
ledis.call("select", 2)
|
||||
ledis.call("set", 2, "b")
|
||||
`
|
||||
|
||||
func TestLuaCall(t *testing.T) {
|
||||
cfg := new(config.Config)
|
||||
cfg.Addr = ":11188"
|
||||
cfg.DataDir = "/tmp/testscript"
|
||||
cfg.DBName = "memory"
|
||||
|
||||
app, e := NewApp(cfg)
|
||||
if e != nil {
|
||||
t.Fatal(e)
|
||||
}
|
||||
go app.Run()
|
||||
|
||||
defer app.Close()
|
||||
|
||||
db, _ := app.ldb.Select(0)
|
||||
m, _ := db.Multi()
|
||||
defer m.Close()
|
||||
|
||||
luaClient := app.s.c
|
||||
luaClient.db = m.DB
|
||||
luaClient.script = m
|
||||
|
||||
l := app.s.l
|
||||
|
||||
err := app.s.l.DoString(testScript1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
v := luaReplyToLedisReply(l)
|
||||
if vv, ok := v.([]interface{}); ok {
|
||||
if len(vv) != 3 {
|
||||
t.Fatal(len(vv))
|
||||
}
|
||||
} else {
|
||||
t.Fatal(fmt.Sprintf("%v %T", v, v))
|
||||
}
|
||||
|
||||
err = app.s.l.DoString(testScript2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
v = luaReplyToLedisReply(l)
|
||||
if vv := v.(string); vv != "PONG" {
|
||||
t.Fatal(fmt.Sprintf("%v %T", v, v))
|
||||
}
|
||||
|
||||
err = app.s.l.DoString(testScript3)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if v, err := db.Get([]byte("1")); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if string(v) != "a" {
|
||||
t.Fatal(string(v))
|
||||
}
|
||||
|
||||
err = app.s.l.DoString(testScript4)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if luaClient.db.Index() != 2 {
|
||||
t.Fatal(luaClient.db.Index())
|
||||
}
|
||||
|
||||
db2, _ := app.ldb.Select(2)
|
||||
if v, err := db2.Get([]byte("2")); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if string(v) != "b" {
|
||||
t.Fatal(string(v))
|
||||
}
|
||||
|
||||
luaClient.db = nil
|
||||
}
|
Loading…
Reference in New Issue