From cb908a859c794597a7364f15301693e62bc2338a Mon Sep 17 00:00:00 2001 From: wenyekui Date: Wed, 13 Aug 2014 11:59:43 +0800 Subject: [PATCH] add zunionstorecommand & zinterstorecommand in server pkg --- ledis/t_zset_test.go | 1 - server/cmd_zset.go | 117 ++++++++++++++++++++++++++++++++++++++++ server/cmd_zset_test.go | 45 ++++++++++++++++ 3 files changed, 162 insertions(+), 1 deletion(-) diff --git a/ledis/t_zset_test.go b/ledis/t_zset_test.go index fc4f41f..a772360 100644 --- a/ledis/t_zset_test.go +++ b/ledis/t_zset_test.go @@ -315,7 +315,6 @@ func TestZUnionStore(t *testing.T) { t.Fatal("invalid value ", v) } - pairs, _ := db.ZRange(out, 0, -1) n, err = db.ZCount(out, 0, 0XFFFE) if err != nil { diff --git a/server/cmd_zset.go b/server/cmd_zset.go index e540b32..f8117fc 100644 --- a/server/cmd_zset.go +++ b/server/cmd_zset.go @@ -520,6 +520,120 @@ func zpersistCommand(req *requestContext) error { return nil } +func zparseZsetoptStore(args [][]byte) (destKey []byte, srcKeys [][]byte, weights []int64, aggregate byte, err error) { + destKey = args[0] + nKeys, err := strconv.Atoi(ledis.String(args[1])) + if err != nil { + err = ErrValue + return + } + args = args[2:] + if len(args) < nKeys { + err = ErrSyntax + return + } + + srcKeys = args[:nKeys] + + args = args[nKeys:] + + var weightsFlag = false + var aggregateFlag = false + + for len(args) > 0 { + if strings.ToLower(ledis.String(args[0])) == "weights" { + if weightsFlag { + err = ErrSyntax + return + } + + args = args[1:] + if len(args) < nKeys { + err = ErrSyntax + return + } + + weights = make([]int64, nKeys) + for i, arg := range args[:nKeys] { + if weights[i], err = ledis.StrInt64(arg, nil); err != nil { + err = ErrValue + return + } + } + args = args[nKeys:] + + weightsFlag = true + + } else if strings.ToLower(ledis.String(args[0])) == "aggregate" { + if aggregateFlag { + err = ErrSyntax + return + } + if len(args) < 2 { + err = ErrSyntax + return + } + + if strings.ToLower(ledis.String(args[1])) == "sum" { + aggregate = ledis.AggregateSum + } else if strings.ToLower(ledis.String(args[1])) == "min" { + aggregate = ledis.AggregateMin + } else if strings.ToLower(ledis.String(args[1])) == "max" { + aggregate = ledis.AggregateMax + } else { + err = ErrSyntax + return + } + args = args[2:] + aggregateFlag = true + } else { + err = ErrSyntax + return + } + } + if !aggregateFlag { + aggregate = ledis.AggregateSum + } + return +} + +func zunionstoreCommand(req *requestContext) error { + args := req.args + if len(args) < 2 { + return ErrCmdParams + } + + destKey, srcKeys, weights, aggregate, err := zparseZsetoptStore(args) + if err != nil { + return err + } + if n, err := req.db.ZUnionStore(destKey, srcKeys, weights, aggregate); err != nil { + return err + } else { + req.resp.writeInteger(n) + } + + return nil +} + +func zinterstoreCommand(req *requestContext) error { + args := req.args + if len(args) < 2 { + return ErrCmdParams + } + + destKey, srcKeys, weights, aggregate, err := zparseZsetoptStore(args) + if err != nil { + return err + } + if n, err := req.db.ZInterStore(destKey, srcKeys, weights, aggregate); err != nil { + return err + } else { + req.resp.writeInteger(n) + } + return nil +} + func init() { register("zadd", zaddCommand) register("zcard", zcardCommand) @@ -536,6 +650,9 @@ func init() { register("zrevrangebyscore", zrevrangebyscoreCommand) register("zscore", zscoreCommand) + register("zunionstore", zunionstoreCommand) + register("zinterstore", zinterstoreCommand) + //ledisdb special command register("zclear", zclearCommand) diff --git a/server/cmd_zset_test.go b/server/cmd_zset_test.go index d9b1272..06cd538 100644 --- a/server/cmd_zset_test.go +++ b/server/cmd_zset_test.go @@ -599,3 +599,48 @@ func TestZsetErrorParams(t *testing.T) { } } + +func TestZUnionStore(t *testing.T) { + c := getTestConn() + defer c.Close() + + if _, err := c.Do("zadd", "k1", "1", "one"); err != nil { + t.Fatal(err.Error()) + } + + if _, err := c.Do("zadd", "k1", "2", "two"); err != nil { + t.Fatal(err.Error()) + } + + if _, err := c.Do("zadd", "k2", "1", "two"); err != nil { + t.Fatal(err.Error()) + } + + if _, err := c.Do("zadd", "k2", "2", "three"); err != nil { + t.Fatal(err.Error()) + } + + if n, err := ledis.Int64(c.Do("zunionstore", "out", "2", "k1", "k2", "weights", "1", "2")); err != nil { + t.Fatal(err.Error()) + } else { + if n != 3 { + t.Fatal("invalid value ", n) + } + } + + if n, err := ledis.Int64(c.Do("zunionstore", "out", "2", "k1", "k2", "weights", "1", "2", "aggregate", "min")); err != nil { + t.Fatal(err.Error()) + } else { + if n != 3 { + t.Fatal("invalid value ", n) + } + } + + if n, err := ledis.Int64(c.Do("zunionstore", "out", "2", "k1", "k2", "aggregate", "max")); err != nil { + t.Fatal(err.Error()) + } else { + if n != 3 { + t.Fatal("invalid value ", n) + } + } +}