Merge branch 'develop' into flush-feature

This commit is contained in:
wenyekui 2014-08-16 16:42:43 +08:00
commit 5b88f17470
16 changed files with 434 additions and 140 deletions

View File

@ -2,14 +2,12 @@ from __future__ import with_statement
import datetime
import time as mod_time
from ledis._compat import (b, izip, imap, iteritems,
basestring, long, nativestr, urlparse, bytes)
basestring, long, nativestr, bytes)
from ledis.connection import ConnectionPool, UnixDomainSocketConnection
from ledis.exceptions import (
ConnectionError,
DataError,
LedisError,
ResponseError,
ExecAbortError,
)
SYM_EMPTY = b('')
@ -75,15 +73,18 @@ class Ledis(object):
"""
RESPONSE_CALLBACKS = dict_merge(
string_keys_to_dict(
'EXISTS EXPIRE EXPIREAT HEXISTS HMSET SETNX '
'PERSIST HPERSIST LPERSIST ZPERSIST BEXPIRE '
'BEXPIREAT BPERSIST BDELETE',
'EXISTS HEXISTS SISMEMBER HMSET SETNX'
'PERSIST HPERSIST LPERSIST ZPERSIST SPERSIST BPERSIST'
'EXPIRE LEXPIRE HEXPIRE SEXPIRE ZEXPIRE BEXPIRE'
'EXPIREAT LBEXPIREAT HEXPIREAT SEXPIREAT ZEXPIREAT BEXPIREAT',
bool
),
string_keys_to_dict(
'DECRBY DEL HDEL HLEN INCRBY LLEN '
'ZADD ZCARD ZREM ZREMRANGEBYRANK ZREMRANGEBYSCORE'
'LMCLEAR HMCLEAR ZMCLEAR BCOUNT BGETBIT BSETBIT BOPT BMSETBIT',
'DECRBY DEL HDEL HLEN INCRBY LLEN ZADD ZCARD ZREM'
'ZREMRANGEBYRANK ZREMRANGEBYSCORE LMCLEAR HMCLEAR'
'ZMCLEAR BCOUNT BGETBIT BSETBIT BOPT BMSETBIT'
'SADD SCARD SDIFFSTORE SINTERSTORE SUNIONSTORE SREM'
'SCLEAR SMLEAR BDELETE',
int
),
string_keys_to_dict(
@ -94,6 +95,10 @@ class Ledis(object):
'MSET SELECT ',
lambda r: nativestr(r) == 'OK'
),
string_keys_to_dict(
'SDIFF SINTER SMEMBERS SUNION',
lambda r: r and set(r) or set()
),
string_keys_to_dict(
'ZRANGE ZRANGEBYSCORE ZREVRANGE ZREVRANGEBYSCORE',
zset_score_pairs
@ -403,6 +408,103 @@ class Ledis(object):
return self.execute_command('LPERSIST', name)
#### SET COMMANDS ####
def sadd(self, name, *values):
"Add ``value(s)`` to set ``name``"
return self.execute_command('SADD', name, *values)
def scard(self, name):
"Return the number of elements in set ``name``"
return self.execute_command('SCARD', name)
def sdiff(self, keys, *args):
"Return the difference of sets specified by ``keys``"
args = list_or_args(keys, args)
return self.execute_command('SDIFF', *args)
def sdiffstore(self, dest, keys, *args):
"""
Store the difference of sets specified by ``keys`` into a new
set named ``dest``. Returns the number of keys in the new set.
"""
args = list_or_args(keys, args)
return self.execute_command('SDIFFSTORE', dest, *args)
def sinter(self, keys, *args):
"Return the intersection of sets specified by ``keys``"
args = list_or_args(keys, args)
return self.execute_command('SINTER', *args)
def sinterstore(self, dest, keys, *args):
"""
Store the intersection of sets specified by ``keys`` into a new
set named ``dest``. Returns the number of keys in the new set.
"""
args = list_or_args(keys, args)
return self.execute_command('SINTERSTORE', dest, *args)
def sismember(self, name, value):
"Return a boolean indicating if ``value`` is a member of set ``name``"
return self.execute_command('SISMEMBER', name, value)
def smembers(self, name):
"Return all members of the set ``name``"
return self.execute_command('SMEMBERS', name)
def srem(self, name, *values):
"Remove ``values`` from set ``name``"
return self.execute_command('SREM', name, *values)
def sunion(self, keys, *args):
"Return the union of sets specified by ``keys``"
args = list_or_args(keys, args)
return self.execute_command('SUNION', *args)
def sunionstore(self, dest, keys, *args):
"""
Store the union of sets specified by ``keys`` into a new
set named ``dest``. Returns the number of keys in the new set.
"""
args = list_or_args(keys, args)
return self.execute_command('SUNIONSTORE', dest, *args)
# SPECIAL COMMANDS SUPPORTED BY LEDISDB
def sclear(self, name):
"Delete key ``name`` from set"
return self.execute_command('SCLEAR', name)
def smclear(self, *names):
"Delete multiple keys ``names`` from set"
return self.execute_command('SMCLEAR', *names)
def sexpire(self, name, time):
"""
Set an expire flag on key name for time milliseconds.
time can be represented by an integer or a Python timedelta object.
"""
if isinstance(time, datetime.timedelta):
time = time.seconds + time.days * 24 * 3600
return self.execute_command('SEXPIRE', name, time)
def sexpireat(self, name, when):
"""
Set an expire flag on key name. when can be represented as an integer
representing unix time in milliseconds (unix time * 1000) or a
Python datetime object.
"""
if isinstance(when, datetime.datetime):
when = int(mod_time.mktime(when.timetuple()))
return self.execute_command('SEXPIREAT', name, when)
def sttl(self, name):
"Returns the number of seconds until the key name will expire"
return self.execute_command('STTL', name)
def spersist(self, name):
"Removes an expiration on name"
return self.execute_command('SPERSIST', name)
#### SORTED SET COMMANDS ####
def zadd(self, name, *args, **kwargs):
"""
@ -693,7 +795,6 @@ class Ledis(object):
return self.execute_command('HPERSIST', name)
### BIT COMMANDS
def bget(self, name):
""

