forked from mirror/ledisdb
Merge branch 'zset-enhance-feature' into develop
This commit is contained in:
commit
826961e0fc
165
ledis/t_zset.go
165
ledis/t_zset.go
|
@ -12,6 +12,10 @@ const (
|
|||
MinScore int64 = -1<<63 + 1
|
||||
MaxScore int64 = 1<<63 - 1
|
||||
InvalidScore int64 = -1 << 63
|
||||
|
||||
AggregateSum byte = 0
|
||||
AggregateMin byte = 1
|
||||
AggregateMax byte = 2
|
||||
)
|
||||
|
||||
type ScorePair struct {
|
||||
|
@ -23,6 +27,9 @@ var errZSizeKey = errors.New("invalid zsize key")
|
|||
var errZSetKey = errors.New("invalid zset key")
|
||||
var errZScoreKey = errors.New("invalid zscore key")
|
||||
var errScoreOverflow = errors.New("zset score overflow")
|
||||
var errInvalidAggregate = errors.New("invalid aggregate")
|
||||
var errInvalidWeightNum = errors.New("invalid weight number")
|
||||
var errInvalidSrcKeyNum = errors.New("invalid src key number")
|
||||
|
||||
const (
|
||||
zsetNScoreSep byte = '<'
|
||||
|
@ -839,3 +846,161 @@ func (db *DB) ZPersist(key []byte) (int64, error) {
|
|||
err = t.Commit()
|
||||
return n, err
|
||||
}
|
||||
|
||||
func getAggregateFunc(aggregate byte) func(int64, int64) int64 {
|
||||
switch aggregate {
|
||||
case AggregateSum:
|
||||
return func(a int64, b int64) int64 {
|
||||
return a + b
|
||||
}
|
||||
case AggregateMax:
|
||||
return func(a int64, b int64) int64 {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
case AggregateMin:
|
||||
return func(a int64, b int64) int64 {
|
||||
if a > b {
|
||||
return b
|
||||
}
|
||||
return a
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) ZUnionStore(destKey []byte, srcKeys [][]byte, weights []int64, aggregate byte) (int64, error) {
|
||||
|
||||
var destMap = map[string]int64{}
|
||||
aggregateFunc := getAggregateFunc(aggregate)
|
||||
if aggregateFunc == nil {
|
||||
return 0, errInvalidAggregate
|
||||
}
|
||||
if len(srcKeys) < 1 {
|
||||
return 0, errInvalidSrcKeyNum
|
||||
}
|
||||
if weights != nil {
|
||||
if len(srcKeys) != len(weights) {
|
||||
return 0, errInvalidWeightNum
|
||||
}
|
||||
} else {
|
||||
weights = make([]int64, len(srcKeys))
|
||||
for i := 0; i < len(weights); i++ {
|
||||
weights[i] = 1
|
||||
}
|
||||
}
|
||||
|
||||
for i, key := range srcKeys {
|
||||
scorePairs, err := db.ZRange(key, 0, -1)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
for _, pair := range scorePairs {
|
||||
if score, ok := destMap[String(pair.Member)]; !ok {
|
||||
destMap[String(pair.Member)] = pair.Score
|
||||
} else {
|
||||
destMap[String(pair.Member)] = aggregateFunc(score, pair.Score*weights[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
t := db.zsetTx
|
||||
t.Lock()
|
||||
defer t.Unlock()
|
||||
|
||||
db.zDelete(t, destKey)
|
||||
|
||||
var num int64 = 0
|
||||
for member, score := range destMap {
|
||||
if err := checkZSetKMSize(destKey, []byte(member)); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if n, err := db.zSetItem(t, destKey, score, []byte(member)); err != nil {
|
||||
return 0, err
|
||||
} else if n == 0 {
|
||||
//add new
|
||||
num++
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := db.zIncrSize(t, destKey, num); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
//todo add binlog
|
||||
if err := t.Commit(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int64(len(destMap)), nil
|
||||
}
|
||||
|
||||
func (db *DB) ZInterStore(destKey []byte, srcKeys [][]byte, weights []int64, aggregate byte) (int64, error) {
|
||||
|
||||
var destMap = map[string]int64{}
|
||||
aggregateFunc := getAggregateFunc(aggregate)
|
||||
if aggregateFunc == nil {
|
||||
return 0, errInvalidAggregate
|
||||
}
|
||||
if len(srcKeys) < 1 {
|
||||
return 0, errInvalidSrcKeyNum
|
||||
}
|
||||
if weights != nil {
|
||||
if len(srcKeys) != len(weights) {
|
||||
return 0, errInvalidWeightNum
|
||||
}
|
||||
} else {
|
||||
weights = make([]int64, len(srcKeys))
|
||||
for i := 0; i < len(weights); i++ {
|
||||
weights[i] = 1
|
||||
}
|
||||
}
|
||||
|
||||
var keptMembers [][]byte
|
||||
for i, key := range srcKeys {
|
||||
scorePairs, err := db.ZRange(key, 0, -1)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
for _, pair := range scorePairs {
|
||||
if score, ok := destMap[String(pair.Member)]; !ok {
|
||||
destMap[String(pair.Member)] = pair.Score
|
||||
} else {
|
||||
keptMembers = append(keptMembers, pair.Member)
|
||||
destMap[String(pair.Member)] = aggregateFunc(score, pair.Score*weights[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
t := db.zsetTx
|
||||
t.Lock()
|
||||
defer t.Unlock()
|
||||
|
||||
db.zDelete(t, destKey)
|
||||
|
||||
var num int64 = 0
|
||||
for _, member := range keptMembers {
|
||||
score := destMap[String(member)]
|
||||
if err := checkZSetKMSize(destKey, member); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if n, err := db.zSetItem(t, destKey, score, member); err != nil {
|
||||
return 0, err
|
||||
} else if n == 0 {
|
||||
//add new
|
||||
num++
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := db.zIncrSize(t, destKey, num); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
//todo add binlog
|
||||
if err := t.Commit(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int64(len(keptMembers)), nil
|
||||
}
|
||||
|
|
|
@ -264,3 +264,122 @@ func TestZSetPersist(t *testing.T) {
|
|||
t.Fatal(n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestZUnionStore(t *testing.T) {
|
||||
db := getTestDB()
|
||||
key1 := []byte("key1")
|
||||
key2 := []byte("key2")
|
||||
|
||||
db.ZAdd(key1, ScorePair{1, []byte("one")})
|
||||
db.ZAdd(key1, ScorePair{1, []byte("two")})
|
||||
|
||||
db.ZAdd(key2, ScorePair{2, []byte("two")})
|
||||
db.ZAdd(key2, ScorePair{2, []byte("three")})
|
||||
|
||||
keys := [][]byte{key1, key2}
|
||||
weights := []int64{1, 2}
|
||||
|
||||
out := []byte("out")
|
||||
n, err := db.ZUnionStore(out, keys, weights, AggregateSum)
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
if n != 3 {
|
||||
t.Fatal("invalid value ", n)
|
||||
}
|
||||
|
||||
v, err := db.ZScore(out, []byte("two"))
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
if v != 5 {
|
||||
t.Fatal("invalid value ", v)
|
||||
}
|
||||
|
||||
out = []byte("out")
|
||||
n, err = db.ZUnionStore(out, keys, weights, AggregateMax)
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
if n != 3 {
|
||||
t.Fatal("invalid value ", n)
|
||||
}
|
||||
|
||||
v, err = db.ZScore(out, []byte("two"))
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
if v != 4 {
|
||||
t.Fatal("invalid value ", v)
|
||||
}
|
||||
|
||||
n, err = db.ZCount(out, 0, 0XFFFE)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
if n != 3 {
|
||||
t.Fatal("invalid value ", v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestZInterStore(t *testing.T) {
|
||||
db := getTestDB()
|
||||
|
||||
key1 := []byte("key1")
|
||||
key2 := []byte("key2")
|
||||
|
||||
db.ZAdd(key1, ScorePair{1, []byte("one")})
|
||||
db.ZAdd(key1, ScorePair{1, []byte("two")})
|
||||
|
||||
db.ZAdd(key2, ScorePair{2, []byte("two")})
|
||||
db.ZAdd(key2, ScorePair{2, []byte("three")})
|
||||
|
||||
keys := [][]byte{key1, key2}
|
||||
weights := []int64{1, 2}
|
||||
out := []byte("out")
|
||||
|
||||
n, err := db.ZInterStore(out, keys, weights, AggregateSum)
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
if n != 1 {
|
||||
t.Fatal("invalid value ", n)
|
||||
}
|
||||
v, err := db.ZScore(out, []byte("two"))
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
if v != 5 {
|
||||
t.Fatal("invalid value ", v)
|
||||
}
|
||||
|
||||
out = []byte("out")
|
||||
n, err = db.ZInterStore(out, keys, weights, AggregateMin)
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
if n != 1 {
|
||||
t.Fatal("invalid value ", n)
|
||||
}
|
||||
|
||||
v, err = db.ZScore(out, []byte("two"))
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
if v != 1 {
|
||||
t.Fatal("invalid value ", v)
|
||||
}
|
||||
|
||||
n, err = db.ZCount(out, 0, 0XFFFF)
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
if n != 1 {
|
||||
t.Fatal("invalid value ", n)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -520,6 +520,120 @@ func zpersistCommand(req *requestContext) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func zparseZsetoptStore(args [][]byte) (destKey []byte, srcKeys [][]byte, weights []int64, aggregate byte, err error) {
|
||||
destKey = args[0]
|
||||
nKeys, err := strconv.Atoi(ledis.String(args[1]))
|
||||
if err != nil {
|
||||
err = ErrValue
|
||||
return
|
||||
}
|
||||
args = args[2:]
|
||||
if len(args) < nKeys {
|
||||
err = ErrSyntax
|
||||
return
|
||||
}
|
||||
|
||||
srcKeys = args[:nKeys]
|
||||
|
||||
args = args[nKeys:]
|
||||
|
||||
var weightsFlag = false
|
||||
var aggregateFlag = false
|
||||
|
||||
for len(args) > 0 {
|
||||
if strings.ToLower(ledis.String(args[0])) == "weights" {
|
||||
if weightsFlag {
|
||||
err = ErrSyntax
|
||||
return
|
||||
}
|
||||
|
||||
args = args[1:]
|
||||
if len(args) < nKeys {
|
||||
err = ErrSyntax
|
||||
return
|
||||
}
|
||||
|
||||
weights = make([]int64, nKeys)
|
||||
for i, arg := range args[:nKeys] {
|
||||
if weights[i], err = ledis.StrInt64(arg, nil); err != nil {
|
||||
err = ErrValue
|
||||
return
|
||||
}
|
||||
}
|
||||
args = args[nKeys:]
|
||||
|
||||
weightsFlag = true
|
||||
|
||||
} else if strings.ToLower(ledis.String(args[0])) == "aggregate" {
|
||||
if aggregateFlag {
|
||||
err = ErrSyntax
|
||||
return
|
||||
}
|
||||
if len(args) < 2 {
|
||||
err = ErrSyntax
|
||||
return
|
||||
}
|
||||
|
||||
if strings.ToLower(ledis.String(args[1])) == "sum" {
|
||||
aggregate = ledis.AggregateSum
|
||||
} else if strings.ToLower(ledis.String(args[1])) == "min" {
|
||||
aggregate = ledis.AggregateMin
|
||||
} else if strings.ToLower(ledis.String(args[1])) == "max" {
|
||||
aggregate = ledis.AggregateMax
|
||||
} else {
|
||||
err = ErrSyntax
|
||||
return
|
||||
}
|
||||
args = args[2:]
|
||||
aggregateFlag = true
|
||||
} else {
|
||||
err = ErrSyntax
|
||||
return
|
||||
}
|
||||
}
|
||||
if !aggregateFlag {
|
||||
aggregate = ledis.AggregateSum
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func zunionstoreCommand(req *requestContext) error {
|
||||
args := req.args
|
||||
if len(args) < 2 {
|
||||
return ErrCmdParams
|
||||
}
|
||||
|
||||
destKey, srcKeys, weights, aggregate, err := zparseZsetoptStore(args)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n, err := req.db.ZUnionStore(destKey, srcKeys, weights, aggregate); err != nil {
|
||||
return err
|
||||
} else {
|
||||
req.resp.writeInteger(n)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func zinterstoreCommand(req *requestContext) error {
|
||||
args := req.args
|
||||
if len(args) < 2 {
|
||||
return ErrCmdParams
|
||||
}
|
||||
|
||||
destKey, srcKeys, weights, aggregate, err := zparseZsetoptStore(args)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n, err := req.db.ZInterStore(destKey, srcKeys, weights, aggregate); err != nil {
|
||||
return err
|
||||
} else {
|
||||
req.resp.writeInteger(n)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
register("zadd", zaddCommand)
|
||||
register("zcard", zcardCommand)
|
||||
|
@ -536,6 +650,9 @@ func init() {
|
|||
register("zrevrangebyscore", zrevrangebyscoreCommand)
|
||||
register("zscore", zscoreCommand)
|
||||
|
||||
register("zunionstore", zunionstoreCommand)
|
||||
register("zinterstore", zinterstoreCommand)
|
||||
|
||||
//ledisdb special command
|
||||
|
||||
register("zclear", zclearCommand)
|
||||
|
|
|
@ -599,3 +599,109 @@ func TestZsetErrorParams(t *testing.T) {
|
|||
}
|
||||
|
||||
}
|
||||
|
||||
func TestZUnionStore(t *testing.T) {
|
||||
c := getTestConn()
|
||||
defer c.Close()
|
||||
|
||||
if _, err := c.Do("zadd", "k1", "1", "one"); err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
|
||||
if _, err := c.Do("zadd", "k1", "2", "two"); err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
|
||||
if _, err := c.Do("zadd", "k2", "1", "two"); err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
|
||||
if _, err := c.Do("zadd", "k2", "2", "three"); err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
|
||||
if n, err := ledis.Int64(c.Do("zunionstore", "out", "2", "k1", "k2", "weights", "1", "2")); err != nil {
|
||||
t.Fatal(err.Error())
|
||||
} else {
|
||||
if n != 3 {
|
||||
t.Fatal("invalid value ", n)
|
||||
}
|
||||
}
|
||||
|
||||
if n, err := ledis.Int64(c.Do("zunionstore", "out", "2", "k1", "k2", "weights", "1", "2", "aggregate", "min")); err != nil {
|
||||
t.Fatal(err.Error())
|
||||
} else {
|
||||
if n != 3 {
|
||||
t.Fatal("invalid value ", n)
|
||||
}
|
||||
}
|
||||
|
||||
if n, err := ledis.Int64(c.Do("zunionstore", "out", "2", "k1", "k2", "aggregate", "max")); err != nil {
|
||||
t.Fatal(err.Error())
|
||||
} else {
|
||||
if n != 3 {
|
||||
t.Fatal("invalid value ", n)
|
||||
}
|
||||
}
|
||||
|
||||
if n, err := ledis.Int64(c.Do("zscore", "out", "two")); err != nil {
|
||||
t.Fatal(err.Error())
|
||||
} else {
|
||||
if n != 2 {
|
||||
t.Fatal("invalid value ", n)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestZInterStore(t *testing.T) {
|
||||
c := getTestConn()
|
||||
defer c.Close()
|
||||
|
||||
if _, err := c.Do("zadd", "k1", "1", "one"); err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
|
||||
if _, err := c.Do("zadd", "k1", "2", "two"); err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
|
||||
if _, err := c.Do("zadd", "k2", "1", "two"); err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
|
||||
if _, err := c.Do("zadd", "k2", "2", "three"); err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
|
||||
if n, err := ledis.Int64(c.Do("zinterstore", "out", "2", "k1", "k2", "weights", "1", "2")); err != nil {
|
||||
t.Fatal(err.Error())
|
||||
} else {
|
||||
if n != 1 {
|
||||
t.Fatal("invalid value ", n)
|
||||
}
|
||||
}
|
||||
|
||||
if n, err := ledis.Int64(c.Do("zinterstore", "out", "2", "k1", "k2", "aggregate", "min", "weights", "1", "2")); err != nil {
|
||||
t.Fatal(err.Error())
|
||||
} else {
|
||||
if n != 1 {
|
||||
t.Fatal("invalid value ", n)
|
||||
}
|
||||
}
|
||||
|
||||
if n, err := ledis.Int64(c.Do("zinterstore", "out", "2", "k1", "k2", "aggregate", "sum")); err != nil {
|
||||
t.Fatal(err.Error())
|
||||
} else {
|
||||
if n != 1 {
|
||||
t.Fatal("invalid value ", n)
|
||||
}
|
||||
}
|
||||
|
||||
if n, err := ledis.Int64(c.Do("zscore", "out", "two")); err != nil {
|
||||
t.Fatal(err.Error())
|
||||
} else {
|
||||
if n != 3 {
|
||||
t.Fatal("invalid value ", n)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue