add set operations: difference, union, intersection

This commit is contained in:
holys 2014-08-14 19:35:41 +08:00
parent 572a8f2c9a
commit c97b09fd0c
3 changed files with 420 additions and 35 deletions

View File

@ -10,7 +10,8 @@ func (db *DB) FlushAll() (drop int64, err error) {
db.lFlush, db.lFlush,
db.hFlush, db.hFlush,
db.zFlush, db.zFlush,
db.bFlush} db.bFlush,
db.sFlush}
for _, flush := range all { for _, flush := range all {
if n, e := flush(); e != nil { if n, e := flush(); e != nil {

View File

@ -13,6 +13,9 @@ var errSSizeKey = errors.New("invalid ssize key")
const ( const (
setStartSep byte = ':' setStartSep byte = ':'
setStopSep byte = setStartSep + 1 setStopSep byte = setStartSep + 1
UnionType byte = 51
DiffType byte = 52
InterType byte = 53
) )
func checkSetKMSize(key []byte, member []byte) error { func checkSetKMSize(key []byte, member []byte) error {
@ -101,9 +104,25 @@ func (db *DB) sEncodeStopKey(key []byte) []byte {
return k return k
} }
// func (db *DB) sFlush() { func (db *DB) sFlush() (drop int64, err error) {
minKey := make([]byte, 2)
minKey[0] = db.index
minKey[1] = SetType
// } maxKey := make([]byte, 2)
maxKey[0] = db.index
maxKey[1] = SSizeType + 1
t := db.setTx
t.Lock()
defer t.Unlock()
drop, err = db.flushRegion(t, minKey, maxKey)
err = db.expFlush(t, SetType)
err = t.Commit()
return
}
func (db *DB) sDelete(t *tx, key []byte) int64 { func (db *DB) sDelete(t *tx, key []byte) int64 {
sk := db.sEncodeSizeKey(key) sk := db.sEncodeSizeKey(key)
@ -223,24 +242,124 @@ func (db *DB) SCard(key []byte) (int64, error) {
return Int64(db.db.Get(sk)) return Int64(db.db.Get(sk))
} }
// TODO func (db *DB) sDiffGeneric(keys ...[]byte) ([][]byte, error) {
// func (db *DB) sDiffGeneric() destMap := make(map[string]bool)
// func (db *DB) SDiff() { members, err := db.SMembers(keys[0])
if err != nil {
return nil, err
}
// } for _, m := range members {
destMap[String(m)] = true
}
// func (db *DB) SDiffStore() { for _, k := range keys[1:] {
members, err := db.SMembers(k)
if err != nil {
return nil, err
}
// } for _, m := range members {
if _, ok := destMap[String(m)]; !ok {
continue
} else if ok {
delete(destMap, String(m))
}
}
// O - A = O, O is zero set.
if len(destMap) == 0 {
return nil, nil
}
}
// func (db *DB) SInter() { slice := make([][]byte, len(destMap))
idx := 0
for k, v := range destMap {
if !v {
continue
}
slice[idx] = []byte(k)
idx++
}
// } return slice, nil
}
// func (db *DB) SInterStore() { func (db *DB) SDiff(keys ...[]byte) ([][]byte, error) {
v, err := db.sDiffGeneric(keys...)
return v, err
}
// } func (db *DB) SDiffStore(dstKey []byte, keys ...[]byte) (int64, error) {
n, err := db.sStoreGeneric(dstKey, DiffType, keys...)
return n, err
}
func (db *DB) sInterGeneric(keys ...[]byte) ([][]byte, error) {
destMap := make(map[string]bool)
members, err := db.SMembers(keys[0])
if err != nil {
return nil, err
}
for _, m := range members {
destMap[String(m)] = true
}
for _, key := range keys[1:] {
if err := checkKeySize(key); err != nil {
return nil, err
}
members, err := db.SMembers(key)
if err != nil {
return nil, err
} else if len(members) == 0 {
return nil, err
}
tempMap := make(map[string]bool)
for _, member := range members {
if err := checkKeySize(member); err != nil {
return nil, err
}
if _, ok := destMap[String(member)]; ok {
tempMap[String(member)] = true //mark this item as selected
}
}
destMap = tempMap //reduce the size of the result set
if len(destMap) == 0 {
return nil, nil
}
}
slice := make([][]byte, len(destMap))
idx := 0
for k, v := range destMap {
if !v {
continue
}
slice[idx] = []byte(k)
idx++
}
return slice, nil
}
func (db *DB) SInter(keys ...[]byte) ([][]byte, error) {
v, err := db.sInterGeneric(keys...)
return v, err
}
func (db *DB) SInterStore(dstKey []byte, keys ...[]byte) (int64, error) {
n, err := db.sStoreGeneric(dstKey, InterType, keys...)
return n, err
}
func (db *DB) SIsMember(key []byte, member []byte) (int64, error) { func (db *DB) SIsMember(key []byte, member []byte) (int64, error) {
ek := db.sEncodeSetKey(key, member) ek := db.sEncodeSetKey(key, member)
@ -317,35 +436,100 @@ func (db *DB) SRem(key []byte, args ...[]byte) (int64, error) {
} }
// TODO func (db *DB) sUnionGeneric(keys ...[]byte) ([][]byte, error) {
// func (db *DB) sUnionGeneric(keys ...[]byte) ([][]byte, error) { dstMap := make(map[string]bool)
// for _, key := range keys { for _, key := range keys {
// if err := checkKeySize(key); err != nil { if err := checkKeySize(key); err != nil {
// return nil, err return nil, err
// } }
// } members, err := db.SMembers(key)
// } if err != nil {
return nil, err
}
// func (db *DB) SUnion(keys ...[]byte) ([][]byte, error) { for _, member := range members {
dstMap[String(member)] = true
}
}
// if v, err := db.sUnionGeneric(keys); err != nil { slice := make([][]byte, len(dstMap))
// return nil, err idx := 0
// } else if v == nil { for k, v := range dstMap {
// return nil, nil if !v {
// } else { continue
// return v, nil }
// } slice[idx] = []byte(k)
idx++
}
// } return slice, nil
}
// func (db *DB) SUnionStore(dstkey []byte, keys []byte) (int64, error) { func (db *DB) SUnion(keys ...[]byte) ([][]byte, error) {
// if err := checkKeySize(dstkey); err != nil { v, err := db.sUnionGeneric(keys...)
// return 0, err return v, err
// } }
// } func (db *DB) SUnionStore(dstKey []byte, keys ...[]byte) (int64, error) {
n, err := db.sStoreGeneric(dstKey, UnionType, keys...)
return n, err
}
func (db *DB) sStoreGeneric(dstKey []byte, optType byte, keys ...[]byte) (int64, error) {
if err := checkKeySize(dstKey); err != nil {
return 0, err
}
t := db.setTx
t.Lock()
defer t.Unlock()
db.sDelete(t, dstKey)
var err error
var ek []byte
var num int64 = 0
var v [][]byte
switch optType {
case UnionType:
v, err = db.sUnionGeneric(keys...)
case DiffType:
v, err = db.sDiffGeneric(keys...)
case InterType:
v, err = db.sInterGeneric(keys...)
}
if err != nil {
return 0, err
}
for _, m := range v {
if err := checkSetKMSize(dstKey, m); err != nil {
return 0, err
}
ek = db.sEncodeSetKey(dstKey, m)
if v, err := db.db.Get(ek); err != nil {
return 0, err
} else if v == nil {
num++
}
t.Put(ek, nil)
}
if _, err = db.sIncrSize(dstKey, num); err != nil {
return 0, err
}
err = t.Commit()
return num, err
}
func (db *DB) SClear(key []byte) (int64, error) { func (db *DB) SClear(key []byte) (int64, error) {
if err := checkKeySize(key); err != nil { if err := checkKeySize(key); err != nil {

View File

@ -126,3 +126,203 @@ func TestDBSet(t *testing.T) {
} }
} }
func TestSetOperation(t *testing.T) {
db := getTestDB()
// testUnion(db, t)
testInter(db, t)
// testDiff(db, t)
}
func testUnion(db *DB, t *testing.T) {
key := []byte("testdb_set_union_1")
key1 := []byte("testdb_set_union_2")
key2 := []byte("testdb_set_union_2")
// member1 := []byte("testdb_set_m1")
// member2 := []byte("testdb_set_m2")
m1 := []byte("m1")
m2 := []byte("m2")
m3 := []byte("m3")
db.SAdd(key, m1, m2)
db.SAdd(key1, m1, m3)
db.SAdd(key2, m2, m3)
if _, err := db.sUnionGeneric(key, key2); err != nil {
t.Fatal(err)
}
if _, err := db.SUnion(key, key2); err != nil {
t.Fatal(err)
}
dstkey := []byte("union_dsk")
if num, err := db.SUnionStore(dstkey, key1, key2); err != nil {
t.Fatal(err)
} else if num != 3 {
t.Fatal(num)
}
if _, err := db.SMembers(dstkey); err != nil {
t.Fatal(err)
}
if n, err := db.SCard(dstkey); err != nil {
t.Fatal(err)
} else if n != 3 {
t.Fatal(n)
}
v1, _ := db.SUnion(key1, key2)
v2, _ := db.SUnion(key2, key1)
if len(v1) != len(v2) {
t.Fatal(v1, v2)
}
v1, _ = db.SUnion(key, key1, key2)
v2, _ = db.SUnion(key, key2, key1)
if len(v1) != len(v2) {
t.Fatal(v1, v2)
}
if v, err := db.SUnion(key, key); err != nil {
t.Fatal(err)
} else if len(v) != 2 {
t.Fatal(v)
}
empKey := []byte("0")
if v, err := db.SUnion(key, empKey); err != nil {
t.Fatal(err)
} else if len(v) != 2 {
t.Fatal(v)
}
}
func testInter(db *DB, t *testing.T) {
key1 := []byte("testdb_set_inter_1")
key2 := []byte("testdb_set_inter_2")
key3 := []byte("testdb_set_inter_3")
m1 := []byte("m1")
m2 := []byte("m2")
m3 := []byte("m3")
m4 := []byte("m4")
db.SAdd(key1, m1, m2)
db.SAdd(key2, m2, m3, m4)
db.SAdd(key3, m2, m4)
if v, err := db.sInterGeneric(key1, key2); err != nil {
t.Fatal(err)
} else if len(v) != 1 {
t.Fatal(v)
}
if v, err := db.SInter(key1, key2); err != nil {
t.Fatal(err)
} else if len(v) != 1 {
t.Fatal(v)
}
dstKey := []byte("inter_dsk")
if n, err := db.SInterStore(dstKey, key1, key2); err != nil {
t.Fatal(err)
} else if n != 1 {
t.Fatal(n)
}
v1, _ := db.SInter(key1, key2)
v2, _ := db.SInter(key2, key1)
if len(v1) != len(v2) {
t.Fatal(v1, v2)
}
v1, _ = db.SInter(key1, key2, key3)
v2, _ = db.SInter(key2, key3, key1)
if len(v1) != len(v2) {
t.Fatal(v1, v2)
}
if v, err := db.SInter(key1, key1); err != nil {
t.Fatal(err)
} else if len(v) != 2 {
t.Fatal(v)
}
empKey := []byte("0")
if v, err := db.SInter(key1, empKey); err != nil {
t.Fatal(err)
} else if len(v) != 0 {
t.Fatal(v)
}
if v, err := db.SInter(empKey, key2); err != nil {
t.Fatal(err)
} else if len(v) != 0 {
t.Fatal(v)
}
}
func testDiff(db *DB, t *testing.T) {
key0 := []byte("testdb_set_diff_0")
key1 := []byte("testdb_set_diff_1")
key2 := []byte("testdb_set_diff_2")
key3 := []byte("testdb_set_diff_3")
m1 := []byte("m1")
m2 := []byte("m2")
m3 := []byte("m3")
m4 := []byte("m4")
db.SAdd(key1, m1, m2)
db.SAdd(key2, m2, m3, m4)
db.SAdd(key3, m3)
if _, err := db.sDiffGeneric(key1, key2); err != nil {
t.Fatal(err)
}
if v, err := db.SDiff(key1, key2); err != nil {
t.Fatal(err)
} else if len(v) != 1 {
t.Fatal(v)
}
dstKey := []byte("diff_dsk")
if n, err := db.SDiffStore(dstKey, key1, key2); err != nil {
t.Fatal(err)
} else if n != 1 {
t.Fatal(n)
}
if v, err := db.SDiff(key2, key1); err != nil {
t.Fatal(err)
} else if len(v) != 2 {
t.Fatal(v)
}
if v, err := db.SDiff(key1, key2, key3); err != nil {
t.Fatal(err)
} else if len(v) != 1 {
t.Fatal(v) //return 1
}
if v, err := db.SDiff(key2, key2); err != nil {
t.Fatal(err)
} else if len(v) != 0 {
t.Fatal(v)
}
if v, err := db.SDiff(key0, key1); err != nil {
t.Fatal(err)
} else if len(v) != 0 {
t.Fatal(v)
}
if v, err := db.SDiff(key1, key0); err != nil {
t.Fatal(err)
} else if len(v) != 2 {
t.Fatal(v)
}
}