View File

@ -3,21 +3,15 @@
import unittest
import sys
import datetime, time
sys.path.append('..')
import ledis
from ledis._compat import b
from ledis import ResponseError
from util import expire_at, expire_at_seconds
l = ledis.Ledis(port=6380)
def current_time():
return datetime.datetime.now()
class TestCmdBit(unittest.TestCase):
def setUp(self):
pass
@ -94,21 +88,17 @@ class TestCmdBit(unittest.TestCase):
assert l.bttl('a') == -1
def test_bexpireat_datetime(self):
expire_at = current_time() + datetime.timedelta(minutes=1)
l.bsetbit('a', 1, True)
assert l.bexpireat('a', expire_at)
assert l.bexpireat('a', expire_at())
assert 0 < l.bttl('a') <= 61
def test_bexpireat_unixtime(self):
expire_at = current_time() + datetime.timedelta(minutes=1)
l.bsetbit('a', 1, True)
expire_at_seconds = int(time.mktime(expire_at.timetuple()))
assert l.bexpireat('a', expire_at_seconds)
assert l.bexpireat('a', expire_at_seconds())
assert 0 < l.bttl('a') <= 61
def test_bexpireat_no_key(self):
expire_at = current_time() + datetime.timedelta(minutes=1)
assert not l.bexpireat('a', expire_at)
assert not l.bexpireat('a', expire_at())
def test_bttl_and_bpersist(self):
l.bsetbit('a', 1, True)

View File

