diff --git a/ledis/t_zset.go b/ledis/t_zset.go index f039487..fc86248 100644 --- a/ledis/t_zset.go +++ b/ledis/t_zset.go @@ -939,7 +939,6 @@ func (db *DB) ZUnionStore(destKey []byte, srcKeys [][]byte, weights []int64, agg func (db *DB) ZInterStore(destKey []byte, srcKeys [][]byte, weights []int64, aggregate byte) (int64, error) { - var destMap = map[string]int64{} aggregateFunc := getAggregateFunc(aggregate) if aggregateFunc == nil { return 0, errInvalidAggregate @@ -958,20 +957,27 @@ func (db *DB) ZInterStore(destKey []byte, srcKeys [][]byte, weights []int64, agg } } - var keptMembers [][]byte - for i, key := range srcKeys { + var destMap = map[string]int64{} + scorePairs, err := db.ZRange(srcKeys[0], 0, -1) + if err != nil { + return 0, err + } + for _, pair := range scorePairs { + destMap[String(pair.Member)] = pair.Score * weights[0] + } + + for i, key := range srcKeys[1:] { scorePairs, err := db.ZRange(key, 0, -1) if err != nil { return 0, err } + tmpMap := map[string]int64{} for _, pair := range scorePairs { - if score, ok := destMap[String(pair.Member)]; !ok { - destMap[String(pair.Member)] = pair.Score * weights[i] - } else { - keptMembers = append(keptMembers, pair.Member) - destMap[String(pair.Member)] = aggregateFunc(score, pair.Score*weights[i]) + if score, ok := destMap[String(pair.Member)]; ok { + tmpMap[String(pair.Member)] = aggregateFunc(score, pair.Score*weights[i+1]) } } + destMap = tmpMap } t := db.zsetTx @@ -981,13 +987,12 @@ func (db *DB) ZInterStore(destKey []byte, srcKeys [][]byte, weights []int64, agg db.zDelete(t, destKey) var num int64 = 0 - for _, member := range keptMembers { - score := destMap[String(member)] - if err := checkZSetKMSize(destKey, member); err != nil { + for member, score := range destMap { + if err := checkZSetKMSize(destKey, []byte(member)); err != nil { return 0, err } - if n, err := db.zSetItem(t, destKey, score, member); err != nil { + if n, err := db.zSetItem(t, destKey, score, []byte(member)); err != nil { return 0, err } else if n == 0 { //add new @@ -1002,7 +1007,7 @@ func (db *DB) ZInterStore(destKey []byte, srcKeys [][]byte, weights []int64, agg if err := t.Commit(); err != nil { return 0, err } - return int64(len(keptMembers)), nil + return int64(len(destMap)), nil } func (db *DB) ZScan(key []byte, count int, inclusive bool) ([][]byte, error) { diff --git a/server/cmd_zset_test.go b/server/cmd_zset_test.go index e6a6a70..8c74bdc 100644 --- a/server/cmd_zset_test.go +++ b/server/cmd_zset_test.go @@ -704,4 +704,36 @@ func TestZInterStore(t *testing.T) { t.Fatal("invalid value ", n) } } + + if _, err := c.Do("zadd", "k3", "3", "three"); err != nil { + t.Fatal(err.Error()) + } + + if n, err := ledis.Int64(c.Do("zinterstore", "out", "3", "k1", "k2", "k3", "aggregate", "sum")); err != nil { + t.Fatal(err.Error()) + } else { + if n != 0 { + t.Fatal("invalid value ", n) + } + } + + if _, err := c.Do("zadd", "k3", "3", "two"); err != nil { + t.Fatal(err.Error()) + } + + if n, err := ledis.Int64(c.Do("zinterstore", "out", "3", "k1", "k2", "k3", "aggregate", "sum", "weights", "3", "2", "2")); err != nil { + t.Fatal(err.Error()) + } else { + if n != 1 { + t.Fatal("invalid value ", n) + } + } + + if n, err := ledis.Int64(c.Do("zscore", "out", "two")); err != nil { + t.Fatal(err.Error()) + } else { + if n != 14 { + t.Fatal("invalid value ", n) + } + } }