Merge branch 'zset-enhance-feature' into develop

This commit is contained in:
wenyekui 2014-08-13 14:24:25 +08:00
commit 826961e0fc
4 changed files with 507 additions and 0 deletions

View File

@ -12,6 +12,10 @@ const (
MinScore int64 = -1<<63 + 1
MaxScore int64 = 1<<63 - 1
InvalidScore int64 = -1 << 63
AggregateSum byte = 0
AggregateMin byte = 1
AggregateMax byte = 2
)
type ScorePair struct {
@ -23,6 +27,9 @@ var errZSizeKey = errors.New("invalid zsize key")
var errZSetKey = errors.New("invalid zset key")
var errZScoreKey = errors.New("invalid zscore key")
var errScoreOverflow = errors.New("zset score overflow")
var errInvalidAggregate = errors.New("invalid aggregate")
var errInvalidWeightNum = errors.New("invalid weight number")
var errInvalidSrcKeyNum = errors.New("invalid src key number")
const (
zsetNScoreSep byte = '<'
@ -839,3 +846,161 @@ func (db *DB) ZPersist(key []byte) (int64, error) {
err = t.Commit()
return n, err
}
func getAggregateFunc(aggregate byte) func(int64, int64) int64 {
switch aggregate {
case AggregateSum:
return func(a int64, b int64) int64 {
return a + b
}
case AggregateMax:
return func(a int64, b int64) int64 {
if a > b {
return a
}
return b
}
case AggregateMin:
return func(a int64, b int64) int64 {
if a > b {
return b
}
return a
}
}
return nil
}
func (db *DB) ZUnionStore(destKey []byte, srcKeys [][]byte, weights []int64, aggregate byte) (int64, error) {
var destMap = map[string]int64{}
aggregateFunc := getAggregateFunc(aggregate)
if aggregateFunc == nil {
return 0, errInvalidAggregate
}
if len(srcKeys) < 1 {
return 0, errInvalidSrcKeyNum
}
if weights != nil {
if len(srcKeys) != len(weights) {
return 0, errInvalidWeightNum
}
} else {
weights = make([]int64, len(srcKeys))
for i := 0; i < len(weights); i++ {
weights[i] = 1
}
}
for i, key := range srcKeys {
scorePairs, err := db.ZRange(key, 0, -1)
if err != nil {
return 0, err
}
for _, pair := range scorePairs {
if score, ok := destMap[String(pair.Member)]; !ok {
destMap[String(pair.Member)] = pair.Score
} else {
destMap[String(pair.Member)] = aggregateFunc(score, pair.Score*weights[i])
}
}
}
t := db.zsetTx
t.Lock()
defer t.Unlock()
db.zDelete(t, destKey)
var num int64 = 0
for member, score := range destMap {
if err := checkZSetKMSize(destKey, []byte(member)); err != nil {
return 0, err
}
if n, err := db.zSetItem(t, destKey, score, []byte(member)); err != nil {
return 0, err
} else if n == 0 {
//add new
num++
}
}
if _, err := db.zIncrSize(t, destKey, num); err != nil {
return 0, err
}
//todo add binlog
if err := t.Commit(); err != nil {
return 0, err
}
return int64(len(destMap)), nil
}
func (db *DB) ZInterStore(destKey []byte, srcKeys [][]byte, weights []int64, aggregate byte) (int64, error) {
var destMap = map[string]int64{}
aggregateFunc := getAggregateFunc(aggregate)
if aggregateFunc == nil {
return 0, errInvalidAggregate
}
if len(srcKeys) < 1 {
return 0, errInvalidSrcKeyNum
}
if weights != nil {
if len(srcKeys) != len(weights) {
return 0, errInvalidWeightNum
}
} else {
weights = make([]int64, len(srcKeys))
for i := 0; i < len(weights); i++ {
weights[i] = 1
}
}
var keptMembers [][]byte
for i, key := range srcKeys {
scorePairs, err := db.ZRange(key, 0, -1)
if err != nil {
return 0, err
}
for _, pair := range scorePairs {
if score, ok := destMap[String(pair.Member)]; !ok {
destMap[String(pair.Member)] = pair.Score
} else {
keptMembers = append(keptMembers, pair.Member)
destMap[String(pair.Member)] = aggregateFunc(score, pair.Score*weights[i])
}
}
}
t := db.zsetTx
t.Lock()
defer t.Unlock()
db.zDelete(t, destKey)
var num int64 = 0
for _, member := range keptMembers {
score := destMap[String(member)]
if err := checkZSetKMSize(destKey, member); err != nil {
return 0, err
}
if n, err := db.zSetItem(t, destKey, score, member); err != nil {
return 0, err
} else if n == 0 {
//add new
num++
}
}
if _, err := db.zIncrSize(t, destKey, num); err != nil {
return 0, err
}
//todo add binlog
if err := t.Commit(); err != nil {
return 0, err
}
return int64(len(keptMembers)), nil
}

View File

@ -264,3 +264,122 @@ func TestZSetPersist(t *testing.T) {
t.Fatal(n)
}
}
func TestZUnionStore(t *testing.T) {
db := getTestDB()
key1 := []byte("key1")
key2 := []byte("key2")
db.ZAdd(key1, ScorePair{1, []byte("one")})
db.ZAdd(key1, ScorePair{1, []byte("two")})
db.ZAdd(key2, ScorePair{2, []byte("two")})
db.ZAdd(key2, ScorePair{2, []byte("three")})
keys := [][]byte{key1, key2}
weights := []int64{1, 2}
out := []byte("out")
n, err := db.ZUnionStore(out, keys, weights, AggregateSum)
if err != nil {
t.Fatal(err.Error())
}
if n != 3 {
t.Fatal("invalid value ", n)
}
v, err := db.ZScore(out, []byte("two"))
if err != nil {
t.Fatal(err.Error())
}
if v != 5 {
t.Fatal("invalid value ", v)
}
out = []byte("out")
n, err = db.ZUnionStore(out, keys, weights, AggregateMax)
if err != nil {
t.Fatal(err.Error())
}
if n != 3 {
t.Fatal("invalid value ", n)
}
v, err = db.ZScore(out, []byte("two"))
if err != nil {
t.Fatal(err.Error())
}
if v != 4 {
t.Fatal("invalid value ", v)
}
n, err = db.ZCount(out, 0, 0XFFFE)
if err != nil {
t.Fatal(err.Error())
}
if n != 3 {
t.Fatal("invalid value ", v)
}
}
func TestZInterStore(t *testing.T) {
db := getTestDB()
key1 := []byte("key1")
key2 := []byte("key2")
db.ZAdd(key1, ScorePair{1, []byte("one")})
db.ZAdd(key1, ScorePair{1, []byte("two")})
db.ZAdd(key2, ScorePair{2, []byte("two")})
db.ZAdd(key2, ScorePair{2, []byte("three")})
keys := [][]byte{key1, key2}
weights := []int64{1, 2}
out := []byte("out")
n, err := db.ZInterStore(out, keys, weights, AggregateSum)
if err != nil {
t.Fatal(err.Error())
}
if n != 1 {
t.Fatal("invalid value ", n)
}
v, err := db.ZScore(out, []byte("two"))
if err != nil {
t.Fatal(err.Error())
}
if v != 5 {
t.Fatal("invalid value ", v)
}
out = []byte("out")
n, err = db.ZInterStore(out, keys, weights, AggregateMin)
if err != nil {
t.Fatal(err.Error())
}
if n != 1 {
t.Fatal("invalid value ", n)
}
v, err = db.ZScore(out, []byte("two"))
if err != nil {
t.Fatal(err.Error())
}
if v != 1 {
t.Fatal("invalid value ", v)
}
n, err = db.ZCount(out, 0, 0XFFFF)
if err != nil {
t.Fatal(err.Error())
}
if n != 1 {
t.Fatal("invalid value ", n)
}
}

View File

@ -520,6 +520,120 @@ func zpersistCommand(req *requestContext) error {
return nil
}
func zparseZsetoptStore(args [][]byte) (destKey []byte, srcKeys [][]byte, weights []int64, aggregate byte, err error) {
destKey = args[0]
nKeys, err := strconv.Atoi(ledis.String(args[1]))
if err != nil {
err = ErrValue
return
}
args = args[2:]
if len(args) < nKeys {
err = ErrSyntax
return
}
srcKeys = args[:nKeys]
args = args[nKeys:]
var weightsFlag = false
var aggregateFlag = false
for len(args) > 0 {
if strings.ToLower(ledis.String(args[0])) == "weights" {
if weightsFlag {
err = ErrSyntax
return
}
args = args[1:]
if len(args) < nKeys {
err = ErrSyntax
return
}
weights = make([]int64, nKeys)
for i, arg := range args[:nKeys] {
if weights[i], err = ledis.StrInt64(arg, nil); err != nil {
err = ErrValue
return
}
}
args = args[nKeys:]
weightsFlag = true
} else if strings.ToLower(ledis.String(args[0])) == "aggregate" {
if aggregateFlag {
err = ErrSyntax
return
}
if len(args) < 2 {
err = ErrSyntax
return
}
if strings.ToLower(ledis.String(args[1])) == "sum" {
aggregate = ledis.AggregateSum
} else if strings.ToLower(ledis.String(args[1])) == "min" {
aggregate = ledis.AggregateMin
} else if strings.ToLower(ledis.String(args[1])) == "max" {
aggregate = ledis.AggregateMax
} else {
err = ErrSyntax
return
}
args = args[2:]
aggregateFlag = true
} else {
err = ErrSyntax
return
}
}
if !aggregateFlag {
aggregate = ledis.AggregateSum
}
return
}
func zunionstoreCommand(req *requestContext) error {
args := req.args
if len(args) < 2 {
return ErrCmdParams
}
destKey, srcKeys, weights, aggregate, err := zparseZsetoptStore(args)
if err != nil {
return err
}
if n, err := req.db.ZUnionStore(destKey, srcKeys, weights, aggregate); err != nil {
return err
} else {
req.resp.writeInteger(n)
}
return nil
}
func zinterstoreCommand(req *requestContext) error {
args := req.args
if len(args) < 2 {
return ErrCmdParams
}
destKey, srcKeys, weights, aggregate, err := zparseZsetoptStore(args)
if err != nil {
return err
}
if n, err := req.db.ZInterStore(destKey, srcKeys, weights, aggregate); err != nil {
return err
} else {
req.resp.writeInteger(n)
}
return nil
}
func init() {
register("zadd", zaddCommand)
register("zcard", zcardCommand)
@ -536,6 +650,9 @@ func init() {
register("zrevrangebyscore", zrevrangebyscoreCommand)
register("zscore", zscoreCommand)
register("zunionstore", zunionstoreCommand)
register("zinterstore", zinterstoreCommand)
//ledisdb special command
register("zclear", zclearCommand)

View File

@ -599,3 +599,109 @@ func TestZsetErrorParams(t *testing.T) {
}
}
func TestZUnionStore(t *testing.T) {
c := getTestConn()
defer c.Close()
if _, err := c.Do("zadd", "k1", "1", "one"); err != nil {
t.Fatal(err.Error())
}
if _, err := c.Do("zadd", "k1", "2", "two"); err != nil {
t.Fatal(err.Error())
}
if _, err := c.Do("zadd", "k2", "1", "two"); err != nil {
t.Fatal(err.Error())
}
if _, err := c.Do("zadd", "k2", "2", "three"); err != nil {
t.Fatal(err.Error())
}
if n, err := ledis.Int64(c.Do("zunionstore", "out", "2", "k1", "k2", "weights", "1", "2")); err != nil {
t.Fatal(err.Error())
} else {
if n != 3 {
t.Fatal("invalid value ", n)
}
}
if n, err := ledis.Int64(c.Do("zunionstore", "out", "2", "k1", "k2", "weights", "1", "2", "aggregate", "min")); err != nil {
t.Fatal(err.Error())
} else {
if n != 3 {
t.Fatal("invalid value ", n)
}
}
if n, err := ledis.Int64(c.Do("zunionstore", "out", "2", "k1", "k2", "aggregate", "max")); err != nil {
t.Fatal(err.Error())
} else {
if n != 3 {
t.Fatal("invalid value ", n)
}
}
if n, err := ledis.Int64(c.Do("zscore", "out", "two")); err != nil {
t.Fatal(err.Error())
} else {
if n != 2 {
t.Fatal("invalid value ", n)
}
}
}
func TestZInterStore(t *testing.T) {
c := getTestConn()
defer c.Close()
if _, err := c.Do("zadd", "k1", "1", "one"); err != nil {
t.Fatal(err.Error())
}
if _, err := c.Do("zadd", "k1", "2", "two"); err != nil {
t.Fatal(err.Error())
}
if _, err := c.Do("zadd", "k2", "1", "two"); err != nil {
t.Fatal(err.Error())
}
if _, err := c.Do("zadd", "k2", "2", "three"); err != nil {
t.Fatal(err.Error())
}
if n, err := ledis.Int64(c.Do("zinterstore", "out", "2", "k1", "k2", "weights", "1", "2")); err != nil {
t.Fatal(err.Error())
} else {
if n != 1 {
t.Fatal("invalid value ", n)
}
}
if n, err := ledis.Int64(c.Do("zinterstore", "out", "2", "k1", "k2", "aggregate", "min", "weights", "1", "2")); err != nil {
t.Fatal(err.Error())
} else {
if n != 1 {
t.Fatal("invalid value ", n)
}
}
if n, err := ledis.Int64(c.Do("zinterstore", "out", "2", "k1", "k2", "aggregate", "sum")); err != nil {
t.Fatal(err.Error())
} else {
if n != 1 {
t.Fatal("invalid value ", n)
}
}
if n, err := ledis.Int64(c.Do("zscore", "out", "two")); err != nil {
t.Fatal(err.Error())
} else {
if n != 3 {
t.Fatal("invalid value ", n)
}
}
}