@ -3,19 +3,16 @@
import unittest
import sys
import datetime, time
sys.path.append('..')
import ledis
from ledis._compat import b, iteritems, itervalues
from ledis import ResponseError
from ledis._compat import itervalues
from util import expire_at, expire_at_seconds
l = ledis.Ledis(port=6380)
def current_time():
return datetime.datetime.now()
class TestCmdHash(unittest.TestCase):
def setUp(self):
@ -110,21 +107,17 @@ class TestCmdHash(unittest.TestCase):
assert l.httl('myhash') <= 100
def test_hexpireat_datetime(self):
expire_at = current_time() + datetime.timedelta(minutes=1)
l.hset('a', 'f', 'foo')
assert l.hexpireat('a', expire_at)
assert l.hexpireat('a', expire_at())
assert 0 < l.httl('a') <= 61
def test_hexpireat_unixtime(self):
expire_at = current_time() + datetime.timedelta(minutes=1)
l.hset('a', 'f', 'foo')
expire_at_seconds = int(time.mktime(expire_at.timetuple()))
assert l.hexpireat('a', expire_at_seconds)
assert l.hexpireat('a', expire_at_seconds())
assert 0 < l.httl('a') <= 61
def test_hexpireat_no_key(self):
expire_at = current_time() + datetime.timedelta(minutes=1)
assert not l.hexpireat('a', expire_at)
assert not l.hexpireat('a', expire_at())
def test_hexpireat(self):
assert l.hexpireat('myhash', 1577808000) == 0

View File

@ -3,19 +3,16 @@
import unittest
import sys
import datetime, time
sys.path.append('..')
import ledis
from ledis._compat import b, iteritems
from util import expire_at, expire_at_seconds
l = ledis.Ledis(port=6380)
def current_time():
return datetime.datetime.now()
class TestCmdKv(unittest.TestCase):
def setUp(self):
pass
@ -32,10 +29,6 @@ class TestCmdKv(unittest.TestCase):
assert l.decr('a', amount=5) == -7
assert l['a'] == b('-7')
#FIXME: how to test exception?
# l.set('b', '234293482390480948029348230948')
# self.assertRaises(ResponseError, l.delete('b'))
def test_decrby(self):
assert l.delete('a') == 1
assert l.decrby('a') == -1
@ -134,21 +127,17 @@ class TestCmdKv(unittest.TestCase):
assert not (l.expire('a', 100))
def test_expireat_datetime(self):
expire_at = current_time() + datetime.timedelta(minutes=1)
l.set('a', '1')
assert l.expireat('a', expire_at)
assert l.expireat('a', expire_at())
assert 0 < l.ttl('a') <= 61
def test_expireat_unixtime(self):
expire_at = current_time() + datetime.timedelta(minutes=1)
l.set('a', '1')
expire_at_seconds = int(time.mktime(expire_at.timetuple()))
assert l.expireat('a', expire_at_seconds)
assert l.expireat('a', expire_at_seconds())
assert 0 < l.ttl('a') <= 61
def test_expireat_no_key(self):
expire_at = current_time() + datetime.timedelta(minutes=1)
assert not l.expireat('a', expire_at)
assert not l.expireat('a', expire_at())
def test_expireat(self):
l.set('a', 'hello')

View File

@ -2,20 +2,17 @@
# Test Cases for list commands
import unittest
import datetime, time
import sys
sys.path.append('..')
import ledis
from ledis._compat import b
from util import expire_at, expire_at_seconds
l = ledis.Ledis(port=6380)
def current_time():
return datetime.datetime.now()
class TestCmdList(unittest.TestCase):
def setUp(self):
pass
@ -84,21 +81,17 @@ class TestCmdList(unittest.TestCase):
assert l.lttl('mylist') == -1
def test_lexpireat_datetime(self):
expire_at = current_time() + datetime.timedelta(minutes=1)
l.rpush('mylist', '1')
assert l.lexpireat('mylist', expire_at)
assert l.lexpireat('mylist', expire_at())
assert 0 < l.lttl('mylist') <= 61
def test_lexpireat_unixtime(self):
expire_at = current_time() + datetime.timedelta(minutes=1)
l.rpush('mylist', '1')
expire_at_seconds = int(time.mktime(expire_at.timetuple()))
assert l.lexpireat('mylist', expire_at_seconds)
assert l.lexpireat('mylist', expire_at_seconds())
assert l.lttl('mylist') <= 61
def test_lexpireat_no_key(self):
expire_at = current_time() + datetime.timedelta(minutes=1)
assert not l.lexpireat('mylist', expire_at)
assert not l.lexpireat('mylist', expire_at())
def test_lttl_and_lpersist(self):
l.rpush('mylist', '1')

View File

