diff --git a/ledis/const.go b/ledis/const.go index 9e4d0cf..ba0cf05 100644 --- a/ledis/const.go +++ b/ledis/const.go @@ -37,4 +37,5 @@ var ( ErrKeySize = errors.New("invalid key size") ErrHashFieldSize = errors.New("invalid hash field size") ErrZSetMemberSize = errors.New("invalid zset member size") + ErrScoreMiss = errors.New("zset score miss") ) diff --git a/ledis/t_zset.go b/ledis/t_zset.go index 222b4ae..8be54b7 100644 --- a/ledis/t_zset.go +++ b/ledis/t_zset.go @@ -8,8 +8,9 @@ import ( ) const ( - MinScore int64 = -1<<63 + 1 - MaxScore int64 = 1<<63 - 1 + MinScore int64 = -1<<63 + 1 + MaxScore int64 = 1<<63 - 1 + InvalidScore int64 = -1 << 63 ) type ScorePair struct { @@ -311,19 +312,25 @@ func (db *DB) ZCard(key []byte) (int64, error) { return Int64(db.db.Get(sk)) } -func (db *DB) ZScore(key []byte, member []byte) ([]byte, error) { +func (db *DB) ZScore(key []byte, member []byte) (int64, error) { if err := checkZSetKMSize(key, member); err != nil { - return nil, err + return InvalidScore, err } + var score int64 = InvalidScore + k := db.zEncodeSetKey(key, member) - - score, err := Int64(db.db.Get(k)) - if err != nil { - return nil, err + if v, err := db.db.Get(k); err != nil { + return InvalidScore, err + } else if v == nil { + return InvalidScore, ErrScoreMiss + } else { + if score, err = Int64(v, nil); err != nil { + return InvalidScore, err + } } - return StrPutInt64(score), nil + return score, nil } func (db *DB) ZRem(key []byte, members ...[]byte) (int64, error) { @@ -356,9 +363,9 @@ func (db *DB) ZRem(key []byte, members ...[]byte) (int64, error) { return num, err } -func (db *DB) ZIncrBy(key []byte, delta int64, member []byte) ([]byte, error) { +func (db *DB) ZIncrBy(key []byte, delta int64, member []byte) (int64, error) { if err := checkZSetKMSize(key, member); err != nil { - return nil, err + return InvalidScore, err } t := db.zsetTx @@ -371,18 +378,17 @@ func (db *DB) ZIncrBy(key []byte, delta int64, member []byte) ([]byte, error) { v, err := db.db.Get(ek) if err != nil { - return nil, err + return InvalidScore, err } else if v != nil { if s, err := Int64(v, err); err != nil { - return nil, err + return InvalidScore, err } else { sk := db.zEncodeScoreKey(key, member, s) t.Delete(sk) score = s + delta - if score >= MaxScore || score <= MinScore { - return nil, errScoreOverflow + return InvalidScore, errScoreOverflow } } } else { @@ -395,7 +401,7 @@ func (db *DB) ZIncrBy(key []byte, delta int64, member []byte) ([]byte, error) { t.Put(sk, []byte{}) err = t.Commit() - return StrPutInt64(score), err + return score, err } func (db *DB) ZCount(key []byte, min int64, max int64) (int64, error) { @@ -563,10 +569,10 @@ func (db *DB) zRange(key []byte, min int64, max int64, withScores bool, offset i continue } - v = append(v, m) - if withScores { - v = append(v, StrPutInt64(s)) + v = append(v, m, s) + } else { + v = append(v, m) } } diff --git a/ledis/t_zset_test.go b/ledis/t_zset_test.go index fa2bb45..202da70 100644 --- a/ledis/t_zset_test.go +++ b/ledis/t_zset_test.go @@ -71,6 +71,16 @@ func TestDBZSet(t *testing.T) { t.Fatal(n) } + if s, err := db.ZScore(key, bin("d")); err != nil { + t.Fatal(err) + } else if s != 3 { + t.Fatal(s) + } + + if s, err := db.ZScore(key, bin("zzz")); err != ErrScoreMiss || s != InvalidScore { + t.Fatal(fmt.Sprintf("s=[%d] err=[%s]", s, err)) + } + // {c':2, 'd':3} if n, err := db.ZRem(key, bin("a"), bin("b")); err != nil { t.Fatal(err) @@ -179,8 +189,10 @@ func TestZSetOrder(t *testing.T) { } // {'a':0, 'b':1, 'c':2, 'd':999, 'e':6, 'f':5} - if _, err := db.ZIncrBy(key, 2, bin("e")); err != nil { + if s, err := db.ZIncrBy(key, 2, bin("e")); err != nil { t.Fatal(err) + } else if s != 6 { + t.Fatal(s) } if pos, _ := db.ZRank(key, bin("e")); int(pos) != 4 { @@ -190,6 +202,19 @@ func TestZSetOrder(t *testing.T) { if pos, _ := db.ZRevRank(key, bin("e")); int(pos) != 1 { t.Fatal(pos) } + + if datas, _ := db.ZRange(key, 0, endPos, true); len(datas) != 12 { + t.Fatal(len(datas)) + } else { + scores := []int64{0, 1, 2, 5, 6, 999} + for i := 1; i < len(datas); i += 2 { + if s, ok := datas[i].(int64); !ok || s != scores[(i-1)/2] { + t.Fatal(fmt.Sprintf("[%d]=%d", i, datas[i])) + } + } + } + + return } func TestDBZScan(t *testing.T) { diff --git a/server/cmd_zset.go b/server/cmd_zset.go index c7f26b5..c72fa71 100644 --- a/server/cmd_zset.go +++ b/server/cmd_zset.go @@ -66,10 +66,14 @@ func zscoreCommand(c *client) error { return ErrCmdParams } - if v, err := c.db.ZScore(args[0], args[1]); err != nil { - return err + if s, err := c.db.ZScore(args[0], args[1]); err != nil { + if err == ledis.ErrScoreMiss { + c.writeBulk(nil) + } else { + return err + } } else { - c.writeBulk(v) + c.writeBulk(ledis.StrPutInt64(s)) } return nil @@ -106,7 +110,7 @@ func zincrbyCommand(c *client) error { if v, err := c.db.ZIncrBy(key, delta, args[2]); err != nil { return err } else { - c.writeBulk(v) + c.writeBulk(ledis.StrPutInt64(v)) } return nil @@ -308,10 +312,16 @@ func zrangeGeneric(c *client, reverse bool) error { withScores = true } - if v, err := c.db.ZRangeGeneric(key, start, stop, withScores, reverse); err != nil { + if datas, err := c.db.ZRangeGeneric(key, start, stop, withScores, reverse); err != nil { return err } else { - c.writeArray(v) + if withScores { + for i := len(datas) - 1; i > 0; i -= 2 { + v, _ := datas[i].(int64) + datas[i] = ledis.StrPutInt64(v) + } + } + c.writeArray(datas) } return nil } @@ -373,10 +383,16 @@ func zrangebyscoreGeneric(c *client, reverse bool) error { return nil } - if v, err := c.db.ZRangeByScoreGeneric(key, min, max, withScores, offset, count, reverse); err != nil { + if datas, err := c.db.ZRangeByScoreGeneric(key, min, max, withScores, offset, count, reverse); err != nil { return err } else { - c.writeArray(v) + if withScores { + for i := len(datas) - 1; i > 0; i -= 2 { + v, _ := datas[i].(int64) + datas[i] = ledis.StrPutInt64(v) + } + } + c.writeArray(datas) } return nil