From c97b09fd0cfe16142bdb3eed3f305c70adc88eea Mon Sep 17 00:00:00 2001 From: holys Date: Thu, 14 Aug 2014 19:35:41 +0800 Subject: [PATCH] add set operations: difference, union, intersection --- ledis/ledis_db.go | 3 +- ledis/t_set.go | 252 ++++++++++++++++++++++++++++++++++++++------ ledis/t_set_test.go | 200 +++++++++++++++++++++++++++++++++++ 3 files changed, 420 insertions(+), 35 deletions(-) diff --git a/ledis/ledis_db.go b/ledis/ledis_db.go index 7d58054..1bd9a54 100644 --- a/ledis/ledis_db.go +++ b/ledis/ledis_db.go @@ -10,7 +10,8 @@ func (db *DB) FlushAll() (drop int64, err error) { db.lFlush, db.hFlush, db.zFlush, - db.bFlush} + db.bFlush, + db.sFlush} for _, flush := range all { if n, e := flush(); e != nil { diff --git a/ledis/t_set.go b/ledis/t_set.go index 78a750c..fdd5f6c 100644 --- a/ledis/t_set.go +++ b/ledis/t_set.go @@ -13,6 +13,9 @@ var errSSizeKey = errors.New("invalid ssize key") const ( setStartSep byte = ':' setStopSep byte = setStartSep + 1 + UnionType byte = 51 + DiffType byte = 52 + InterType byte = 53 ) func checkSetKMSize(key []byte, member []byte) error { @@ -101,9 +104,25 @@ func (db *DB) sEncodeStopKey(key []byte) []byte { return k } -// func (db *DB) sFlush() { +func (db *DB) sFlush() (drop int64, err error) { + minKey := make([]byte, 2) + minKey[0] = db.index + minKey[1] = SetType -// } + maxKey := make([]byte, 2) + maxKey[0] = db.index + maxKey[1] = SSizeType + 1 + + t := db.setTx + t.Lock() + defer t.Unlock() + + drop, err = db.flushRegion(t, minKey, maxKey) + err = db.expFlush(t, SetType) + + err = t.Commit() + return +} func (db *DB) sDelete(t *tx, key []byte) int64 { sk := db.sEncodeSizeKey(key) @@ -223,24 +242,124 @@ func (db *DB) SCard(key []byte) (int64, error) { return Int64(db.db.Get(sk)) } -// TODO -// func (db *DB) sDiffGeneric() +func (db *DB) sDiffGeneric(keys ...[]byte) ([][]byte, error) { + destMap := make(map[string]bool) -// func (db *DB) SDiff() { + members, err := db.SMembers(keys[0]) + if err != nil { + return nil, err + } -// } + for _, m := range members { + destMap[String(m)] = true + } -// func (db *DB) SDiffStore() { + for _, k := range keys[1:] { + members, err := db.SMembers(k) + if err != nil { + return nil, err + } -// } + for _, m := range members { + if _, ok := destMap[String(m)]; !ok { + continue + } else if ok { + delete(destMap, String(m)) + } + } + // O - A = O, O is zero set. + if len(destMap) == 0 { + return nil, nil + } + } -// func (db *DB) SInter() { + slice := make([][]byte, len(destMap)) + idx := 0 + for k, v := range destMap { + if !v { + continue + } + slice[idx] = []byte(k) + idx++ + } -// } + return slice, nil +} -// func (db *DB) SInterStore() { +func (db *DB) SDiff(keys ...[]byte) ([][]byte, error) { + v, err := db.sDiffGeneric(keys...) + return v, err +} -// } +func (db *DB) SDiffStore(dstKey []byte, keys ...[]byte) (int64, error) { + n, err := db.sStoreGeneric(dstKey, DiffType, keys...) + return n, err +} + +func (db *DB) sInterGeneric(keys ...[]byte) ([][]byte, error) { + destMap := make(map[string]bool) + + members, err := db.SMembers(keys[0]) + if err != nil { + return nil, err + } + + for _, m := range members { + destMap[String(m)] = true + } + + for _, key := range keys[1:] { + if err := checkKeySize(key); err != nil { + return nil, err + } + + members, err := db.SMembers(key) + if err != nil { + return nil, err + } else if len(members) == 0 { + return nil, err + } + + tempMap := make(map[string]bool) + for _, member := range members { + if err := checkKeySize(member); err != nil { + return nil, err + } + if _, ok := destMap[String(member)]; ok { + tempMap[String(member)] = true //mark this item as selected + } + } + destMap = tempMap //reduce the size of the result set + if len(destMap) == 0 { + return nil, nil + } + } + + slice := make([][]byte, len(destMap)) + idx := 0 + for k, v := range destMap { + if !v { + continue + } + + slice[idx] = []byte(k) + idx++ + } + + return slice, nil + +} + +func (db *DB) SInter(keys ...[]byte) ([][]byte, error) { + v, err := db.sInterGeneric(keys...) + return v, err + +} + +func (db *DB) SInterStore(dstKey []byte, keys ...[]byte) (int64, error) { + n, err := db.sStoreGeneric(dstKey, InterType, keys...) + return n, err +} func (db *DB) SIsMember(key []byte, member []byte) (int64, error) { ek := db.sEncodeSetKey(key, member) @@ -317,35 +436,100 @@ func (db *DB) SRem(key []byte, args ...[]byte) (int64, error) { } -// TODO -// func (db *DB) sUnionGeneric(keys ...[]byte) ([][]byte, error) { +func (db *DB) sUnionGeneric(keys ...[]byte) ([][]byte, error) { + dstMap := make(map[string]bool) -// for _, key := range keys { -// if err := checkKeySize(key); err != nil { -// return nil, err -// } + for _, key := range keys { + if err := checkKeySize(key); err != nil { + return nil, err + } -// } -// } + members, err := db.SMembers(key) + if err != nil { + return nil, err + } -// func (db *DB) SUnion(keys ...[]byte) ([][]byte, error) { + for _, member := range members { + dstMap[String(member)] = true + } + } -// if v, err := db.sUnionGeneric(keys); err != nil { -// return nil, err -// } else if v == nil { -// return nil, nil -// } else { -// return v, nil -// } + slice := make([][]byte, len(dstMap)) + idx := 0 + for k, v := range dstMap { + if !v { + continue + } + slice[idx] = []byte(k) + idx++ + } -// } + return slice, nil +} -// func (db *DB) SUnionStore(dstkey []byte, keys []byte) (int64, error) { -// if err := checkKeySize(dstkey); err != nil { -// return 0, err -// } +func (db *DB) SUnion(keys ...[]byte) ([][]byte, error) { + v, err := db.sUnionGeneric(keys...) + return v, err +} -// } +func (db *DB) SUnionStore(dstKey []byte, keys ...[]byte) (int64, error) { + n, err := db.sStoreGeneric(dstKey, UnionType, keys...) + return n, err +} + +func (db *DB) sStoreGeneric(dstKey []byte, optType byte, keys ...[]byte) (int64, error) { + if err := checkKeySize(dstKey); err != nil { + return 0, err + } + + t := db.setTx + t.Lock() + defer t.Unlock() + + db.sDelete(t, dstKey) + + var err error + var ek []byte + var num int64 = 0 + var v [][]byte + + switch optType { + case UnionType: + v, err = db.sUnionGeneric(keys...) + case DiffType: + v, err = db.sDiffGeneric(keys...) + case InterType: + v, err = db.sInterGeneric(keys...) + } + + if err != nil { + return 0, err + } + + for _, m := range v { + if err := checkSetKMSize(dstKey, m); err != nil { + return 0, err + } + + ek = db.sEncodeSetKey(dstKey, m) + + if v, err := db.db.Get(ek); err != nil { + return 0, err + } else if v == nil { + num++ + } + + t.Put(ek, nil) + + } + + if _, err = db.sIncrSize(dstKey, num); err != nil { + return 0, err + } + + err = t.Commit() + return num, err +} func (db *DB) SClear(key []byte) (int64, error) { if err := checkKeySize(key); err != nil { diff --git a/ledis/t_set_test.go b/ledis/t_set_test.go index 2284a68..d308bd5 100644 --- a/ledis/t_set_test.go +++ b/ledis/t_set_test.go @@ -126,3 +126,203 @@ func TestDBSet(t *testing.T) { } } + +func TestSetOperation(t *testing.T) { + db := getTestDB() + // testUnion(db, t) + testInter(db, t) + // testDiff(db, t) + +} + +func testUnion(db *DB, t *testing.T) { + + key := []byte("testdb_set_union_1") + key1 := []byte("testdb_set_union_2") + key2 := []byte("testdb_set_union_2") + // member1 := []byte("testdb_set_m1") + // member2 := []byte("testdb_set_m2") + + m1 := []byte("m1") + m2 := []byte("m2") + m3 := []byte("m3") + db.SAdd(key, m1, m2) + db.SAdd(key1, m1, m3) + db.SAdd(key2, m2, m3) + if _, err := db.sUnionGeneric(key, key2); err != nil { + t.Fatal(err) + } + + if _, err := db.SUnion(key, key2); err != nil { + t.Fatal(err) + } + + dstkey := []byte("union_dsk") + if num, err := db.SUnionStore(dstkey, key1, key2); err != nil { + t.Fatal(err) + } else if num != 3 { + t.Fatal(num) + } + if _, err := db.SMembers(dstkey); err != nil { + t.Fatal(err) + } + + if n, err := db.SCard(dstkey); err != nil { + t.Fatal(err) + } else if n != 3 { + t.Fatal(n) + } + + v1, _ := db.SUnion(key1, key2) + v2, _ := db.SUnion(key2, key1) + if len(v1) != len(v2) { + t.Fatal(v1, v2) + } + + v1, _ = db.SUnion(key, key1, key2) + v2, _ = db.SUnion(key, key2, key1) + if len(v1) != len(v2) { + t.Fatal(v1, v2) + } + + if v, err := db.SUnion(key, key); err != nil { + t.Fatal(err) + } else if len(v) != 2 { + t.Fatal(v) + } + + empKey := []byte("0") + if v, err := db.SUnion(key, empKey); err != nil { + t.Fatal(err) + } else if len(v) != 2 { + t.Fatal(v) + } +} + +func testInter(db *DB, t *testing.T) { + key1 := []byte("testdb_set_inter_1") + key2 := []byte("testdb_set_inter_2") + key3 := []byte("testdb_set_inter_3") + + m1 := []byte("m1") + m2 := []byte("m2") + m3 := []byte("m3") + m4 := []byte("m4") + + db.SAdd(key1, m1, m2) + db.SAdd(key2, m2, m3, m4) + db.SAdd(key3, m2, m4) + + if v, err := db.sInterGeneric(key1, key2); err != nil { + t.Fatal(err) + } else if len(v) != 1 { + t.Fatal(v) + } + + if v, err := db.SInter(key1, key2); err != nil { + t.Fatal(err) + } else if len(v) != 1 { + t.Fatal(v) + } + + dstKey := []byte("inter_dsk") + if n, err := db.SInterStore(dstKey, key1, key2); err != nil { + t.Fatal(err) + } else if n != 1 { + t.Fatal(n) + } + + v1, _ := db.SInter(key1, key2) + v2, _ := db.SInter(key2, key1) + if len(v1) != len(v2) { + t.Fatal(v1, v2) + } + + v1, _ = db.SInter(key1, key2, key3) + v2, _ = db.SInter(key2, key3, key1) + if len(v1) != len(v2) { + t.Fatal(v1, v2) + } + + if v, err := db.SInter(key1, key1); err != nil { + t.Fatal(err) + } else if len(v) != 2 { + t.Fatal(v) + } + + empKey := []byte("0") + if v, err := db.SInter(key1, empKey); err != nil { + t.Fatal(err) + } else if len(v) != 0 { + t.Fatal(v) + } + + if v, err := db.SInter(empKey, key2); err != nil { + t.Fatal(err) + } else if len(v) != 0 { + t.Fatal(v) + } +} + +func testDiff(db *DB, t *testing.T) { + key0 := []byte("testdb_set_diff_0") + key1 := []byte("testdb_set_diff_1") + key2 := []byte("testdb_set_diff_2") + key3 := []byte("testdb_set_diff_3") + + m1 := []byte("m1") + m2 := []byte("m2") + m3 := []byte("m3") + m4 := []byte("m4") + + db.SAdd(key1, m1, m2) + db.SAdd(key2, m2, m3, m4) + db.SAdd(key3, m3) + + if _, err := db.sDiffGeneric(key1, key2); err != nil { + t.Fatal(err) + } + + if v, err := db.SDiff(key1, key2); err != nil { + t.Fatal(err) + } else if len(v) != 1 { + t.Fatal(v) + } + + dstKey := []byte("diff_dsk") + if n, err := db.SDiffStore(dstKey, key1, key2); err != nil { + t.Fatal(err) + } else if n != 1 { + t.Fatal(n) + } + + if v, err := db.SDiff(key2, key1); err != nil { + t.Fatal(err) + } else if len(v) != 2 { + t.Fatal(v) + } + + if v, err := db.SDiff(key1, key2, key3); err != nil { + t.Fatal(err) + } else if len(v) != 1 { + t.Fatal(v) //return 1 + } + + if v, err := db.SDiff(key2, key2); err != nil { + t.Fatal(err) + } else if len(v) != 0 { + t.Fatal(v) + } + + if v, err := db.SDiff(key0, key1); err != nil { + t.Fatal(err) + } else if len(v) != 0 { + t.Fatal(v) + } + + if v, err := db.SDiff(key1, key0); err != nil { + t.Fatal(err) + } else if len(v) != 2 { + t.Fatal(v) + } +}