add zunionstore & zintrestore interface

This commit is contained in:
wenyekui 2014-08-12 18:06:02 +08:00
parent de98e7887e
commit 3927d2bd8f
2 changed files with 191 additions and 0 deletions

View File

@ -12,6 +12,10 @@ const (
MinScore int64 = -1<<63 + 1 MinScore int64 = -1<<63 + 1
MaxScore int64 = 1<<63 - 1 MaxScore int64 = 1<<63 - 1
InvalidScore int64 = -1 << 63 InvalidScore int64 = -1 << 63
AggregateSum byte = 0
AggregateMin byte = 1
AggregateMax byte = 2
) )
type ScorePair struct { type ScorePair struct {
@ -23,6 +27,9 @@ var errZSizeKey = errors.New("invalid zsize key")
var errZSetKey = errors.New("invalid zset key") var errZSetKey = errors.New("invalid zset key")
var errZScoreKey = errors.New("invalid zscore key") var errZScoreKey = errors.New("invalid zscore key")
var errScoreOverflow = errors.New("zset score overflow") var errScoreOverflow = errors.New("zset score overflow")
var errInvalidAggregate = errors.New("invalid aggregate")
var errInvalidWeightNum = errors.New("invalid weight number")
var errInvalidSrcKeyNum = errors.New("invalid src key number")
const ( const (
zsetNScoreSep byte = '<' zsetNScoreSep byte = '<'
@ -839,3 +846,154 @@ func (db *DB) ZPersist(key []byte) (int64, error) {
err = t.Commit() err = t.Commit()
return n, err return n, err
} }
func getAggregateFunc(aggregate byte) func(int64, int64) int64 {
switch aggregate {
case AggregateSum:
return func(a int64, b int64) int64 {
return a + b
}
case AggregateMax:
return func(a int64, b int64) int64 {
if a > b {
return a
}
return b
}
case AggregateMin:
return func(a int64, b int64) int64 {
if a > b {
return b
}
return a
}
}
return nil
}
func (db *DB) ZUnionStore(destKey []byte, srcKeys [][]byte, weights []int64, aggregate byte) (int64, error) {
var destMap = map[string]int64{}
aggregateFunc := getAggregateFunc(aggregate)
if aggregateFunc == nil {
return 0, errInvalidAggregate
}
if len(srcKeys) < 1 {
return 0, errInvalidSrcKeyNum
}
if weights != nil {
if len(srcKeys) != len(weights) {
return 0, errInvalidWeightNum
}
} else {
weights = make([]int64, len(srcKeys))
for i := 0; i < len(weights); i++ {
weights[i] = 1
}
}
for i, key := range srcKeys {
scorePairs, err := db.ZRange(key, 0, -1)
if err != nil {
return 0, err
}
for _, pair := range scorePairs {
if score, ok := destMap[String(pair.Member)]; !ok {
destMap[String(pair.Member)] = pair.Score
} else {
destMap[String(pair.Member)] = aggregateFunc(score, pair.Score*weights[i])
}
}
}
t := db.zsetTx
t.Lock()
defer t.Unlock()
db.zDelete(t, destKey)
var num int64 = 0
for member, score := range destMap {
if err := checkZSetKMSize(destKey, []byte(member)); err != nil {
return 0, err
}
if n, err := db.zSetItem(t, destKey, score, []byte(member)); err != nil {
return 0, err
} else if n == 0 {
//add new
num++
}
}
if _, err := db.zIncrSize(t, destKey, num); err != nil {
return 0, err
}
//todo add binlog
err := t.Commit()
if err != nil {
return 0, err
}
return int64(len(destMap)), nil
}
func (db *DB) ZInterStore(destKey []byte, srcKeys [][]byte, weights []int64, aggregate byte) (int64, error) {
var destMap = map[string]int64{}
aggregateFunc := getAggregateFunc(aggregate)
if aggregateFunc == nil {
return 0, errInvalidAggregate
}
if len(srcKeys) < 1 {
return 0, errInvalidSrcKeyNum
}
if weights != nil {
if len(srcKeys) != len(weights) {
return 0, errInvalidWeightNum
}
} else {
weights = make([]int64, len(srcKeys))
for i := 0; i < len(weights); i++ {
weights[i] = 1
}
}
var keptMembers [][]byte
for i, key := range srcKeys {
scorePairs, err := db.ZRange(key, 0, -1)
if err != nil {
return 0, err
}
for _, pair := range scorePairs {
if score, ok := destMap[String(pair.Member)]; !ok {
destMap[String(pair.Member)] = pair.Score
} else {
keptMembers = append(keptMembers, pair.Member)
destMap[String(pair.Member)] = aggregateFunc(score, pair.Score*weights[i])
}
}
}
t := db.zsetTx
t.Lock()
defer t.Unlock()
db.zDelete(t, destKey)
var num int64 = 0
for _, member := range keptMembers {
score := destMap[String(member)]
if err := checkZSetKMSize(destKey, member); err != nil {
return 0, err
}
if n, err := db.zSetItem(t, destKey, score, member); err != nil {
return 0, err
} else if n == 0 {
//add new
num++
}
}
return int64(len(keptMembers)), nil
}

View File

@ -264,3 +264,36 @@ func TestZSetPersist(t *testing.T) {
t.Fatal(n) t.Fatal(n)
} }
} }
func TestZUnionStore(t *testing.T) {
db := getTestDB()
key1 := []byte("key1")
key2 := []byte("key2")
db.ZAdd(key1, ScorePair{1, []byte("one")})
db.ZAdd(key1, ScorePair{1, []byte("two")})
db.ZAdd(key2, ScorePair{2, []byte("two")})
db.ZAdd(key2, ScorePair{2, []byte("three")})
keys := [][]byte{key1, key2}
weights := []int64{1, 2}
out := []byte("out")
n, err := db.ZUnionStore(out, keys, weights, AggregateSum)
if err != nil {
t.Fatal(err.Error())
}
if n != 3 {
t.Fatal("invalid value ", n)
}
v, err := db.ZScore(out, []byte("two"))
if err != nil {
t.Fatal(err.Error())
}
if v != 5 {
t.Fatal("invalid value ", v)
}
}