*scan commands bug fix; add more tests

This commit is contained in:
holys 2014-09-04 12:53:19 +08:00
parent 6f3f049707
commit ba9f611488
13 changed files with 173 additions and 25 deletions

View File

@ -23,3 +23,6 @@ clean:
test: test:
go test -tags '$(GO_BUILD_TAGS)' ./... go test -tags '$(GO_BUILD_TAGS)' ./...
pytest:
sh client/ledis-py/tests/all.sh

5
client/ledis-py/Makefile Normal file
View File

@ -0,0 +1,5 @@
.PHONY: test
test:
sh tests/all.sh

View File

@ -4,7 +4,7 @@ import time as mod_time
from itertools import chain, starmap from itertools import chain, starmap
from ledis._compat import (b, izip, imap, iteritems, from ledis._compat import (b, izip, imap, iteritems,
basestring, long, nativestr, bytes) basestring, long, nativestr, bytes)
from ledis.connection import ConnectionPool, UnixDomainSocketConnection from ledis.connection import ConnectionPool, UnixDomainSocketConnection, Token
from ledis.exceptions import ( from ledis.exceptions import (
ConnectionError, ConnectionError,
DataError, DataError,
@ -87,6 +87,7 @@ def parse_info(response):
return info return info
# def parse_lscan(response, )
class Ledis(object): class Ledis(object):
""" """
@ -138,6 +139,7 @@ class Ledis(object):
'INFO': parse_info, 'INFO': parse_info,
} }
) )
@classmethod @classmethod
@ -382,8 +384,21 @@ class Ledis(object):
"Removes an expiration on name" "Removes an expiration on name"
return self.execute_command('PERSIST', name) return self.execute_command('PERSIST', name)
def scan(self, key, match = "", count = 10): def scan(self, key="" , match=None, count=10):
return self.execute_command("SCAN", key, match, count) pieces = [key]
if match is not None:
pieces.extend(["MATCH", match])
pieces.extend(["COUNT", count])
return self.execute_command("SCAN", *pieces)
def scan_iter(self, match=None, count=10):
key = ""
while key != "":
key, data = self.scan(key=key, match=match, count=count)
for item in data:
yield item
#### LIST COMMANDS #### #### LIST COMMANDS ####
def lindex(self, name, index): def lindex(self, name, index):
@ -460,8 +475,8 @@ class Ledis(object):
"Removes an expiration on ``name``" "Removes an expiration on ``name``"
return self.execute_command('LPERSIST', name) return self.execute_command('LPERSIST', name)
def lscan(self, key, match = "", count = 10): def lscan(self, key="", match=None, count=10):
return self.execute_command("LSCAN", key, match, count) return self.scan_generic("LSCAN", key=key, match=match, count=count)
#### SET COMMANDS #### #### SET COMMANDS ####
@ -560,8 +575,8 @@ class Ledis(object):
"Removes an expiration on name" "Removes an expiration on name"
return self.execute_command('SPERSIST', name) return self.execute_command('SPERSIST', name)
def sscan(self, key, match = "", count = 10): def sscan(self, key="", match=None, count = 10):
return self.execute_command("SSCAN", key, match, count) return self.scan_generic("SSCAN", key=key, match=match, count=count)
#### SORTED SET COMMANDS #### #### SORTED SET COMMANDS ####
@ -759,9 +774,17 @@ class Ledis(object):
"Removes an expiration on name" "Removes an expiration on name"
return self.execute_command('ZPERSIST', name) return self.execute_command('ZPERSIST', name)
def zscan(self, key, match = "", count = 10):
return self.execute_command("ZSCAN", key, match, count)
def scan_generic(self, scan_type, key="", match=None, count=10):
pieces = [key]
if match is not None:
pieces.extend([Token("MATCH"), match])
pieces.extend([Token("count"), count])
scan_type = scan_type.upper()
return self.execute_command(scan_type, *pieces)
def zscan(self, key="", match=None, count=10):
return self.scan_generic("ZSCAN", key=key, match=match, count=count)
#### HASH COMMANDS #### #### HASH COMMANDS ####
def hdel(self, name, *keys): def hdel(self, name, *keys):
@ -855,8 +878,8 @@ class Ledis(object):
"Removes an expiration on name" "Removes an expiration on name"
return self.execute_command('HPERSIST', name) return self.execute_command('HPERSIST', name)
def hscan(self, key, match = "", count = 10): def hscan(self, key="", match=None, count=10):
return self.execute_command("HSCAN", key, match, count) return self.scan_generic("HSCAN", key=key, match=match, count=count)
### BIT COMMANDS ### BIT COMMANDS
@ -934,8 +957,8 @@ class Ledis(object):
"Removes an expiration on name" "Removes an expiration on name"
return self.execute_command('BPERSIST', name) return self.execute_command('BPERSIST', name)
def bscan(self, key, match = "", count = 10): def bscan(self, key="", match=None, count=10):
return self.execute_command("BSCAN", key, match, count) return self.scan_generic("BSCAN", key=key, match=match, count=count)
def eval(self, script, keys, *args): def eval(self, script, keys, *args):
n = len(keys) n = len(keys)