@ -0,0 +1,156 @@
# coding: utf-8
# Test set commands
import unittest
import sys
sys.path.append('..')
import pytest
import ledis
from ledis._compat import b
from ledis import ResponseError
from util import expire_at, expire_at_seconds
l = ledis.Ledis(port=6380)
class TestCmdSet(unittest.TestCase):
def setUp(self):
pass
def tearDown(self):
l.smclear('a', 'b', 'c')
def test_sadd(self):
members = set([b('1'), b('2'), b('3')])
l.sadd('a', *members)
assert l.smembers('a') == members
def test_scard(self):
l.sadd('a', '1', '2', '3')
assert l.scard('a') == 3
def test_sdiff(self):
l.sadd('a', '1', '2', '3')
assert l.sdiff('a', 'b') == set([b('1'), b('2'), b('3')])
l.sadd('b', '2', '3')
assert l.sdiff('a', 'b') == set([b('1')])
def test_sdiffstore(self):
l.sadd('a', '1', '2', '3')
assert l.sdiffstore('c', 'a', 'b') == 3
assert l.smembers('c') == set([b('1'), b('2'), b('3')])
l.sadd('b', '2', '3')
print l.smembers('c')
print "before"
assert l.sdiffstore('c', 'a', 'b') == 1
print l.smembers('c')
assert l.smembers('c') == set([b('1')])
def test_sinter(self):
l.sadd('a', '1', '2', '3')
assert l.sinter('a', 'b') == set()
l.sadd('b', '2', '3')
assert l.sinter('a', 'b') == set([b('2'), b('3')])
def test_sinterstore(self):
l.sadd('a', '1', '2', '3')
assert l.sinterstore('c', 'a', 'b') == 0
assert l.smembers('c') == set()
l.sadd('b', '2', '3')
assert l.sinterstore('c', 'a', 'b') == 2
assert l.smembers('c') == set([b('2'), b('3')])
def test_sismember(self):
l.sadd('a', '1', '2', '3')
assert l.sismember('a', '1')
assert l.sismember('a', '2')
assert l.sismember('a', '3')
assert not l.sismember('a', '4')
def test_smembers(self):
l.sadd('a', '1', '2', '3')
assert l.smembers('a') == set([b('1'), b('2'), b('3')])
def test_srem(self):
l.sadd('a', '1', '2', '3', '4')
assert l.srem('a', '5') == 0
assert l.srem('a', '2', '4') == 2
assert l.smembers('a') == set([b('1'), b('3')])
def test_sunion(self):
l.sadd('a', '1', '2')
l.sadd('b', '2', '3')
assert l.sunion('a', 'b') == set([b('1'), b('2'), b('3')])
def test_sunionstore(self):
l.sadd('a', '1', '2')
l.sadd('b', '2', '3')
assert l.sunionstore('c', 'a', 'b') == 3
assert l.smembers('c') == set([b('1'), b('2'), b('3')])
def test_sclear(self):
members = set([b('1'), b('2'), b('3')])
l.sadd('a', *members)
assert l.sclear('a') == 3
assert l.sclear('a') == 0
def test_smclear(self):
members = set([b('1'), b('2'), b('3')])
l.sadd('a', *members)
l.sadd('b', *members)
assert l.smclear('a', 'b') == 2
def test_sexpire(self):
members = set([b('1'), b('2'), b('3')])
assert l.sexpire('a', 100) == 0
l.sadd('a', *members)
assert l.sexpire('a', 100) == 1
assert l.sttl('a') <= 100
def test_sexpireat_datetime(self):
members = set([b('1'), b('2'), b('3')])
l.sadd('a', *members)
assert l.sexpireat('a', expire_at())
assert 0 < l.sttl('a') <= 61
def test_sexpireat_unixtime(self):
members = set([b('1'), b('2'), b('3')])
l.sadd('a', *members)
assert l.sexpireat('a', expire_at_seconds())
assert 0 < l.sttl('a') <= 61
def test_sexpireat_no_key(self):
assert not l.sexpireat('a', expire_at())
def test_sexpireat(self):
assert l.sexpireat('a', 1577808000) == 0
members = set([b('1'), b('2'), b('3')])
l.sadd('a', *members)
assert l.sexpireat('a', 1577808000) == 1
def test_sttl(self):
members = set([b('1'), b('2'), b('3')])
l.sadd('a', *members)
assert l.sexpire('a', 100)
assert l.sttl('a') <= 100
def test_spersist(self):
members = set([b('1'), b('2'), b('3')])
l.sadd('a', *members)
l.sexpire('a', 100)
assert l.sttl('a') <= 100
assert l.spersist('a')
assert l.sttl('a') == -1
def test_invalid_params(self):
with pytest.raises(ResponseError) as excinfo:
l.sadd("a")
assert excinfo.value.message == "invalid command param"
def test_invalid_value(self):
members = set([b('1'), b('2'), b('3')])
l.sadd('a', *members)
self.assertRaises(ResponseError, lambda: l.sexpire('a', 'a'))

