From 3927d2bd8fe3662ec7251fe240ddae8cf1bb24b2 Mon Sep 17 00:00:00 2001 From: wenyekui Date: Tue, 12 Aug 2014 18:06:02 +0800 Subject: [PATCH] add zunionstore & zintrestore interface --- ledis/t_zset.go | 158 +++++++++++++++++++++++++++++++++++++++++++ ledis/t_zset_test.go | 33 +++++++++ 2 files changed, 191 insertions(+) diff --git a/ledis/t_zset.go b/ledis/t_zset.go index 151f8eb..add519f 100644 --- a/ledis/t_zset.go +++ b/ledis/t_zset.go @@ -12,6 +12,10 @@ const ( MinScore int64 = -1<<63 + 1 MaxScore int64 = 1<<63 - 1 InvalidScore int64 = -1 << 63 + + AggregateSum byte = 0 + AggregateMin byte = 1 + AggregateMax byte = 2 ) type ScorePair struct { @@ -23,6 +27,9 @@ var errZSizeKey = errors.New("invalid zsize key") var errZSetKey = errors.New("invalid zset key") var errZScoreKey = errors.New("invalid zscore key") var errScoreOverflow = errors.New("zset score overflow") +var errInvalidAggregate = errors.New("invalid aggregate") +var errInvalidWeightNum = errors.New("invalid weight number") +var errInvalidSrcKeyNum = errors.New("invalid src key number") const ( zsetNScoreSep byte = '<' @@ -839,3 +846,154 @@ func (db *DB) ZPersist(key []byte) (int64, error) { err = t.Commit() return n, err } + +func getAggregateFunc(aggregate byte) func(int64, int64) int64 { + switch aggregate { + case AggregateSum: + return func(a int64, b int64) int64 { + return a + b + } + case AggregateMax: + return func(a int64, b int64) int64 { + if a > b { + return a + } + return b + } + case AggregateMin: + return func(a int64, b int64) int64 { + if a > b { + return b + } + return a + } + } + return nil +} + +func (db *DB) ZUnionStore(destKey []byte, srcKeys [][]byte, weights []int64, aggregate byte) (int64, error) { + + var destMap = map[string]int64{} + aggregateFunc := getAggregateFunc(aggregate) + if aggregateFunc == nil { + return 0, errInvalidAggregate + } + if len(srcKeys) < 1 { + return 0, errInvalidSrcKeyNum + } + if weights != nil { + if len(srcKeys) != len(weights) { + return 0, errInvalidWeightNum + } + } else { + weights = make([]int64, len(srcKeys)) + for i := 0; i < len(weights); i++ { + weights[i] = 1 + } + } + + for i, key := range srcKeys { + scorePairs, err := db.ZRange(key, 0, -1) + if err != nil { + return 0, err + } + for _, pair := range scorePairs { + if score, ok := destMap[String(pair.Member)]; !ok { + destMap[String(pair.Member)] = pair.Score + } else { + destMap[String(pair.Member)] = aggregateFunc(score, pair.Score*weights[i]) + } + } + } + + t := db.zsetTx + t.Lock() + defer t.Unlock() + + db.zDelete(t, destKey) + + var num int64 = 0 + for member, score := range destMap { + if err := checkZSetKMSize(destKey, []byte(member)); err != nil { + return 0, err + } + + if n, err := db.zSetItem(t, destKey, score, []byte(member)); err != nil { + return 0, err + } else if n == 0 { + //add new + num++ + } + } + + if _, err := db.zIncrSize(t, destKey, num); err != nil { + return 0, err + } + + //todo add binlog + err := t.Commit() + if err != nil { + return 0, err + } + return int64(len(destMap)), nil +} + +func (db *DB) ZInterStore(destKey []byte, srcKeys [][]byte, weights []int64, aggregate byte) (int64, error) { + + var destMap = map[string]int64{} + aggregateFunc := getAggregateFunc(aggregate) + if aggregateFunc == nil { + return 0, errInvalidAggregate + } + if len(srcKeys) < 1 { + return 0, errInvalidSrcKeyNum + } + if weights != nil { + if len(srcKeys) != len(weights) { + return 0, errInvalidWeightNum + } + } else { + weights = make([]int64, len(srcKeys)) + for i := 0; i < len(weights); i++ { + weights[i] = 1 + } + } + + var keptMembers [][]byte + for i, key := range srcKeys { + scorePairs, err := db.ZRange(key, 0, -1) + if err != nil { + return 0, err + } + for _, pair := range scorePairs { + if score, ok := destMap[String(pair.Member)]; !ok { + destMap[String(pair.Member)] = pair.Score + } else { + keptMembers = append(keptMembers, pair.Member) + destMap[String(pair.Member)] = aggregateFunc(score, pair.Score*weights[i]) + } + } + } + + t := db.zsetTx + t.Lock() + defer t.Unlock() + + db.zDelete(t, destKey) + + var num int64 = 0 + for _, member := range keptMembers { + score := destMap[String(member)] + if err := checkZSetKMSize(destKey, member); err != nil { + return 0, err + } + + if n, err := db.zSetItem(t, destKey, score, member); err != nil { + return 0, err + } else if n == 0 { + //add new + num++ + } + } + return int64(len(keptMembers)), nil +} diff --git a/ledis/t_zset_test.go b/ledis/t_zset_test.go index 74cf526..857f005 100644 --- a/ledis/t_zset_test.go +++ b/ledis/t_zset_test.go @@ -264,3 +264,36 @@ func TestZSetPersist(t *testing.T) { t.Fatal(n) } } + +func TestZUnionStore(t *testing.T) { + db := getTestDB() + key1 := []byte("key1") + key2 := []byte("key2") + + db.ZAdd(key1, ScorePair{1, []byte("one")}) + db.ZAdd(key1, ScorePair{1, []byte("two")}) + + db.ZAdd(key2, ScorePair{2, []byte("two")}) + db.ZAdd(key2, ScorePair{2, []byte("three")}) + + keys := [][]byte{key1, key2} + weights := []int64{1, 2} + + out := []byte("out") + n, err := db.ZUnionStore(out, keys, weights, AggregateSum) + if err != nil { + t.Fatal(err.Error()) + } + if n != 3 { + t.Fatal("invalid value ", n) + } + + v, err := db.ZScore(out, []byte("two")) + + if err != nil { + t.Fatal(err.Error()) + } + if v != 5 { + t.Fatal("invalid value ", v) + } +}