View File

@ -588,3 +588,23 @@ class BlockingConnectionPool(object):
timeout=self.timeout, timeout=self.timeout,
connection_class=self.connection_class, connection_class=self.connection_class,
queue_class=self.queue_class, **self.connection_kwargs) queue_class=self.queue_class, **self.connection_kwargs)
class Token(object):
"""
Literal strings in Redis commands, such as the command names and any
hard-coded arguments are wrapped in this class so we know not to apply
and encoding rules on them.
"""
def __init__(self, value):
if isinstance(value, Token):
value = value.value
self.value = value
def __repr__(self):
return self.value
def __str__(self):
return self.value

View File

@ -1,7 +1,7 @@
dbs=(leveldb rocksdb hyperleveldb goleveldb boltdb lmdb) dbs=(leveldb rocksdb hyperleveldb goleveldb boltdb lmdb)
for db in "${dbs[@]}" for db in "${dbs[@]}"
do do
killall ledis-server
ledis-server -db_name=$db & ledis-server -db_name=$db &
py.test py.test
killall ledis-server
done done

View File

@ -17,8 +17,7 @@ class TestCmdBit(unittest.TestCase):
pass pass
def tearDown(self): def tearDown(self):
l.bdelete('a') l.flushdb()
l.bdelete('non_exists_key')
def test_bget(self): def test_bget(self):
"bget is the same as get in K/V commands" "bget is the same as get in K/V commands"

View File

@ -19,7 +19,7 @@ class TestCmdHash(unittest.TestCase):
pass pass
def tearDown(self): def tearDown(self):
l.hmclear('myhash', 'a') l.flushdb()
def test_hdel(self): def test_hdel(self):

View File

@ -18,7 +18,7 @@ class TestCmdKv(unittest.TestCase):
pass pass
def tearDown(self): def tearDown(self):
l.delete('a', 'b', 'c', 'non_exist_key') l.flushdb()
def test_decr(self): def test_decr(self):
assert l.delete('a') == 1 assert l.delete('a') == 1

View File

@ -18,7 +18,7 @@ class TestCmdList(unittest.TestCase):
pass pass
def tearDown(self): def tearDown(self):
l.lmclear('mylist', 'mylist1', 'mylist2') l.flushdb()
def test_lindex(self): def test_lindex(self):
l.rpush('mylist', '1', '2', '3') l.rpush('mylist', '1', '2', '3')

View File

@ -20,7 +20,7 @@ class TestCmdSet(unittest.TestCase):
pass pass
def tearDown(self): def tearDown(self):
l.smclear('a', 'b', 'c') l.flushdb()
def test_sadd(self): def test_sadd(self):
members = set([b('1'), b('2'), b('3')]) members = set([b('1'), b('2'), b('3')])

View File

@ -17,7 +17,7 @@ class TestCmdZset(unittest.TestCase):
pass pass
def tearDown(self): def tearDown(self):
l.zclear('a') l.flushdb()
def test_zadd(self): def test_zadd(self):
l.zadd('a', a1=1, a2=2, a3=3) l.zadd('a', a1=1, a2=2, a3=3)

View File