View File

@ -3,17 +3,14 @@
import unittest
import sys
import datetime, time
sys.path.append('..')
import ledis
from ledis._compat import b, iteritems
from ledis import ResponseError
from ledis._compat import b
from util import expire_at, expire_at_seconds
l = ledis.Ledis(port=6380)
def current_time():
return datetime.datetime.now()
class TestCmdZset(unittest.TestCase):
def setUp(self):
@ -145,21 +142,17 @@ class TestCmdZset(unittest.TestCase):
assert l.zttl('a') == -1
def test_zexpireat_datetime(self):
expire_at = current_time() + datetime.timedelta(minutes=1)
l.zadd('a', a1=1)
assert l.zexpireat('a', expire_at)
assert l.zexpireat('a', expire_at())
assert 0 < l.zttl('a') <= 61
def test_zexpireat_unixtime(self):
expire_at = current_time() + datetime.timedelta(minutes=1)
l.zadd('a', a1=1)
expire_at_seconds = int(time.mktime(expire_at.timetuple()))
assert l.zexpireat('a', expire_at_seconds)
assert l.zexpireat('a', expire_at_seconds())
assert 0 < l.zttl('a') <= 61
def test_zexpireat_no_key(self):
expire_at = current_time() + datetime.timedelta(minutes=1)
assert not l.zexpireat('a', expire_at)
assert not l.zexpireat('a', expire_at())
def test_zttl_and_zpersist(self):
l.zadd('a', a1=1)

View File

@ -0,0 +1,20 @@
#coding: utf-8
import datetime
import time
def current_time():
return datetime.datetime.now()
def expire_at(minute=1):
expire_at = current_time() + datetime.timedelta(minutes=minute)
return expire_at
def expire_at_seconds(minute=1):
return int(time.mktime(expire_at(minute=minute).timetuple()))
if __name__ == "__main__":
print expire_at()
print expire_at_seconds()

View File

@ -36,16 +36,38 @@ client.bget("bit key 3", function(err, result){
});
//test zunionstore & zinterstore
client.zadd("zset1", 1, "one")
client.zadd("zset1", 2, "two")
client.zadd("zset1", 1, "one");
client.zadd("zset1", 2, "two");
client.zadd("zset2", 1, "one")
client.zadd("zset2", 2, "two")
client.zadd("zset2", 3, "three")
client.zadd("zset2", 1, "one");
client.zadd("zset2", 2, "two");
client.zadd("zset2", 3, "three");
client.zunionstore("out", 2, "zset1", "zset2", "weights", 2, 3, ledis.print);
client.zrange("out", 0, -1, "withscores", ledis.print);
client.zinterstore("out", 2, "zset1", "zset2", "weights", 2, 3, ledis.print);
client.zrange("out", 0, -1, "withscores", ledis.print);
//example of set commands
client.sadd("a", 1, 2, 3);
client.sadd("b", 3, 4, 5);
client.sismember("a", 1, ledis.print);
client.smembers("a", ledis.print);
client.sdiff("a", "b", ledis.print);
client.sdiffstore("c", "a", "b", ledis.print);
client.sinter("a", "b", ledis.print);
client.sinterstore("c", "a", "b", ledis.print);
client.sunion("a", "b", ledis.print);
client.sunionstore("c", "a", "b", ledis.print);
client.srem("a", 1, ledis.print);
client.sclear("c", ledis.print);
client.smclear("a", "b", ledis.print);
client.sexpire("a", 100, ledis.print);
client.sexpireat("a", 1577808000, ledis.print);
client.sttl("a", ledis.print);
client.spersist("a", ledis.print);
client.zunionstore("out", 2, "zset1", "zset2", "weights", 2, 3, ledis.print)
client.zrange("out", 0, -1, "withscores", ledis.print)
client.zinterstore("out", 2, "zset1", "zset2", "weights", 2, 3, ledis.print)
client.zrange("out", 0, -1, "withscores", ledis.print)
client.quit()

