diff --git a/ledis/const.go b/ledis/const.go index 393b97d..3c9b99c 100644 --- a/ledis/const.go +++ b/ledis/const.go @@ -104,7 +104,7 @@ var ( ) const ( - MaxDatabases int = 16 + MaxDatabases int = 10240 //max key size MaxKeySize int = 1024 diff --git a/ledis/event.go b/ledis/event.go index b9a4833..1f5b792 100644 --- a/ledis/event.go +++ b/ledis/event.go @@ -18,7 +18,11 @@ func formatEventKey(buf []byte, k []byte) ([]byte, error) { buf = append(buf, fmt.Sprintf("%s ", TypeName[k[1]])...) db := new(DB) - db.index = k[0] + index, _, err := decodeDBIndex(k) + if err != nil { + return nil, err + } + db.setIndex(index) //to do format at respective place diff --git a/ledis/ledis_db.go b/ledis/ledis_db.go index 21811b5..3ceeaae 100644 --- a/ledis/ledis_db.go +++ b/ledis/ledis_db.go @@ -1,6 +1,8 @@ package ledis import ( + "bytes" + "encoding/binary" "fmt" "github.com/siddontang/ledisdb/store" "sync" @@ -30,7 +32,10 @@ type DB struct { bucket ibucket - index uint8 + index int + + // buffer to store index varint + indexVarBuf []byte kvBatch *batch listBatch *batch @@ -56,7 +61,7 @@ func (l *Ledis) newDB(index int) *DB { d.bucket = d.sdb // d.status = DBAutoCommit - d.index = uint8(index) + d.setIndex(index) d.kvBatch = d.newBatch() d.listBatch = d.newBatch() @@ -72,6 +77,37 @@ func (l *Ledis) newDB(index int) *DB { return d } +func decodeDBIndex(buf []byte) (int, int, error) { + index, n := binary.Uvarint(buf) + if n == 0 { + return 0, 0, fmt.Errorf("buf is too small to save index") + } else if n < 0 { + return 0, 0, fmt.Errorf("value larger than 64 bits") + } else if index > uint64(MaxDatabases) { + return 0, 0, fmt.Errorf("value %d is larger than max databases %d", index, MaxDatabases) + } + return int(index), n, nil +} + +func (db *DB) setIndex(index int) { + db.index = index + // the most size for varint is 10 bytes + buf := make([]byte, 10) + n := binary.PutUvarint(buf, uint64(index)) + + db.indexVarBuf = buf[0:n] +} + +func (db *DB) checkKeyIndex(buf []byte) (int, error) { + if len(buf) < len(db.indexVarBuf) { + return 0, fmt.Errorf("key is too small") + } else if !bytes.Equal(db.indexVarBuf, buf[0:len(db.indexVarBuf)]) { + return 0, fmt.Errorf("invalid db index") + } + + return len(db.indexVarBuf), nil +} + func (db *DB) newTTLChecker() *ttlChecker { c := new(ttlChecker) c.db = db diff --git a/ledis/ledis_test.go b/ledis/ledis_test.go index 51e4d0d..a0c9879 100644 --- a/ledis/ledis_test.go +++ b/ledis/ledis_test.go @@ -37,24 +37,85 @@ func TestDB(t *testing.T) { func TestSelect(t *testing.T) { db0, _ := testLedis.Select(0) db1, _ := testLedis.Select(1) + db1024, _ := testLedis.Select(1024) - key0 := []byte("db0_test_key") - key1 := []byte("db1_test_key") + testSelect(t, db0) + testSelect(t, db1) + testSelect(t, db1024) +} - db0.Set(key0, []byte("0")) - db1.Set(key1, []byte("1")) - - if v, err := db0.Get(key0); err != nil { +func testSelect(t *testing.T, db *DB) { + key := []byte("test_select_key") + value := []byte("value") + if err := db.Set(key, value); err != nil { t.Fatal(err) - } else if string(v) != "0" { + } + + if v, err := db.Get(key); err != nil { + t.Fatal(err) + } else if string(v) != string(value) { t.Fatal(string(v)) } - if v, err := db1.Get(key1); err != nil { + if _, err := db.Expire(key, 100); err != nil { t.Fatal(err) - } else if string(v) != "1" { + } + + if _, err := db.TTL(key); err != nil { + t.Fatal(err) + } + + if _, err := db.Persist(key); err != nil { + t.Fatal(err) + } + + key = []byte("test_select_list_key") + if _, err := db.LPush(key, value); err != nil { + t.Fatal(err) + } + + if _, err := db.LRange(key, 0, 100); err != nil { + t.Fatal(err) + } + + if v, err := db.LPop(key); err != nil { + t.Fatal(err) + } else if string(v) != string(value) { t.Fatal(string(v)) } + + key = []byte("test_select_hash_key") + if _, err := db.HSet(key, []byte("a"), value); err != nil { + t.Fatal(err) + } + + if v, err := db.HGet(key, []byte("a")); err != nil { + t.Fatal(err) + } else if string(v) != string(value) { + t.Fatal(string(v)) + } + + key = []byte("test_select_set_key") + if _, err := db.SAdd(key, []byte("a"), []byte("b")); err != nil { + t.Fatal(err) + } + + if n, err := db.SIsMember(key, []byte("a")); err != nil { + t.Fatal(err) + } else if n != 1 { + t.Fatal(n) + } + + key = []byte("test_select_zset_key") + if _, err := db.ZAdd(key, ScorePair{1, []byte("a")}, ScorePair{2, []byte("b")}); err != nil { + t.Fatal(err) + } + + if v, err := db.ZRangeByScore(key, 0, 100, 0, -1); err != nil { + t.Fatal(err) + } else if len(v) != 2 { + t.Fatal(len(v)) + } } func TestFlush(t *testing.T) { diff --git a/ledis/scan.go b/ledis/scan.go index cdf18dc..94bff5d 100644 --- a/ledis/scan.go +++ b/ledis/scan.go @@ -178,11 +178,22 @@ func (db *DB) encodeScanKey(storeDataType byte, key []byte) ([]byte, error) { } } -func (db *DB) decodeScanKey(storeDataType byte, ek []byte) ([]byte, error) { - if len(ek) < 2 || ek[0] != db.index || ek[1] != storeDataType { - return nil, errMetaKey +func (db *DB) decodeScanKey(storeDataType byte, ek []byte) (key []byte, err error) { + switch storeDataType { + case KVType: + key, err = db.decodeKVKey(ek) + case LMetaType: + key, err = db.lDecodeMetaKey(ek) + case HSizeType: + key, err = db.hDecodeSizeKey(ek) + case ZSizeType: + key, err = db.zDecodeSizeKey(ek) + case SSizeType: + key, err = db.sDecodeSizeKey(ek) + default: + err = errDataType } - return ek[2:], nil + return } // for specail data scan diff --git a/ledis/t_hash.go b/ledis/t_hash.go index cd158cb..9914ea8 100644 --- a/ledis/t_hash.go +++ b/ledis/t_hash.go @@ -31,29 +31,41 @@ func checkHashKFSize(key []byte, field []byte) error { } func (db *DB) hEncodeSizeKey(key []byte) []byte { - buf := make([]byte, len(key)+2) + buf := make([]byte, len(key)+1+len(db.indexVarBuf)) - buf[0] = db.index - buf[1] = HSizeType + pos := 0 + n := copy(buf, db.indexVarBuf) + + pos += n + buf[pos] = HSizeType + + pos++ + copy(buf[pos:], key) - copy(buf[2:], key) return buf } func (db *DB) hDecodeSizeKey(ek []byte) ([]byte, error) { - if len(ek) < 2 || ek[0] != db.index || ek[1] != HSizeType { - return nil, errHSizeKey + pos, err := db.checkKeyIndex(ek) + if err != nil { + return nil, err } - return ek[2:], nil + if pos+1 > len(ek) || ek[pos] != HSizeType { + return nil, errHSizeKey + } + pos++ + + return ek[pos:], nil } func (db *DB) hEncodeHashKey(key []byte, field []byte) []byte { - buf := make([]byte, len(key)+len(field)+1+1+2+1) + buf := make([]byte, len(key)+len(field)+1+1+2+len(db.indexVarBuf)) pos := 0 - buf[pos] = db.index - pos++ + n := copy(buf, db.indexVarBuf) + pos += n + buf[pos] = HashType pos++ @@ -71,15 +83,24 @@ func (db *DB) hEncodeHashKey(key []byte, field []byte) []byte { } func (db *DB) hDecodeHashKey(ek []byte) ([]byte, []byte, error) { - if len(ek) < 5 || ek[0] != db.index || ek[1] != HashType { + pos, err := db.checkKeyIndex(ek) + if err != nil { + return nil, nil, err + } + + if pos+1 > len(ek) || ek[pos] != HashType { + return nil, nil, errHashKey + } + pos++ + + if pos+2 > len(ek) { return nil, nil, errHashKey } - pos := 2 keyLen := int(binary.BigEndian.Uint16(ek[pos:])) pos += 2 - if keyLen+5 > len(ek) { + if keyLen+pos > len(ek) { return nil, nil, errHashKey } diff --git a/ledis/t_kv.go b/ledis/t_kv.go index eb2f1fe..8c85c52 100644 --- a/ledis/t_kv.go +++ b/ledis/t_kv.go @@ -33,19 +33,26 @@ func checkValueSize(value []byte) error { } func (db *DB) encodeKVKey(key []byte) []byte { - ek := make([]byte, len(key)+2) - ek[0] = db.index - ek[1] = KVType - copy(ek[2:], key) + ek := make([]byte, len(key)+1+len(db.indexVarBuf)) + pos := copy(ek, db.indexVarBuf) + ek[pos] = KVType + pos++ + copy(ek[pos:], key) return ek } func (db *DB) decodeKVKey(ek []byte) ([]byte, error) { - if len(ek) < 2 || ek[0] != db.index || ek[1] != KVType { + pos, err := db.checkKeyIndex(ek) + if err != nil { + return nil, err + } + if pos+1 > len(ek) || ek[pos] != KVType { return nil, errKVKey } - return ek[2:], nil + pos++ + + return ek[pos:], nil } func (db *DB) encodeKVMinKey() []byte { diff --git a/ledis/t_list.go b/ledis/t_list.go index ce6eeb0..c872c7b 100644 --- a/ledis/t_list.go +++ b/ledis/t_list.go @@ -24,28 +24,34 @@ var errListKey = errors.New("invalid list key") var errListSeq = errors.New("invalid list sequence, overflow") func (db *DB) lEncodeMetaKey(key []byte) []byte { - buf := make([]byte, len(key)+2) - buf[0] = db.index - buf[1] = LMetaType + buf := make([]byte, len(key)+1+len(db.indexVarBuf)) + pos := copy(buf, db.indexVarBuf) + buf[pos] = LMetaType + pos++ - copy(buf[2:], key) + copy(buf[pos:], key) return buf } func (db *DB) lDecodeMetaKey(ek []byte) ([]byte, error) { - if len(ek) < 2 || ek[0] != db.index || ek[1] != LMetaType { + pos, err := db.checkKeyIndex(ek) + if err != nil { + return nil, err + } + + if pos+1 > len(ek) || ek[pos] != LMetaType { return nil, errLMetaKey } - return ek[2:], nil + pos++ + return ek[pos:], nil } func (db *DB) lEncodeListKey(key []byte, seq int32) []byte { - buf := make([]byte, len(key)+8) + buf := make([]byte, len(key)+7+len(db.indexVarBuf)) + + pos := copy(buf, db.indexVarBuf) - pos := 0 - buf[pos] = db.index - pos++ buf[pos] = ListType pos++ @@ -61,19 +67,33 @@ func (db *DB) lEncodeListKey(key []byte, seq int32) []byte { } func (db *DB) lDecodeListKey(ek []byte) (key []byte, seq int32, err error) { - if len(ek) < 8 || ek[0] != db.index || ek[1] != ListType { + pos := 0 + pos, err = db.checkKeyIndex(ek) + if err != nil { + return + } + + if pos+1 > len(ek) || ek[pos] != ListType { err = errListKey return } - keyLen := int(binary.BigEndian.Uint16(ek[2:])) - if keyLen+8 != len(ek) { + pos++ + + if pos+2 > len(ek) { err = errListKey return } - key = ek[4 : 4+keyLen] - seq = int32(binary.BigEndian.Uint32(ek[4+keyLen:])) + keyLen := int(binary.BigEndian.Uint16(ek[pos:])) + pos += 2 + if keyLen+pos+4 != len(ek) { + err = errListKey + return + } + + key = ek[pos : pos+keyLen] + seq = int32(binary.BigEndian.Uint32(ek[pos+keyLen:])) return } diff --git a/ledis/t_set.go b/ledis/t_set.go index 3980768..998170a 100644 --- a/ledis/t_set.go +++ b/ledis/t_set.go @@ -29,29 +29,36 @@ func checkSetKMSize(key []byte, member []byte) error { } func (db *DB) sEncodeSizeKey(key []byte) []byte { - buf := make([]byte, len(key)+2) + buf := make([]byte, len(key)+1+len(db.indexVarBuf)) - buf[0] = db.index - buf[1] = SSizeType + pos := copy(buf, db.indexVarBuf) + buf[pos] = SSizeType - copy(buf[2:], key) + pos++ + + copy(buf[pos:], key) return buf } func (db *DB) sDecodeSizeKey(ek []byte) ([]byte, error) { - if len(ek) < 2 || ek[0] != db.index || ek[1] != SSizeType { - return nil, errSSizeKey + pos, err := db.checkKeyIndex(ek) + if err != nil { + return nil, err } - return ek[2:], nil + if pos+1 > len(ek) || ek[pos] != SSizeType { + return nil, errSSizeKey + } + pos++ + + return ek[pos:], nil } func (db *DB) sEncodeSetKey(key []byte, member []byte) []byte { - buf := make([]byte, len(key)+len(member)+1+1+2+1) + buf := make([]byte, len(key)+len(member)+1+1+2+len(db.indexVarBuf)) + + pos := copy(buf, db.indexVarBuf) - pos := 0 - buf[pos] = db.index - pos++ buf[pos] = SetType pos++ @@ -69,15 +76,25 @@ func (db *DB) sEncodeSetKey(key []byte, member []byte) []byte { } func (db *DB) sDecodeSetKey(ek []byte) ([]byte, []byte, error) { - if len(ek) < 5 || ek[0] != db.index || ek[1] != SetType { + pos, err := db.checkKeyIndex(ek) + if err != nil { + return nil, nil, err + } + + if pos+1 > len(ek) || ek[pos] != SetType { + return nil, nil, errSetKey + } + + pos++ + + if pos+2 > len(ek) { return nil, nil, errSetKey } - pos := 2 keyLen := int(binary.BigEndian.Uint16(ek[pos:])) pos += 2 - if keyLen+5 > len(ek) { + if keyLen+pos > len(ek) { return nil, nil, errSetKey } diff --git a/ledis/t_ttl.go b/ledis/t_ttl.go index 7758fc8..2785a1d 100644 --- a/ledis/t_ttl.go +++ b/ledis/t_ttl.go @@ -28,11 +28,12 @@ type ttlChecker struct { var errExpType = errors.New("invalid expire type") func (db *DB) expEncodeTimeKey(dataType byte, key []byte, when int64) []byte { - buf := make([]byte, len(key)+11) + buf := make([]byte, len(key)+10+len(db.indexVarBuf)) - buf[0] = db.index - buf[1] = ExpTimeType - pos := 2 + pos := copy(buf, db.indexVarBuf) + + buf[pos] = ExpTimeType + pos++ binary.BigEndian.PutUint64(buf[pos:], uint64(when)) pos += 8 @@ -46,12 +47,13 @@ func (db *DB) expEncodeTimeKey(dataType byte, key []byte, when int64) []byte { } func (db *DB) expEncodeMetaKey(dataType byte, key []byte) []byte { - buf := make([]byte, len(key)+3) + buf := make([]byte, len(key)+2+len(db.indexVarBuf)) - buf[0] = db.index - buf[1] = ExpMetaType - buf[2] = dataType - pos := 3 + pos := copy(buf, db.indexVarBuf) + buf[pos] = ExpMetaType + pos++ + buf[pos] = dataType + pos++ copy(buf[pos:], key) @@ -59,19 +61,29 @@ func (db *DB) expEncodeMetaKey(dataType byte, key []byte) []byte { } func (db *DB) expDecodeMetaKey(mk []byte) (byte, []byte, error) { - if len(mk) <= 3 || mk[0] != db.index || mk[1] != ExpMetaType { + pos, err := db.checkKeyIndex(mk) + if err != nil { + return 0, nil, err + } + + if pos+2 > len(mk) || mk[pos] != ExpMetaType { return 0, nil, errExpMetaKey } - return mk[2], mk[3:], nil + return mk[pos+1], mk[pos+2:], nil } func (db *DB) expDecodeTimeKey(tk []byte) (byte, []byte, int64, error) { - if len(tk) < 11 || tk[0] != db.index || tk[1] != ExpTimeType { + pos, err := db.checkKeyIndex(tk) + if err != nil { + return 0, nil, 0, err + } + + if pos+10 > len(tk) || tk[pos] != ExpTimeType { return 0, nil, 0, errExpTimeKey } - return tk[10], tk[11:], int64(binary.BigEndian.Uint64(tk[2:])), nil + return tk[pos+9], tk[pos+10:], int64(binary.BigEndian.Uint64(tk[pos+1:])), nil } func (db *DB) expire(t *batch, dataType byte, key []byte, duration int64) { diff --git a/ledis/t_ttl_test.go b/ledis/t_ttl_test.go index 371a40d..f261399 100644 --- a/ledis/t_ttl_test.go +++ b/ledis/t_ttl_test.go @@ -437,3 +437,31 @@ func TestExpCompose(t *testing.T) { return } + +func TestTTLCodec(t *testing.T) { + db := getTestDB() + + key := []byte("key") + ek := db.expEncodeTimeKey(KVType, key, 10) + + if tp, k, when, err := db.expDecodeTimeKey(ek); err != nil { + t.Fatal(err) + } else if tp != KVType { + t.Fatal(tp, KVType) + } else if string(k) != "key" { + t.Fatal(string(k)) + } else if when != 10 { + t.Fatal(when) + } + + ek = db.expEncodeMetaKey(KVType, key) + + if tp, k, err := db.expDecodeMetaKey(ek); err != nil { + t.Fatal(err) + } else if tp != KVType { + t.Fatal(tp, KVType) + } else if string(k) != "key" { + t.Fatal(string(k)) + } + +} diff --git a/ledis/t_zset.go b/ledis/t_zset.go index e33344a..6b65f4d 100644 --- a/ledis/t_zset.go +++ b/ledis/t_zset.go @@ -51,28 +51,31 @@ func checkZSetKMSize(key []byte, member []byte) error { } func (db *DB) zEncodeSizeKey(key []byte) []byte { - buf := make([]byte, len(key)+2) - buf[0] = db.index - buf[1] = ZSizeType - - copy(buf[2:], key) + buf := make([]byte, len(key)+1+len(db.indexVarBuf)) + pos := copy(buf, db.indexVarBuf) + buf[pos] = ZSizeType + pos++ + copy(buf[pos:], key) return buf } func (db *DB) zDecodeSizeKey(ek []byte) ([]byte, error) { - if len(ek) < 2 || ek[0] != db.index || ek[1] != ZSizeType { - return nil, errZSizeKey + pos, err := db.checkKeyIndex(ek) + if err != nil { + return nil, err } - return ek[2:], nil + if pos+1 > len(ek) || ek[pos] != ZSizeType { + return nil, errZSizeKey + } + pos++ + return ek[pos:], nil } func (db *DB) zEncodeSetKey(key []byte, member []byte) []byte { - buf := make([]byte, len(key)+len(member)+5) + buf := make([]byte, len(key)+len(member)+4+len(db.indexVarBuf)) - pos := 0 - buf[pos] = db.index - pos++ + pos := copy(buf, db.indexVarBuf) buf[pos] = ZSetType pos++ @@ -92,22 +95,35 @@ func (db *DB) zEncodeSetKey(key []byte, member []byte) []byte { } func (db *DB) zDecodeSetKey(ek []byte) ([]byte, []byte, error) { - if len(ek) < 5 || ek[0] != db.index || ek[1] != ZSetType { + pos, err := db.checkKeyIndex(ek) + if err != nil { + return nil, nil, err + } + + if pos+1 > len(ek) || ek[pos] != ZSetType { return nil, nil, errZSetKey } - keyLen := int(binary.BigEndian.Uint16(ek[2:])) - if keyLen+5 > len(ek) { + pos++ + + if pos+2 > len(ek) { return nil, nil, errZSetKey } - key := ek[4 : 4+keyLen] - - if ek[4+keyLen] != zsetStartMemSep { + keyLen := int(binary.BigEndian.Uint16(ek[pos:])) + if keyLen+pos > len(ek) { return nil, nil, errZSetKey } - member := ek[5+keyLen:] + pos += 2 + key := ek[pos : pos+keyLen] + + if ek[pos+keyLen] != zsetStartMemSep { + return nil, nil, errZSetKey + } + pos++ + + member := ek[pos+keyLen:] return key, member, nil } @@ -123,11 +139,9 @@ func (db *DB) zEncodeStopSetKey(key []byte) []byte { } func (db *DB) zEncodeScoreKey(key []byte, member []byte, score int64) []byte { - buf := make([]byte, len(key)+len(member)+14) + buf := make([]byte, len(key)+len(member)+13+len(db.indexVarBuf)) - pos := 0 - buf[pos] = db.index - pos++ + pos := copy(buf, db.indexVarBuf) buf[pos] = ZScoreType pos++ @@ -166,20 +180,38 @@ func (db *DB) zEncodeStopScoreKey(key []byte, score int64) []byte { } func (db *DB) zDecodeScoreKey(ek []byte) (key []byte, member []byte, score int64, err error) { - if len(ek) < 14 || ek[0] != db.index || ek[1] != ZScoreType { + pos := 0 + pos, err = db.checkKeyIndex(ek) + if err != nil { + return + } + + if pos+1 > len(ek) || ek[pos] != ZScoreType { + err = errZScoreKey + return + } + pos++ + + if pos+2 > len(ek) { + err = errZScoreKey + return + } + keyLen := int(binary.BigEndian.Uint16(ek[pos:])) + pos += 2 + + if keyLen+pos > len(ek) { err = errZScoreKey return } - keyLen := int(binary.BigEndian.Uint16(ek[2:])) - if keyLen+14 > len(ek) { + key = ek[pos : pos+keyLen] + pos += keyLen + + if pos+10 > len(ek) { err = errZScoreKey return } - key = ek[4 : 4+keyLen] - pos := 4 + keyLen - if (ek[pos] != zsetNScoreSep) && (ek[pos] != zsetPScoreSep) { err = errZScoreKey return