@ -10,13 +10,14 @@ from ledis._compat import b
from ledis import ResponseError from ledis import ResponseError
l = ledis.Ledis(port=6380) l = ledis.Ledis(port=6380)
dbs = ["leveldb", "rocksdb", "goleveldb", "hyperleveldb", "lmdb", "boltdb"]
class TestOtherCommands(unittest.TestCase): class TestOtherCommands(unittest.TestCase):
def setUp(self): def setUp(self):
pass pass
def tearDown(self): def tearDown(self):
pass l.flushdb()
# server information # server information
def test_echo(self): def test_echo(self):
@ -29,3 +30,92 @@ class TestOtherCommands(unittest.TestCase):
assert l.select('1') assert l.select('1')
assert l.select('15') assert l.select('15')
self.assertRaises(ResponseError, lambda: l.select('16')) self.assertRaises(ResponseError, lambda: l.select('16'))
def test_info(self):
info1 = l.info()
assert info1.get("db_name") in dbs
info2 = l.info(section="server")
assert info2.get("os") in ["linux", "darwin"]
def test_flushdb(self):
l.set("a", 1)
assert l.flushdb() == "OK"
assert l.get("a") is None
def test_flushall(self):
l.select(1)
l.set("a", 1)
assert l.get("a") == b("1")
l.select(10)
l.set("a", 1)
assert l.get("a") == b("1")
assert l.flushall() == "OK"
assert l.get("a") is None
l.select(1)
assert l.get("a") is None
# test *scan commands
def check_keys(self, scan_type):
d = {
"scan": l.scan,
"sscan": l.sscan,
"lscan": l.lscan,
"hscan": l.hscan,
"zscan": l.zscan,
"bscan": l.bscan
}
key, keys = d[scan_type]()
assert key == ""
assert set(keys) == set([b("a"), b("b"), b("c")])
_, keys = d[scan_type](match="a")
assert set(keys) == set([b("a")])
_, keys = d[scan_type](key="a")
assert set(keys) == set([b("b"), b("c")])
def test_scan(self):
d = {"a":1, "b":2, "c": 3}
l.mset(d)
self.check_keys("scan")
def test_lscan(self):
l.rpush("a", 1)
l.rpush("b", 1)
l.rpush("c", 1)
self.check_keys("lscan")
def test_hscan(self):
l.hset("a", "hello", "world")
l.hset("b", "hello", "world")
l.hset("c", "hello", "world")
self.check_keys("hscan")
def test_sscan(self):
l.sadd("a", 1)
l.sadd("b", 2)
l.sadd("c", 3)
self.check_keys("sscan")
def test_zscan(self):
l.zadd("a", 1, "a")
l.zadd("b", 1, "a")
l.zadd("c", 1, "a")
self.check_keys("zscan")
def test_bscan(self):
l.bsetbit("a", 1, 1)
l.bsetbit("b", 1, 1)
l.bsetbit("c", 1, 1)
self.check_keys("bscan")

View File

@ -4,14 +4,21 @@ sys.path.append("..")
import ledis import ledis
global_l = ledis.Ledis()
#db that do not support transaction
dbs = ["leveldb", "rocksdb", "hyperleveldb", "goleveldb"]
check = global_l.info().get("db_name") in dbs
class TestTx(unittest.TestCase): class TestTx(unittest.TestCase):
def setUp(self): def setUp(self):
self.l = ledis.Ledis(port=6380) self.l = ledis.Ledis(port=6380)
def tearDown(self): def tearDown(self):
self.l.delete("a") self.l.flushdb()
@unittest.skipIf(check, reason="db not support transaction")
def test_commit(self): def test_commit(self):
tx = self.l.tx() tx = self.l.tx()
self.l.set("a", "no-tx") self.l.set("a", "no-tx")
@ -24,6 +31,7 @@ class TestTx(unittest.TestCase):
tx.commit() tx.commit()
assert self.l.get("a") == "tx" assert self.l.get("a") == "tx"
@unittest.skipIf(check, reason="db not support transaction")
def test_rollback(self): def test_rollback(self):
tx = self.l.tx() tx = self.l.tx()
self.l.set("a", "no-tx") self.l.set("a", "no-tx")