View File

@ -2,6 +2,10 @@
module.exports = [
"quit",
"ping",
"echo",
"select",
"bget",
"bdelete",
"bsetbit",
@ -93,4 +97,26 @@ module.exports = [
"zexpireat",
"zttl",
"zpersist",
"sadd",
"scard",
"sdiff",
"sdiffstore",
"sinter",
"sinterstore",
"sismember",
"smembers",
"srem",
"sunion",
"sunionstore",
"sclear",
"smclear",
"sexpire",
"sexpireat",
"sttl",
"spersist"
];

View File

@ -113,6 +113,27 @@ local commands = {
"bttl",
"bpersist",
--[[set]]
"sadd",
"scard",
"sdiff",
"sdiffstore",
"sinter",
"sinterstore",
"sismember",
"smembers",
"srem",
"sunion",
"sunionstore",
"sclear",
"smclear",
"sexpire",
"sexpireat",
"sttl",
"spersist",
--[[server]]
"ping",
"echo",

View File

@ -1,8 +1,6 @@
package ledis
import (
"fmt"
"github.com/siddontang/ledisdb/store"
"testing"
)
@ -198,26 +196,6 @@ func TestDBBScan(t *testing.T) {
t.Fatal(err.Error())
}
ek1 := db.bEncodeMetaKey(k1)
fmt.Printf("%x\n", ek1)
ek2 := db.bEncodeMetaKey(k2)
fmt.Printf("%x\n", ek2)
ek3 := db.bEncodeMetaKey(k3)
fmt.Printf("%x\n", ek3)
start := db.bEncodeMetaKey(nil)
fmt.Printf("start: %x\n", start)
end := db.bEncodeMetaKey(nil)
end[len(end)-1] = BitMetaType + 1
fmt.Printf("end: %x\n", end)
it := db.db.RangeLimitIterator(start, end, store.RangeClose, 0, 4)
for ; it.Valid(); it.Next() {
fmt.Printf("%x\n", it.RawKey())
}
it.Close()
if v, err := db.BScan(nil, 1, true); err != nil {
t.Fatal(err)
} else if len(v) != 1 {

View File

@ -490,7 +490,6 @@ func (db *DB) sStoreGeneric(dstKey []byte, optType byte, keys ...[]byte) (int64,
var err error
var ek []byte
var num int64 = 0
var v [][]byte
switch optType {
@ -513,22 +512,21 @@ func (db *DB) sStoreGeneric(dstKey []byte, optType byte, keys ...[]byte) (int64,
ek = db.sEncodeSetKey(dstKey, m)
if v, err := db.db.Get(ek); err != nil {
if _, err := db.db.Get(ek); err != nil {
return 0, err
} else if v == nil {
num++
}
t.Put(ek, nil)
}
if _, err = db.sIncrSize(dstKey, num); err != nil {
var num = int64(len(v))
sk := db.sEncodeSizeKey(dstKey)
t.Put(sk, PutInt64(num))
if err = t.Commit(); err != nil {
return 0, err
}
err = t.Commit()
return num, err
return num, nil
}
func (db *DB) SClear(key []byte) (int64, error) {

View File

@ -147,7 +147,7 @@ func testUnion(db *DB, t *testing.T) {
m2 := []byte("m2")
m3 := []byte("m3")
db.SAdd(key, m1, m2)
db.SAdd(key1, m1, m3)
db.SAdd(key1, m1, m2, m3)
db.SAdd(key2, m2, m3)
if _, err := db.sUnionGeneric(key, key2); err != nil {
t.Fatal(err)
@ -158,11 +158,13 @@ func testUnion(db *DB, t *testing.T) {
}
dstkey := []byte("union_dsk")
db.SAdd(dstkey, []byte("x"))
if num, err := db.SUnionStore(dstkey, key1, key2); err != nil {
t.Fatal(err)
} else if num != 3 {
t.Fatal(num)
}
if _, err := db.SMembers(dstkey); err != nil {
t.Fatal(err)
}

View File

@ -848,29 +848,25 @@ func (db *DB) ZUnionStore(destKey []byte, srcKeys [][]byte, weights []int64, agg
db.zDelete(t, destKey)
var num int64 = 0
for member, score := range destMap {
if err := checkZSetKMSize(destKey, []byte(member)); err != nil {
return 0, err
}
if n, err := db.zSetItem(t, destKey, score, []byte(member)); err != nil {
if _, err := db.zSetItem(t, destKey, score, []byte(member)); err != nil {
return 0, err
} else if n == 0 {
//add new
num++
}
}
if _, err := db.zIncrSize(t, destKey, num); err != nil {
return 0, err
}
var num = int64(len(destMap))
sk := db.zEncodeSizeKey(destKey)
t.Put(sk, PutInt64(num))
//todo add binlog
if err := t.Commit(); err != nil {
return 0, err
}
return int64(len(destMap)), nil
return num, nil
}
func (db *DB) ZInterStore(destKey []byte, srcKeys [][]byte, weights []int64, aggregate byte) (int64, error) {
@ -922,28 +918,23 @@ func (db *DB) ZInterStore(destKey []byte, srcKeys [][]byte, weights []int64, agg
db.zDelete(t, destKey)
var num int64 = 0
for member, score := range destMap {
if err := checkZSetKMSize(destKey, []byte(member)); err != nil {
return 0, err
}
if n, err := db.zSetItem(t, destKey, score, []byte(member)); err != nil {
if _, err := db.zSetItem(t, destKey, score, []byte(member)); err != nil {
return 0, err
} else if n == 0 {
//add new
num++
}
}
if _, err := db.zIncrSize(t, destKey, num); err != nil {
return 0, err
}
var num int64 = int64(len(destMap))
sk := db.zEncodeSizeKey(destKey)
t.Put(sk, PutInt64(num))
//todo add binlog
if err := t.Commit(); err != nil {
return 0, err
}
return int64(len(destMap)), nil
return num, nil
}
func (db *DB) ZScan(key []byte, count int, inclusive bool) ([][]byte, error) {

View File

@ -253,6 +253,9 @@ func TestZUnionStore(t *testing.T) {
weights := []int64{1, 2}
out := []byte("out")
db.ZAdd(out, ScorePair{3, []byte("out")})
n, err := db.ZUnionStore(out, keys, weights, AggregateSum)
if err != nil {
t.Fatal(err.Error())
@ -296,6 +299,15 @@ func TestZUnionStore(t *testing.T) {
if n != 3 {
t.Fatal("invalid value ", v)
}
n, err = db.ZCard(out)
if err != nil {
t.Fatal(err.Error())
}
if n != 3 {
t.Fatal("invalid value ", n)
}
}
func TestZInterStore(t *testing.T) {
@ -314,6 +326,8 @@ func TestZInterStore(t *testing.T) {
weights := []int64{2, 3}
out := []byte("out")
db.ZAdd(out, ScorePair{3, []byte("out")})
n, err := db.ZInterStore(out, keys, weights, AggregateSum)
if err != nil {
t.Fatal(err.Error())
@ -329,7 +343,6 @@ func TestZInterStore(t *testing.T) {
t.Fatal("invalid value ", v)
}
out = []byte("out")
n, err = db.ZInterStore(out, keys, weights, AggregateMin)
if err != nil {
t.Fatal(err.Error())
@ -355,4 +368,12 @@ func TestZInterStore(t *testing.T) {
t.Fatal("invalid value ", n)
}
n, err = db.ZCard(out)
if err != nil {
t.Fatal(err.Error())
}
if n != 1 {
t.Fatal("invalid value ", n)
}
}