diff --git a/ssdb/app.go b/ssdb/app.go index d3286aa..8b3e9d5 100644 --- a/ssdb/app.go +++ b/ssdb/app.go @@ -4,7 +4,6 @@ import ( "github.com/siddontang/golib/leveldb" "net" "strings" - "sync" ) type App struct { @@ -14,10 +13,10 @@ type App struct { db *leveldb.DB - kvMutex sync.Mutex - hashMutex sync.Mutex - listMutex sync.Mutex - zsetMutex sync.Mutex + kvTx *tx + listTx *tx + hashTx *tx + zsetTx *tx } func NewApp(cfg *Config) (*App, error) { @@ -42,6 +41,11 @@ func NewApp(cfg *Config) (*App, error) { return nil, err } + app.kvTx = app.newTx() + app.listTx = app.newTx() + app.hashTx = app.newTx() + app.zsetTx = app.newTx() + return app, nil } diff --git a/ssdb/client.go b/ssdb/client.go index 2c8ad96..8c7b296 100644 --- a/ssdb/client.go +++ b/ssdb/client.go @@ -123,8 +123,6 @@ func (c *client) readRequest() ([][]byte, error) { } } else { - println("return 6") - return nil, errReadRequest } } diff --git a/ssdb/cmd_list.go b/ssdb/cmd_list.go index a10cb55..4fba6e0 100644 --- a/ssdb/cmd_list.go +++ b/ssdb/cmd_list.go @@ -1,30 +1,131 @@ package ssdb +import ( + "github.com/siddontang/golib/hack" + "strconv" +) + func lpushCommand(c *client) error { + args := c.args + if len(args) < 2 { + return ErrCmdParams + } + + if n, err := c.app.list_lpush(args[0], args[1:]); err != nil { + return err + } else { + c.writeInteger(n) + } + return nil } func rpushCommand(c *client) error { + args := c.args + if len(args) < 2 { + return ErrCmdParams + } + + if n, err := c.app.list_rpush(args[0], args[1:]); err != nil { + return err + } else { + c.writeInteger(n) + } + return nil } func lpopCommand(c *client) error { + args := c.args + if len(args) != 1 { + return ErrCmdParams + } + + if v, err := c.app.list_lpop(args[0]); err != nil { + return err + } else { + c.writeBulk(v) + } + return nil } func rpopCommand(c *client) error { + args := c.args + if len(args) != 1 { + return ErrCmdParams + } + + if v, err := c.app.list_rpop(args[0]); err != nil { + return err + } else { + c.writeBulk(v) + } + return nil } func llenCommand(c *client) error { + args := c.args + if len(args) != 1 { + return ErrCmdParams + } + + if n, err := c.app.list_len(args[0]); err != nil { + return err + } else { + c.writeInteger(n) + } + return nil } func lindexCommand(c *client) error { + args := c.args + if len(args) != 2 { + return ErrCmdParams + } + + index, err := strconv.ParseInt(hack.String(args[1]), 10, 64) + if err != nil { + return err + } + + if v, err := c.app.list_index(args[0], index); err != nil { + return err + } else { + c.writeBulk(v) + } + return nil } func lrangeCommand(c *client) error { + args := c.args + if len(args) != 3 { + return ErrCmdParams + } + + var start int64 + var stop int64 + var err error + + start, err = strconv.ParseInt(hack.String(args[1]), 10, 64) + if err != nil { + return err + } + + stop, err = strconv.ParseInt(hack.String(args[2]), 10, 64) + if err != nil { + return err + } + + if v, err := c.app.list_range(args[0], start, stop); err != nil { + return err + } else { + c.writeArray(v) + } + return nil } diff --git a/ssdb/cmd_list_test.go b/ssdb/cmd_list_test.go new file mode 100644 index 0000000..8cfed80 --- /dev/null +++ b/ssdb/cmd_list_test.go @@ -0,0 +1,293 @@ +package ssdb + +import ( + "fmt" + "github.com/garyburd/redigo/redis" + "strconv" + "testing" +) + +func testListIndex(key []byte, index int64, v int) error { + c := getTestConn() + defer c.Close() + + n, err := redis.Int(c.Do("lindex", key, index)) + if err == redis.ErrNil && v != 0 { + return fmt.Errorf("must nil") + } else if err != nil && err != redis.ErrNil { + return err + } else if n != v { + return fmt.Errorf("index err number %d != %d", n, v) + } + + return nil +} + +func testListRange(key []byte, start int64, stop int64, checkValues ...int) error { + c := getTestConn() + defer c.Close() + + vs, err := redis.MultiBulk(c.Do("lrange", key, start, stop)) + if err != nil { + return err + } + + if len(vs) != len(checkValues) { + return fmt.Errorf("invalid return number %d != %d", len(vs), len(checkValues)) + } + + var n int + for i, v := range vs { + if d, ok := v.([]byte); ok { + n, err = strconv.Atoi(string(d)) + if err != nil { + return err + } else if n != checkValues[i] { + return fmt.Errorf("invalid data %d: %d != %d", i, n, checkValues[i]) + } + } else { + return fmt.Errorf("invalid data %v %T", v, v) + } + } + + return nil +} + +func testPrintList(key []byte) { + headSeq, _ := testApp.db.GetString(encode_list_key(key, listHeadSeq)) + tailSeq, _ := testApp.db.GetString(encode_list_key(key, listTailSeq)) + + size, _ := testApp.db.GetString(encode_lsize_key(key)) + + println("begin ---------------------") + println(headSeq, tailSeq, size) + + it := testApp.db.Iterator(encode_list_key(key, listMinSeq), + encode_list_key(key, listMaxSeq), 0) + for ; it.Valid(); it.Next() { + k, seq, _ := decode_list_key(it.Key()) + println(string(k), seq, string(it.Value())) + } + println("end ---------------------") +} + +func TestList(t *testing.T) { + startTestApp() + + c := getTestConn() + defer c.Close() + + key := []byte("a") + + if n, err := redis.Int(c.Do("lpush", key, 1)); err != nil { + t.Fatal(err) + } else if n != 1 { + t.Fatal(n) + } + + if n, err := redis.Int(c.Do("rpush", key, 2)); err != nil { + t.Fatal(err) + } else if n != 2 { + t.Fatal(n) + } + + if n, err := redis.Int(c.Do("rpush", key, 3)); err != nil { + t.Fatal(err) + } else if n != 3 { + t.Fatal(n) + } + + if n, err := redis.Int(c.Do("llen", key)); err != nil { + t.Fatal(err) + } else if n != 3 { + t.Fatal(n) + } + + //for redis-cli a 1 2 3 + // 127.0.0.1:6379> lrange a 0 0 + // 1) "1" + if err := testListRange(key, 0, 0, 1); err != nil { + t.Fatal(err) + } + + // 127.0.0.1:6379> lrange a 0 1 + // 1) "1" + // 2) "2" + + if err := testListRange(key, 0, 1, 1, 2); err != nil { + t.Fatal(err) + } + + // 127.0.0.1:6379> lrange a 0 5 + // 1) "1" + // 2) "2" + // 3) "3" + if err := testListRange(key, 0, 5, 1, 2, 3); err != nil { + t.Fatal(err) + } + + // 127.0.0.1:6379> lrange a -1 5 + // 1) "3" + if err := testListRange(key, -1, 5, 3); err != nil { + t.Fatal(err) + } + + // 127.0.0.1:6379> lrange a -5 -1 + // 1) "1" + // 2) "2" + // 3) "3" + if err := testListRange(key, -5, -1, 1, 2, 3); err != nil { + t.Fatal(err) + } + + // 127.0.0.1:6379> lrange a -2 -1 + // 1) "2" + // 2) "3" + if err := testListRange(key, -2, -1, 2, 3); err != nil { + t.Fatal(err) + } + + // 127.0.0.1:6379> lrange a -1 -2 + // (empty list or set) + if err := testListRange(key, -1, -2); err != nil { + t.Fatal(err) + } + + // 127.0.0.1:6379> lrange a -1 2 + // 1) "3" + if err := testListRange(key, -1, 2, 3); err != nil { + t.Fatal(err) + } + + // 127.0.0.1:6379> lrange a -5 5 + // 1) "1" + // 2) "2" + // 3) "3" + if err := testListRange(key, -5, 5, 1, 2, 3); err != nil { + t.Fatal(err) + } + + // 127.0.0.1:6379> lrange a -1 0 + // (empty list or set) + if err := testListRange(key, -1, 0); err != nil { + t.Fatal(err) + } + + if err := testListRange([]byte("empty list"), 0, 100); err != nil { + t.Fatal(err) + } + + // 127.0.0.1:6379> lrange a -1 -1 + // 1) "3" + if err := testListRange(key, -1, -1, 3); err != nil { + t.Fatal(err) + } + + if err := testListIndex(key, -1, 3); err != nil { + t.Fatal(err) + } + + if err := testListIndex(key, 0, 1); err != nil { + t.Fatal(err) + } + + if err := testListIndex(key, 1, 2); err != nil { + t.Fatal(err) + } + + if err := testListIndex(key, 2, 3); err != nil { + t.Fatal(err) + } + + if err := testListIndex(key, 5, 0); err != nil { + t.Fatal(err) + } + + if err := testListIndex(key, -1, 3); err != nil { + t.Fatal(err) + } + + if err := testListIndex(key, -2, 2); err != nil { + t.Fatal(err) + } + + if err := testListIndex(key, -3, 1); err != nil { + t.Fatal(err) + } +} + +func TestListMPush(t *testing.T) { + startTestApp() + c := getTestConn() + defer c.Close() + + key := []byte("b") + if n, err := redis.Int(c.Do("rpush", key, 1, 2, 3)); err != nil { + t.Fatal(err) + } else if n != 3 { + t.Fatal(n) + } + + if err := testListRange(key, 0, 3, 1, 2, 3); err != nil { + t.Fatal(err) + } + + if n, err := redis.Int(c.Do("lpush", key, 1, 2, 3)); err != nil { + t.Fatal(err) + } else if n != 6 { + t.Fatal(n) + } + + if err := testListRange(key, 0, 6, 3, 2, 1, 1, 2, 3); err != nil { + t.Fatal(err) + } +} + +func TestPop(t *testing.T) { + startTestApp() + c := getTestConn() + defer c.Close() + + key := []byte("c") + if n, err := redis.Int(c.Do("rpush", key, 1, 2, 3, 4, 5, 6)); err != nil { + t.Fatal(err) + } else if n != 6 { + t.Fatal(n) + } + + if v, err := redis.Int(c.Do("lpop", key)); err != nil { + t.Fatal(err) + } else if v != 1 { + t.Fatal(v) + } + + if v, err := redis.Int(c.Do("rpop", key)); err != nil { + t.Fatal(err) + } else if v != 6 { + t.Fatal(v) + } + + if n, err := redis.Int(c.Do("lpush", key, 1)); err != nil { + t.Fatal(err) + } else if n != 5 { + t.Fatal(n) + } + + if err := testListRange(key, 0, 5, 1, 2, 3, 4, 5); err != nil { + t.Fatal(err) + } + + for i := 1; i <= 5; i++ { + if v, err := redis.Int(c.Do("lpop", key)); err != nil { + t.Fatal(err) + } else if v != i { + t.Fatal(v) + } + } + + if n, err := redis.Int(c.Do("llen", key)); err != nil { + t.Fatal(err) + } else if n != 0 { + t.Fatal(n) + } +} diff --git a/ssdb/t_kv.go b/ssdb/t_kv.go index b7bb4d4..fcae406 100644 --- a/ssdb/t_kv.go +++ b/ssdb/t_kv.go @@ -6,7 +6,6 @@ import ( "strconv" ) -var errEmptyKVKey = errors.New("invalid empty kv key") var errKVKey = errors.New("invalid encode kv key") func encode_kv_key(key []byte) []byte { @@ -16,12 +15,12 @@ func encode_kv_key(key []byte) []byte { return ek } -func decode_kv_key(encodeKey []byte) ([]byte, error) { - if encodeKey[0] != KV_TYPE { +func decode_kv_key(ek []byte) ([]byte, error) { + if len(ek) == 0 || ek[0] != KV_TYPE { return nil, errKVKey } - return encodeKey[1:], nil + return ek[1:], nil } func (a *App) kv_get(key []byte) ([]byte, error) { @@ -34,10 +33,10 @@ func (a *App) kv_set(key []byte, value []byte) error { key = encode_kv_key(key) var err error - t := a.newTx() + t := a.kvTx - a.kvMutex.Lock() - defer a.kvMutex.Unlock() + t.Lock() + defer t.Unlock() t.Put(key, value) @@ -52,13 +51,13 @@ func (a *App) kv_getset(key []byte, value []byte) ([]byte, error) { key = encode_kv_key(key) var err error - a.kvMutex.Lock() - defer a.kvMutex.Unlock() + t := a.kvTx + + t.Lock() + defer t.Unlock() oldValue, _ := a.db.Get(key) - t := a.newTx() - t.Put(key, value) //todo, binlog @@ -73,10 +72,10 @@ func (a *App) kv_setnx(key []byte, value []byte) (int64, error) { var n int64 = 1 - t := a.newTx() + t := a.kvTx - a.kvMutex.Lock() - defer a.kvMutex.Unlock() + t.Lock() + defer t.Unlock() if v, _ := a.db.Get(key); v != nil { n = 0 @@ -108,25 +107,17 @@ func (a *App) kv_incr(key []byte, delta int64) (int64, error) { key = encode_kv_key(key) var err error - t := a.newTx() + t := a.kvTx - a.kvMutex.Lock() - defer a.kvMutex.Unlock() + t.Lock() + defer t.Unlock() - var v []byte - v, err = a.db.Get(key) + var n int64 + n, err = a.db.GetInt(key) if err != nil { return 0, err } - var n int64 = 0 - if v != nil { - n, err = strconv.ParseInt(hack.String(v), 10, 64) - if err != nil { - return 0, err - } - } - n += delta t.Put(key, hack.Slice(strconv.FormatInt(n, 10))) @@ -142,10 +133,10 @@ func (a *App) tx_del(keys [][]byte) (int64, error) { keys[i] = encode_kv_key(keys[i]) } - t := a.newTx() + t := a.kvTx - a.kvMutex.Lock() - defer a.kvMutex.Unlock() + t.Lock() + defer t.Unlock() for i := range keys { t.Delete(keys[i]) @@ -157,10 +148,10 @@ func (a *App) tx_del(keys [][]byte) (int64, error) { } func (a *App) tx_mset(args [][]byte) error { - t := a.newTx() + t := a.kvTx - a.kvMutex.Lock() - defer a.kvMutex.Unlock() + t.Lock() + defer t.Unlock() for i := 0; i < len(args); i += 2 { key := encode_kv_key(args[i]) diff --git a/ssdb/t_list.go b/ssdb/t_list.go new file mode 100644 index 0000000..670f289 --- /dev/null +++ b/ssdb/t_list.go @@ -0,0 +1,280 @@ +package ssdb + +import ( + "encoding/binary" + "errors" + "github.com/siddontang/golib/hack" + "strconv" +) + +const ( + listHeadSeq int64 = 1 + listTailSeq int64 = 2 + + listMinSeq int64 = 1000 + listMaxSeq int64 = 1<<63 - 1000 + listInitialSeq int64 = listMinSeq + (listMaxSeq-listMinSeq)/2 +) + +var errLSizeKey = errors.New("invalid lsize key") +var errListKey = errors.New("invalid list key") +var errListSeq = errors.New("invalid list sequence, overflow") + +func encode_lsize_key(key []byte) []byte { + buf := make([]byte, len(key)+1) + buf[0] = LSIZE_TYPE + + copy(buf[1:], key) + return buf +} + +func decode_lsize_key(ek []byte) ([]byte, error) { + if len(ek) == 0 || ek[0] != LSIZE_TYPE { + return nil, errLSizeKey + } + + return ek[1:], nil +} + +func encode_list_key(key []byte, seq int64) []byte { + buf := make([]byte, len(key)+13) + + pos := 0 + buf[pos] = LIST_TYPE + pos++ + + binary.BigEndian.PutUint32(buf[pos:], uint32(len(key))) + pos += 4 + + copy(buf[pos:], key) + pos += len(key) + + binary.BigEndian.PutUint64(buf[pos:], uint64(seq)) + + return buf +} + +func decode_list_key(ek []byte) (key []byte, seq int64, err error) { + if len(ek) < 13 || ek[0] != LIST_TYPE { + err = errListKey + return + } + + keyLen := int(binary.BigEndian.Uint32(ek[1:])) + if keyLen+13 != len(ek) { + err = errListKey + return + } + + key = ek[5 : 5+keyLen] + seq = int64(binary.BigEndian.Uint64(ek[5+keyLen:])) + return +} + +func (a *App) list_lpush(key []byte, args [][]byte) (int64, error) { + return a.list_push(key, args, listHeadSeq) +} + +func (a *App) list_rpush(key []byte, args [][]byte) (int64, error) { + return a.list_push(key, args, listTailSeq) +} + +func (a *App) list_lpop(key []byte) ([]byte, error) { + return a.list_pop(key, listHeadSeq) +} + +func (a *App) list_rpop(key []byte) ([]byte, error) { + return a.list_pop(key, listTailSeq) +} + +func (a *App) list_getSeq(key []byte, whereSeq int64) (int64, error) { + ek := encode_list_key(key, whereSeq) + + return a.db.GetInt(ek) +} + +func (a *App) list_len(key []byte) (int64, error) { + ek := encode_lsize_key(key) + + return a.db.GetInt(ek) +} + +func (a *App) list_push(key []byte, args [][]byte, whereSeq int64) (int64, error) { + t := a.listTx + t.Lock() + defer t.Unlock() + + seq, err := a.list_getSeq(key, whereSeq) + if err != nil { + return 0, err + } + + var size int64 = 0 + + var delta int64 = 1 + if whereSeq == listHeadSeq { + delta = -1 + } + + if seq == 0 { + seq = listInitialSeq + + t.Put(encode_list_key(key, listHeadSeq), hack.Slice(strconv.FormatInt(seq, 10))) + t.Put(encode_list_key(key, listTailSeq), hack.Slice(strconv.FormatInt(seq, 10))) + } else { + size, err = a.list_len(key) + if err != nil { + return 0, err + } + + seq += delta + } + + for i := 0; i < len(args); i++ { + t.Put(encode_list_key(key, seq+int64(i)*delta), args[i]) + //to do add binlog + } + + seq += int64(len(args)-1) * delta + + if seq <= listMinSeq || seq >= listMaxSeq { + return 0, errListSeq + } + + size += int64(len(args)) + + t.Put(encode_lsize_key(key), hack.Slice(strconv.FormatInt(size, 10))) + t.Put(encode_list_key(key, whereSeq), hack.Slice(strconv.FormatInt(seq, 10))) + + err = t.Commit() + + return size, err +} + +func (a *App) list_pop(key []byte, whereSeq int64) ([]byte, error) { + t := a.listTx + t.Lock() + defer t.Unlock() + + var delta int64 = 1 + if whereSeq == listTailSeq { + delta = -1 + } + + seq, err := a.list_getSeq(key, whereSeq) + if err != nil { + return nil, err + } + + var value []byte + value, err = a.db.Get(encode_list_key(key, seq)) + if err != nil { + return nil, err + } + + t.Delete(encode_list_key(key, seq)) + seq += delta + + var size int64 + size, err = a.list_len(key) + if err != nil { + return nil, err + } + + size-- + if size <= 0 { + t.Delete(encode_lsize_key(key)) + t.Delete(encode_list_key(key, listHeadSeq)) + t.Delete(encode_list_key(key, listTailSeq)) + } else { + t.Put(encode_list_key(key, whereSeq), hack.Slice(strconv.FormatInt(seq, 10))) + t.Put(encode_lsize_key(key), hack.Slice(strconv.FormatInt(size, 10))) + } + + //todo add binlog + err = t.Commit() + return value, err +} + +func (a *App) list_range(key []byte, start int64, stop int64) ([]interface{}, error) { + v := make([]interface{}, 0, 16) + + var startSeq int64 + var stopSeq int64 + + if start > stop { + return []interface{}{}, nil + } else if start >= 0 && stop >= 0 { + seq, err := a.list_getSeq(key, listHeadSeq) + if err != nil { + return nil, err + } + + startSeq = seq + start + stopSeq = seq + stop + 1 + + } else if start < 0 && stop < 0 { + seq, err := a.list_getSeq(key, listTailSeq) + if err != nil { + return nil, err + } + + startSeq = seq + start + 1 + stopSeq = seq + stop + 2 + } else { + //start < 0 && stop > 0 + var err error + startSeq, err = a.list_getSeq(key, listTailSeq) + if err != nil { + return nil, err + } + + startSeq += start + 1 + + stopSeq, err = a.list_getSeq(key, listHeadSeq) + if err != nil { + return nil, err + } + + stopSeq += stop + 1 + } + + if startSeq < listMinSeq { + startSeq = listMinSeq + } else if stopSeq > listMaxSeq { + stopSeq = listMaxSeq + } + + it := a.db.Iterator(encode_list_key(key, startSeq), + encode_list_key(key, stopSeq), 0) + for ; it.Valid(); it.Next() { + v = append(v, it.Value()) + } + + it.Close() + + return v, nil +} + +func (a *App) list_index(key []byte, index int64) ([]byte, error) { + var seq int64 + var err error + if index >= 0 { + seq, err = a.list_getSeq(key, listHeadSeq) + if err != nil { + return nil, err + } + + seq = seq + index + + } else { + seq, err = a.list_getSeq(key, listTailSeq) + if err != nil { + return nil, err + } + + seq = seq + index + 1 + } + + return a.db.Get(encode_list_key(key, seq)) +} diff --git a/ssdb/tx.go b/ssdb/tx.go index 2f34525..afe7d40 100644 --- a/ssdb/tx.go +++ b/ssdb/tx.go @@ -2,9 +2,12 @@ package ssdb import ( "github.com/siddontang/golib/leveldb" + "sync" ) type tx struct { + m sync.Mutex + app *App wb *leveldb.WriteBatch @@ -20,6 +23,10 @@ func (app *App) newTx() *tx { return t } +func (t *tx) Close() { + t.wb.Close() +} + func (t *tx) Put(key []byte, value []byte) { t.wb.Put(key, value) } @@ -28,6 +35,15 @@ func (t *tx) Delete(key []byte) { t.wb.Delete(key) } +func (t *tx) Lock() { + t.m.Lock() +} + +func (t *tx) Unlock() { + t.wb.Rollback() + t.m.Unlock() +} + func (t *tx) Commit() error { err := t.wb.Commit() return err