Merge branch 'client-feature' into develop

This commit is contained in:
holys 2014-08-16 16:41:36 +08:00
commit 917a6416de
13 changed files with 402 additions and 98 deletions

View File

@ -2,14 +2,12 @@ from __future__ import with_statement
import datetime import datetime
import time as mod_time import time as mod_time
from ledis._compat import (b, izip, imap, iteritems, from ledis._compat import (b, izip, imap, iteritems,
basestring, long, nativestr, urlparse, bytes) basestring, long, nativestr, bytes)
from ledis.connection import ConnectionPool, UnixDomainSocketConnection from ledis.connection import ConnectionPool, UnixDomainSocketConnection
from ledis.exceptions import ( from ledis.exceptions import (
ConnectionError, ConnectionError,
DataError, DataError,
LedisError, LedisError,
ResponseError,
ExecAbortError,
) )
SYM_EMPTY = b('') SYM_EMPTY = b('')
@ -75,15 +73,18 @@ class Ledis(object):
""" """
RESPONSE_CALLBACKS = dict_merge( RESPONSE_CALLBACKS = dict_merge(
string_keys_to_dict( string_keys_to_dict(
'EXISTS EXPIRE EXPIREAT HEXISTS HMSET SETNX ' 'EXISTS HEXISTS SISMEMBER HMSET SETNX'
'PERSIST HPERSIST LPERSIST ZPERSIST BEXPIRE ' 'PERSIST HPERSIST LPERSIST ZPERSIST SPERSIST BPERSIST'
'BEXPIREAT BPERSIST BDELETE', 'EXPIRE LEXPIRE HEXPIRE SEXPIRE ZEXPIRE BEXPIRE'
'EXPIREAT LBEXPIREAT HEXPIREAT SEXPIREAT ZEXPIREAT BEXPIREAT',
bool bool
), ),
string_keys_to_dict( string_keys_to_dict(
'DECRBY DEL HDEL HLEN INCRBY LLEN ' 'DECRBY DEL HDEL HLEN INCRBY LLEN ZADD ZCARD ZREM'
'ZADD ZCARD ZREM ZREMRANGEBYRANK ZREMRANGEBYSCORE' 'ZREMRANGEBYRANK ZREMRANGEBYSCORE LMCLEAR HMCLEAR'
'LMCLEAR HMCLEAR ZMCLEAR BCOUNT BGETBIT BSETBIT BOPT BMSETBIT', 'ZMCLEAR BCOUNT BGETBIT BSETBIT BOPT BMSETBIT'
'SADD SCARD SDIFFSTORE SINTERSTORE SUNIONSTORE SREM'
'SCLEAR SMLEAR BDELETE',
int int
), ),
string_keys_to_dict( string_keys_to_dict(
@ -94,6 +95,10 @@ class Ledis(object):
'MSET SELECT ', 'MSET SELECT ',
lambda r: nativestr(r) == 'OK' lambda r: nativestr(r) == 'OK'
), ),
string_keys_to_dict(
'SDIFF SINTER SMEMBERS SUNION',
lambda r: r and set(r) or set()
),
string_keys_to_dict( string_keys_to_dict(
'ZRANGE ZRANGEBYSCORE ZREVRANGE ZREVRANGEBYSCORE', 'ZRANGE ZRANGEBYSCORE ZREVRANGE ZREVRANGEBYSCORE',
zset_score_pairs zset_score_pairs
@ -403,6 +408,103 @@ class Ledis(object):
return self.execute_command('LPERSIST', name) return self.execute_command('LPERSIST', name)
#### SET COMMANDS ####
def sadd(self, name, *values):
"Add ``value(s)`` to set ``name``"
return self.execute_command('SADD', name, *values)
def scard(self, name):
"Return the number of elements in set ``name``"
return self.execute_command('SCARD', name)
def sdiff(self, keys, *args):
"Return the difference of sets specified by ``keys``"
args = list_or_args(keys, args)
return self.execute_command('SDIFF', *args)
def sdiffstore(self, dest, keys, *args):
"""
Store the difference of sets specified by ``keys`` into a new
set named ``dest``. Returns the number of keys in the new set.
"""
args = list_or_args(keys, args)
return self.execute_command('SDIFFSTORE', dest, *args)
def sinter(self, keys, *args):
"Return the intersection of sets specified by ``keys``"
args = list_or_args(keys, args)
return self.execute_command('SINTER', *args)
def sinterstore(self, dest, keys, *args):
"""
Store the intersection of sets specified by ``keys`` into a new
set named ``dest``. Returns the number of keys in the new set.
"""
args = list_or_args(keys, args)
return self.execute_command('SINTERSTORE', dest, *args)
def sismember(self, name, value):
"Return a boolean indicating if ``value`` is a member of set ``name``"
return self.execute_command('SISMEMBER', name, value)
def smembers(self, name):
"Return all members of the set ``name``"
return self.execute_command('SMEMBERS', name)
def srem(self, name, *values):
"Remove ``values`` from set ``name``"
return self.execute_command('SREM', name, *values)
def sunion(self, keys, *args):
"Return the union of sets specified by ``keys``"
args = list_or_args(keys, args)
return self.execute_command('SUNION', *args)
def sunionstore(self, dest, keys, *args):
"""
Store the union of sets specified by ``keys`` into a new
set named ``dest``. Returns the number of keys in the new set.
"""
args = list_or_args(keys, args)
return self.execute_command('SUNIONSTORE', dest, *args)
# SPECIAL COMMANDS SUPPORTED BY LEDISDB
def sclear(self, name):
"Delete key ``name`` from set"
return self.execute_command('SCLEAR', name)
def smclear(self, *names):
"Delete multiple keys ``names`` from set"
return self.execute_command('SMCLEAR', *names)
def sexpire(self, name, time):
"""
Set an expire flag on key name for time milliseconds.
time can be represented by an integer or a Python timedelta object.
"""
if isinstance(time, datetime.timedelta):
time = time.seconds + time.days * 24 * 3600
return self.execute_command('SEXPIRE', name, time)
def sexpireat(self, name, when):
"""
Set an expire flag on key name. when can be represented as an integer
representing unix time in milliseconds (unix time * 1000) or a
Python datetime object.
"""
if isinstance(when, datetime.datetime):
when = int(mod_time.mktime(when.timetuple()))
return self.execute_command('SEXPIREAT', name, when)
def sttl(self, name):
"Returns the number of seconds until the key name will expire"
return self.execute_command('STTL', name)
def spersist(self, name):
"Removes an expiration on name"
return self.execute_command('SPERSIST', name)
#### SORTED SET COMMANDS #### #### SORTED SET COMMANDS ####
def zadd(self, name, *args, **kwargs): def zadd(self, name, *args, **kwargs):
""" """
@ -693,7 +795,6 @@ class Ledis(object):
return self.execute_command('HPERSIST', name) return self.execute_command('HPERSIST', name)
### BIT COMMANDS ### BIT COMMANDS
def bget(self, name): def bget(self, name):
"" ""

View File

@ -3,21 +3,15 @@
import unittest import unittest
import sys import sys
import datetime, time
sys.path.append('..') sys.path.append('..')
import ledis import ledis
from ledis._compat import b from ledis._compat import b
from ledis import ResponseError from util import expire_at, expire_at_seconds
l = ledis.Ledis(port=6380) l = ledis.Ledis(port=6380)
def current_time():
return datetime.datetime.now()
class TestCmdBit(unittest.TestCase): class TestCmdBit(unittest.TestCase):
def setUp(self): def setUp(self):
pass pass
@ -94,21 +88,17 @@ class TestCmdBit(unittest.TestCase):
assert l.bttl('a') == -1 assert l.bttl('a') == -1
def test_bexpireat_datetime(self): def test_bexpireat_datetime(self):
expire_at = current_time() + datetime.timedelta(minutes=1)
l.bsetbit('a', 1, True) l.bsetbit('a', 1, True)
assert l.bexpireat('a', expire_at) assert l.bexpireat('a', expire_at())
assert 0 < l.bttl('a') <= 61 assert 0 < l.bttl('a') <= 61
def test_bexpireat_unixtime(self): def test_bexpireat_unixtime(self):
expire_at = current_time() + datetime.timedelta(minutes=1)
l.bsetbit('a', 1, True) l.bsetbit('a', 1, True)
expire_at_seconds = int(time.mktime(expire_at.timetuple())) assert l.bexpireat('a', expire_at_seconds())
assert l.bexpireat('a', expire_at_seconds)
assert 0 < l.bttl('a') <= 61 assert 0 < l.bttl('a') <= 61
def test_bexpireat_no_key(self): def test_bexpireat_no_key(self):
expire_at = current_time() + datetime.timedelta(minutes=1) assert not l.bexpireat('a', expire_at())
assert not l.bexpireat('a', expire_at)
def test_bttl_and_bpersist(self): def test_bttl_and_bpersist(self):
l.bsetbit('a', 1, True) l.bsetbit('a', 1, True)

View File

@ -3,19 +3,16 @@
import unittest import unittest
import sys import sys
import datetime, time
sys.path.append('..') sys.path.append('..')
import ledis import ledis
from ledis._compat import b, iteritems, itervalues from ledis._compat import itervalues
from ledis import ResponseError from util import expire_at, expire_at_seconds
l = ledis.Ledis(port=6380) l = ledis.Ledis(port=6380)
def current_time():
return datetime.datetime.now()
class TestCmdHash(unittest.TestCase): class TestCmdHash(unittest.TestCase):
def setUp(self): def setUp(self):
@ -110,21 +107,17 @@ class TestCmdHash(unittest.TestCase):
assert l.httl('myhash') <= 100 assert l.httl('myhash') <= 100
def test_hexpireat_datetime(self): def test_hexpireat_datetime(self):
expire_at = current_time() + datetime.timedelta(minutes=1)
l.hset('a', 'f', 'foo') l.hset('a', 'f', 'foo')
assert l.hexpireat('a', expire_at) assert l.hexpireat('a', expire_at())
assert 0 < l.httl('a') <= 61 assert 0 < l.httl('a') <= 61
def test_hexpireat_unixtime(self): def test_hexpireat_unixtime(self):
expire_at = current_time() + datetime.timedelta(minutes=1)
l.hset('a', 'f', 'foo') l.hset('a', 'f', 'foo')
expire_at_seconds = int(time.mktime(expire_at.timetuple())) assert l.hexpireat('a', expire_at_seconds())
assert l.hexpireat('a', expire_at_seconds)
assert 0 < l.httl('a') <= 61 assert 0 < l.httl('a') <= 61
def test_hexpireat_no_key(self): def test_hexpireat_no_key(self):
expire_at = current_time() + datetime.timedelta(minutes=1) assert not l.hexpireat('a', expire_at())
assert not l.hexpireat('a', expire_at)
def test_hexpireat(self): def test_hexpireat(self):
assert l.hexpireat('myhash', 1577808000) == 0 assert l.hexpireat('myhash', 1577808000) == 0

View File

@ -3,19 +3,16 @@
import unittest import unittest
import sys import sys
import datetime, time
sys.path.append('..') sys.path.append('..')
import ledis import ledis
from ledis._compat import b, iteritems from ledis._compat import b, iteritems
from util import expire_at, expire_at_seconds
l = ledis.Ledis(port=6380) l = ledis.Ledis(port=6380)
def current_time():
return datetime.datetime.now()
class TestCmdKv(unittest.TestCase): class TestCmdKv(unittest.TestCase):
def setUp(self): def setUp(self):
pass pass
@ -32,10 +29,6 @@ class TestCmdKv(unittest.TestCase):
assert l.decr('a', amount=5) == -7 assert l.decr('a', amount=5) == -7
assert l['a'] == b('-7') assert l['a'] == b('-7')
#FIXME: how to test exception?
# l.set('b', '234293482390480948029348230948')
# self.assertRaises(ResponseError, l.delete('b'))
def test_decrby(self): def test_decrby(self):
assert l.delete('a') == 1 assert l.delete('a') == 1
assert l.decrby('a') == -1 assert l.decrby('a') == -1
@ -134,21 +127,17 @@ class TestCmdKv(unittest.TestCase):
assert not (l.expire('a', 100)) assert not (l.expire('a', 100))
def test_expireat_datetime(self): def test_expireat_datetime(self):
expire_at = current_time() + datetime.timedelta(minutes=1)
l.set('a', '1') l.set('a', '1')
assert l.expireat('a', expire_at) assert l.expireat('a', expire_at())
assert 0 < l.ttl('a') <= 61 assert 0 < l.ttl('a') <= 61
def test_expireat_unixtime(self): def test_expireat_unixtime(self):
expire_at = current_time() + datetime.timedelta(minutes=1)
l.set('a', '1') l.set('a', '1')
expire_at_seconds = int(time.mktime(expire_at.timetuple())) assert l.expireat('a', expire_at_seconds())
assert l.expireat('a', expire_at_seconds)
assert 0 < l.ttl('a') <= 61 assert 0 < l.ttl('a') <= 61
def test_expireat_no_key(self): def test_expireat_no_key(self):
expire_at = current_time() + datetime.timedelta(minutes=1) assert not l.expireat('a', expire_at())
assert not l.expireat('a', expire_at)
def test_expireat(self): def test_expireat(self):
l.set('a', 'hello') l.set('a', 'hello')

View File

@ -2,20 +2,17 @@
# Test Cases for list commands # Test Cases for list commands
import unittest import unittest
import datetime, time
import sys import sys
sys.path.append('..') sys.path.append('..')
import ledis import ledis
from ledis._compat import b from ledis._compat import b
from util import expire_at, expire_at_seconds
l = ledis.Ledis(port=6380) l = ledis.Ledis(port=6380)
def current_time():
return datetime.datetime.now()
class TestCmdList(unittest.TestCase): class TestCmdList(unittest.TestCase):
def setUp(self): def setUp(self):
pass pass
@ -84,21 +81,17 @@ class TestCmdList(unittest.TestCase):
assert l.lttl('mylist') == -1 assert l.lttl('mylist') == -1
def test_lexpireat_datetime(self): def test_lexpireat_datetime(self):
expire_at = current_time() + datetime.timedelta(minutes=1)
l.rpush('mylist', '1') l.rpush('mylist', '1')
assert l.lexpireat('mylist', expire_at) assert l.lexpireat('mylist', expire_at())
assert 0 < l.lttl('mylist') <= 61 assert 0 < l.lttl('mylist') <= 61
def test_lexpireat_unixtime(self): def test_lexpireat_unixtime(self):
expire_at = current_time() + datetime.timedelta(minutes=1)
l.rpush('mylist', '1') l.rpush('mylist', '1')
expire_at_seconds = int(time.mktime(expire_at.timetuple())) assert l.lexpireat('mylist', expire_at_seconds())
assert l.lexpireat('mylist', expire_at_seconds)
assert l.lttl('mylist') <= 61 assert l.lttl('mylist') <= 61
def test_lexpireat_no_key(self): def test_lexpireat_no_key(self):
expire_at = current_time() + datetime.timedelta(minutes=1) assert not l.lexpireat('mylist', expire_at())
assert not l.lexpireat('mylist', expire_at)
def test_lttl_and_lpersist(self): def test_lttl_and_lpersist(self):
l.rpush('mylist', '1') l.rpush('mylist', '1')

View File

@ -0,0 +1,156 @@
# coding: utf-8
# Test set commands
import unittest
import sys
sys.path.append('..')
import pytest
import ledis
from ledis._compat import b
from ledis import ResponseError
from util import expire_at, expire_at_seconds
l = ledis.Ledis(port=6380)
class TestCmdSet(unittest.TestCase):
def setUp(self):
pass
def tearDown(self):
l.smclear('a', 'b', 'c')
def test_sadd(self):
members = set([b('1'), b('2'), b('3')])
l.sadd('a', *members)
assert l.smembers('a') == members
def test_scard(self):
l.sadd('a', '1', '2', '3')
assert l.scard('a') == 3
def test_sdiff(self):
l.sadd('a', '1', '2', '3')
assert l.sdiff('a', 'b') == set([b('1'), b('2'), b('3')])
l.sadd('b', '2', '3')
assert l.sdiff('a', 'b') == set([b('1')])
def test_sdiffstore(self):
l.sadd('a', '1', '2', '3')
assert l.sdiffstore('c', 'a', 'b') == 3
assert l.smembers('c') == set([b('1'), b('2'), b('3')])
l.sadd('b', '2', '3')
print l.smembers('c')
print "before"
assert l.sdiffstore('c', 'a', 'b') == 1
print l.smembers('c')
assert l.smembers('c') == set([b('1')])
def test_sinter(self):
l.sadd('a', '1', '2', '3')
assert l.sinter('a', 'b') == set()
l.sadd('b', '2', '3')
assert l.sinter('a', 'b') == set([b('2'), b('3')])
def test_sinterstore(self):
l.sadd('a', '1', '2', '3')
assert l.sinterstore('c', 'a', 'b') == 0
assert l.smembers('c') == set()
l.sadd('b', '2', '3')
assert l.sinterstore('c', 'a', 'b') == 2
assert l.smembers('c') == set([b('2'), b('3')])
def test_sismember(self):
l.sadd('a', '1', '2', '3')
assert l.sismember('a', '1')
assert l.sismember('a', '2')
assert l.sismember('a', '3')
assert not l.sismember('a', '4')
def test_smembers(self):
l.sadd('a', '1', '2', '3')
assert l.smembers('a') == set([b('1'), b('2'), b('3')])
def test_srem(self):
l.sadd('a', '1', '2', '3', '4')
assert l.srem('a', '5') == 0
assert l.srem('a', '2', '4') == 2
assert l.smembers('a') == set([b('1'), b('3')])
def test_sunion(self):
l.sadd('a', '1', '2')
l.sadd('b', '2', '3')
assert l.sunion('a', 'b') == set([b('1'), b('2'), b('3')])
def test_sunionstore(self):
l.sadd('a', '1', '2')
l.sadd('b', '2', '3')
assert l.sunionstore('c', 'a', 'b') == 3
assert l.smembers('c') == set([b('1'), b('2'), b('3')])
def test_sclear(self):
members = set([b('1'), b('2'), b('3')])
l.sadd('a', *members)
assert l.sclear('a') == 3
assert l.sclear('a') == 0
def test_smclear(self):
members = set([b('1'), b('2'), b('3')])
l.sadd('a', *members)
l.sadd('b', *members)
assert l.smclear('a', 'b') == 2
def test_sexpire(self):
members = set([b('1'), b('2'), b('3')])
assert l.sexpire('a', 100) == 0
l.sadd('a', *members)
assert l.sexpire('a', 100) == 1
assert l.sttl('a') <= 100
def test_sexpireat_datetime(self):
members = set([b('1'), b('2'), b('3')])
l.sadd('a', *members)
assert l.sexpireat('a', expire_at())
assert 0 < l.sttl('a') <= 61
def test_sexpireat_unixtime(self):
members = set([b('1'), b('2'), b('3')])
l.sadd('a', *members)
assert l.sexpireat('a', expire_at_seconds())
assert 0 < l.sttl('a') <= 61
def test_sexpireat_no_key(self):
assert not l.sexpireat('a', expire_at())
def test_sexpireat(self):
assert l.sexpireat('a', 1577808000) == 0
members = set([b('1'), b('2'), b('3')])
l.sadd('a', *members)
assert l.sexpireat('a', 1577808000) == 1
def test_sttl(self):
members = set([b('1'), b('2'), b('3')])
l.sadd('a', *members)
assert l.sexpire('a', 100)
assert l.sttl('a') <= 100
def test_spersist(self):
members = set([b('1'), b('2'), b('3')])
l.sadd('a', *members)
l.sexpire('a', 100)
assert l.sttl('a') <= 100
assert l.spersist('a')
assert l.sttl('a') == -1
def test_invalid_params(self):
with pytest.raises(ResponseError) as excinfo:
l.sadd("a")
assert excinfo.value.message == "invalid command param"
def test_invalid_value(self):
members = set([b('1'), b('2'), b('3')])
l.sadd('a', *members)
self.assertRaises(ResponseError, lambda: l.sexpire('a', 'a'))

View File

@ -3,17 +3,14 @@
import unittest import unittest
import sys import sys
import datetime, time
sys.path.append('..') sys.path.append('..')
import ledis import ledis
from ledis._compat import b, iteritems from ledis._compat import b
from ledis import ResponseError from util import expire_at, expire_at_seconds
l = ledis.Ledis(port=6380) l = ledis.Ledis(port=6380)
def current_time():
return datetime.datetime.now()
class TestCmdZset(unittest.TestCase): class TestCmdZset(unittest.TestCase):
def setUp(self): def setUp(self):
@ -145,21 +142,17 @@ class TestCmdZset(unittest.TestCase):
assert l.zttl('a') == -1 assert l.zttl('a') == -1
def test_zexpireat_datetime(self): def test_zexpireat_datetime(self):
expire_at = current_time() + datetime.timedelta(minutes=1)
l.zadd('a', a1=1) l.zadd('a', a1=1)
assert l.zexpireat('a', expire_at) assert l.zexpireat('a', expire_at())
assert 0 < l.zttl('a') <= 61 assert 0 < l.zttl('a') <= 61
def test_zexpireat_unixtime(self): def test_zexpireat_unixtime(self):
expire_at = current_time() + datetime.timedelta(minutes=1)
l.zadd('a', a1=1) l.zadd('a', a1=1)
expire_at_seconds = int(time.mktime(expire_at.timetuple())) assert l.zexpireat('a', expire_at_seconds())
assert l.zexpireat('a', expire_at_seconds)
assert 0 < l.zttl('a') <= 61 assert 0 < l.zttl('a') <= 61
def test_zexpireat_no_key(self): def test_zexpireat_no_key(self):
expire_at = current_time() + datetime.timedelta(minutes=1) assert not l.zexpireat('a', expire_at())
assert not l.zexpireat('a', expire_at)
def test_zttl_and_zpersist(self): def test_zttl_and_zpersist(self):
l.zadd('a', a1=1) l.zadd('a', a1=1)

View File

@ -0,0 +1,20 @@
#coding: utf-8
import datetime
import time
def current_time():
return datetime.datetime.now()
def expire_at(minute=1):
expire_at = current_time() + datetime.timedelta(minutes=minute)
return expire_at
def expire_at_seconds(minute=1):
return int(time.mktime(expire_at(minute=minute).timetuple()))
if __name__ == "__main__":
print expire_at()
print expire_at_seconds()

View File

@ -36,16 +36,38 @@ client.bget("bit key 3", function(err, result){
}); });
//test zunionstore & zinterstore //test zunionstore & zinterstore
client.zadd("zset1", 1, "one") client.zadd("zset1", 1, "one");
client.zadd("zset1", 2, "two") client.zadd("zset1", 2, "two");
client.zadd("zset2", 1, "one") client.zadd("zset2", 1, "one");
client.zadd("zset2", 2, "two") client.zadd("zset2", 2, "two");
client.zadd("zset2", 3, "three") client.zadd("zset2", 3, "three");
client.zunionstore("out", 2, "zset1", "zset2", "weights", 2, 3, ledis.print);
client.zrange("out", 0, -1, "withscores", ledis.print);
client.zinterstore("out", 2, "zset1", "zset2", "weights", 2, 3, ledis.print);
client.zrange("out", 0, -1, "withscores", ledis.print);
//example of set commands
client.sadd("a", 1, 2, 3);
client.sadd("b", 3, 4, 5);
client.sismember("a", 1, ledis.print);
client.smembers("a", ledis.print);
client.sdiff("a", "b", ledis.print);
client.sdiffstore("c", "a", "b", ledis.print);
client.sinter("a", "b", ledis.print);
client.sinterstore("c", "a", "b", ledis.print);
client.sunion("a", "b", ledis.print);
client.sunionstore("c", "a", "b", ledis.print);
client.srem("a", 1, ledis.print);
client.sclear("c", ledis.print);
client.smclear("a", "b", ledis.print);
client.sexpire("a", 100, ledis.print);
client.sexpireat("a", 1577808000, ledis.print);
client.sttl("a", ledis.print);
client.spersist("a", ledis.print);
client.zunionstore("out", 2, "zset1", "zset2", "weights", 2, 3, ledis.print)
client.zrange("out", 0, -1, "withscores", ledis.print)
client.zinterstore("out", 2, "zset1", "zset2", "weights", 2, 3, ledis.print)
client.zrange("out", 0, -1, "withscores", ledis.print)
client.quit() client.quit()

View File

@ -2,6 +2,10 @@
module.exports = [ module.exports = [
"quit", "quit",
"ping",
"echo",
"select",
"bget", "bget",
"bdelete", "bdelete",
"bsetbit", "bsetbit",
@ -93,4 +97,26 @@ module.exports = [
"zexpireat", "zexpireat",
"zttl", "zttl",
"zpersist", "zpersist",
"sadd",
"scard",
"sdiff",
"sdiffstore",
"sinter",
"sinterstore",
"sismember",
"smembers",
"srem",
"sunion",
"sunionstore",
"sclear",
"smclear",
"sexpire",
"sexpireat",
"sttl",
"spersist"
]; ];

View File

@ -113,6 +113,27 @@ local commands = {
"bttl", "bttl",
"bpersist", "bpersist",
--[[set]]
"sadd",
"scard",
"sdiff",
"sdiffstore",
"sinter",
"sinterstore",
"sismember",
"smembers",
"srem",
"sunion",
"sunionstore",
"sclear",
"smclear",
"sexpire",
"sexpireat",
"sttl",
"spersist",
--[[server]] --[[server]]
"ping", "ping",
"echo", "echo",

View File

@ -490,7 +490,6 @@ func (db *DB) sStoreGeneric(dstKey []byte, optType byte, keys ...[]byte) (int64,
var err error var err error
var ek []byte var ek []byte
var num int64 = 0
var v [][]byte var v [][]byte
switch optType { switch optType {
@ -513,22 +512,21 @@ func (db *DB) sStoreGeneric(dstKey []byte, optType byte, keys ...[]byte) (int64,
ek = db.sEncodeSetKey(dstKey, m) ek = db.sEncodeSetKey(dstKey, m)
if v, err := db.db.Get(ek); err != nil { if _, err := db.db.Get(ek); err != nil {
return 0, err return 0, err
} else if v == nil {
num++
} }
t.Put(ek, nil) t.Put(ek, nil)
} }
if _, err = db.sIncrSize(dstKey, num); err != nil { var num = int64(len(v))
sk := db.sEncodeSizeKey(dstKey)
t.Put(sk, PutInt64(num))
if err = t.Commit(); err != nil {
return 0, err return 0, err
} }
return num, nil
err = t.Commit()
return num, err
} }
func (db *DB) SClear(key []byte) (int64, error) { func (db *DB) SClear(key []byte) (int64, error) {

View File

@ -147,7 +147,7 @@ func testUnion(db *DB, t *testing.T) {
m2 := []byte("m2") m2 := []byte("m2")
m3 := []byte("m3") m3 := []byte("m3")
db.SAdd(key, m1, m2) db.SAdd(key, m1, m2)
db.SAdd(key1, m1, m3) db.SAdd(key1, m1, m2, m3)
db.SAdd(key2, m2, m3) db.SAdd(key2, m2, m3)
if _, err := db.sUnionGeneric(key, key2); err != nil { if _, err := db.sUnionGeneric(key, key2); err != nil {
t.Fatal(err) t.Fatal(err)
@ -158,11 +158,13 @@ func testUnion(db *DB, t *testing.T) {
} }
dstkey := []byte("union_dsk") dstkey := []byte("union_dsk")
db.SAdd(dstkey, []byte("x"))
if num, err := db.SUnionStore(dstkey, key1, key2); err != nil { if num, err := db.SUnionStore(dstkey, key1, key2); err != nil {
t.Fatal(err) t.Fatal(err)
} else if num != 3 { } else if num != 3 {
t.Fatal(num) t.Fatal(num)
} }
if _, err := db.SMembers(dstkey); err != nil { if _, err := db.SMembers(dstkey); err != nil {
t.Fatal(err) t.Fatal(err)
} }