mirror of https://github.com/ledisdb/ledisdb.git
add set operations: difference, union, intersection
This commit is contained in:
parent
572a8f2c9a
commit
c97b09fd0c
|
@ -10,7 +10,8 @@ func (db *DB) FlushAll() (drop int64, err error) {
|
|||
db.lFlush,
|
||||
db.hFlush,
|
||||
db.zFlush,
|
||||
db.bFlush}
|
||||
db.bFlush,
|
||||
db.sFlush}
|
||||
|
||||
for _, flush := range all {
|
||||
if n, e := flush(); e != nil {
|
||||
|
|
252
ledis/t_set.go
252
ledis/t_set.go
|
@ -13,6 +13,9 @@ var errSSizeKey = errors.New("invalid ssize key")
|
|||
const (
|
||||
setStartSep byte = ':'
|
||||
setStopSep byte = setStartSep + 1
|
||||
UnionType byte = 51
|
||||
DiffType byte = 52
|
||||
InterType byte = 53
|
||||
)
|
||||
|
||||
func checkSetKMSize(key []byte, member []byte) error {
|
||||
|
@ -101,9 +104,25 @@ func (db *DB) sEncodeStopKey(key []byte) []byte {
|
|||
return k
|
||||
}
|
||||
|
||||
// func (db *DB) sFlush() {
|
||||
func (db *DB) sFlush() (drop int64, err error) {
|
||||
minKey := make([]byte, 2)
|
||||
minKey[0] = db.index
|
||||
minKey[1] = SetType
|
||||
|
||||
// }
|
||||
maxKey := make([]byte, 2)
|
||||
maxKey[0] = db.index
|
||||
maxKey[1] = SSizeType + 1
|
||||
|
||||
t := db.setTx
|
||||
t.Lock()
|
||||
defer t.Unlock()
|
||||
|
||||
drop, err = db.flushRegion(t, minKey, maxKey)
|
||||
err = db.expFlush(t, SetType)
|
||||
|
||||
err = t.Commit()
|
||||
return
|
||||
}
|
||||
|
||||
func (db *DB) sDelete(t *tx, key []byte) int64 {
|
||||
sk := db.sEncodeSizeKey(key)
|
||||
|
@ -223,24 +242,124 @@ func (db *DB) SCard(key []byte) (int64, error) {
|
|||
return Int64(db.db.Get(sk))
|
||||
}
|
||||
|
||||
// TODO
|
||||
// func (db *DB) sDiffGeneric()
|
||||
func (db *DB) sDiffGeneric(keys ...[]byte) ([][]byte, error) {
|
||||
destMap := make(map[string]bool)
|
||||
|
||||
// func (db *DB) SDiff() {
|
||||
members, err := db.SMembers(keys[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// }
|
||||
for _, m := range members {
|
||||
destMap[String(m)] = true
|
||||
}
|
||||
|
||||
// func (db *DB) SDiffStore() {
|
||||
for _, k := range keys[1:] {
|
||||
members, err := db.SMembers(k)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// }
|
||||
for _, m := range members {
|
||||
if _, ok := destMap[String(m)]; !ok {
|
||||
continue
|
||||
} else if ok {
|
||||
delete(destMap, String(m))
|
||||
}
|
||||
}
|
||||
// O - A = O, O is zero set.
|
||||
if len(destMap) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
// func (db *DB) SInter() {
|
||||
slice := make([][]byte, len(destMap))
|
||||
idx := 0
|
||||
for k, v := range destMap {
|
||||
if !v {
|
||||
continue
|
||||
}
|
||||
slice[idx] = []byte(k)
|
||||
idx++
|
||||
}
|
||||
|
||||
// }
|
||||
return slice, nil
|
||||
}
|
||||
|
||||
// func (db *DB) SInterStore() {
|
||||
func (db *DB) SDiff(keys ...[]byte) ([][]byte, error) {
|
||||
v, err := db.sDiffGeneric(keys...)
|
||||
return v, err
|
||||
}
|
||||
|
||||
// }
|
||||
func (db *DB) SDiffStore(dstKey []byte, keys ...[]byte) (int64, error) {
|
||||
n, err := db.sStoreGeneric(dstKey, DiffType, keys...)
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (db *DB) sInterGeneric(keys ...[]byte) ([][]byte, error) {
|
||||
destMap := make(map[string]bool)
|
||||
|
||||
members, err := db.SMembers(keys[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, m := range members {
|
||||
destMap[String(m)] = true
|
||||
}
|
||||
|
||||
for _, key := range keys[1:] {
|
||||
if err := checkKeySize(key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
members, err := db.SMembers(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if len(members) == 0 {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tempMap := make(map[string]bool)
|
||||
for _, member := range members {
|
||||
if err := checkKeySize(member); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, ok := destMap[String(member)]; ok {
|
||||
tempMap[String(member)] = true //mark this item as selected
|
||||
}
|
||||
}
|
||||
destMap = tempMap //reduce the size of the result set
|
||||
if len(destMap) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
slice := make([][]byte, len(destMap))
|
||||
idx := 0
|
||||
for k, v := range destMap {
|
||||
if !v {
|
||||
continue
|
||||
}
|
||||
|
||||
slice[idx] = []byte(k)
|
||||
idx++
|
||||
}
|
||||
|
||||
return slice, nil
|
||||
|
||||
}
|
||||
|
||||
func (db *DB) SInter(keys ...[]byte) ([][]byte, error) {
|
||||
v, err := db.sInterGeneric(keys...)
|
||||
return v, err
|
||||
|
||||
}
|
||||
|
||||
func (db *DB) SInterStore(dstKey []byte, keys ...[]byte) (int64, error) {
|
||||
n, err := db.sStoreGeneric(dstKey, InterType, keys...)
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (db *DB) SIsMember(key []byte, member []byte) (int64, error) {
|
||||
ek := db.sEncodeSetKey(key, member)
|
||||
|
@ -317,35 +436,100 @@ func (db *DB) SRem(key []byte, args ...[]byte) (int64, error) {
|
|||
|
||||
}
|
||||
|
||||
// TODO
|
||||
// func (db *DB) sUnionGeneric(keys ...[]byte) ([][]byte, error) {
|
||||
func (db *DB) sUnionGeneric(keys ...[]byte) ([][]byte, error) {
|
||||
dstMap := make(map[string]bool)
|
||||
|
||||
// for _, key := range keys {
|
||||
// if err := checkKeySize(key); err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
for _, key := range keys {
|
||||
if err := checkKeySize(key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// }
|
||||
// }
|
||||
members, err := db.SMembers(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// func (db *DB) SUnion(keys ...[]byte) ([][]byte, error) {
|
||||
for _, member := range members {
|
||||
dstMap[String(member)] = true
|
||||
}
|
||||
}
|
||||
|
||||
// if v, err := db.sUnionGeneric(keys); err != nil {
|
||||
// return nil, err
|
||||
// } else if v == nil {
|
||||
// return nil, nil
|
||||
// } else {
|
||||
// return v, nil
|
||||
// }
|
||||
slice := make([][]byte, len(dstMap))
|
||||
idx := 0
|
||||
for k, v := range dstMap {
|
||||
if !v {
|
||||
continue
|
||||
}
|
||||
slice[idx] = []byte(k)
|
||||
idx++
|
||||
}
|
||||
|
||||
// }
|
||||
return slice, nil
|
||||
}
|
||||
|
||||
// func (db *DB) SUnionStore(dstkey []byte, keys []byte) (int64, error) {
|
||||
// if err := checkKeySize(dstkey); err != nil {
|
||||
// return 0, err
|
||||
// }
|
||||
func (db *DB) SUnion(keys ...[]byte) ([][]byte, error) {
|
||||
v, err := db.sUnionGeneric(keys...)
|
||||
return v, err
|
||||
}
|
||||
|
||||
// }
|
||||
func (db *DB) SUnionStore(dstKey []byte, keys ...[]byte) (int64, error) {
|
||||
n, err := db.sStoreGeneric(dstKey, UnionType, keys...)
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (db *DB) sStoreGeneric(dstKey []byte, optType byte, keys ...[]byte) (int64, error) {
|
||||
if err := checkKeySize(dstKey); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
t := db.setTx
|
||||
t.Lock()
|
||||
defer t.Unlock()
|
||||
|
||||
db.sDelete(t, dstKey)
|
||||
|
||||
var err error
|
||||
var ek []byte
|
||||
var num int64 = 0
|
||||
var v [][]byte
|
||||
|
||||
switch optType {
|
||||
case UnionType:
|
||||
v, err = db.sUnionGeneric(keys...)
|
||||
case DiffType:
|
||||
v, err = db.sDiffGeneric(keys...)
|
||||
case InterType:
|
||||
v, err = db.sInterGeneric(keys...)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
for _, m := range v {
|
||||
if err := checkSetKMSize(dstKey, m); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
ek = db.sEncodeSetKey(dstKey, m)
|
||||
|
||||
if v, err := db.db.Get(ek); err != nil {
|
||||
return 0, err
|
||||
} else if v == nil {
|
||||
num++
|
||||
}
|
||||
|
||||
t.Put(ek, nil)
|
||||
|
||||
}
|
||||
|
||||
if _, err = db.sIncrSize(dstKey, num); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
err = t.Commit()
|
||||
return num, err
|
||||
}
|
||||
|
||||
func (db *DB) SClear(key []byte) (int64, error) {
|
||||
if err := checkKeySize(key); err != nil {
|
||||
|
|
|
@ -126,3 +126,203 @@ func TestDBSet(t *testing.T) {
|
|||
}
|
||||
|
||||
}
|
||||
|
||||
func TestSetOperation(t *testing.T) {
|
||||
db := getTestDB()
|
||||
// testUnion(db, t)
|
||||
testInter(db, t)
|
||||
// testDiff(db, t)
|
||||
|
||||
}
|
||||
|
||||
func testUnion(db *DB, t *testing.T) {
|
||||
|
||||
key := []byte("testdb_set_union_1")
|
||||
key1 := []byte("testdb_set_union_2")
|
||||
key2 := []byte("testdb_set_union_2")
|
||||
// member1 := []byte("testdb_set_m1")
|
||||
// member2 := []byte("testdb_set_m2")
|
||||
|
||||
m1 := []byte("m1")
|
||||
m2 := []byte("m2")
|
||||
m3 := []byte("m3")
|
||||
db.SAdd(key, m1, m2)
|
||||
db.SAdd(key1, m1, m3)
|
||||
db.SAdd(key2, m2, m3)
|
||||
if _, err := db.sUnionGeneric(key, key2); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err := db.SUnion(key, key2); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
dstkey := []byte("union_dsk")
|
||||
if num, err := db.SUnionStore(dstkey, key1, key2); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if num != 3 {
|
||||
t.Fatal(num)
|
||||
}
|
||||
if _, err := db.SMembers(dstkey); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if n, err := db.SCard(dstkey); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if n != 3 {
|
||||
t.Fatal(n)
|
||||
}
|
||||
|
||||
v1, _ := db.SUnion(key1, key2)
|
||||
v2, _ := db.SUnion(key2, key1)
|
||||
if len(v1) != len(v2) {
|
||||
t.Fatal(v1, v2)
|
||||
}
|
||||
|
||||
v1, _ = db.SUnion(key, key1, key2)
|
||||
v2, _ = db.SUnion(key, key2, key1)
|
||||
if len(v1) != len(v2) {
|
||||
t.Fatal(v1, v2)
|
||||
}
|
||||
|
||||
if v, err := db.SUnion(key, key); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if len(v) != 2 {
|
||||
t.Fatal(v)
|
||||
}
|
||||
|
||||
empKey := []byte("0")
|
||||
if v, err := db.SUnion(key, empKey); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if len(v) != 2 {
|
||||
t.Fatal(v)
|
||||
}
|
||||
}
|
||||
|
||||
func testInter(db *DB, t *testing.T) {
|
||||
key1 := []byte("testdb_set_inter_1")
|
||||
key2 := []byte("testdb_set_inter_2")
|
||||
key3 := []byte("testdb_set_inter_3")
|
||||
|
||||
m1 := []byte("m1")
|
||||
m2 := []byte("m2")
|
||||
m3 := []byte("m3")
|
||||
m4 := []byte("m4")
|
||||
|
||||
db.SAdd(key1, m1, m2)
|
||||
db.SAdd(key2, m2, m3, m4)
|
||||
db.SAdd(key3, m2, m4)
|
||||
|
||||
if v, err := db.sInterGeneric(key1, key2); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if len(v) != 1 {
|
||||
t.Fatal(v)
|
||||
}
|
||||
|
||||
if v, err := db.SInter(key1, key2); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if len(v) != 1 {
|
||||
t.Fatal(v)
|
||||
}
|
||||
|
||||
dstKey := []byte("inter_dsk")
|
||||
if n, err := db.SInterStore(dstKey, key1, key2); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if n != 1 {
|
||||
t.Fatal(n)
|
||||
}
|
||||
|
||||
v1, _ := db.SInter(key1, key2)
|
||||
v2, _ := db.SInter(key2, key1)
|
||||
if len(v1) != len(v2) {
|
||||
t.Fatal(v1, v2)
|
||||
}
|
||||
|
||||
v1, _ = db.SInter(key1, key2, key3)
|
||||
v2, _ = db.SInter(key2, key3, key1)
|
||||
if len(v1) != len(v2) {
|
||||
t.Fatal(v1, v2)
|
||||
}
|
||||
|
||||
if v, err := db.SInter(key1, key1); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if len(v) != 2 {
|
||||
t.Fatal(v)
|
||||
}
|
||||
|
||||
empKey := []byte("0")
|
||||
if v, err := db.SInter(key1, empKey); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if len(v) != 0 {
|
||||
t.Fatal(v)
|
||||
}
|
||||
|
||||
if v, err := db.SInter(empKey, key2); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if len(v) != 0 {
|
||||
t.Fatal(v)
|
||||
}
|
||||
}
|
||||
|
||||
func testDiff(db *DB, t *testing.T) {
|
||||
key0 := []byte("testdb_set_diff_0")
|
||||
key1 := []byte("testdb_set_diff_1")
|
||||
key2 := []byte("testdb_set_diff_2")
|
||||
key3 := []byte("testdb_set_diff_3")
|
||||
|
||||
m1 := []byte("m1")
|
||||
m2 := []byte("m2")
|
||||
m3 := []byte("m3")
|
||||
m4 := []byte("m4")
|
||||
|
||||
db.SAdd(key1, m1, m2)
|
||||
db.SAdd(key2, m2, m3, m4)
|
||||
db.SAdd(key3, m3)
|
||||
|
||||
if _, err := db.sDiffGeneric(key1, key2); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if v, err := db.SDiff(key1, key2); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if len(v) != 1 {
|
||||
t.Fatal(v)
|
||||
}
|
||||
|
||||
dstKey := []byte("diff_dsk")
|
||||
if n, err := db.SDiffStore(dstKey, key1, key2); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if n != 1 {
|
||||
t.Fatal(n)
|
||||
}
|
||||
|
||||
if v, err := db.SDiff(key2, key1); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if len(v) != 2 {
|
||||
t.Fatal(v)
|
||||
}
|
||||
|
||||
if v, err := db.SDiff(key1, key2, key3); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if len(v) != 1 {
|
||||
t.Fatal(v) //return 1
|
||||
}
|
||||
|
||||
if v, err := db.SDiff(key2, key2); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if len(v) != 0 {
|
||||
t.Fatal(v)
|
||||
}
|
||||
|
||||
if v, err := db.SDiff(key0, key1); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if len(v) != 0 {
|
||||
t.Fatal(v)
|
||||
}
|
||||
|
||||
if v, err := db.SDiff(key1, key0); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if len(v) != 2 {
|
||||
t.Fatal(v)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue