diff --git a/ledis/t_zset.go b/ledis/t_zset.go index 151f8eb..5d6bfb3 100644 --- a/ledis/t_zset.go +++ b/ledis/t_zset.go @@ -12,6 +12,10 @@ const ( MinScore int64 = -1<<63 + 1 MaxScore int64 = 1<<63 - 1 InvalidScore int64 = -1 << 63 + + AggregateSum byte = 0 + AggregateMin byte = 1 + AggregateMax byte = 2 ) type ScorePair struct { @@ -23,6 +27,9 @@ var errZSizeKey = errors.New("invalid zsize key") var errZSetKey = errors.New("invalid zset key") var errZScoreKey = errors.New("invalid zscore key") var errScoreOverflow = errors.New("zset score overflow") +var errInvalidAggregate = errors.New("invalid aggregate") +var errInvalidWeightNum = errors.New("invalid weight number") +var errInvalidSrcKeyNum = errors.New("invalid src key number") const ( zsetNScoreSep byte = '<' @@ -839,3 +846,161 @@ func (db *DB) ZPersist(key []byte) (int64, error) { err = t.Commit() return n, err } + +func getAggregateFunc(aggregate byte) func(int64, int64) int64 { + switch aggregate { + case AggregateSum: + return func(a int64, b int64) int64 { + return a + b + } + case AggregateMax: + return func(a int64, b int64) int64 { + if a > b { + return a + } + return b + } + case AggregateMin: + return func(a int64, b int64) int64 { + if a > b { + return b + } + return a + } + } + return nil +} + +func (db *DB) ZUnionStore(destKey []byte, srcKeys [][]byte, weights []int64, aggregate byte) (int64, error) { + + var destMap = map[string]int64{} + aggregateFunc := getAggregateFunc(aggregate) + if aggregateFunc == nil { + return 0, errInvalidAggregate + } + if len(srcKeys) < 1 { + return 0, errInvalidSrcKeyNum + } + if weights != nil { + if len(srcKeys) != len(weights) { + return 0, errInvalidWeightNum + } + } else { + weights = make([]int64, len(srcKeys)) + for i := 0; i < len(weights); i++ { + weights[i] = 1 + } + } + + for i, key := range srcKeys { + scorePairs, err := db.ZRange(key, 0, -1) + if err != nil { + return 0, err + } + for _, pair := range scorePairs { + if score, ok := destMap[String(pair.Member)]; !ok { + destMap[String(pair.Member)] = pair.Score + } else { + destMap[String(pair.Member)] = aggregateFunc(score, pair.Score*weights[i]) + } + } + } + + t := db.zsetTx + t.Lock() + defer t.Unlock() + + db.zDelete(t, destKey) + + var num int64 = 0 + for member, score := range destMap { + if err := checkZSetKMSize(destKey, []byte(member)); err != nil { + return 0, err + } + + if n, err := db.zSetItem(t, destKey, score, []byte(member)); err != nil { + return 0, err + } else if n == 0 { + //add new + num++ + } + } + + if _, err := db.zIncrSize(t, destKey, num); err != nil { + return 0, err + } + + //todo add binlog + if err := t.Commit(); err != nil { + return 0, err + } + return int64(len(destMap)), nil +} + +func (db *DB) ZInterStore(destKey []byte, srcKeys [][]byte, weights []int64, aggregate byte) (int64, error) { + + var destMap = map[string]int64{} + aggregateFunc := getAggregateFunc(aggregate) + if aggregateFunc == nil { + return 0, errInvalidAggregate + } + if len(srcKeys) < 1 { + return 0, errInvalidSrcKeyNum + } + if weights != nil { + if len(srcKeys) != len(weights) { + return 0, errInvalidWeightNum + } + } else { + weights = make([]int64, len(srcKeys)) + for i := 0; i < len(weights); i++ { + weights[i] = 1 + } + } + + var keptMembers [][]byte + for i, key := range srcKeys { + scorePairs, err := db.ZRange(key, 0, -1) + if err != nil { + return 0, err + } + for _, pair := range scorePairs { + if score, ok := destMap[String(pair.Member)]; !ok { + destMap[String(pair.Member)] = pair.Score + } else { + keptMembers = append(keptMembers, pair.Member) + destMap[String(pair.Member)] = aggregateFunc(score, pair.Score*weights[i]) + } + } + } + + t := db.zsetTx + t.Lock() + defer t.Unlock() + + db.zDelete(t, destKey) + + var num int64 = 0 + for _, member := range keptMembers { + score := destMap[String(member)] + if err := checkZSetKMSize(destKey, member); err != nil { + return 0, err + } + + if n, err := db.zSetItem(t, destKey, score, member); err != nil { + return 0, err + } else if n == 0 { + //add new + num++ + } + } + + if _, err := db.zIncrSize(t, destKey, num); err != nil { + return 0, err + } + //todo add binlog + if err := t.Commit(); err != nil { + return 0, err + } + return int64(len(keptMembers)), nil +} diff --git a/ledis/t_zset_test.go b/ledis/t_zset_test.go index 74cf526..a772360 100644 --- a/ledis/t_zset_test.go +++ b/ledis/t_zset_test.go @@ -264,3 +264,122 @@ func TestZSetPersist(t *testing.T) { t.Fatal(n) } } + +func TestZUnionStore(t *testing.T) { + db := getTestDB() + key1 := []byte("key1") + key2 := []byte("key2") + + db.ZAdd(key1, ScorePair{1, []byte("one")}) + db.ZAdd(key1, ScorePair{1, []byte("two")}) + + db.ZAdd(key2, ScorePair{2, []byte("two")}) + db.ZAdd(key2, ScorePair{2, []byte("three")}) + + keys := [][]byte{key1, key2} + weights := []int64{1, 2} + + out := []byte("out") + n, err := db.ZUnionStore(out, keys, weights, AggregateSum) + if err != nil { + t.Fatal(err.Error()) + } + if n != 3 { + t.Fatal("invalid value ", n) + } + + v, err := db.ZScore(out, []byte("two")) + + if err != nil { + t.Fatal(err.Error()) + } + if v != 5 { + t.Fatal("invalid value ", v) + } + + out = []byte("out") + n, err = db.ZUnionStore(out, keys, weights, AggregateMax) + if err != nil { + t.Fatal(err.Error()) + } + if n != 3 { + t.Fatal("invalid value ", n) + } + + v, err = db.ZScore(out, []byte("two")) + + if err != nil { + t.Fatal(err.Error()) + } + if v != 4 { + t.Fatal("invalid value ", v) + } + + n, err = db.ZCount(out, 0, 0XFFFE) + + if err != nil { + t.Fatal(err.Error()) + } + if n != 3 { + t.Fatal("invalid value ", v) + } +} + +func TestZInterStore(t *testing.T) { + db := getTestDB() + + key1 := []byte("key1") + key2 := []byte("key2") + + db.ZAdd(key1, ScorePair{1, []byte("one")}) + db.ZAdd(key1, ScorePair{1, []byte("two")}) + + db.ZAdd(key2, ScorePair{2, []byte("two")}) + db.ZAdd(key2, ScorePair{2, []byte("three")}) + + keys := [][]byte{key1, key2} + weights := []int64{1, 2} + out := []byte("out") + + n, err := db.ZInterStore(out, keys, weights, AggregateSum) + if err != nil { + t.Fatal(err.Error()) + } + if n != 1 { + t.Fatal("invalid value ", n) + } + v, err := db.ZScore(out, []byte("two")) + if err != nil { + t.Fatal(err.Error()) + } + if v != 5 { + t.Fatal("invalid value ", v) + } + + out = []byte("out") + n, err = db.ZInterStore(out, keys, weights, AggregateMin) + if err != nil { + t.Fatal(err.Error()) + } + if n != 1 { + t.Fatal("invalid value ", n) + } + + v, err = db.ZScore(out, []byte("two")) + + if err != nil { + t.Fatal(err.Error()) + } + if v != 1 { + t.Fatal("invalid value ", v) + } + + n, err = db.ZCount(out, 0, 0XFFFF) + if err != nil { + t.Fatal(err.Error()) + } + if n != 1 { + t.Fatal("invalid value ", n) + } + +} diff --git a/server/cmd_zset.go b/server/cmd_zset.go index e540b32..f8117fc 100644 --- a/server/cmd_zset.go +++ b/server/cmd_zset.go @@ -520,6 +520,120 @@ func zpersistCommand(req *requestContext) error { return nil } +func zparseZsetoptStore(args [][]byte) (destKey []byte, srcKeys [][]byte, weights []int64, aggregate byte, err error) { + destKey = args[0] + nKeys, err := strconv.Atoi(ledis.String(args[1])) + if err != nil { + err = ErrValue + return + } + args = args[2:] + if len(args) < nKeys { + err = ErrSyntax + return + } + + srcKeys = args[:nKeys] + + args = args[nKeys:] + + var weightsFlag = false + var aggregateFlag = false + + for len(args) > 0 { + if strings.ToLower(ledis.String(args[0])) == "weights" { + if weightsFlag { + err = ErrSyntax + return + } + + args = args[1:] + if len(args) < nKeys { + err = ErrSyntax + return + } + + weights = make([]int64, nKeys) + for i, arg := range args[:nKeys] { + if weights[i], err = ledis.StrInt64(arg, nil); err != nil { + err = ErrValue + return + } + } + args = args[nKeys:] + + weightsFlag = true + + } else if strings.ToLower(ledis.String(args[0])) == "aggregate" { + if aggregateFlag { + err = ErrSyntax + return + } + if len(args) < 2 { + err = ErrSyntax + return + } + + if strings.ToLower(ledis.String(args[1])) == "sum" { + aggregate = ledis.AggregateSum + } else if strings.ToLower(ledis.String(args[1])) == "min" { + aggregate = ledis.AggregateMin + } else if strings.ToLower(ledis.String(args[1])) == "max" { + aggregate = ledis.AggregateMax + } else { + err = ErrSyntax + return + } + args = args[2:] + aggregateFlag = true + } else { + err = ErrSyntax + return + } + } + if !aggregateFlag { + aggregate = ledis.AggregateSum + } + return +} + +func zunionstoreCommand(req *requestContext) error { + args := req.args + if len(args) < 2 { + return ErrCmdParams + } + + destKey, srcKeys, weights, aggregate, err := zparseZsetoptStore(args) + if err != nil { + return err + } + if n, err := req.db.ZUnionStore(destKey, srcKeys, weights, aggregate); err != nil { + return err + } else { + req.resp.writeInteger(n) + } + + return nil +} + +func zinterstoreCommand(req *requestContext) error { + args := req.args + if len(args) < 2 { + return ErrCmdParams + } + + destKey, srcKeys, weights, aggregate, err := zparseZsetoptStore(args) + if err != nil { + return err + } + if n, err := req.db.ZInterStore(destKey, srcKeys, weights, aggregate); err != nil { + return err + } else { + req.resp.writeInteger(n) + } + return nil +} + func init() { register("zadd", zaddCommand) register("zcard", zcardCommand) @@ -536,6 +650,9 @@ func init() { register("zrevrangebyscore", zrevrangebyscoreCommand) register("zscore", zscoreCommand) + register("zunionstore", zunionstoreCommand) + register("zinterstore", zinterstoreCommand) + //ledisdb special command register("zclear", zclearCommand) diff --git a/server/cmd_zset_test.go b/server/cmd_zset_test.go index d9b1272..e6a6a70 100644 --- a/server/cmd_zset_test.go +++ b/server/cmd_zset_test.go @@ -599,3 +599,109 @@ func TestZsetErrorParams(t *testing.T) { } } + +func TestZUnionStore(t *testing.T) { + c := getTestConn() + defer c.Close() + + if _, err := c.Do("zadd", "k1", "1", "one"); err != nil { + t.Fatal(err.Error()) + } + + if _, err := c.Do("zadd", "k1", "2", "two"); err != nil { + t.Fatal(err.Error()) + } + + if _, err := c.Do("zadd", "k2", "1", "two"); err != nil { + t.Fatal(err.Error()) + } + + if _, err := c.Do("zadd", "k2", "2", "three"); err != nil { + t.Fatal(err.Error()) + } + + if n, err := ledis.Int64(c.Do("zunionstore", "out", "2", "k1", "k2", "weights", "1", "2")); err != nil { + t.Fatal(err.Error()) + } else { + if n != 3 { + t.Fatal("invalid value ", n) + } + } + + if n, err := ledis.Int64(c.Do("zunionstore", "out", "2", "k1", "k2", "weights", "1", "2", "aggregate", "min")); err != nil { + t.Fatal(err.Error()) + } else { + if n != 3 { + t.Fatal("invalid value ", n) + } + } + + if n, err := ledis.Int64(c.Do("zunionstore", "out", "2", "k1", "k2", "aggregate", "max")); err != nil { + t.Fatal(err.Error()) + } else { + if n != 3 { + t.Fatal("invalid value ", n) + } + } + + if n, err := ledis.Int64(c.Do("zscore", "out", "two")); err != nil { + t.Fatal(err.Error()) + } else { + if n != 2 { + t.Fatal("invalid value ", n) + } + } +} + +func TestZInterStore(t *testing.T) { + c := getTestConn() + defer c.Close() + + if _, err := c.Do("zadd", "k1", "1", "one"); err != nil { + t.Fatal(err.Error()) + } + + if _, err := c.Do("zadd", "k1", "2", "two"); err != nil { + t.Fatal(err.Error()) + } + + if _, err := c.Do("zadd", "k2", "1", "two"); err != nil { + t.Fatal(err.Error()) + } + + if _, err := c.Do("zadd", "k2", "2", "three"); err != nil { + t.Fatal(err.Error()) + } + + if n, err := ledis.Int64(c.Do("zinterstore", "out", "2", "k1", "k2", "weights", "1", "2")); err != nil { + t.Fatal(err.Error()) + } else { + if n != 1 { + t.Fatal("invalid value ", n) + } + } + + if n, err := ledis.Int64(c.Do("zinterstore", "out", "2", "k1", "k2", "aggregate", "min", "weights", "1", "2")); err != nil { + t.Fatal(err.Error()) + } else { + if n != 1 { + t.Fatal("invalid value ", n) + } + } + + if n, err := ledis.Int64(c.Do("zinterstore", "out", "2", "k1", "k2", "aggregate", "sum")); err != nil { + t.Fatal(err.Error()) + } else { + if n != 1 { + t.Fatal("invalid value ", n) + } + } + + if n, err := ledis.Int64(c.Do("zscore", "out", "two")); err != nil { + t.Fatal(err.Error()) + } else { + if n != 3 { + t.Fatal("invalid value ", n) + } + } +}