forked from mirror/ledisdb
add zunionstore & zintrestore interface
This commit is contained in:
parent
de98e7887e
commit
3927d2bd8f
158
ledis/t_zset.go
158
ledis/t_zset.go
|
@ -12,6 +12,10 @@ const (
|
||||||
MinScore int64 = -1<<63 + 1
|
MinScore int64 = -1<<63 + 1
|
||||||
MaxScore int64 = 1<<63 - 1
|
MaxScore int64 = 1<<63 - 1
|
||||||
InvalidScore int64 = -1 << 63
|
InvalidScore int64 = -1 << 63
|
||||||
|
|
||||||
|
AggregateSum byte = 0
|
||||||
|
AggregateMin byte = 1
|
||||||
|
AggregateMax byte = 2
|
||||||
)
|
)
|
||||||
|
|
||||||
type ScorePair struct {
|
type ScorePair struct {
|
||||||
|
@ -23,6 +27,9 @@ var errZSizeKey = errors.New("invalid zsize key")
|
||||||
var errZSetKey = errors.New("invalid zset key")
|
var errZSetKey = errors.New("invalid zset key")
|
||||||
var errZScoreKey = errors.New("invalid zscore key")
|
var errZScoreKey = errors.New("invalid zscore key")
|
||||||
var errScoreOverflow = errors.New("zset score overflow")
|
var errScoreOverflow = errors.New("zset score overflow")
|
||||||
|
var errInvalidAggregate = errors.New("invalid aggregate")
|
||||||
|
var errInvalidWeightNum = errors.New("invalid weight number")
|
||||||
|
var errInvalidSrcKeyNum = errors.New("invalid src key number")
|
||||||
|
|
||||||
const (
|
const (
|
||||||
zsetNScoreSep byte = '<'
|
zsetNScoreSep byte = '<'
|
||||||
|
@ -839,3 +846,154 @@ func (db *DB) ZPersist(key []byte) (int64, error) {
|
||||||
err = t.Commit()
|
err = t.Commit()
|
||||||
return n, err
|
return n, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getAggregateFunc(aggregate byte) func(int64, int64) int64 {
|
||||||
|
switch aggregate {
|
||||||
|
case AggregateSum:
|
||||||
|
return func(a int64, b int64) int64 {
|
||||||
|
return a + b
|
||||||
|
}
|
||||||
|
case AggregateMax:
|
||||||
|
return func(a int64, b int64) int64 {
|
||||||
|
if a > b {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
case AggregateMin:
|
||||||
|
return func(a int64, b int64) int64 {
|
||||||
|
if a > b {
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) ZUnionStore(destKey []byte, srcKeys [][]byte, weights []int64, aggregate byte) (int64, error) {
|
||||||
|
|
||||||
|
var destMap = map[string]int64{}
|
||||||
|
aggregateFunc := getAggregateFunc(aggregate)
|
||||||
|
if aggregateFunc == nil {
|
||||||
|
return 0, errInvalidAggregate
|
||||||
|
}
|
||||||
|
if len(srcKeys) < 1 {
|
||||||
|
return 0, errInvalidSrcKeyNum
|
||||||
|
}
|
||||||
|
if weights != nil {
|
||||||
|
if len(srcKeys) != len(weights) {
|
||||||
|
return 0, errInvalidWeightNum
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
weights = make([]int64, len(srcKeys))
|
||||||
|
for i := 0; i < len(weights); i++ {
|
||||||
|
weights[i] = 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, key := range srcKeys {
|
||||||
|
scorePairs, err := db.ZRange(key, 0, -1)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
for _, pair := range scorePairs {
|
||||||
|
if score, ok := destMap[String(pair.Member)]; !ok {
|
||||||
|
destMap[String(pair.Member)] = pair.Score
|
||||||
|
} else {
|
||||||
|
destMap[String(pair.Member)] = aggregateFunc(score, pair.Score*weights[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t := db.zsetTx
|
||||||
|
t.Lock()
|
||||||
|
defer t.Unlock()
|
||||||
|
|
||||||
|
db.zDelete(t, destKey)
|
||||||
|
|
||||||
|
var num int64 = 0
|
||||||
|
for member, score := range destMap {
|
||||||
|
if err := checkZSetKMSize(destKey, []byte(member)); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if n, err := db.zSetItem(t, destKey, score, []byte(member)); err != nil {
|
||||||
|
return 0, err
|
||||||
|
} else if n == 0 {
|
||||||
|
//add new
|
||||||
|
num++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := db.zIncrSize(t, destKey, num); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
//todo add binlog
|
||||||
|
err := t.Commit()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return int64(len(destMap)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) ZInterStore(destKey []byte, srcKeys [][]byte, weights []int64, aggregate byte) (int64, error) {
|
||||||
|
|
||||||
|
var destMap = map[string]int64{}
|
||||||
|
aggregateFunc := getAggregateFunc(aggregate)
|
||||||
|
if aggregateFunc == nil {
|
||||||
|
return 0, errInvalidAggregate
|
||||||
|
}
|
||||||
|
if len(srcKeys) < 1 {
|
||||||
|
return 0, errInvalidSrcKeyNum
|
||||||
|
}
|
||||||
|
if weights != nil {
|
||||||
|
if len(srcKeys) != len(weights) {
|
||||||
|
return 0, errInvalidWeightNum
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
weights = make([]int64, len(srcKeys))
|
||||||
|
for i := 0; i < len(weights); i++ {
|
||||||
|
weights[i] = 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var keptMembers [][]byte
|
||||||
|
for i, key := range srcKeys {
|
||||||
|
scorePairs, err := db.ZRange(key, 0, -1)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
for _, pair := range scorePairs {
|
||||||
|
if score, ok := destMap[String(pair.Member)]; !ok {
|
||||||
|
destMap[String(pair.Member)] = pair.Score
|
||||||
|
} else {
|
||||||
|
keptMembers = append(keptMembers, pair.Member)
|
||||||
|
destMap[String(pair.Member)] = aggregateFunc(score, pair.Score*weights[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t := db.zsetTx
|
||||||
|
t.Lock()
|
||||||
|
defer t.Unlock()
|
||||||
|
|
||||||
|
db.zDelete(t, destKey)
|
||||||
|
|
||||||
|
var num int64 = 0
|
||||||
|
for _, member := range keptMembers {
|
||||||
|
score := destMap[String(member)]
|
||||||
|
if err := checkZSetKMSize(destKey, member); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if n, err := db.zSetItem(t, destKey, score, member); err != nil {
|
||||||
|
return 0, err
|
||||||
|
} else if n == 0 {
|
||||||
|
//add new
|
||||||
|
num++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return int64(len(keptMembers)), nil
|
||||||
|
}
|
||||||
|
|
|
@ -264,3 +264,36 @@ func TestZSetPersist(t *testing.T) {
|
||||||
t.Fatal(n)
|
t.Fatal(n)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestZUnionStore(t *testing.T) {
|
||||||
|
db := getTestDB()
|
||||||
|
key1 := []byte("key1")
|
||||||
|
key2 := []byte("key2")
|
||||||
|
|
||||||
|
db.ZAdd(key1, ScorePair{1, []byte("one")})
|
||||||
|
db.ZAdd(key1, ScorePair{1, []byte("two")})
|
||||||
|
|
||||||
|
db.ZAdd(key2, ScorePair{2, []byte("two")})
|
||||||
|
db.ZAdd(key2, ScorePair{2, []byte("three")})
|
||||||
|
|
||||||
|
keys := [][]byte{key1, key2}
|
||||||
|
weights := []int64{1, 2}
|
||||||
|
|
||||||
|
out := []byte("out")
|
||||||
|
n, err := db.ZUnionStore(out, keys, weights, AggregateSum)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err.Error())
|
||||||
|
}
|
||||||
|
if n != 3 {
|
||||||
|
t.Fatal("invalid value ", n)
|
||||||
|
}
|
||||||
|
|
||||||
|
v, err := db.ZScore(out, []byte("two"))
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err.Error())
|
||||||
|
}
|
||||||
|
if v != 5 {
|
||||||
|
t.Fatal("invalid value ", v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue