diff --git a/ledis/batch.go b/ledis/batch.go new file mode 100644 index 0000000..b23cc47 --- /dev/null +++ b/ledis/batch.go @@ -0,0 +1,105 @@ +package ledis + +import ( + "github.com/siddontang/ledisdb/store" + "sync" +) + +type batch struct { + l *Ledis + + store.WriteBatch + + sync.Locker + + logs [][]byte + + tx *Tx +} + +func (b *batch) Commit() error { + b.l.commitLock.Lock() + defer b.l.commitLock.Unlock() + + err := b.WriteBatch.Commit() + + if b.l.binlog != nil { + if err == nil { + if b.tx == nil { + b.l.binlog.Log(b.logs...) + } else { + b.tx.logs = append(b.tx.logs, b.logs...) + } + } + b.logs = [][]byte{} + } + + return err +} + +func (b *batch) Lock() { + b.Locker.Lock() +} + +func (b *batch) Unlock() { + if b.l.binlog != nil { + b.logs = [][]byte{} + } + b.WriteBatch.Rollback() + b.Locker.Unlock() +} + +func (b *batch) Put(key []byte, value []byte) { + if b.l.binlog != nil { + buf := encodeBinLogPut(key, value) + b.logs = append(b.logs, buf) + } + b.WriteBatch.Put(key, value) +} + +func (b *batch) Delete(key []byte) { + if b.l.binlog != nil { + buf := encodeBinLogDelete(key) + b.logs = append(b.logs, buf) + } + b.WriteBatch.Delete(key) +} + +type dbBatchLocker struct { + l *sync.Mutex + wrLock *sync.RWMutex +} + +func (l *dbBatchLocker) Lock() { + l.wrLock.RLock() + l.l.Lock() +} + +func (l *dbBatchLocker) Unlock() { + l.l.Unlock() + l.wrLock.RUnlock() +} + +type txBatchLocker struct { +} + +func (l *txBatchLocker) Lock() {} +func (l *txBatchLocker) Unlock() {} + +type multiBatchLocker struct { +} + +func (l *multiBatchLocker) Lock() {} +func (l *multiBatchLocker) Unlock() {} + +func (l *Ledis) newBatch(wb store.WriteBatch, locker sync.Locker, tx *Tx) *batch { + b := new(batch) + b.l = l + b.WriteBatch = wb + + b.tx = tx + b.Locker = locker + + b.logs = [][]byte{} + return b +} diff --git a/ledis/const.go b/ledis/const.go index ef416de..e889f4e 100644 --- a/ledis/const.go +++ b/ledis/const.go @@ -86,3 +86,9 @@ const ( BinLogTypePut uint8 = 0x1 BinLogTypeCommand uint8 = 0x2 ) + +const ( + DBAutoCommit uint8 = 0x0 + DBInTransaction uint8 = 0x1 + DBInMulti uint8 = 0x2 +) diff --git a/ledis/ledis_db.go b/ledis/ledis_db.go index 9241b1d..dd8ff74 100644 --- a/ledis/ledis_db.go +++ b/ledis/ledis_db.go @@ -3,6 +3,7 @@ package ledis import ( "fmt" "github.com/siddontang/ledisdb/store" + "sync" ) type ibucket interface { @@ -37,7 +38,7 @@ type DB struct { binBatch *batch setBatch *batch - isTx bool + status uint8 } func (l *Ledis) newDB(index uint8) *DB { @@ -49,7 +50,7 @@ func (l *Ledis) newDB(index uint8) *DB { d.bucket = d.sdb - d.isTx = false + d.status = DBAutoCommit d.index = index d.kvBatch = d.newBatch() @@ -62,10 +63,18 @@ func (l *Ledis) newDB(index uint8) *DB { return d } +func (db *DB) newBatch() *batch { + return db.l.newBatch(db.bucket.NewWriteBatch(), &dbBatchLocker{l: &sync.Mutex{}, wrLock: &db.l.wLock}, nil) +} + func (db *DB) Index() int { return int(db.index) } +func (db *DB) IsAutoCommit() bool { + return db.status == DBAutoCommit +} + func (db *DB) FlushAll() (drop int64, err error) { all := [...](func() (int64, error)){ db.flush, diff --git a/ledis/multi.go b/ledis/multi.go new file mode 100644 index 0000000..a549c2c --- /dev/null +++ b/ledis/multi.go @@ -0,0 +1,73 @@ +package ledis + +import ( + "errors" + "fmt" +) + +var ( + ErrNestMulti = errors.New("nest multi not supported") + ErrMultiDone = errors.New("multi has been closed") +) + +type Multi struct { + *DB +} + +func (db *DB) IsInMulti() bool { + return db.status == DBInMulti +} + +// begin a mutli to execute commands, +// it will block any other write operations before you close the multi, unlike transaction, mutli can not rollback +func (db *DB) Multi() (*Multi, error) { + if db.IsInMulti() { + return nil, ErrNestMulti + } + + m := new(Multi) + + m.DB = new(DB) + m.DB.status = DBInMulti + + m.DB.l = db.l + + m.l.wLock.Lock() + + m.DB.sdb = db.sdb + + m.DB.bucket = db.sdb + + m.DB.index = db.index + + m.DB.kvBatch = m.newBatch() + m.DB.listBatch = m.newBatch() + m.DB.hashBatch = m.newBatch() + m.DB.zsetBatch = m.newBatch() + m.DB.binBatch = m.newBatch() + m.DB.setBatch = m.newBatch() + + return m, nil +} + +func (m *Multi) newBatch() *batch { + return m.l.newBatch(m.bucket.NewWriteBatch(), &multiBatchLocker{}, nil) +} + +func (m *Multi) Close() error { + if m.bucket == nil { + return ErrMultiDone + } + m.l.wLock.Unlock() + m.bucket = nil + return nil +} + +func (m *Multi) Select(index int) error { + if index < 0 || index >= int(MaxDBNumber) { + return fmt.Errorf("invalid db index %d", index) + } + + m.DB.index = uint8(index) + return nil +} diff --git a/ledis/multi_test.go b/ledis/multi_test.go new file mode 100644 index 0000000..936c141 --- /dev/null +++ b/ledis/multi_test.go @@ -0,0 +1,51 @@ +package ledis + +import ( + "sync" + "testing" +) + +func TestMulti(t *testing.T) { + db := getTestDB() + + key := []byte("test_multi_1") + v1 := []byte("v1") + v2 := []byte("v2") + + m, err := db.Multi() + if err != nil { + t.Fatal(err) + } + + wg := sync.WaitGroup{} + + wg.Add(1) + + go func() { + if err := db.Set(key, v2); err != nil { + t.Fatal(err) + } + wg.Done() + }() + + if err := m.Set(key, v1); err != nil { + t.Fatal(err) + } + + if v, err := m.Get(key); err != nil { + t.Fatal(err) + } else if string(v) != string(v1) { + t.Fatal(string(v)) + } + + m.Close() + + wg.Wait() + + if v, err := db.Get(key); err != nil { + t.Fatal(err) + } else if string(v) != string(v2) { + t.Fatal(string(v)) + } + +} diff --git a/ledis/replication_test.go b/ledis/replication_test.go index 96bb10a..2a64a11 100644 --- a/ledis/replication_test.go +++ b/ledis/replication_test.go @@ -70,6 +70,12 @@ func TestReplication(t *testing.T) { db.HSet([]byte("c"), []byte("3"), []byte("value")) } + m, _ := db.Multi() + m.Set([]byte("a1"), []byte("value")) + m.Set([]byte("b1"), []byte("value")) + m.Set([]byte("c1"), []byte("value")) + m.Close() + for _, name := range master.binlog.LogNames() { p := path.Join(master.binlog.LogPath(), name) diff --git a/ledis/tx.go b/ledis/tx.go deleted file mode 100644 index 7488233..0000000 --- a/ledis/tx.go +++ /dev/null @@ -1,209 +0,0 @@ -package ledis - -import ( - "errors" - "github.com/siddontang/ledisdb/store" - "sync" -) - -var ( - ErrNestTx = errors.New("nest transaction not supported") - ErrTxDone = errors.New("Transaction has already been committed or rolled back") -) - -type batch struct { - l *Ledis - - store.WriteBatch - - sync.Locker - - logs [][]byte - - tx *Tx -} - -type dbBatchLocker struct { - l *sync.Mutex - wrLock *sync.RWMutex -} - -func (l *dbBatchLocker) Lock() { - l.wrLock.RLock() - l.l.Lock() -} - -func (l *dbBatchLocker) Unlock() { - l.l.Unlock() - l.wrLock.RUnlock() -} - -type txBatchLocker struct { -} - -func (l *txBatchLocker) Lock() {} -func (l *txBatchLocker) Unlock() {} - -func (l *Ledis) newBatch(wb store.WriteBatch, tx *Tx) *batch { - b := new(batch) - b.l = l - b.WriteBatch = wb - - b.tx = tx - if tx == nil { - b.Locker = &dbBatchLocker{l: &sync.Mutex{}, wrLock: &l.wLock} - } else { - b.Locker = &txBatchLocker{} - } - - b.logs = [][]byte{} - return b -} - -func (db *DB) newBatch() *batch { - return db.l.newBatch(db.bucket.NewWriteBatch(), nil) -} - -func (b *batch) Commit() error { - b.l.commitLock.Lock() - defer b.l.commitLock.Unlock() - - err := b.WriteBatch.Commit() - - if b.l.binlog != nil { - if err == nil { - if b.tx == nil { - b.l.binlog.Log(b.logs...) - } else { - b.tx.logs = append(b.tx.logs, b.logs...) - } - } - b.logs = [][]byte{} - } - - return err -} - -func (b *batch) Lock() { - b.Locker.Lock() -} - -func (b *batch) Unlock() { - if b.l.binlog != nil { - b.logs = [][]byte{} - } - b.WriteBatch.Rollback() - b.Locker.Unlock() -} - -func (b *batch) Put(key []byte, value []byte) { - if b.l.binlog != nil { - buf := encodeBinLogPut(key, value) - b.logs = append(b.logs, buf) - } - b.WriteBatch.Put(key, value) -} - -func (b *batch) Delete(key []byte) { - if b.l.binlog != nil { - buf := encodeBinLogDelete(key) - b.logs = append(b.logs, buf) - } - b.WriteBatch.Delete(key) -} - -type Tx struct { - *DB - - tx *store.Tx - - logs [][]byte - - index uint8 -} - -func (db *DB) IsTransaction() bool { - return db.isTx -} - -// Begin a transaction, it will block all other write operations before calling Commit or Rollback. -// You must be very careful to prevent long-time transaction. -func (db *DB) Begin() (*Tx, error) { - if db.isTx { - return nil, ErrNestTx - } - - tx := new(Tx) - - tx.DB = new(DB) - tx.DB.l = db.l - - tx.l.wLock.Lock() - - tx.index = db.index - - tx.DB.sdb = db.sdb - - var err error - tx.tx, err = db.sdb.Begin() - if err != nil { - tx.l.wLock.Unlock() - return nil, err - } - - tx.DB.bucket = tx.tx - - tx.DB.isTx = true - - tx.DB.index = db.index - - tx.DB.kvBatch = tx.newTxBatch() - tx.DB.listBatch = tx.newTxBatch() - tx.DB.hashBatch = tx.newTxBatch() - tx.DB.zsetBatch = tx.newTxBatch() - tx.DB.binBatch = tx.newTxBatch() - tx.DB.setBatch = tx.newTxBatch() - - return tx, nil -} - -func (tx *Tx) Commit() error { - if tx.tx == nil { - return ErrTxDone - } - - tx.l.commitLock.Lock() - err := tx.tx.Commit() - tx.tx = nil - - if len(tx.logs) > 0 { - tx.l.binlog.Log(tx.logs...) - } - - tx.l.commitLock.Unlock() - - tx.l.wLock.Unlock() - tx.DB = nil - return err -} - -func (tx *Tx) Rollback() error { - if tx.tx == nil { - return ErrTxDone - } - - err := tx.tx.Rollback() - tx.tx = nil - - tx.l.wLock.Unlock() - tx.DB = nil - return err -} - -func (tx *Tx) newTxBatch() *batch { - return tx.l.newBatch(tx.tx.NewWriteBatch(), tx) -} - -func (tx *Tx) Index() int { - return int(tx.index) -} diff --git a/ledis/tx_test.go b/ledis/tx_test.go index bf06012..026b70d 100644 --- a/ledis/tx_test.go +++ b/ledis/tx_test.go @@ -144,6 +144,51 @@ func testTxCommit(t *testing.T, db *DB) { } } +func testTxSelect(t *testing.T, db *DB) { + tx, err := db.Begin() + if err != nil { + t.Fatal(err) + } + + defer tx.Rollback() + + tx.Set([]byte("tx_select_1"), []byte("a")) + + tx.Select(1) + + tx.Set([]byte("tx_select_2"), []byte("b")) + + if err = tx.Commit(); err != nil { + t.Fatal(err) + } + + if v, err := db.Get([]byte("tx_select_1")); err != nil { + t.Fatal(err) + } else if string(v) != "a" { + t.Fatal(string(v)) + } + + if v, err := db.Get([]byte("tx_select_2")); err != nil { + t.Fatal(err) + } else if v != nil { + t.Fatal("must nil") + } + + db, _ = db.l.Select(1) + + if v, err := db.Get([]byte("tx_select_2")); err != nil { + t.Fatal(err) + } else if string(v) != "b" { + t.Fatal(string(v)) + } + + if v, err := db.Get([]byte("tx_select_1")); err != nil { + t.Fatal(err) + } else if v != nil { + t.Fatal("must nil") + } +} + func testTx(t *testing.T, name string) { cfg := new(config.Config) cfg.DataDir = "/tmp/ledis_test_tx" @@ -164,6 +209,7 @@ func testTx(t *testing.T, name string) { testTxRollback(t, db) testTxCommit(t, db) + testTxSelect(t, db) } //only lmdb, boltdb support Transaction diff --git a/server/app.go b/server/app.go index d5c77c9..edd65c8 100644 --- a/server/app.go +++ b/server/app.go @@ -27,6 +27,8 @@ type App struct { m *master info *info + + s *script } func netType(s string) string { @@ -85,6 +87,8 @@ func NewApp(cfg *config.Config) (*App, error) { app.m = newMaster(app) + app.openScript() + return app, nil } @@ -103,6 +107,8 @@ func (app *App) Close() { app.httpListener.Close() } + app.closeScript() + app.m.Close() if app.access != nil { diff --git a/server/client.go b/server/client.go index f28a930..27e08b1 100644 --- a/server/client.go +++ b/server/client.go @@ -16,6 +16,18 @@ var txUnsupportedCmds = map[string]struct{}{ "begin": struct{}{}, "flushall": struct{}{}, "flushdb": struct{}{}, + "eval": struct{}{}, +} + +var scriptUnsupportedCmds = map[string]struct{}{ + "slaveof": struct{}{}, + "fullsync": struct{}{}, + "sync": struct{}{}, + "begin": struct{}{}, + "commit": struct{}{}, + "rollback": struct{}{}, + "flushall": struct{}{}, + "flushdb": struct{}{}, } type responseWriter interface { @@ -34,7 +46,8 @@ type responseWriter interface { type client struct { app *App ldb *ledis.Ledis - db *ledis.DB + + db *ledis.DB remoteAddr string cmd string @@ -49,7 +62,8 @@ type client struct { buf bytes.Buffer - tx *ledis.Tx + tx *ledis.Tx + script *ledis.Multi } func newClient(app *App) *client { @@ -59,16 +73,12 @@ func newClient(app *App) *client { c.ldb = app.ldb c.db, _ = app.ldb.Select(0) //use default db - c.compressBuf = make([]byte, 256) + c.compressBuf = []byte{} c.reqErr = make(chan error) return c } -func (c *client) isInTransaction() bool { - return c.tx != nil -} - func (c *client) perform() { var err error @@ -79,10 +89,14 @@ func (c *client) perform() { } else if exeCmd, ok := regCmds[c.cmd]; !ok { err = ErrNotFound } else { - if c.isInTransaction() { + if c.db.IsTransaction() { if _, ok := txUnsupportedCmds[c.cmd]; ok { err = fmt.Errorf("%s not supported in transaction", c.cmd) } + } else if c.db.IsInMulti() { + if _, ok := scriptUnsupportedCmds[c.cmd]; ok { + err = fmt.Errorf("%s not supported in multi", c.cmd) + } } if err == nil { @@ -128,3 +142,22 @@ func (c *client) catGenericCommand() []byte { return buffer.Bytes() } + +func writeValue(w responseWriter, value interface{}) { + switch v := value.(type) { + case []interface{}: + w.writeArray(v) + case [][]byte: + w.writeSliceArray(v) + case []byte: + w.writeBulk(v) + case string: + w.writeStatus(v) + case nil: + w.writeBulk(nil) + case int64: + w.writeInteger(v) + default: + panic("invalid value type") + } +} diff --git a/server/cmd_script.go b/server/cmd_script.go index 20f5b63..f0dabd1 100644 --- a/server/cmd_script.go +++ b/server/cmd_script.go @@ -1,158 +1,207 @@ package server import ( + "encoding/hex" "fmt" "github.com/aarzilli/golua/lua" "github.com/siddontang/ledisdb/ledis" - "io" + "strconv" + "strings" ) -//ledis <-> lua type conversion, same as http://redis.io/commands/eval - -type luaClient struct { - l *lua.State -} - -type luaWriter struct { - l *lua.State -} - -func (w *luaWriter) writeError(err error) { - w.l.NewTable() - top := w.l.GetTop() - - w.l.PushString("err") - w.l.PushString(err.Error()) - w.l.SetTable(top) -} - -func (w *luaWriter) writeStatus(status string) { - w.l.NewTable() - top := w.l.GetTop() - - w.l.PushString("ok") - w.l.PushString(status) - w.l.SetTable(top) -} - -func (w *luaWriter) writeInteger(n int64) { - w.l.PushInteger(n) -} - -func (w *luaWriter) writeBulk(b []byte) { - if b == nil { - w.l.PushBoolean(false) - } else { - w.l.PushString(ledis.String(b)) - } -} - -func (w *luaWriter) writeArray(lst []interface{}) { - if lst == nil { - w.l.PushBoolean(false) - return +func parseEvalArgs(l *lua.State, c *client) error { + args := c.args + if len(args) < 2 { + return ErrCmdParams } - base := w.l.GetTop() + args = args[1:] + + n, err := strconv.Atoi(ledis.String(args[0])) + if err != nil { + return err + } + + if n > len(args)-1 { + return ErrCmdParams + } + + luaSetGlobalArray(l, "KEYS", args[1:n+1]) + luaSetGlobalArray(l, "ARGV", args[n+1:]) + + return nil +} + +func evalGenericCommand(c *client, evalSha1 bool) error { + m, err := c.db.Multi() + if err != nil { + return err + } + + s := c.app.s + luaClient := s.c + l := s.l + + s.Lock() + + base := l.GetTop() + defer func() { - if e := recover(); e != nil { - w.l.SetTop(base) - w.writeError(fmt.Errorf("%v", e)) - } + l.SetTop(base) + luaClient.db = nil + luaClient.script = nil + + s.Unlock() + + m.Close() }() - w.l.CreateTable(len(lst), 0) - top := w.l.GetTop() + luaClient.db = m.DB + luaClient.script = m + luaClient.remoteAddr = c.remoteAddr - for i, _ := range lst { - w.l.PushInteger(int64(i) + 1) - - switch v := lst[i].(type) { - case []interface{}: - w.writeArray(v) - case [][]byte: - w.writeSliceArray(v) - case []byte: - w.writeBulk(v) - case nil: - w.writeBulk(nil) - case int64: - w.writeInteger(v) - default: - panic("invalid array type") - } - - w.l.SetTable(top) - } -} - -func (w *luaWriter) writeSliceArray(lst [][]byte) { - if lst == nil { - w.l.PushBoolean(false) - return + if err := parseEvalArgs(l, c); err != nil { + return err } - w.l.CreateTable(len(lst), 0) - top := w.l.GetTop() - for i, v := range lst { - w.l.PushInteger(int64(i) + 1) - w.l.PushString(ledis.String(v)) - w.l.SetTable(top) - } -} - -func (w *luaWriter) writeFVPairArray(lst []ledis.FVPair) { - if lst == nil { - w.l.PushBoolean(false) - return - } - - w.l.CreateTable(len(lst)*2, 0) - top := w.l.GetTop() - for i, v := range lst { - w.l.PushInteger(int64(2*i) + 1) - w.l.PushString(ledis.String(v.Field)) - w.l.SetTable(top) - - w.l.PushInteger(int64(2*i) + 2) - w.l.PushString(ledis.String(v.Value)) - w.l.SetTable(top) - } -} - -func (w *luaWriter) writeScorePairArray(lst []ledis.ScorePair, withScores bool) { - if lst == nil { - w.l.PushBoolean(false) - return - } - - if withScores { - w.l.CreateTable(len(lst)*2, 0) - top := w.l.GetTop() - for i, v := range lst { - w.l.PushInteger(int64(2*i) + 1) - w.l.PushString(ledis.String(v.Member)) - w.l.SetTable(top) - - w.l.PushInteger(int64(2*i) + 2) - w.l.PushString(ledis.String(ledis.StrPutInt64(v.Score))) - w.l.SetTable(top) - } + var sha1 string + if !evalSha1 { + sha1 = hex.EncodeToString(c.args[0]) } else { - w.l.CreateTable(len(lst), 0) - top := w.l.GetTop() - for i, v := range lst { - w.l.PushInteger(int64(i) + 1) - w.l.PushString(ledis.String(v.Member)) - w.l.SetTable(top) + sha1 = ledis.String(c.args[0]) + } + + l.GetGlobal(sha1) + + if l.IsNil(-1) { + l.Pop(1) + + if evalSha1 { + return fmt.Errorf("missing %s script", sha1) + } + + if r := l.LoadString(ledis.String(c.args[0])); r != 0 { + err := fmt.Errorf("%s", l.ToString(-1)) + l.Pop(1) + return err + } else { + l.PushValue(-1) + l.SetGlobal(sha1) + + s.chunks[sha1] = struct{}{} } } + + if err := l.Call(0, lua.LUA_MULTRET); err != nil { + return err + } else { + r := luaReplyToLedisReply(l) + m.Close() + writeValue(c.resp, r) + } + + return nil } -func (w *luaWriter) writeBulkFrom(n int64, rb io.Reader) { - w.writeError(fmt.Errorf("unsupport")) +func evalCommand(c *client) error { + return evalGenericCommand(c, false) } -func (w *luaWriter) flush() { - +func evalshaCommand(c *client) error { + return evalGenericCommand(c, true) +} + +func scriptCommand(c *client) error { + s := c.app.s + l := s.l + + s.Lock() + + base := l.GetTop() + + defer func() { + l.SetTop(base) + s.Unlock() + }() + + args := c.args + if len(args) < 1 { + return ErrCmdParams + } + + switch strings.ToLower(c.cmd) { + case "script load": + return scriptLoadCommand(c) + case "script exists": + return scriptExistsCommand(c) + case "script flush": + return scriptFlushCommand(c) + default: + return fmt.Errorf("invalid scirpt cmd %s", args[0]) + } + + return nil +} + +func scriptLoadCommand(c *client) error { + s := c.app.s + l := s.l + + if len(c.args) != 1 { + return ErrCmdParams + } + + sha1 := hex.EncodeToString(c.args[1]) + + if r := l.LoadString(ledis.String(c.args[1])); r != 0 { + err := fmt.Errorf("%s", l.ToString(-1)) + l.Pop(1) + return err + } else { + l.PushValue(-1) + l.SetGlobal(sha1) + + s.chunks[sha1] = struct{}{} + } + + c.resp.writeBulk(ledis.Slice(sha1)) + return nil +} + +func scriptExistsCommand(c *client) error { + s := c.app.s + + ay := make([]interface{}, len(c.args[1:])) + for i, n := range c.args[1:] { + if _, ok := s.chunks[ledis.String(n)]; ok { + ay[i] = int64(1) + } else { + ay[i] = int64(0) + } + } + + c.resp.writeArray(ay) + return nil +} + +func scriptFlushCommand(c *client) error { + s := c.app.s + l := s.l + + for n, _ := range s.chunks { + l.PushNil() + l.SetGlobal(n) + } + + c.resp.writeStatus(OK) + + return nil +} + +func init() { + register("eval", evalCommand) + register("evalsha", evalshaCommand) + register("script load", scriptCommand) + register("script flush", scriptCommand) + register("script exists", scriptCommand) } diff --git a/server/cmd_script_test.go b/server/cmd_script_test.go new file mode 100644 index 0000000..5c47866 --- /dev/null +++ b/server/cmd_script_test.go @@ -0,0 +1,29 @@ +package server + +import ( + "fmt" + "github.com/siddontang/ledisdb/client/go/ledis" + "reflect" + "testing" +) + +func TestCmdEval(t *testing.T) { + c := getTestConn() + defer c.Close() + + if v, err := ledis.Strings(c.Do("eval", "return {KEYS[1],KEYS[2],ARGV[1],ARGV[2]}", 2, "key1", "key2", "first", "second")); err != nil { + t.Fatal(err) + } else if len(v) != 4 { + t.Fatal(err) + } else if !reflect.DeepEqual(v, []string{"key1", "key2", "first", "second"}) { + t.Fatal(fmt.Sprintf("%v", v)) + } + + if v, err := ledis.Strings(c.Do("eval", "return {KEYS[1],KEYS[2],ARGV[1],ARGV[2]}", 2, "key1", "key2", "first", "second")); err != nil { + t.Fatal(err) + } else if len(v) != 4 { + t.Fatal(err) + } else if !reflect.DeepEqual(v, []string{"key1", "key2", "first", "second"}) { + t.Fatal(fmt.Sprintf("%v", v)) + } +} diff --git a/server/command.go b/server/command.go index af95244..458343b 100644 --- a/server/command.go +++ b/server/command.go @@ -41,12 +41,26 @@ func selectCommand(c *client) error { if index, err := strconv.Atoi(ledis.String(c.args[0])); err != nil { return err } else { - if db, err := c.ldb.Select(index); err != nil { - return err + if c.db.IsTransaction() { + if err := c.tx.Select(index); err != nil { + return err + } else { + c.db = c.tx.DB + } + } else if c.db.IsInMulti() { + if err := c.script.Select(index); err != nil { + return err + } else { + c.db = c.script.DB + } } else { - c.db = db - c.resp.writeStatus(OK) + if db, err := c.ldb.Select(index); err != nil { + return err + } else { + c.db = db + } } + c.resp.writeStatus(OK) } return nil diff --git a/server/script.go b/server/script.go new file mode 100644 index 0000000..c2c8ff2 --- /dev/null +++ b/server/script.go @@ -0,0 +1,380 @@ +package server + +import ( + "encoding/hex" + "fmt" + "github.com/aarzilli/golua/lua" + "github.com/siddontang/ledisdb/ledis" + "io" + "sync" +) + +//ledis <-> lua type conversion, same as http://redis.io/commands/eval + +type luaWriter struct { + l *lua.State +} + +func (w *luaWriter) writeError(err error) { + panic(err) +} + +func (w *luaWriter) writeStatus(status string) { + w.l.NewTable() + top := w.l.GetTop() + + w.l.PushString("ok") + w.l.PushString(status) + w.l.SetTable(top) +} + +func (w *luaWriter) writeInteger(n int64) { + w.l.PushInteger(n) +} + +func (w *luaWriter) writeBulk(b []byte) { + if b == nil { + w.l.PushBoolean(false) + } else { + w.l.PushString(ledis.String(b)) + } +} + +func (w *luaWriter) writeArray(lst []interface{}) { + if lst == nil { + w.l.PushBoolean(false) + return + } + + w.l.CreateTable(len(lst), 0) + top := w.l.GetTop() + + for i, _ := range lst { + w.l.PushInteger(int64(i) + 1) + + switch v := lst[i].(type) { + case []interface{}: + w.writeArray(v) + case [][]byte: + w.writeSliceArray(v) + case []byte: + w.writeBulk(v) + case nil: + w.writeBulk(nil) + case int64: + w.writeInteger(v) + default: + panic("invalid array type") + } + + w.l.SetTable(top) + } +} + +func (w *luaWriter) writeSliceArray(lst [][]byte) { + if lst == nil { + w.l.PushBoolean(false) + return + } + + w.l.CreateTable(len(lst), 0) + for i, v := range lst { + w.l.PushString(ledis.String(v)) + w.l.RawSeti(-2, i+1) + } +} + +func (w *luaWriter) writeFVPairArray(lst []ledis.FVPair) { + if lst == nil { + w.l.PushBoolean(false) + return + } + + w.l.CreateTable(len(lst)*2, 0) + for i, v := range lst { + w.l.PushString(ledis.String(v.Field)) + w.l.RawSeti(-2, 2*i+1) + + w.l.PushString(ledis.String(v.Value)) + w.l.RawSeti(-2, 2*i+2) + } +} + +func (w *luaWriter) writeScorePairArray(lst []ledis.ScorePair, withScores bool) { + if lst == nil { + w.l.PushBoolean(false) + return + } + + if withScores { + w.l.CreateTable(len(lst)*2, 0) + for i, v := range lst { + w.l.PushString(ledis.String(v.Member)) + w.l.RawSeti(-2, 2*i+1) + + w.l.PushString(ledis.String(ledis.StrPutInt64(v.Score))) + w.l.RawSeti(-2, 2*i+2) + } + } else { + w.l.CreateTable(len(lst), 0) + for i, v := range lst { + w.l.PushString(ledis.String(v.Member)) + w.l.RawSeti(-2, i+1) + } + } +} + +func (w *luaWriter) writeBulkFrom(n int64, rb io.Reader) { + w.writeError(fmt.Errorf("unsupport")) +} + +func (w *luaWriter) flush() { + +} + +type script struct { + sync.Mutex + + app *App + l *lua.State + c *client + + chunks map[string]struct{} +} + +func (app *App) openScript() { + s := new(script) + s.app = app + + s.chunks = make(map[string]struct{}) + + app.s = s + + l := lua.NewState() + + l.OpenBase() + l.OpenLibs() + l.OpenMath() + l.OpenString() + l.OpenTable() + l.OpenPackage() + + s.l = l + s.c = newClient(app) + s.c.db = nil + + w := new(luaWriter) + w.l = l + s.c.resp = w + + l.NewTable() + l.PushString("call") + l.PushGoFunction(luaCall) + l.SetTable(-3) + + l.PushString("pcall") + l.PushGoFunction(luaPCall) + l.SetTable(-3) + + l.PushString("sha1hex") + l.PushGoFunction(luaSha1Hex) + l.SetTable(-3) + + l.PushString("error_reply") + l.PushGoFunction(luaErrorReply) + l.SetTable(-3) + + l.PushString("status_reply") + l.PushGoFunction(luaStatusReply) + l.SetTable(-3) + + l.SetGlobal("ledis") + + setMapState(l, s) +} + +func (app *App) closeScript() { + app.s.l.Close() + delMapState(app.s.l) + app.s = nil +} + +var mapState = map[*lua.State]*script{} +var stateLock sync.Mutex + +func setMapState(l *lua.State, s *script) { + stateLock.Lock() + defer stateLock.Unlock() + + mapState[l] = s +} + +func getMapState(l *lua.State) *script { + stateLock.Lock() + defer stateLock.Unlock() + + return mapState[l] +} + +func delMapState(l *lua.State) { + stateLock.Lock() + defer stateLock.Unlock() + + delete(mapState, l) +} + +func luaCall(l *lua.State) int { + return luaCallGenericCommand(l) +} + +func luaPCall(l *lua.State) (n int) { + defer func() { + if e := recover(); e != nil { + luaPushError(l, fmt.Sprintf("%v", e)) + n = 1 + } + return + }() + return luaCallGenericCommand(l) +} + +func luaErrorReply(l *lua.State) int { + return luaReturnSingleFieldTable(l, "err") +} + +func luaStatusReply(l *lua.State) int { + return luaReturnSingleFieldTable(l, "ok") +} + +func luaReturnSingleFieldTable(l *lua.State, filed string) int { + if l.GetTop() != 1 || l.Type(-1) != lua.LUA_TSTRING { + luaPushError(l, "wrong number or type of arguments") + return 1 + } + + l.NewTable() + l.PushString(filed) + l.PushValue(-3) + l.SetTable(-3) + return 1 +} + +func luaSha1Hex(l *lua.State) int { + argc := l.GetTop() + if argc != 1 { + luaPushError(l, "wrong number of arguments") + return 1 + } + + s := l.ToString(1) + s = hex.EncodeToString(ledis.Slice(s)) + + l.PushString(s) + return 1 +} + +func luaPushError(l *lua.State, msg string) { + l.NewTable() + l.PushString("err") + err := l.NewError(msg) + l.PushString(err.Error()) + l.SetTable(-3) +} + +func luaCallGenericCommand(l *lua.State) int { + s := getMapState(l) + if s == nil { + panic("Invalid lua call") + } else if s.c.db == nil { + panic("Invalid lua call, not prepared") + } + + c := s.c + + argc := l.GetTop() + if argc < 1 { + panic("Please specify at least one argument for ledis.call()") + } + + c.cmd = l.ToString(1) + + c.args = make([][]byte, argc-1) + + for i := 2; i <= argc; i++ { + switch l.Type(i) { + case lua.LUA_TNUMBER: + c.args[i-2] = []byte(fmt.Sprintf("%.17g", l.ToNumber(i))) + case lua.LUA_TSTRING: + c.args[i-2] = []byte(l.ToString(i)) + default: + panic("Lua ledis() command arguments must be strings or integers") + } + } + + c.perform() + + return 1 +} + +func luaSetGlobalArray(l *lua.State, name string, ay [][]byte) { + l.NewTable() + + for i := 0; i < len(ay); i++ { + l.PushString(ledis.String(ay[i])) + l.RawSeti(-2, i+1) + } + + l.SetGlobal(name) +} + +func luaReplyToLedisReply(l *lua.State) interface{} { + base := l.GetTop() + defer func() { + l.SetTop(base - 1) + }() + + switch l.Type(-1) { + case lua.LUA_TSTRING: + return ledis.Slice(l.ToString(-1)) + case lua.LUA_TBOOLEAN: + if l.ToBoolean(-1) { + return int64(1) + } else { + return nil + } + case lua.LUA_TNUMBER: + return int64(l.ToInteger(-1)) + case lua.LUA_TTABLE: + l.PushString("err") + l.GetTable(-2) + if l.Type(-1) == lua.LUA_TSTRING { + return fmt.Errorf("%s", l.ToString(-1)) + } + + l.Pop(1) + l.PushString("ok") + l.GetTable(-2) + if l.Type(-1) == lua.LUA_TSTRING { + return l.ToString(-1) + } else { + l.Pop(1) + + ay := make([]interface{}, 0) + + for i := 1; ; i++ { + l.PushInteger(int64(i)) + l.GetTable(-2) + if l.Type(-1) == lua.LUA_TNIL { + l.Pop(1) + break + } + + ay = append(ay, luaReplyToLedisReply(l)) + } + return ay + + } + default: + return nil + } +} diff --git a/server/script_test.go b/server/script_test.go new file mode 100644 index 0000000..0d91231 --- /dev/null +++ b/server/script_test.go @@ -0,0 +1,177 @@ +package server + +import ( + "fmt" + "github.com/aarzilli/golua/lua" + "github.com/siddontang/ledisdb/config" + + "testing" +) + +var testLuaWriter = &luaWriter{} + +func testLuaWriteError(l *lua.State) int { + testLuaWriter.writeError(fmt.Errorf("test error")) + return 1 +} + +func testLuaWriteArray(l *lua.State) int { + ay := make([]interface{}, 2) + ay[0] = []byte("1") + b := make([]interface{}, 2) + b[0] = int64(10) + b[1] = []byte("11") + + ay[1] = b + + testLuaWriter.writeArray(ay) + + return 1 +} + +func TestLuaWriter(t *testing.T) { + l := lua.NewState() + + l.OpenBase() + + testLuaWriter.l = l + + l.Register("WriteError", testLuaWriteError) + + str := ` + WriteError() + ` + + err := l.DoString(str) + + if err == nil { + t.Fatal("must error") + } + + l.Register("WriteArray", testLuaWriteArray) + + str = ` + local a = WriteArray() + + if #a ~= 2 then + error("len a must 2") + elseif a[1] ~= "1" then + error("a[1] must 1") + elseif #a[2] ~= 2 then + error("len a[2] must 2") + elseif a[2][1] ~= 10 then + error("a[2][1] must 10") + elseif a[2][2] ~= "11" then + error("a[2][2] must 11") + end + ` + + err = l.DoString(str) + if err != nil { + t.Fatal(err) + } + + l.Close() +} + +var testScript1 = ` + return {1,2,3} +` + +var testScript2 = ` + return ledis.call("ping") +` + +var testScript3 = ` + ledis.call("set", 1, "a") + + local a = ledis.call("get", 1) + if type(a) ~= "string" then + error("must string") + elseif a ~= "a" then + error("must a") + end +` + +var testScript4 = ` + ledis.call("select", 2) + ledis.call("set", 2, "b") +` + +func TestLuaCall(t *testing.T) { + cfg := new(config.Config) + cfg.Addr = ":11188" + cfg.DataDir = "/tmp/testscript" + cfg.DBName = "memory" + + app, e := NewApp(cfg) + if e != nil { + t.Fatal(e) + } + go app.Run() + + defer app.Close() + + db, _ := app.ldb.Select(0) + m, _ := db.Multi() + defer m.Close() + + luaClient := app.s.c + luaClient.db = m.DB + luaClient.script = m + + l := app.s.l + + err := app.s.l.DoString(testScript1) + if err != nil { + t.Fatal(err) + } + + v := luaReplyToLedisReply(l) + if vv, ok := v.([]interface{}); ok { + if len(vv) != 3 { + t.Fatal(len(vv)) + } + } else { + t.Fatal(fmt.Sprintf("%v %T", v, v)) + } + + err = app.s.l.DoString(testScript2) + if err != nil { + t.Fatal(err) + } + + v = luaReplyToLedisReply(l) + if vv := v.(string); vv != "PONG" { + t.Fatal(fmt.Sprintf("%v %T", v, v)) + } + + err = app.s.l.DoString(testScript3) + if err != nil { + t.Fatal(err) + } + + if v, err := db.Get([]byte("1")); err != nil { + t.Fatal(err) + } else if string(v) != "a" { + t.Fatal(string(v)) + } + + err = app.s.l.DoString(testScript4) + if err != nil { + t.Fatal(err) + } + + if luaClient.db.Index() != 2 { + t.Fatal(luaClient.db.Index()) + } + + db2, _ := app.ldb.Select(2) + if v, err := db2.Get([]byte("2")); err != nil { + t.Fatal(err) + } else if string(v) != "b" { + t.Fatal(string(v)) + } + + luaClient.db = nil +}