diff --git a/ledis/const.go b/ledis/const.go index ff9809f..156f439 100644 --- a/ledis/const.go +++ b/ledis/const.go @@ -1,5 +1,9 @@ package ledis +import ( + "errors" +) + const ( kvType byte = iota + 1 hashType @@ -10,3 +14,23 @@ const ( zSizeType zScoreType ) + +const ( + //we don't support too many databases + MaxDBNumber uint8 = 16 + + //max key size + MaxKeySize int = 1<<16 - 1 + + //max hash field size + MaxHashFieldSize int = 1<<16 - 1 + + //max zset member size + MaxZSetMemberSize int = 1<<16 - 1 +) + +var ( + ErrKeySize = errors.New("invalid key size") + ErrHashFieldSize = errors.New("invalid hash field size") + ErrZSetMemberSize = errors.New("invalid zset member size") +) diff --git a/ledis/db.go b/ledis/db.go deleted file mode 100644 index 391de74..0000000 --- a/ledis/db.go +++ /dev/null @@ -1,55 +0,0 @@ -package ledis - -import ( - "encoding/json" - "github.com/siddontang/go-leveldb/leveldb" -) - -type DBConfig struct { - DataDB leveldb.Config `json:"data_db"` -} - -type DB struct { - cfg *DBConfig - - db *leveldb.DB - - kvTx *tx - listTx *tx - hashTx *tx - zsetTx *tx -} - -func OpenDB(configJson json.RawMessage) (*DB, error) { - var cfg DBConfig - - if err := json.Unmarshal(configJson, &cfg); err != nil { - return nil, err - } - - return OpenDBWithConfig(&cfg) -} - -func OpenDBWithConfig(cfg *DBConfig) (*DB, error) { - db, err := leveldb.OpenWithConfig(&cfg.DataDB) - if err != nil { - return nil, err - } - - d := new(DB) - - d.cfg = cfg - - d.db = db - - d.kvTx = &tx{wb: db.NewWriteBatch()} - d.listTx = &tx{wb: db.NewWriteBatch()} - d.hashTx = &tx{wb: db.NewWriteBatch()} - d.zsetTx = &tx{wb: db.NewWriteBatch()} - - return d, nil -} - -func (db *DB) Close() { - db.db.Close() -} diff --git a/ledis/db_test.go b/ledis/db_test.go deleted file mode 100644 index 683882a..0000000 --- a/ledis/db_test.go +++ /dev/null @@ -1,41 +0,0 @@ -package ledis - -import ( - "sync" - "testing" -) - -var testDB *DB -var testDBOnce sync.Once - -func getTestDB() *DB { - f := func() { - var d = []byte(` - { - "data_db" : { - "path" : "/tmp/testdb", - "compression":true, - "block_size" : 32768, - "write_buffer_size" : 2097152, - "cache_size" : 20971520 - } - } - `) - db, err := OpenDB(d) - if err != nil { - println(err.Error()) - panic(err) - } - - testDB = db - - testDB.db.Clear() - } - - testDBOnce.Do(f) - return testDB -} - -func TestDB(t *testing.T) { - getTestDB() -} diff --git a/ledis/ledis.go b/ledis/ledis.go new file mode 100644 index 0000000..330f24c --- /dev/null +++ b/ledis/ledis.go @@ -0,0 +1,82 @@ +package ledis + +import ( + "encoding/json" + "fmt" + "github.com/siddontang/go-leveldb/leveldb" +) + +type Config struct { + DataDB leveldb.Config `json:"data_db"` +} + +type DB struct { + db *leveldb.DB + + index uint8 + + kvTx *tx + listTx *tx + hashTx *tx + zsetTx *tx +} + +type Ledis struct { + cfg *Config + + ldb *leveldb.DB + dbs [MaxDBNumber]*DB +} + +func Open(configJson json.RawMessage) (*Ledis, error) { + var cfg Config + + if err := json.Unmarshal(configJson, &cfg); err != nil { + return nil, err + } + + return OpenWithConfig(&cfg) +} + +func OpenWithConfig(cfg *Config) (*Ledis, error) { + ldb, err := leveldb.OpenWithConfig(&cfg.DataDB) + if err != nil { + return nil, err + } + + l := new(Ledis) + l.ldb = ldb + + for i := uint8(0); i < MaxDBNumber; i++ { + l.dbs[i] = newDB(l, i) + } + + return l, nil +} + +func newDB(l *Ledis, index uint8) *DB { + d := new(DB) + + d.db = l.ldb + + d.index = index + + d.kvTx = &tx{wb: d.db.NewWriteBatch()} + d.listTx = &tx{wb: d.db.NewWriteBatch()} + d.hashTx = &tx{wb: d.db.NewWriteBatch()} + d.zsetTx = &tx{wb: d.db.NewWriteBatch()} + + return d +} + +func (l *Ledis) Close() { + l.ldb.Close() +} + +func (l *Ledis) Select(index int) (*DB, error) { + if index < 0 || index >= int(MaxDBNumber) { + return nil, fmt.Errorf("invalid db index %d", index) + } + + return l.dbs[index], nil +} diff --git a/ledis/ledis_test.go b/ledis/ledis_test.go new file mode 100644 index 0000000..abd656c --- /dev/null +++ b/ledis/ledis_test.go @@ -0,0 +1,64 @@ +package ledis + +import ( + "sync" + "testing" +) + +var testLedis *Ledis +var testLedisOnce sync.Once + +func getTestDB() *DB { + f := func() { + var d = []byte(` + { + "data_db" : { + "path" : "/tmp/testdb", + "compression":true, + "block_size" : 32768, + "write_buffer_size" : 2097152, + "cache_size" : 20971520 + } + } + `) + var err error + testLedis, err = Open(d) + if err != nil { + println(err.Error()) + panic(err) + } + + testLedis.ldb.Clear() + } + + testLedisOnce.Do(f) + db, _ := testLedis.Select(0) + return db +} + +func TestDB(t *testing.T) { + getTestDB() +} + +func TestSelect(t *testing.T) { + db0, _ := testLedis.Select(0) + db1, _ := testLedis.Select(1) + + key0 := []byte("db0_test_key") + key1 := []byte("db1_test_key") + + db0.Set(key0, []byte("0")) + db1.Set(key1, []byte("1")) + + if v, err := db0.Get(key0); err != nil { + t.Fatal(err) + } else if string(v) != "0" { + t.Fatal(string(v)) + } + + if v, err := db1.Get(key1); err != nil { + t.Fatal(err) + } else if string(v) != "1" { + t.Fatal(string(v)) + } +} diff --git a/ledis/t_hash.go b/ledis/t_hash.go index f6e8826..5c949c5 100644 --- a/ledis/t_hash.go +++ b/ledis/t_hash.go @@ -19,30 +19,44 @@ const ( hashStopSep byte = hashStartSep + 1 ) -func encode_hsize_key(key []byte) []byte { - buf := make([]byte, len(key)+1) - buf[0] = hSizeType +func checkHashKFSize(key []byte, field []byte) error { + if len(key) > MaxKeySize { + return ErrKeySize + } else if len(field) > MaxHashFieldSize { + return ErrHashFieldSize + } + return nil +} - copy(buf[1:], key) +func (db *DB) hEncodeSizeKey(key []byte) []byte { + buf := make([]byte, len(key)+2) + + buf[0] = db.index + buf[1] = hSizeType + + copy(buf[2:], key) return buf } -func decode_hsize_key(ek []byte) ([]byte, error) { - if len(ek) == 0 || ek[0] != hSizeType { +func (db *DB) hDecodeSizeKey(ek []byte) ([]byte, error) { + if len(ek) < 2 || ek[0] != db.index || ek[1] != hSizeType { return nil, errHSizeKey } - return ek[1:], nil + return ek[2:], nil } -func encode_hash_key(key []byte, field []byte) []byte { - buf := make([]byte, len(key)+len(field)+1+4+1) +func (db *DB) hEncodeHashKey(key []byte, field []byte) []byte { + buf := make([]byte, len(key)+len(field)+1+1+2+1) pos := 0 + buf[pos] = db.index + pos++ buf[pos] = hashType pos++ - binary.BigEndian.PutUint32(buf[pos:], uint32(len(key))) - pos += 4 + + binary.BigEndian.PutUint16(buf[pos:], uint16(len(key))) + pos += 2 copy(buf[pos:], key) pos += len(key) @@ -54,29 +68,16 @@ func encode_hash_key(key []byte, field []byte) []byte { return buf } -func encode_hash_start_key(key []byte) []byte { - k := encode_hash_key(key, nil) - return k -} - -func encode_hash_stop_key(key []byte) []byte { - k := encode_hash_key(key, nil) - - k[len(k)-1] = hashStopSep - - return k -} - -func decode_hash_key(ek []byte) ([]byte, []byte, error) { - if len(ek) < 6 || ek[0] != hashType { +func (db *DB) hDecodeHashKey(ek []byte) ([]byte, []byte, error) { + if len(ek) < 5 || ek[0] != db.index || ek[1] != hashType { return nil, nil, errHashKey } - pos := 1 - keyLen := int(binary.BigEndian.Uint32(ek[pos:])) - pos += 4 + pos := 2 + keyLen := int(binary.BigEndian.Uint16(ek[pos:])) + pos += 2 - if keyLen+6 > len(ek) { + if keyLen+5 > len(ek) { return nil, nil, errHashKey } @@ -92,14 +93,26 @@ func decode_hash_key(ek []byte) ([]byte, []byte, error) { return key, field, nil } +func (db *DB) hEncodeStartKey(key []byte) []byte { + return db.hEncodeHashKey(key, nil) +} + +func (db *DB) hEncodeStopKey(key []byte) []byte { + k := db.hEncodeHashKey(key, nil) + + k[len(k)-1] = hashStopSep + + return k +} + func (db *DB) HLen(key []byte) (int64, error) { - return Int64(db.db.Get(encode_hsize_key(key))) + return Int64(db.db.Get(db.hEncodeSizeKey(key))) } func (db *DB) hSetItem(key []byte, field []byte, value []byte) (int64, error) { t := db.hashTx - ek := encode_hash_key(key, field) + ek := db.hEncodeHashKey(key, field) var n int64 = 1 if v, _ := db.db.Get(ek); v != nil { @@ -115,6 +128,10 @@ func (db *DB) hSetItem(key []byte, field []byte, value []byte) (int64, error) { } func (db *DB) HSet(key []byte, field []byte, value []byte) (int64, error) { + if err := checkHashKFSize(key, field); err != nil { + return 0, err + } + t := db.hashTx t.Lock() defer t.Unlock() @@ -131,7 +148,11 @@ func (db *DB) HSet(key []byte, field []byte, value []byte) (int64, error) { } func (db *DB) HGet(key []byte, field []byte) ([]byte, error) { - return db.db.Get(encode_hash_key(key, field)) + if err := checkHashKFSize(key, field); err != nil { + return nil, err + } + + return db.db.Get(db.hEncodeHashKey(key, field)) } func (db *DB) HMset(key []byte, args ...FVPair) error { @@ -139,30 +160,48 @@ func (db *DB) HMset(key []byte, args ...FVPair) error { t.Lock() defer t.Unlock() + var err error + var ek []byte var num int64 = 0 for i := 0; i < len(args); i++ { - ek := encode_hash_key(key, args[i].Field) - if v, _ := db.db.Get(ek); v == nil { + if err := checkHashKFSize(key, args[i].Field); err != nil { + return err + } + + ek = db.hEncodeHashKey(key, args[i].Field) + + if v, err := db.db.Get(ek); err != nil { + return err + } else if v == nil { num++ } t.Put(ek, args[i].Value) } - if _, err := db.hIncrSize(key, num); err != nil { + if _, err = db.hIncrSize(key, num); err != nil { return err } //todo add binglog - err := t.Commit() + err = t.Commit() return err } func (db *DB) HMget(key []byte, args [][]byte) ([]interface{}, error) { + var ek []byte + var v []byte + var err error + r := make([]interface{}, len(args)) for i := 0; i < len(args); i++ { - v, err := db.db.Get(encode_hash_key(key, args[i])) - if err != nil { + if err := checkHashKFSize(key, args[i]); err != nil { + return nil, err + } + + ek = db.hEncodeHashKey(key, args[i]) + + if v, err = db.db.Get(ek); err != nil { return nil, err } @@ -174,13 +213,23 @@ func (db *DB) HMget(key []byte, args [][]byte) ([]interface{}, error) { func (db *DB) HDel(key []byte, args [][]byte) (int64, error) { t := db.hashTx + + var ek []byte + var v []byte + var err error + t.Lock() defer t.Unlock() var num int64 = 0 for i := 0; i < len(args); i++ { - ek := encode_hash_key(key, args[i]) - if v, err := db.db.Get(ek); err != nil { + if err := checkHashKFSize(key, args[i]); err != nil { + return 0, err + } + + ek = db.hEncodeHashKey(key, args[i]) + + if v, err = db.db.Get(ek); err != nil { return 0, err } else if v == nil { continue @@ -190,20 +239,22 @@ func (db *DB) HDel(key []byte, args [][]byte) (int64, error) { } } - if _, err := db.hIncrSize(key, -num); err != nil { + if _, err = db.hIncrSize(key, -num); err != nil { return 0, err } - err := t.Commit() + err = t.Commit() return num, err } func (db *DB) hIncrSize(key []byte, delta int64) (int64, error) { t := db.hashTx - sk := encode_hsize_key(key) - size, err := Int64(db.db.Get(sk)) - if err != nil { + sk := db.hEncodeSizeKey(key) + + var err error + var size int64 = 0 + if size, err = Int64(db.db.Get(sk)); err != nil { return 0, err } else { size += delta @@ -219,15 +270,21 @@ func (db *DB) hIncrSize(key []byte, delta int64) (int64, error) { } func (db *DB) HIncrBy(key []byte, field []byte, delta int64) (int64, error) { + if err := checkHashKFSize(key, field); err != nil { + return 0, err + } + t := db.hashTx + var ek []byte + var err error + t.Lock() defer t.Unlock() - ek := encode_hash_key(key, field) + ek = db.hEncodeHashKey(key, field) var n int64 = 0 - n, err := StrInt64(db.db.Get(ek)) - if err != nil { + if n, err = StrInt64(db.db.Get(ek)); err != nil { return 0, err } @@ -244,14 +301,18 @@ func (db *DB) HIncrBy(key []byte, field []byte, delta int64) (int64, error) { } func (db *DB) HGetAll(key []byte) ([]interface{}, error) { - start := encode_hash_start_key(key) - stop := encode_hash_stop_key(key) + if err := checkKeySize(key); err != nil { + return nil, err + } + + start := db.hEncodeStartKey(key) + stop := db.hEncodeStopKey(key) v := make([]interface{}, 0, 16) it := db.db.Iterator(start, stop, leveldb.RangeROpen, 0, -1) for ; it.Valid(); it.Next() { - _, k, err := decode_hash_key(it.Key()) + _, k, err := db.hDecodeHashKey(it.Key()) if err != nil { return nil, err } @@ -265,14 +326,18 @@ func (db *DB) HGetAll(key []byte) ([]interface{}, error) { } func (db *DB) HKeys(key []byte) ([]interface{}, error) { - start := encode_hash_start_key(key) - stop := encode_hash_stop_key(key) + if err := checkKeySize(key); err != nil { + return nil, err + } + + start := db.hEncodeStartKey(key) + stop := db.hEncodeStopKey(key) v := make([]interface{}, 0, 16) it := db.db.Iterator(start, stop, leveldb.RangeROpen, 0, -1) for ; it.Valid(); it.Next() { - _, k, err := decode_hash_key(it.Key()) + _, k, err := db.hDecodeHashKey(it.Key()) if err != nil { return nil, err } @@ -285,8 +350,12 @@ func (db *DB) HKeys(key []byte) ([]interface{}, error) { } func (db *DB) HValues(key []byte) ([]interface{}, error) { - start := encode_hash_start_key(key) - stop := encode_hash_stop_key(key) + if err := checkKeySize(key); err != nil { + return nil, err + } + + start := db.hEncodeStartKey(key) + stop := db.hEncodeStopKey(key) v := make([]interface{}, 0, 16) @@ -301,15 +370,18 @@ func (db *DB) HValues(key []byte) ([]interface{}, error) { } func (db *DB) HClear(key []byte) (int64, error) { - sk := encode_hsize_key(key) + if err := checkKeySize(key); err != nil { + return 0, err + } + + sk := db.hEncodeSizeKey(key) + start := db.hEncodeStartKey(key) + stop := db.hEncodeStopKey(key) t := db.hashTx t.Lock() defer t.Unlock() - start := encode_hash_start_key(key) - stop := encode_hash_stop_key(key) - var num int64 = 0 it := db.db.Iterator(start, stop, leveldb.RangeROpen, 0, -1) for ; it.Valid(); it.Next() { diff --git a/ledis/t_hash_test.go b/ledis/t_hash_test.go index ed13f2b..5521ad2 100644 --- a/ledis/t_hash_test.go +++ b/ledis/t_hash_test.go @@ -4,6 +4,29 @@ import ( "testing" ) +func TestHashCodec(t *testing.T) { + db := getTestDB() + + key := []byte("key") + field := []byte("field") + + ek := db.hEncodeSizeKey(key) + if k, err := db.hDecodeSizeKey(ek); err != nil { + t.Fatal(err) + } else if string(k) != "key" { + t.Fatal(string(k)) + } + + ek = db.hEncodeHashKey(key, field) + if k, f, err := db.hDecodeHashKey(ek); err != nil { + t.Fatal(err) + } else if string(k) != "key" { + t.Fatal(string(k)) + } else if string(f) != "field" { + t.Fatal(string(f)) + } +} + func TestDBHash(t *testing.T) { db := getTestDB() diff --git a/ledis/t_kv.go b/ledis/t_kv.go index 5df995b..5f08d62 100644 --- a/ledis/t_kv.go +++ b/ledis/t_kv.go @@ -11,24 +11,36 @@ type KVPair struct { var errKVKey = errors.New("invalid encode kv key") -func encode_kv_key(key []byte) []byte { - ek := make([]byte, len(key)+1) - ek[0] = kvType - copy(ek[1:], key) +func checkKeySize(key []byte) error { + if len(key) > MaxKeySize { + return ErrKeySize + } + return nil +} + +func (db *DB) encodeKVKey(key []byte) []byte { + ek := make([]byte, len(key)+2) + ek[0] = db.index + ek[1] = kvType + copy(ek[2:], key) return ek } -func decode_kv_key(ek []byte) ([]byte, error) { - if len(ek) == 0 || ek[0] != kvType { +func (db *DB) decodeKVKey(ek []byte) ([]byte, error) { + if len(ek) < 2 || ek[0] != db.index || ek[1] != kvType { return nil, errKVKey } - return ek[1:], nil + return ek[2:], nil } func (db *DB) incr(key []byte, delta int64) (int64, error) { - key = encode_kv_key(key) + if err := checkKeySize(key); err != nil { + return 0, err + } + var err error + key = db.encodeKVKey(key) t := db.kvTx @@ -64,8 +76,9 @@ func (db *DB) Del(keys ...[]byte) (int64, error) { return 0, nil } + var err error for i := range keys { - keys[i] = encode_kv_key(keys[i]) + keys[i] = db.encodeKVKey(keys[i]) } t := db.kvTx @@ -78,13 +91,17 @@ func (db *DB) Del(keys ...[]byte) (int64, error) { //todo binlog } - err := t.Commit() + err = t.Commit() return int64(len(keys)), err } func (db *DB) Exists(key []byte) (int64, error) { - key = encode_kv_key(key) + if err := checkKeySize(key); err != nil { + return 0, err + } + var err error + key = db.encodeKVKey(key) var v []byte v, err = db.db.Get(key) @@ -96,13 +113,21 @@ func (db *DB) Exists(key []byte) (int64, error) { } func (db *DB) Get(key []byte) ([]byte, error) { - key = encode_kv_key(key) + if err := checkKeySize(key); err != nil { + return nil, err + } + + key = db.encodeKVKey(key) return db.db.Get(key) } func (db *DB) GetSet(key []byte, value []byte) ([]byte, error) { - key = encode_kv_key(key) + if err := checkKeySize(key); err != nil { + return nil, err + } + + key = db.encodeKVKey(key) t := db.kvTx @@ -133,10 +158,14 @@ func (db *DB) IncryBy(key []byte, increment int64) (int64, error) { func (db *DB) MGet(keys ...[]byte) ([]interface{}, error) { values := make([]interface{}, len(keys)) + var err error + var value []byte for i := range keys { - key := encode_kv_key(keys[i]) - value, err := db.db.Get(key) - if err != nil { + if err := checkKeySize(keys[i]); err != nil { + return nil, err + } + + if value, err = db.db.Get(db.encodeKVKey(keys[i])); err != nil { return nil, err } @@ -153,25 +182,38 @@ func (db *DB) MSet(args ...KVPair) error { t := db.kvTx + var err error + var key []byte + var value []byte + t.Lock() defer t.Unlock() for i := 0; i < len(args); i++ { - key := encode_kv_key(args[i].Key) - value := args[i].Value + if err := checkKeySize(args[i].Key); err != nil { + return err + } + + key = db.encodeKVKey(args[i].Key) + + value = args[i].Value t.Put(key, value) //todo binlog } - err := t.Commit() + err = t.Commit() return err } func (db *DB) Set(key []byte, value []byte) error { - key = encode_kv_key(key) + if err := checkKeySize(key); err != nil { + return err + } + var err error + key = db.encodeKVKey(key) t := db.kvTx @@ -188,8 +230,12 @@ func (db *DB) Set(key []byte, value []byte) error { } func (db *DB) SetNX(key []byte, value []byte) (int64, error) { - key = encode_kv_key(key) + if err := checkKeySize(key); err != nil { + return 0, err + } + var err error + key = db.encodeKVKey(key) var n int64 = 1 diff --git a/ledis/t_kv_test.go b/ledis/t_kv_test.go index f041b43..9bd21e6 100644 --- a/ledis/t_kv_test.go +++ b/ledis/t_kv_test.go @@ -4,6 +4,18 @@ import ( "testing" ) +func TestKVCodec(t *testing.T) { + db := getTestDB() + + ek := db.encodeKVKey([]byte("key")) + + if k, err := db.decodeKVKey(ek); err != nil { + t.Fatal(err) + } else if string(k) != "key" { + t.Fatal(string(k)) + } +} + func TestDBKV(t *testing.T) { db := getTestDB() diff --git a/ledis/t_list.go b/ledis/t_list.go index 2a8aa3e..bb18585 100644 --- a/ledis/t_list.go +++ b/ledis/t_list.go @@ -19,31 +19,34 @@ var errLMetaKey = errors.New("invalid lmeta key") var errListKey = errors.New("invalid list key") var errListSeq = errors.New("invalid list sequence, overflow") -func encode_lmeta_key(key []byte) []byte { - buf := make([]byte, len(key)+1) - buf[0] = lMetaType +func (db *DB) lEncodeMetaKey(key []byte) []byte { + buf := make([]byte, len(key)+2) + buf[0] = db.index + buf[1] = lMetaType - copy(buf[1:], key) + copy(buf[2:], key) return buf } -func decode_lmeta_key(ek []byte) ([]byte, error) { - if len(ek) == 0 || ek[0] != lMetaType { +func (db *DB) lDecodeMetaKey(ek []byte) ([]byte, error) { + if len(ek) < 2 || ek[0] != db.index || ek[1] != lMetaType { return nil, errLMetaKey } - return ek[1:], nil + return ek[2:], nil } -func encode_list_key(key []byte, seq int32) []byte { - buf := make([]byte, len(key)+9) +func (db *DB) lEncodeListKey(key []byte, seq int32) []byte { + buf := make([]byte, len(key)+8) pos := 0 + buf[pos] = db.index + pos++ buf[pos] = listType pos++ - binary.BigEndian.PutUint32(buf[pos:], uint32(len(key))) - pos += 4 + binary.BigEndian.PutUint16(buf[pos:], uint16(len(key))) + pos += 2 copy(buf[pos:], key) pos += len(key) @@ -53,25 +56,34 @@ func encode_list_key(key []byte, seq int32) []byte { return buf } -func decode_list_key(ek []byte) (key []byte, seq int32, err error) { - if len(ek) < 9 || ek[0] != listType { +func (db *DB) lDecodeListKey(ek []byte) (key []byte, seq int32, err error) { + if len(ek) < 8 || ek[0] != db.index || ek[1] != listType { err = errListKey return } - keyLen := int(binary.BigEndian.Uint32(ek[1:])) - if keyLen+9 != len(ek) { + keyLen := int(binary.BigEndian.Uint16(ek[2:])) + if keyLen+8 != len(ek) { err = errListKey return } - key = ek[5 : 5+keyLen] - seq = int32(binary.BigEndian.Uint32(ek[5+keyLen:])) + key = ek[4 : 4+keyLen] + seq = int32(binary.BigEndian.Uint32(ek[4+keyLen:])) return } func (db *DB) lpush(key []byte, whereSeq int32, args ...[]byte) (int64, error) { - metaKey := encode_lmeta_key(key) + if err := checkKeySize(key); err != nil { + return 0, err + } + + var headSeq int32 + var tailSeq int32 + var size int32 + var err error + + metaKey := db.lEncodeMetaKey(key) if len(args) == 0 { _, _, size, err := db.lGetMeta(metaKey) @@ -82,9 +94,7 @@ func (db *DB) lpush(key []byte, whereSeq int32, args ...[]byte) (int64, error) { t.Lock() defer t.Unlock() - headSeq, tailSeq, size, err := db.lGetMeta(metaKey) - - if err != nil { + if headSeq, tailSeq, size, err = db.lGetMeta(metaKey); err != nil { return 0, err } @@ -106,7 +116,8 @@ func (db *DB) lpush(key []byte, whereSeq int32, args ...[]byte) (int64, error) { } for i := 0; i < len(args); i++ { - t.Put(encode_list_key(key, seq+int32(i)*delta), args[i]) + ek := db.lEncodeListKey(key, seq+int32(i)*delta) + t.Put(ek, args[i]) //to do add binlog } @@ -132,12 +143,22 @@ func (db *DB) lpush(key []byte, whereSeq int32, args ...[]byte) (int64, error) { } func (db *DB) lpop(key []byte, whereSeq int32) ([]byte, error) { + if err := checkKeySize(key); err != nil { + return nil, err + } + t := db.listTx t.Lock() defer t.Unlock() - metaKey := encode_lmeta_key(key) - headSeq, tailSeq, size, err := db.lGetMeta(metaKey) + var headSeq int32 + var tailSeq int32 + var size int32 + var err error + + metaKey := db.lEncodeMetaKey(key) + + headSeq, tailSeq, size, err = db.lGetMeta(metaKey) if err != nil { return nil, err @@ -152,7 +173,7 @@ func (db *DB) lpop(key []byte, whereSeq int32) ([]byte, error) { seq = tailSeq } - itemKey := encode_list_key(key, seq) + itemKey := db.lEncodeListKey(key, seq) var value []byte value, err = db.db.Get(itemKey) if err != nil { @@ -181,7 +202,7 @@ func (db *DB) lpop(key []byte, whereSeq int32) ([]byte, error) { } func (db *DB) lGetSeq(key []byte, whereSeq int32) (int64, error) { - ek := encode_list_key(key, whereSeq) + ek := db.lEncodeListKey(key, whereSeq) return Int64(db.db.Get(ek)) } @@ -215,8 +236,18 @@ func (db *DB) lSetMeta(ek []byte, headSeq int32, tailSeq int32, size int32) { } func (db *DB) LIndex(key []byte, index int32) ([]byte, error) { + if err := checkKeySize(key); err != nil { + return nil, err + } + var seq int32 - headSeq, tailSeq, _, err := db.lGetMeta(encode_lmeta_key(key)) + var headSeq int32 + var tailSeq int32 + var err error + + metaKey := db.lEncodeMetaKey(key) + + headSeq, tailSeq, _, err = db.lGetMeta(metaKey) if err != nil { return nil, err } @@ -227,11 +258,16 @@ func (db *DB) LIndex(key []byte, index int32) ([]byte, error) { seq = tailSeq + index + 1 } - return db.db.Get(encode_list_key(key, seq)) + sk := db.lEncodeListKey(key, seq) + return db.db.Get(sk) } func (db *DB) LLen(key []byte) (int64, error) { - ek := encode_lmeta_key(key) + if err := checkKeySize(key); err != nil { + return 0, err + } + + ek := db.lEncodeMetaKey(key) _, _, size, err := db.lGetMeta(ek) return int64(size), err } @@ -245,6 +281,10 @@ func (db *DB) LPush(key []byte, args ...[]byte) (int64, error) { } func (db *DB) LRange(key []byte, start int32, stop int32) ([]interface{}, error) { + if err := checkKeySize(key); err != nil { + return nil, err + } + v := make([]interface{}, 0, 16) var startSeq int32 @@ -254,8 +294,13 @@ func (db *DB) LRange(key []byte, start int32, stop int32) ([]interface{}, error) return []interface{}{}, nil } - headSeq, tailSeq, _, err := db.lGetMeta(encode_lmeta_key(key)) - if err != nil { + var headSeq int32 + var tailSeq int32 + var err error + + metaKey := db.lEncodeMetaKey(key) + + if headSeq, tailSeq, _, err = db.lGetMeta(metaKey); err != nil { return nil, err } @@ -277,8 +322,9 @@ func (db *DB) LRange(key []byte, start int32, stop int32) ([]interface{}, error) stopSeq = listMaxSeq } - it := db.db.Iterator(encode_list_key(key, startSeq), - encode_list_key(key, stopSeq), leveldb.RangeClose, 0, -1) + startKey := db.lEncodeListKey(key, startSeq) + stopKey := db.lEncodeListKey(key, stopSeq) + it := db.db.Iterator(startKey, stopKey, leveldb.RangeClose, 0, -1) for ; it.Valid(); it.Next() { v = append(v, it.Value()) } @@ -297,22 +343,31 @@ func (db *DB) RPush(key []byte, args ...[]byte) (int64, error) { } func (db *DB) LClear(key []byte) (int64, error) { - mk := encode_lmeta_key(key) + if err := checkKeySize(key); err != nil { + return 0, err + } + + mk := db.lEncodeMetaKey(key) t := db.listTx t.Lock() defer t.Unlock() - metaKey := encode_lmeta_key(key) - headSeq, tailSeq, _, err := db.lGetMeta(metaKey) + var headSeq int32 + var tailSeq int32 + var err error + + headSeq, tailSeq, _, err = db.lGetMeta(mk) if err != nil { return 0, err } var num int64 = 0 - it := db.db.Iterator(encode_list_key(key, headSeq), - encode_list_key(key, tailSeq), leveldb.RangeClose, 0, -1) + startKey := db.lEncodeListKey(key, headSeq) + stopKey := db.lEncodeListKey(key, tailSeq) + + it := db.db.Iterator(startKey, stopKey, leveldb.RangeClose, 0, -1) for ; it.Valid(); it.Next() { t.Delete(it.Key()) num++ diff --git a/ledis/t_list_test.go b/ledis/t_list_test.go index 056602e..fe5d24d 100644 --- a/ledis/t_list_test.go +++ b/ledis/t_list_test.go @@ -4,6 +4,28 @@ import ( "testing" ) +func TestListCodec(t *testing.T) { + db := getTestDB() + + key := []byte("key") + + ek := db.lEncodeMetaKey(key) + if k, err := db.lDecodeMetaKey(ek); err != nil { + t.Fatal(err) + } else if string(k) != "key" { + t.Fatal(string(k)) + } + + ek = db.lEncodeListKey(key, 1024) + if k, seq, err := db.lDecodeListKey(ek); err != nil { + t.Fatal(err) + } else if string(k) != "key" { + t.Fatal(string(k)) + } else if seq != 1024 { + t.Fatal(seq) + } +} + func TestDBList(t *testing.T) { db := getTestDB() diff --git a/ledis/t_zset.go b/ledis/t_zset.go index 013e66b..bd1318c 100644 --- a/ledis/t_zset.go +++ b/ledis/t_zset.go @@ -31,31 +31,44 @@ const ( zsetStopMemSep byte = zsetStartMemSep + 1 ) -func encode_zsize_key(key []byte) []byte { - buf := make([]byte, len(key)+1) - buf[0] = zSizeType +func checkZSetKMSize(key []byte, member []byte) error { + if len(key) > MaxKeySize { + return ErrKeySize + } else if len(member) > MaxZSetMemberSize { + return ErrZSetMemberSize + } + return nil +} - copy(buf[1:], key) +func (db *DB) zEncodeSizeKey(key []byte) []byte { + buf := make([]byte, len(key)+2) + buf[0] = db.index + buf[1] = zSizeType + + copy(buf[2:], key) return buf } -func decode_zsize_key(ek []byte) ([]byte, error) { - if len(ek) == 0 || ek[0] != zSizeType { +func (db *DB) zDecodeSizeKey(ek []byte) ([]byte, error) { + if len(ek) < 2 || ek[0] != db.index || ek[1] != zSizeType { return nil, errZSizeKey } - return ek[1:], nil + return ek[2:], nil } -func encode_zset_key(key []byte, member []byte) []byte { - buf := make([]byte, len(key)+len(member)+5) +func (db *DB) zEncodeSetKey(key []byte, member []byte) []byte { + buf := make([]byte, len(key)+len(member)+4) pos := 0 + buf[pos] = db.index + pos++ + buf[pos] = zsetType pos++ - binary.BigEndian.PutUint32(buf[pos:], uint32(len(key))) - pos += 4 + binary.BigEndian.PutUint16(buf[pos:], uint16(len(key))) + pos += 2 copy(buf[pos:], key) pos += len(key) @@ -65,30 +78,33 @@ func encode_zset_key(key []byte, member []byte) []byte { return buf } -func decode_zset_key(ek []byte) ([]byte, []byte, error) { - if len(ek) < 5 || ek[0] != zsetType { +func (db *DB) zDecodeSetKey(ek []byte) ([]byte, []byte, error) { + if len(ek) < 4 || ek[0] != db.index || ek[1] != zsetType { return nil, nil, errZSetKey } - keyLen := int(binary.BigEndian.Uint32(ek[1:])) - if keyLen+5 > len(ek) { + keyLen := int(binary.BigEndian.Uint16(ek[2:])) + if keyLen+4 > len(ek) { return nil, nil, errZSetKey } - key := ek[5 : 5+keyLen] - member := ek[5+keyLen:] + key := ek[4 : 4+keyLen] + member := ek[4+keyLen:] return key, member, nil } -func encode_zscore_key(key []byte, member []byte, score int64) []byte { - buf := make([]byte, len(key)+len(member)+15) +func (db *DB) zEncodeScoreKey(key []byte, member []byte, score int64) []byte { + buf := make([]byte, len(key)+len(member)+14) pos := 0 + buf[pos] = db.index + pos++ + buf[pos] = zScoreType pos++ - binary.BigEndian.PutUint32(buf[pos:], uint32(len(key))) - pos += 4 + binary.BigEndian.PutUint16(buf[pos:], uint16(len(key))) + pos += 2 copy(buf[pos:], key) pos += len(key) @@ -110,31 +126,30 @@ func encode_zscore_key(key []byte, member []byte, score int64) []byte { return buf } -func encode_start_zscore_key(key []byte, score int64) []byte { - k := encode_zscore_key(key, nil, score) - return k +func (db *DB) zEncodeStartScoreKey(key []byte, score int64) []byte { + return db.zEncodeScoreKey(key, nil, score) } -func encode_stop_zscore_key(key []byte, score int64) []byte { - k := encode_zscore_key(key, nil, score) +func (db *DB) zEncodeStopScoreKey(key []byte, score int64) []byte { + k := db.zEncodeScoreKey(key, nil, score) k[len(k)-1] = zsetStopMemSep return k } -func decode_zscore_key(ek []byte) (key []byte, member []byte, score int64, err error) { - if len(ek) < 15 || ek[0] != zScoreType { +func (db *DB) zDecodeScoreKey(ek []byte) (key []byte, member []byte, score int64, err error) { + if len(ek) < 14 || ek[0] != db.index || ek[1] != zScoreType { err = errZScoreKey return } - keyLen := int(binary.BigEndian.Uint32(ek[1:])) + keyLen := int(binary.BigEndian.Uint16(ek[2:])) if keyLen+14 > len(ek) { err = errZScoreKey return } - key = ek[5 : 5+keyLen] - pos := 5 + keyLen + key = ek[4 : 4+keyLen] + pos := 4 + keyLen if (ek[pos] != zsetNScoreSep) && (ek[pos] != zsetPScoreSep) { err = errZScoreKey @@ -164,7 +179,8 @@ func (db *DB) zSetItem(key []byte, score int64, member []byte) (int64, error) { t := db.zsetTx var exists int64 = 0 - ek := encode_zset_key(key, member) + ek := db.zEncodeSetKey(key, member) + if v, err := db.db.Get(ek); err != nil { return 0, err } else if v != nil { @@ -173,14 +189,14 @@ func (db *DB) zSetItem(key []byte, score int64, member []byte) (int64, error) { if s, err := Int64(v, err); err != nil { return 0, err } else { - sk := encode_zscore_key(key, member, s) + sk := db.zEncodeScoreKey(key, member, s) t.Delete(sk) } } t.Put(ek, PutInt64(score)) - sk := encode_zscore_key(key, member, score) + sk := db.zEncodeScoreKey(key, member, score) t.Put(sk, []byte{}) return exists, nil @@ -189,7 +205,7 @@ func (db *DB) zSetItem(key []byte, score int64, member []byte) (int64, error) { func (db *DB) zDelItem(key []byte, member []byte, skipDelScore bool) (int64, error) { t := db.zsetTx - ek := encode_zset_key(key, member) + ek := db.zEncodeSetKey(key, member) if v, err := db.db.Get(ek); err != nil { return 0, err } else if v == nil { @@ -202,7 +218,7 @@ func (db *DB) zDelItem(key []byte, member []byte, skipDelScore bool) (int64, err if s, err := Int64(v, err); err != nil { return 0, err } else { - sk := encode_zscore_key(key, member, s) + sk := db.zEncodeScoreKey(key, member, s) t.Delete(sk) } } @@ -226,6 +242,10 @@ func (db *DB) ZAdd(key []byte, args ...ScorePair) (int64, error) { score := args[i].Score member := args[i].Member + if err := checkZSetKMSize(key, member); err != nil { + return 0, err + } + if n, err := db.zSetItem(key, score, member); err != nil { return 0, err } else if n == 0 { @@ -245,7 +265,8 @@ func (db *DB) ZAdd(key []byte, args ...ScorePair) (int64, error) { func (db *DB) zIncrSize(key []byte, delta int64) (int64, error) { t := db.zsetTx - sk := encode_zsize_key(key) + sk := db.zEncodeSizeKey(key) + size, err := Int64(db.db.Get(sk)) if err != nil { return 0, err @@ -263,13 +284,21 @@ func (db *DB) zIncrSize(key []byte, delta int64) (int64, error) { } func (db *DB) ZCard(key []byte) (int64, error) { - sk := encode_zsize_key(key) - size, err := Int64(db.db.Get(sk)) - return size, err + if err := checkKeySize(key); err != nil { + return 0, err + } + + sk := db.zEncodeSizeKey(key) + return Int64(db.db.Get(sk)) } func (db *DB) ZScore(key []byte, member []byte) ([]byte, error) { - k := encode_zset_key(key, member) + if err := checkZSetKMSize(key, member); err != nil { + return nil, err + } + + k := db.zEncodeSetKey(key, member) + score, err := Int64(db.db.Get(k)) if err != nil { return nil, err @@ -289,6 +318,10 @@ func (db *DB) ZRem(key []byte, members ...[]byte) (int64, error) { var num int64 = 0 for i := 0; i < len(members); i++ { + if err := checkZSetKMSize(key, members[i]); err != nil { + return 0, err + } + if n, err := db.zDelItem(key, members[i], false); err != nil { return 0, err } else if n == 1 { @@ -305,12 +338,18 @@ func (db *DB) ZRem(key []byte, members ...[]byte) (int64, error) { } func (db *DB) ZIncrBy(key []byte, delta int64, member []byte) ([]byte, error) { + if err := checkZSetKMSize(key, member); err != nil { + return nil, err + } + t := db.zsetTx t.Lock() defer t.Unlock() - ek := encode_zset_key(key, member) + ek := db.zEncodeSetKey(key, member) + var score int64 = delta + v, err := db.db.Get(ek) if err != nil { return nil, err @@ -318,7 +357,7 @@ func (db *DB) ZIncrBy(key []byte, delta int64, member []byte) ([]byte, error) { if s, err := Int64(v, err); err != nil { return nil, err } else { - sk := encode_zscore_key(key, member, s) + sk := db.zEncodeScoreKey(key, member, s) t.Delete(sk) score = s + delta @@ -333,15 +372,19 @@ func (db *DB) ZIncrBy(key []byte, delta int64, member []byte) ([]byte, error) { t.Put(ek, PutInt64(score)) - t.Put(encode_zscore_key(key, member, score), []byte{}) + sk := db.zEncodeScoreKey(key, member, score) + t.Put(sk, []byte{}) err = t.Commit() return StrPutInt64(score), err } func (db *DB) ZCount(key []byte, min int64, max int64) (int64, error) { - minKey := encode_start_zscore_key(key, min) - maxKey := encode_stop_zscore_key(key, max) + if err := checkKeySize(key); err != nil { + return 0, err + } + minKey := db.zEncodeStartScoreKey(key, min) + maxKey := db.zEncodeStopScoreKey(key, max) rangeType := leveldb.RangeROpen @@ -356,7 +399,11 @@ func (db *DB) ZCount(key []byte, min int64, max int64) (int64, error) { } func (db *DB) zrank(key []byte, member []byte, reverse bool) (int64, error) { - k := encode_zset_key(key, member) + if err := checkZSetKMSize(key, member); err != nil { + return 0, err + } + + k := db.zEncodeSetKey(key, member) if v, err := db.db.Get(k); err != nil { return 0, err @@ -368,13 +415,13 @@ func (db *DB) zrank(key []byte, member []byte, reverse bool) (int64, error) { } else { var it *leveldb.Iterator - sk := encode_zscore_key(key, member, s) + sk := db.zEncodeScoreKey(key, member, s) if !reverse { - minKey := encode_start_zscore_key(key, MinScore) + minKey := db.zEncodeStartScoreKey(key, MinScore) it = db.db.Iterator(minKey, sk, leveldb.RangeClose, 0, -1) } else { - maxKey := encode_stop_zscore_key(key, MaxScore) + maxKey := db.zEncodeStopScoreKey(key, MaxScore) it = db.db.RevIterator(sk, maxKey, leveldb.RangeClose, 0, -1) } @@ -389,7 +436,7 @@ func (db *DB) zrank(key []byte, member []byte, reverse bool) (int64, error) { it.Close() - if _, m, _, err := decode_zscore_key(lastKey); err == nil && bytes.Equal(m, member) { + if _, m, _, err := db.zDecodeScoreKey(lastKey); err == nil && bytes.Equal(m, member) { n-- return n, nil } @@ -400,8 +447,8 @@ func (db *DB) zrank(key []byte, member []byte, reverse bool) (int64, error) { } func (db *DB) zIterator(key []byte, min int64, max int64, offset int, limit int, reverse bool) *leveldb.Iterator { - minKey := encode_start_zscore_key(key, min) - maxKey := encode_stop_zscore_key(key, max) + minKey := db.zEncodeStartScoreKey(key, min) + maxKey := db.zEncodeStopScoreKey(key, max) if !reverse { return db.db.Iterator(minKey, maxKey, leveldb.RangeClose, offset, limit) @@ -411,6 +458,10 @@ func (db *DB) zIterator(key []byte, min int64, max int64, offset int, limit int, } func (db *DB) zRemRange(key []byte, min int64, max int64, offset int, limit int) (int64, error) { + if len(key) > MaxKeySize { + return 0, ErrKeySize + } + t := db.zsetTx t.Lock() defer t.Unlock() @@ -419,7 +470,7 @@ func (db *DB) zRemRange(key []byte, min int64, max int64, offset int, limit int) var num int64 = 0 for ; it.Valid(); it.Next() { k := it.Key() - _, m, _, err := decode_zscore_key(k) + _, m, _, err := db.zDecodeScoreKey(k) if err != nil { continue } @@ -459,6 +510,10 @@ func (db *DB) zReverse(s []interface{}, withScores bool) []interface{} { } func (db *DB) zRange(key []byte, min int64, max int64, withScores bool, offset int, limit int, reverse bool) ([]interface{}, error) { + if len(key) > MaxKeySize { + return nil, ErrKeySize + } + if offset < 0 { return []interface{}{}, nil } @@ -483,7 +538,7 @@ func (db *DB) zRange(key []byte, min int64, max int64, withScores bool, offset i } for ; it.Valid(); it.Next() { - _, m, s, err := decode_zscore_key(it.Key()) + _, m, s, err := db.zDecodeScoreKey(it.Key()) //may be we will check key equal? if err != nil { continue diff --git a/ledis/t_zset_test.go b/ledis/t_zset_test.go index fcbe32a..3eed420 100644 --- a/ledis/t_zset_test.go +++ b/ledis/t_zset_test.go @@ -4,6 +4,41 @@ import ( "testing" ) +func TestZSetCodec(t *testing.T) { + db := getTestDB() + + key := []byte("key") + member := []byte("member") + + ek := db.zEncodeSizeKey(key) + if k, err := db.zDecodeSizeKey(ek); err != nil { + t.Fatal(err) + } else if string(k) != "key" { + t.Fatal(string(k)) + } + + ek = db.zEncodeSetKey(key, member) + if k, m, err := db.zDecodeSetKey(ek); err != nil { + t.Fatal(err) + } else if string(k) != "key" { + t.Fatal(string(k)) + } else if string(m) != "member" { + t.Fatal(string(m)) + } + + ek = db.zEncodeScoreKey(key, member, 100) + if k, m, s, err := db.zDecodeScoreKey(ek); err != nil { + t.Fatal(err) + } else if string(k) != "key" { + t.Fatal(string(k)) + } else if string(m) != "member" { + t.Fatal(string(m)) + } else if s != 100 { + t.Fatal(s) + } + +} + func TestDBZSet(t *testing.T) { db := getTestDB() diff --git a/server/app.go b/server/app.go index 99e7dce..9e7e029 100644 --- a/server/app.go +++ b/server/app.go @@ -11,7 +11,7 @@ type App struct { listener net.Listener - db *ledis.DB + ldb *ledis.Ledis closed bool } @@ -35,7 +35,7 @@ func NewApp(cfg *Config) (*App, error) { return nil, err } - app.db, err = ledis.OpenDBWithConfig(&cfg.DB) + app.ldb, err = ledis.OpenWithConfig(&cfg.DB) if err != nil { return nil, err } @@ -50,7 +50,7 @@ func (app *App) Close() { app.listener.Close() - app.db.Close() + app.ldb.Close() app.closed = true } @@ -62,6 +62,6 @@ func (app *App) Run() { continue } - newClient(conn, app.db) + newClient(conn, app.ldb) } } diff --git a/server/client.go b/server/client.go index f89f38e..9fd1d68 100644 --- a/server/client.go +++ b/server/client.go @@ -15,6 +15,8 @@ import ( var errReadRequest = errors.New("invalid request protocol") type client struct { + ldb *ledis.Ledis + db *ledis.DB c net.Conn @@ -27,9 +29,11 @@ type client struct { reqC chan error } -func newClient(c net.Conn, db *ledis.DB) { +func newClient(c net.Conn, ldb *ledis.Ledis) { co := new(client) - co.db = db + co.ldb = ldb + //use default db + co.db, _ = ldb.Select(0) co.c = c co.rb = bufio.NewReaderSize(c, 256) diff --git a/server/command.go b/server/command.go index 0a64810..b54e11e 100644 --- a/server/command.go +++ b/server/command.go @@ -2,6 +2,9 @@ package server import ( "fmt" + "github.com/siddontang/ledisdb/ledis" + "strconv" + "strings" ) @@ -31,7 +34,26 @@ func echoCommand(c *client) error { return nil } +func selectCommand(c *client) error { + if len(c.args) != 1 { + return ErrCmdParams + } + + if index, err := strconv.Atoi(ledis.String(c.args[0])); err != nil { + return err + } else { + if db, err := c.ldb.Select(index); err != nil { + return err + } else { + c.db = db + c.writeStatus(OK) + } + } + return nil +} + func init() { register("ping", pingCommand) register("echo", echoCommand) + register("select", selectCommand) } diff --git a/server/config.go b/server/config.go index ee9b503..5e98694 100644 --- a/server/config.go +++ b/server/config.go @@ -9,7 +9,7 @@ import ( type Config struct { Addr string `json:"addr"` - DB ledis.DBConfig `json:"db"` + DB ledis.Config `json:"db"` } func NewConfig(data json.RawMessage) (*Config, error) {