diff --git a/ledis/tx.go b/ledis/tx.go new file mode 100644 index 0000000..6339bae --- /dev/null +++ b/ledis/tx.go @@ -0,0 +1,112 @@ +package ledis + +import ( + "errors" + "fmt" + "github.com/siddontang/ledisdb/store" +) + +var ( + ErrNestTx = errors.New("nest transaction not supported") + ErrTxDone = errors.New("Transaction has already been committed or rolled back") +) + +type Tx struct { + *DB + + tx *store.Tx + + logs [][]byte +} + +func (db *DB) IsTransaction() bool { + return db.status == DBInTransaction +} + +// Begin a transaction, it will block all other write operations before calling Commit or Rollback. +// You must be very careful to prevent long-time transaction. +func (db *DB) Begin() (*Tx, error) { + if db.IsTransaction() { + return nil, ErrNestTx + } + + tx := new(Tx) + + tx.DB = new(DB) + tx.DB.l = db.l + + tx.l.wLock.Lock() + + tx.DB.sdb = db.sdb + + var err error + tx.tx, err = db.sdb.Begin() + if err != nil { + tx.l.wLock.Unlock() + return nil, err + } + + tx.DB.bucket = tx.tx + + tx.DB.status = DBInTransaction + + tx.DB.index = db.index + + tx.DB.kvBatch = tx.newBatch() + tx.DB.listBatch = tx.newBatch() + tx.DB.hashBatch = tx.newBatch() + tx.DB.zsetBatch = tx.newBatch() + tx.DB.binBatch = tx.newBatch() + tx.DB.setBatch = tx.newBatch() + + return tx, nil +} + +func (tx *Tx) Commit() error { + if tx.tx == nil { + return ErrTxDone + } + + tx.l.commitLock.Lock() + err := tx.tx.Commit() + tx.tx = nil + + if len(tx.logs) > 0 { + tx.l.binlog.Log(tx.logs...) + } + + tx.l.commitLock.Unlock() + + tx.l.wLock.Unlock() + + tx.DB.bucket = nil + + return err +} + +func (tx *Tx) Rollback() error { + if tx.tx == nil { + return ErrTxDone + } + + err := tx.tx.Rollback() + tx.tx = nil + + tx.l.wLock.Unlock() + tx.DB.bucket = nil + + return err +} + +func (tx *Tx) newBatch() *batch { + return tx.l.newBatch(tx.tx.NewWriteBatch(), &txBatchLocker{}, tx) +} + +func (tx *Tx) Select(index int) error { + if index < 0 || index >= int(MaxDBNumber) { + return fmt.Errorf("invalid db index %d", index) + } + + tx.DB.index = uint8(index) + return nil +} diff --git a/server/cmd_script.go b/server/cmd_script.go index f0dabd1..1325477 100644 --- a/server/cmd_script.go +++ b/server/cmd_script.go @@ -1,6 +1,7 @@ package server import ( + "crypto/sha1" "encoding/hex" "fmt" "github.com/aarzilli/golua/lua" @@ -64,20 +65,21 @@ func evalGenericCommand(c *client, evalSha1 bool) error { return err } - var sha1 string + var key string if !evalSha1 { - sha1 = hex.EncodeToString(c.args[0]) + h := sha1.Sum(c.args[0]) + key = hex.EncodeToString(h[0:20]) } else { - sha1 = ledis.String(c.args[0]) + key = strings.ToLower(ledis.String(c.args[0])) } - l.GetGlobal(sha1) + l.GetGlobal(key) if l.IsNil(-1) { l.Pop(1) if evalSha1 { - return fmt.Errorf("missing %s script", sha1) + return fmt.Errorf("missing %s script", key) } if r := l.LoadString(ledis.String(c.args[0])); r != 0 { @@ -86,9 +88,9 @@ func evalGenericCommand(c *client, evalSha1 bool) error { return err } else { l.PushValue(-1) - l.SetGlobal(sha1) + l.SetGlobal(key) - s.chunks[sha1] = struct{}{} + s.chunks[key] = struct{}{} } } @@ -125,9 +127,6 @@ func scriptCommand(c *client) error { }() args := c.args - if len(args) < 1 { - return ErrCmdParams - } switch strings.ToLower(c.cmd) { case "script load": @@ -137,7 +136,7 @@ func scriptCommand(c *client) error { case "script flush": return scriptFlushCommand(c) default: - return fmt.Errorf("invalid scirpt cmd %s", args[0]) + return fmt.Errorf("invalid script cmd %s", args[0]) } return nil @@ -151,28 +150,33 @@ func scriptLoadCommand(c *client) error { return ErrCmdParams } - sha1 := hex.EncodeToString(c.args[1]) + h := sha1.Sum(c.args[0]) + key := hex.EncodeToString(h[0:20]) - if r := l.LoadString(ledis.String(c.args[1])); r != 0 { + if r := l.LoadString(ledis.String(c.args[0])); r != 0 { err := fmt.Errorf("%s", l.ToString(-1)) l.Pop(1) return err } else { l.PushValue(-1) - l.SetGlobal(sha1) + l.SetGlobal(key) - s.chunks[sha1] = struct{}{} + s.chunks[key] = struct{}{} } - c.resp.writeBulk(ledis.Slice(sha1)) + c.resp.writeBulk(ledis.Slice(key)) return nil } func scriptExistsCommand(c *client) error { s := c.app.s - ay := make([]interface{}, len(c.args[1:])) - for i, n := range c.args[1:] { + if len(c.args) < 1 { + return ErrCmdParams + } + + ay := make([]interface{}, len(c.args)) + for i, n := range c.args { if _, ok := s.chunks[ledis.String(n)]; ok { ay[i] = int64(1) } else { @@ -193,6 +197,8 @@ func scriptFlushCommand(c *client) error { l.SetGlobal(n) } + s.chunks = map[string]struct{}{} + c.resp.writeStatus(OK) return nil diff --git a/server/cmd_script_test.go b/server/cmd_script_test.go index 5c47866..d652e08 100644 --- a/server/cmd_script_test.go +++ b/server/cmd_script_test.go @@ -13,17 +13,45 @@ func TestCmdEval(t *testing.T) { if v, err := ledis.Strings(c.Do("eval", "return {KEYS[1],KEYS[2],ARGV[1],ARGV[2]}", 2, "key1", "key2", "first", "second")); err != nil { t.Fatal(err) - } else if len(v) != 4 { - t.Fatal(err) } else if !reflect.DeepEqual(v, []string{"key1", "key2", "first", "second"}) { t.Fatal(fmt.Sprintf("%v", v)) } if v, err := ledis.Strings(c.Do("eval", "return {KEYS[1],KEYS[2],ARGV[1],ARGV[2]}", 2, "key1", "key2", "first", "second")); err != nil { t.Fatal(err) - } else if len(v) != 4 { + } else if !reflect.DeepEqual(v, []string{"key1", "key2", "first", "second"}) { + t.Fatal(fmt.Sprintf("%v", v)) + } + + var sha1 string + var err error + if sha1, err = ledis.String(c.Do("script load", "return {KEYS[1],KEYS[2],ARGV[1],ARGV[2]}")); err != nil { + t.Fatal(err) + } else if len(sha1) != 40 { + t.Fatal(sha1) + } + + if v, err := ledis.Strings(c.Do("evalsha", sha1, 2, "key1", "key2", "first", "second")); err != nil { t.Fatal(err) } else if !reflect.DeepEqual(v, []string{"key1", "key2", "first", "second"}) { t.Fatal(fmt.Sprintf("%v", v)) } + + if ay, err := ledis.Values(c.Do("script exists", sha1, "01234567890123456789")); err != nil { + t.Fatal(err) + } else if !reflect.DeepEqual(ay, []interface{}{int64(1), int64(0)}) { + t.Fatal(fmt.Sprintf("%v", ay)) + } + + if ok, err := ledis.String(c.Do("script flush")); err != nil { + t.Fatal(err) + } else if ok != "OK" { + t.Fatal(ok) + } + + if ay, err := ledis.Values(c.Do("script exists", sha1)); err != nil { + t.Fatal(err) + } else if !reflect.DeepEqual(ay, []interface{}{int64(0)}) { + t.Fatal(fmt.Sprintf("%v", ay)) + } }