update ledis-py, added new tests

This commit is contained in:
holys 2014-07-01 10:40:56 +08:00
parent 115ca10a6c
commit c86882a530
21 changed files with 845 additions and 4559 deletions

4
.gitignore vendored
View File

@ -1 +1,3 @@
build
build
*.pyc
.DS_Store

View File

@ -1,31 +1,27 @@
from redis.client import Redis, StrictRedis
from redis.connection import (
from ledis.client import Ledis
from ledis.connection import (
BlockingConnectionPool,
ConnectionPool,
Connection,
UnixDomainSocketConnection
)
from redis.utils import from_url
from redis.exceptions import (
AuthenticationError,
from ledis.utils import from_url
from ledis.exceptions import (
ConnectionError,
BusyLoadingError,
DataError,
InvalidResponse,
PubSubError,
RedisError,
LedisError,
ResponseError,
WatchError,
)
__version__ = '2.7.6'
__version__ = '0.0.1'
VERSION = tuple(map(int, __version__.split('.')))
__all__ = [
'Redis', 'StrictRedis', 'ConnectionPool', 'BlockingConnectionPool',
'Ledis', 'ConnectionPool', 'BlockingConnectionPool',
'Connection', 'UnixDomainSocketConnection',
'RedisError', 'ConnectionError', 'ResponseError', 'AuthenticationError',
'InvalidResponse', 'DataError', 'PubSubError', 'WatchError', 'from_url',
'BusyLoadingError'
'LedisError', 'ConnectionError', 'ResponseError',
'InvalidResponse', 'DataError', 'from_url', 'BusyLoadingError',
]

View File

@ -3,7 +3,7 @@ import sys
if sys.version_info[0] < 3:
from urlparse import urlparse
from urlparse import parse_qs, urlparse
from itertools import imap, izip
from string import letters as ascii_letters
from Queue import Queue
@ -28,7 +28,7 @@ if sys.version_info[0] < 3:
bytes = str
long = long
else:
from urllib.parse import urlparse
from urllib.parse import parse_qs, urlparse
from io import BytesIO
from string import ascii_letters
from queue import Queue

File diff suppressed because it is too large Load Diff

View File

@ -3,22 +3,17 @@ import os
import socket
import sys
from redis._compat import (b, xrange, imap, byte_to_chr, unicode, bytes, long,
BytesIO, nativestr, basestring,
LifoQueue, Empty, Full)
from redis.exceptions import (
RedisError,
from ledis._compat import (b, xrange, imap, byte_to_chr, unicode, bytes, long,
BytesIO, nativestr, basestring, iteritems,
LifoQueue, Empty, Full, urlparse, parse_qs)
from ledis.exceptions import (
LedisError,
ConnectionError,
BusyLoadingError,
ResponseError,
InvalidResponse,
AuthenticationError,
NoScriptError,
ExecAbortError,
)
from redis.utils import HIREDIS_AVAILABLE
if HIREDIS_AVAILABLE:
import hiredis
)
SYM_STAR = b('*')
@ -36,7 +31,6 @@ class PythonParser(object):
'ERR': ResponseError,
'EXECABORT': ExecAbortError,
'LOADING': BusyLoadingError,
'NOSCRIPT': NoScriptError,
}
def __init__(self):
@ -146,61 +140,11 @@ class PythonParser(object):
return response
class HiredisParser(object):
"Parser class for connections using Hiredis"
def __init__(self):
if not HIREDIS_AVAILABLE:
raise RedisError("Hiredis is not installed")
def __del__(self):
try:
self.on_disconnect()
except Exception:
pass
def on_connect(self, connection):
self._sock = connection._sock
kwargs = {
'protocolError': InvalidResponse,
'replyError': ResponseError,
}
if connection.decode_responses:
kwargs['encoding'] = connection.encoding
self._reader = hiredis.Reader(**kwargs)
def on_disconnect(self):
self._sock = None
self._reader = None
def read_response(self):
if not self._reader:
raise ConnectionError("Socket closed on remote end")
response = self._reader.gets()
while response is False:
try:
buffer = self._sock.recv(4096)
except (socket.error, socket.timeout):
e = sys.exc_info()[1]
raise ConnectionError("Error while reading from socket: %s" %
(e.args,))
if not buffer:
raise ConnectionError("Socket closed on remote end")
self._reader.feed(buffer)
# proactively, but not conclusively, check if more data is in the
# buffer. if the data received doesn't end with \n, there's more.
if not buffer.endswith(SYM_LF):
continue
response = self._reader.gets()
return response
if HIREDIS_AVAILABLE:
DefaultParser = HiredisParser
else:
DefaultParser = PythonParser
DefaultParser = PythonParser
class Connection(object):
"Manages TCP communication to and from a Redis server"
"Manages TCP communication to and from a Ledis server"
def __init__(self, host='localhost', port=6379, db=0, password=None,
socket_timeout=None, encoding='utf-8',
encoding_errors='strict', decode_responses=False,
@ -224,7 +168,7 @@ class Connection(object):
pass
def connect(self):
"Connects to the Redis server if not already connected"
"Connects to the Ledis server if not already connected"
if self._sock:
return
try:
@ -270,7 +214,7 @@ class Connection(object):
raise ConnectionError('Invalid Database')
def disconnect(self):
"Disconnects from the Redis server"
"Disconnects from the Ledis server"
self._parser.on_disconnect()
if self._sock is None:
return
@ -281,7 +225,7 @@ class Connection(object):
self._sock = None
def send_packed_command(self, command):
"Send an already packed command to the Redis server"
"Send an already packed command to the Ledis server"
if not self._sock:
self.connect()
try:
@ -300,7 +244,7 @@ class Connection(object):
raise
def send_command(self, *args):
"Pack and send a command to the Redis server"
"Pack and send a command to the Ledis server"
self.send_packed_command(self.pack_command(*args))
def read_response(self):
@ -327,7 +271,7 @@ class Connection(object):
return value
def pack_command(self, *args):
"Pack a series of arguments into a value Redis command"
"Pack a series of arguments into a value Ledis command"
output = SYM_STAR + b(str(len(args))) + SYM_CRLF
for enc_value in imap(self.encode, args):
output += SYM_DOLLAR
@ -375,6 +319,88 @@ class UnixDomainSocketConnection(Connection):
# TODO: add ability to block waiting on a connection to be released
class ConnectionPool(object):
"Generic connection pool"
@classmethod
def from_url(cls, url, db=None, **kwargs):
"""
Return a connection pool configured from the given URL.
For example::
redis://[:password]@localhost:6379/0
rediss://[:password]@localhost:6379/0
unix://[:password]@/path/to/socket.sock?db=0
Three URL schemes are supported:
redis:// creates a normal TCP socket connection
rediss:// creates a SSL wrapped TCP socket connection
unix:// creates a Unix Domain Socket connection
There are several ways to specify a database number. The parse function
will return the first specified option:
1. A ``db`` querystring option, e.g. redis://localhost?db=0
2. If using the redis:// scheme, the path argument of the url, e.g.
redis://localhost/0
3. The ``db`` argument to this function.
If none of these options are specified, db=0 is used.
Any additional querystring arguments and keyword arguments will be
passed along to the ConnectionPool class's initializer. In the case
of conflicting arguments, querystring arguments always win.
"""
url_string = url
url = urlparse(url)
qs = ''
# in python2.6, custom URL schemes don't recognize querystring values
# they're left as part of the url.path.
if '?' in url.path and not url.query:
# chop the querystring including the ? off the end of the url
# and reparse it.
qs = url.path.split('?', 1)[1]
url = urlparse(url_string[:-(len(qs) + 1)])
else:
qs = url.query
url_options = {}
for name, value in iteritems(parse_qs(qs)):
if value and len(value) > 0:
url_options[name] = value[0]
# We only support redis:// and unix:// schemes.
if url.scheme == 'unix':
url_options.update({
'password': url.password,
'path': url.path,
'connection_class': UnixDomainSocketConnection,
})
else:
url_options.update({
'host': url.hostname,
'port': int(url.port or 6379),
'password': url.password,
})
# If there's a path argument, use it as the db argument if a
# querystring value wasn't specified
if 'db' not in url_options and url.path:
try:
url_options['db'] = int(url.path.replace('/', ''))
except (AttributeError, ValueError):
pass
if url.scheme == 'lediss':
url_options['connection_class'] = SSLConnection
# last shot at the db value
url_options['db'] = int(url_options.get('db', db or 0))
# update the arguments from the URL values
kwargs.update(url_options)
return cls(**kwargs)
def __init__(self, connection_class=Connection, max_connections=None,
**connection_kwargs):
self.pid = os.getpid()

View File

@ -1,15 +1,10 @@
"Core exceptions raised by the Redis client"
"Core exceptions raised by the LedisDB client"
class RedisError(Exception):
class LedisError(Exception):
pass
class AuthenticationError(RedisError):
pass
class ServerError(RedisError):
class ServerError(LedisError):
pass
@ -25,23 +20,11 @@ class InvalidResponse(ServerError):
pass
class ResponseError(RedisError):
class ResponseError(LedisError):
pass
class DataError(RedisError):
pass
class PubSubError(RedisError):
pass
class WatchError(RedisError):
pass
class NoScriptError(ResponseError):
class DataError(LedisError):
pass

View File

@ -1,16 +1,10 @@
try:
import hiredis
HIREDIS_AVAILABLE = True
except ImportError:
HIREDIS_AVAILABLE = False
def from_url(url, db=None, **kwargs):
"""
Returns an active Redis client generated from the given database URL.
Returns an active Ledis client generated from the given database URL.
Will attempt to extract the database id from the path url fragment, if
none is provided.
"""
from redis.client import Redis
return Redis.from_url(url, db, **kwargs)
from ledis.client import Ledis
return Ledis.from_url(url, db, **kwargs)

View File

@ -2,7 +2,7 @@
import os
import sys
from redis import __version__
from ledis import __version__
try:
from setuptools import setup
@ -30,18 +30,14 @@ long_description = f.read()
f.close()
setup(
name='redis',
name='ledis',
version=__version__,
description='Python client for Redis key-value store',
description='Python client for ledis key-value store',
long_description=long_description,
url='http://github.com/andymccurdy/redis-py',
author='Andy McCurdy',
author_email='sedrik@gmail.com',
maintainer='Andy McCurdy',
maintainer_email='sedrik@gmail.com',
keywords=['Redis', 'key-value store'],
url='https://github.com/siddontang/ledisdb',
keywords=['ledis', 'key-value store'],
license='MIT',
packages=['redis'],
packages=['ledis'],
tests_require=['pytest>=2.5.0'],
cmdclass={'test': PyTest},
classifiers=[

View File

@ -1,46 +0,0 @@
import pytest
import redis
from distutils.version import StrictVersion
_REDIS_VERSIONS = {}
def get_version(**kwargs):
params = {'host': 'localhost', 'port': 6379, 'db': 9}
params.update(kwargs)
key = '%s:%s' % (params['host'], params['port'])
if key not in _REDIS_VERSIONS:
client = redis.Redis(**params)
_REDIS_VERSIONS[key] = client.info()['redis_version']
client.connection_pool.disconnect()
return _REDIS_VERSIONS[key]
def _get_client(cls, request=None, **kwargs):
params = {'host': 'localhost', 'port': 6379, 'db': 9}
params.update(kwargs)
client = cls(**params)
client.flushdb()
if request:
def teardown():
client.flushdb()
client.connection_pool.disconnect()
request.addfinalizer(teardown)
return client
def skip_if_server_version_lt(min_version):
check = StrictVersion(get_version()) < StrictVersion(min_version)
return pytest.mark.skipif(check, reason="")
@pytest.fixture()
def r(request, **kwargs):
return _get_client(redis.Redis, request, **kwargs)
@pytest.fixture()
def sr(request, **kwargs):
return _get_client(redis.StrictRedis, request, **kwargs)

View File

@ -0,0 +1,143 @@
# coding: utf-8
# Test Cases for list commands
import unittest
import sys
import datetime, time
sys.path.append('..')
import ledis
from ledis._compat import b, iteritems, itervalues
from ledis import ResponseError
def current_time():
return datetime.datetime.now()
class TestCmdHash(unittest.TestCase):
def setUp(self):
self.l = ledis.Ledis(port=6666)
def tearDown(self):
self.l.hmclear('myhash', 'a')
def test_hdel(self):
self.l.hset('myhash', 'field1', 'foo')
assert self.l.hdel('myhash', 'field1') == 1
assert self.l.hdel('myhash', 'field1') == 0
assert self.l.hdel('myhash', 'field1', 'field2') == 0
def test_hexists(self):
self.l.hset('myhash', 'field1', 'foo')
self.l.hdel('myhash', 'field2')
assert self.l.hexists('myhash', 'field1') == 1
assert self.l.hexists('myhash', 'field2') == 0
def test_hget(self):
self.l.hset('myhash', 'field1', 'foo')
assert self.l.hget('myhash', 'field1') == 'foo'
self.assertIsNone(self.l.hget('myhash', 'field2'))
def test_hgetall(self):
h = {'field1': 'foo', 'field2': 'bar'}
self.l.hmset('myhash', h)
assert self.l.hgetall('myhash') == h
def test_hincrby(self):
assert self.l.hincrby('myhash', 'field1') == 1
self.l.hclear('myhash')
assert self.l.hincrby('myhash', 'field1', 1) == 1
assert self.l.hincrby('myhash', 'field1', 5) == 6
assert self.l.hincrby('myhash', 'field1', -10) == -4
def test_hkeys(self):
h = {'field1': 'foo', 'field2': 'bar'}
self.l.hmset('myhash', h)
assert self.l.hkeys('myhash') == ['field1', 'field2']
def test_hlen(self):
self.l.hset('myhash', 'field1', 'foo')
assert self.l.hlen('myhash') == 1
self.l.hset('myhash', 'field2', 'bar')
assert self.l.hlen('myhash') == 2
def test_hmget(self):
assert self.l.hmset('myhash', {'a': '1', 'b': '2', 'c': '3'})
assert self.l.hmget('myhash', 'a', 'b', 'c') == ['1', '2', '3']
def test_hmset(self):
h = {'a': '1', 'b': '2', 'c': '3'}
assert self.l.hmset('myhash', h)
assert self.l.hgetall('myhash') == h
def test_hset(self):
self.l.hclear('myhash')
assert int(self.l.hset('myhash', 'field1', 'foo')) == 1
assert self.l.hset('myhash', 'field1', 'foo') == 0
def test_hvals(self):
h = {'a': '1', 'b': '2', 'c': '3'}
self.l.hmset('myhash', h)
local_vals = list(itervalues(h))
remote_vals = self.l.hvals('myhash')
assert sorted(local_vals) == sorted(remote_vals)
def test_hclear(self):
h = {'a': '1', 'b': '2', 'c': '3'}
self.l.hmset('myhash', h)
assert self.l.hclear('myhash') == 3
assert self.l.hclear('myhash') == 0
def test_hmclear(self):
h = {'a': '1', 'b': '2', 'c': '3'}
self.l.hmset('myhash1', h)
self.l.hmset('myhash2', h)
assert self.l.hmclear('myhash1', 'myhash2') == 2
def test_hexpire(self):
assert self.l.hexpire('myhash', 100) == 0
self.l.hset('myhash', 'field1', 'foo')
assert self.l.hexpire('myhash', 100) == 1
assert self.l.httl('myhash') <= 100
def test_hexpireat_datetime(self):
expire_at = current_time() + datetime.timedelta(minutes=1)
self.l.hset('a', 'f', 'foo')
assert self.l.hexpireat('a', expire_at)
assert 0 < self.l.httl('a') <= 61
def test_hexpireat_unixtime(self):
expire_at = current_time() + datetime.timedelta(minutes=1)
self.l.hset('a', 'f', 'foo')
expire_at_seconds = int(time.mktime(expire_at.timetuple()))
assert self.l.hexpireat('a', expire_at_seconds)
assert 0 < self.l.httl('a') <= 61
def test_zexpireat_no_key(self):
expire_at = current_time() + datetime.timedelta(minutes=1)
assert not self.l.hexpireat('a', expire_at)
def test_hexpireat(self):
assert self.l.hexpireat('myhash', 1577808000) == 0
self.l.hset('myhash', 'field1', 'foo')
assert self.l.hexpireat('myhash', 1577808000) == 1
def test_httl(self):
self.l.hset('myhash', 'field1', 'foo')
assert self.l.hexpire('myhash', 100)
assert self.l.httl('myhash') <= 100
def test_hpersist(self):
self.l.hset('myhash', 'field1', 'foo')
self.l.hexpire('myhash', 100)
assert self.l.httl('myhash') <= 100
assert self.l.hpersist('myhash')
assert self.l.httl('myhash') == -1

View File

@ -0,0 +1,156 @@
# coding: utf-8
# Test Cases for k/v commands
import unittest
import sys
import datetime, time
sys.path.append('..')
import ledis
from ledis._compat import b, iteritems
def current_time():
return datetime.datetime.now()
class TestCmdKv(unittest.TestCase):
def setUp(self):
self.l = ledis.Ledis(port=6666)
def tearDown(self):
self.l.delete('a', 'b', 'c', 'non_exist_key')
def test_decr(self):
assert self.l.delete('a') == 1
assert self.l.decr('a') == -1
assert self.l['a'] == b('-1')
assert self.l.decr('a') == -2
assert self.l['a'] == b('-2')
assert self.l.decr('a', amount=5) == -7
assert self.l['a'] == b('-7')
#FIXME: how to test exception?
# self.l.set('b', '234293482390480948029348230948')
# self.assertRaises(ResponseError, self.l.delete('b'))
def test_decrby(self):
assert self.l.delete('a') == 1
assert self.l.decrby('a') == -1
assert self.l['a'] == b('-1')
assert self.l.decrby('a') == -2
assert self.l['a'] == b('-2')
assert self.l.decrby('a', amount=5) == -7
assert self.l['a'] == b('-7')
def test_del(self):
assert self.l.delete('a') == 1
assert self.l.delete('a', 'b', 'c') == 3
def test_exists(self):
self.l.delete('a', 'non_exist_key')
self.l.set('a', 'hello')
self.assertTrue(self.l.exists('a'))
self.assertFalse(self.l.exists('non_exist_key'))
def test_get(self):
self.l.set('a', 'hello')
assert self.l.get('a') == 'hello'
self.l.set('b', '中文')
assert self.l.get('b') == '中文'
self.l.delete('non_exist_key')
self.assertIsNone(self.l.get('non_exist_key'))
def test_getset(self):
self.l.set('a', 'hello')
assert self.l.getset('a', 'world') == 'hello'
assert self.l.get('a') == 'world'
self.l.delete('non_exist_key')
self.assertIsNone(self.l.getset('non_exist_key', 'non'))
def test_incr(self):
self.l.delete('non_exist_key')
assert self.l.incr('non_exist_key') == 1
self.l.set('a', 100)
assert self.l.incr('a') == 101
def test_incrby(self):
self.l.delete('a')
assert self.l.incrby('a', 100) == 100
self.l.set('a', 100)
assert self.l.incrby('a', 100) == 200
assert self.l.incrby('a', amount=100) == 300
def test_mget(self):
self.l.set('a', 'hello')
self.l.set('b', 'world')
self.l.delete('non_exist_key')
assert self.l.mget('a', 'b', 'non_exist_key') == ['hello', 'world', None]
self.l.delete('a', 'b')
assert self.l.mget(['a', 'b']) == [None, None]
def test_mset(self):
d = {'a': b('1'), 'b': b('2'), 'c': b('3')}
assert self.l.mset(**d)
for k, v in iteritems(d):
assert self.l[k] == v
def test_set(self):
self.assertTrue(self.l.set('a', 100))
def test_setnx(self):
self.l.delete('a')
assert self.l.setnx('a', '1')
assert self.l['a'] == b('1')
assert not self.l.setnx('a', '2')
assert self.l['a'] == b('1')
def test_ttl(self):
assert self.l.set('a', 'hello')
assert self.l.expire('a', 100)
assert self.l.ttl('a') <= 100
self.l.delete('a')
assert self.l.ttl('a') == -1
self.l.set('a', 'hello')
assert self.l.ttl('a') == -1
def test_persist(self):
assert self.l.set('a', 'hello')
assert self.l.expire('a', 100)
assert self.l.ttl('a') <= 100
assert self.l.persist('a')
self.l.delete('non_exist_key')
assert not self.l.persist('non_exist_key')
def test_expire(self):
assert not self.l.expire('a', 100)
self.l.set('a', 'hello')
self.assertTrue(self.l.expire('a', 100))
self.l.delete('a')
self.assertFalse(self.l.expire('a', 100))
def test_expireat_datetime(self):
expire_at = current_time() + datetime.timedelta(minutes=1)
self.l.set('a', '1')
assert self.l.expireat('a', expire_at)
assert 0 < self.l.ttl('a') <= 61
def test_expireat_unixtime(self):
expire_at = current_time() + datetime.timedelta(minutes=1)
self.l.set('a', '1')
expire_at_seconds = int(time.mktime(expire_at.timetuple()))
assert self.l.expireat('a', expire_at_seconds)
assert 0 < self.l.ttl('a') <= 61
def test_expireat_no_key(self):
expire_at = current_time() + datetime.timedelta(minutes=1)
assert not self.l.expireat('a', expire_at)
def test_expireat(self):
self.l.set('a', 'hello')
self.assertTrue(self.l.expireat('a', 1577808000)) # time is 2020.1.1
self.l.delete('a')
self.assertFalse(self.l.expireat('a', 1577808000))

View File

@ -0,0 +1,106 @@
# coding: utf-8
# Test Cases for list commands
import unittest
import datetime, time
import sys
sys.path.append('..')
import ledis
def current_time():
return datetime.datetime.now()
class TestCmdList(unittest.TestCase):
def setUp(self):
self.l = ledis.Ledis(port=6666)
def tearDown(self):
self.l.lmclear('mylist', 'mylist1', 'mylist2')
def test_lindex(self):
self.l.rpush('mylist', '1', '2', '3')
assert self.l.lindex('mylist', 0) == '1'
assert self.l.lindex('mylist', 1) == '2'
assert self.l.lindex('mylist', 2) == '3'
def test_llen(self):
self.l.rpush('mylist', '1', '2', '3')
assert self.l.llen('mylist') == 3
def test_lpop(self):
self.l.rpush('mylist', '1', '2', '3')
assert self.l.lpop('mylist') == '1'
assert self.l.lpop('mylist') == '2'
assert self.l.lpop('mylist') == '3'
assert self.l.lpop('mylist') is None
def test_lpush(self):
assert self.l.lpush('mylist', '1') == 1
assert self.l.lpush('mylist', '2') == 2
assert self.l.lpush('mylist', '3', '4', '5') == 5
assert self.l.lrange('mylist', 0, 5) == ['5', '4', '3', '2', '1']
def test_lrange(self):
self.l.rpush('mylist', '1', '2', '3', '4', '5')
assert self.l.lrange('mylist', 0, 2) == ['1', '2', '3']
assert self.l.lrange('mylist', 2, 10) == ['3', '4', '5']
assert self.l.lrange('mylist', 0, 5) == ['1', '2', '3', '4', '5']
def test_rpush(self):
assert self.l.rpush('mylist', '1') == 1
assert self.l.rpush('mylist', '2') == 2
assert self.l.rpush('mylist', '3', '4') == 4
assert self.l.lrange('mylist', 0, 5) == ['1', '2', '3', '4']
def test_rpop(self):
self.l.rpush('mylist', '1', '2', '3')
assert self.l.rpop('mylist') == '3'
assert self.l.rpop('mylist') == '2'
assert self.l.rpop('mylist') == '1'
assert self.l.rpop('mylist') is None
def test_lclear(self):
self.l.rpush('mylist', '1', '2', '3')
assert self.l.lclear('mylist') == 3
assert self.l.lclear('mylist') == 0
def test_lmclear(self):
self.l.rpush('mylist1', '1', '2', '3')
self.l.rpush('mylist2', '1', '2', '3')
assert self.l.lmclear('mylist1', 'mylist2') == 2
def test_lexpire(self):
assert not self.l.lexpire('mylist', 100)
self.l.rpush('mylist', '1')
assert self.l.lexpire('mylist', 100)
assert 0 < self.l.lttl('mylist') <= 100
assert self.l.lpersist('mylist')
assert self.l.lttl('mylist') == -1
def test_lexpireat_datetime(self):
expire_at = current_time() + datetime.timedelta(minutes=1)
self.l.rpush('mylist', '1')
assert self.l.lexpireat('mylist', expire_at)
assert 0 < self.l.lttl('mylist') <= 61
def test_lexpireat_unixtime(self):
expire_at = current_time() + datetime.timedelta(minutes=1)
self.l.rpush('mylist', '1')
expire_at_seconds = int(time.mktime(expire_at.timetuple()))
assert self.l.lexpireat('mylist', expire_at_seconds)
assert self.l.lttl('mylist') <= 61
def test_lexpireat_no_key(self):
expire_at = current_time() + datetime.timedelta(minutes=1)
assert not self.l.lexpireat('mylist', expire_at)
def test_lttl_and_lpersist(self):
self.l.rpush('mylist', '1')
self.l.lexpire('mylist', 100)
assert 0 < self.l.lttl('mylist') <= 100
assert self.l.lpersist('mylist')
assert self.l.lttl('mylist') == -1

View File

@ -0,0 +1,180 @@
# coding: utf-8
# Test Cases for list commands
import unittest
import sys
import datetime, time
sys.path.append('..')
import ledis
from ledis._compat import b, iteritems
from ledis import ResponseError
def current_time():
return datetime.datetime.now()
class TestCmdZset(unittest.TestCase):
def setUp(self):
self.l = ledis.Ledis(port=6666)
def tearDown(self):
self.l.zclear('a')
def test_zadd(self):
self.l.zadd('a', a1=1, a2=2, a3=3)
assert self.l.zrange('a', 0, -1) == ['a1', 'a2', 'a3']
def test_zcard(self):
self.l.zadd('a', a1=1, a2=2, a3=3)
assert self.l.zcard('a') == 3
def test_zcount(self):
self.l.zadd('a', a1=1, a2=2, a3=3)
assert self.l.zcount('a', '-inf', '+inf') == 3
assert self.l.zcount('a', 1, 2) == 2
assert self.l.zcount('a', 10, 20) == 0
def test_zincrby(self):
self.l.zadd('a', a1=1, a2=2, a3=3)
assert self.l.zincrby('a', 'a2') == 3.0
assert self.l.zincrby('a', 'a3', amount=5) == 8.0
assert self.l.zscore('a', 'a2') == 3.0
assert self.l.zscore('a', 'a3') == 8.0
def test_zrange(self):
self.l.zadd('a', a1=1, a2=2, a3=3)
assert self.l.zrange('a', 0, 1) == ['a1', 'a2']
assert self.l.zrange('a', 2, 3) == ['a3']
#withscores
assert self.l.zrange('a', 0, 1, withscores=True) == \
[('a1', 1.0), ('a2', 2.0)]
assert self.l.zrange('a', 2, 3, withscores=True) == \
[('a3', 3.0)]
def test_zrangebyscore(self):
self.l.zadd('a', a1=1, a2=2, a3=3, a4=4, a5=5)
assert self.l.zrangebyscore('a', 2, 4) == ['a2', 'a3', 'a4']
# slicing with start/num
assert self.l.zrangebyscore('a', 2, 4, start=1, num=2) == \
['a3', 'a4']
# withscores
assert self.l.zrangebyscore('a', 2, 4, withscores=True) == \
[('a2', 2.0), ('a3', 3.0), ('a4', 4.0)]
# custom score function
assert self.l.zrangebyscore('a', 2, 4, withscores=True,
score_cast_func=int) == \
[('a2', 2), ('a3', 3), ('a4', 4)]
def test_rank(self):
self.l.zadd('a', a1=1, a2=2, a3=3, a4=4, a5=5)
assert self.l.zrank('a', 'a1') == 0
assert self.l.zrank('a', 'a3') == 2
assert self.l.zrank('a', 'a6') is None
def test_zrem(self):
self.l.zadd('a', a1=1, a2=2, a3=3)
assert self.l.zrem('a', 'a2') == 1
assert self.l.zrange('a', 0, -1) == ['a1', 'a3']
assert self.l.zrem('a', 'b') == 0
assert self.l.zrange('a', 0, -1) == ['a1', 'a3']
# multiple keys
self.l.zadd('a', a1=1, a2=2, a3=3)
assert self.l.zrem('a', 'a1', 'a2') == 2
assert self.l.zrange('a', 0, -1) == ['a3']
def test_zremrangebyrank(self):
self.l.zadd('a', a1=1, a2=2, a3=3, a4=4, a5=5)
assert self.l.zremrangebyrank('a', 1, 3) == 3
assert self.l.zrange('a', 0, -1) == ['a1', 'a5']
def test_zremrangebyscore(self):
self.l.zadd('a', a1=1, a2=2, a3=3, a4=4, a5=5)
assert self.l.zremrangebyscore('a', 2, 4) == 3
assert self.l.zrange('a', 0, -1) == ['a1', 'a5']
assert self.l.zremrangebyscore('a', 2, 4) == 0
assert self.l.zrange('a', 0, -1) == ['a1', 'a5']
def test_zrevrange(self):
self.l.zadd('a', a1=1, a2=2, a3=3)
assert self.l.zrevrange('a', 0, 1) == ['a3', 'a2']
assert self.l.zrevrange('a', 1, 2) == ['a2', 'a1']
def test_zrevrank(self):
self.l.zadd('a', a1=1, a2=2, a3=3, a4=4, a5=5)
assert self.l.zrevrank('a', 'a1') == 4
assert self.l.zrevrank('a', 'a2') == 3
assert self.l.zrevrank('a', 'a6') is None
def test_zrevrangebyscore(self):
self.l.zadd('a', a1=1, a2=2, a3=3, a4=4, a5=5)
assert self.l.zrevrangebyscore('a', 2, 4) == ['a4', 'a3', 'a2']
# slicing with start/num
assert self.l.zrevrangebyscore('a', 2, 4, start=1, num=2) == \
['a3', 'a2']
# withscores
assert self.l.zrevrangebyscore('a', 2, 4, withscores=True) == \
[('a4', 4.0), ('a3', 3.0), ('a2', 2.0)]
# custom score function
assert self.l.zrevrangebyscore('a', 2, 4, withscores=True,
score_cast_func=int) == \
[('a4', 4), ('a3', 3), ('a2', 2)]
def test_zscore(self):
self.l.zadd('a', a1=1, a2=2, a3=3)
assert self.l.zscore('a', 'a1') == 1.0
assert self.l.zscore('a', 'a2') == 2.0
assert self.l.zscore('a', 'a4') is None
def test_zclear(self):
self.l.zadd('a', a1=1, a2=2, a3=3)
assert self.l.zclear('a') == 3
assert self.l.zclear('a') == 0
def test_zmclear(self):
self.l.zadd('a', a1=1, a2=2, a3=3)
self.l.zadd('b', b1=1, b2=2, b3=3)
assert self.l.lmclear('a', 'b') == 2
assert self.l.lmclear('c', 'd') == 2
def test_zexpire(self):
assert not self.l.zexpire('a', 100)
self.l.zadd('a', a1=1, a2=2, a3=3)
assert self.l.zexpire('a', 100)
assert 0 < self.l.zttl('a') <= 100
assert self.l.zpersist('a')
assert self.l.zttl('a') == -1
def test_zexpireat_datetime(self):
expire_at = current_time() + datetime.timedelta(minutes=1)
self.l.zadd('a', a1=1)
assert self.l.zexpireat('a', expire_at)
assert 0 < self.l.zttl('a') <= 61
def test_zexpireat_unixtime(self):
expire_at = current_time() + datetime.timedelta(minutes=1)
self.l.zadd('a', a1=1)
expire_at_seconds = int(time.mktime(expire_at.timetuple()))
assert self.l.zexpireat('a', expire_at_seconds)
assert 0 < self.l.zttl('a') <= 61
def test_zexpireat_no_key(self):
expire_at = current_time() + datetime.timedelta(minutes=1)
assert not self.l.zexpireat('a', expire_at)
def test_zttl_and_zpersist(self):
self.l.zadd('a', a1=1)
self.l.zexpire('a', 100)
assert 0 < self.l.zttl('a') <= 100
assert self.l.zpersist('a')
assert self.l.zttl('a') == -1

File diff suppressed because it is too large Load Diff

View File

@ -1,402 +0,0 @@
from __future__ import with_statement
import os
import pytest
import redis
import time
import re
from threading import Thread
from redis.connection import ssl_available
from .conftest import skip_if_server_version_lt
class DummyConnection(object):
description_format = "DummyConnection<>"
def __init__(self, **kwargs):
self.kwargs = kwargs
self.pid = os.getpid()
class TestConnectionPool(object):
def get_pool(self, connection_kwargs=None, max_connections=None,
connection_class=DummyConnection):
connection_kwargs = connection_kwargs or {}
pool = redis.ConnectionPool(
connection_class=connection_class,
max_connections=max_connections,
**connection_kwargs)
return pool
def test_connection_creation(self):
connection_kwargs = {'foo': 'bar', 'biz': 'baz'}
pool = self.get_pool(connection_kwargs=connection_kwargs)
connection = pool.get_connection('_')
assert isinstance(connection, DummyConnection)
assert connection.kwargs == connection_kwargs
def test_multiple_connections(self):
pool = self.get_pool()
c1 = pool.get_connection('_')
c2 = pool.get_connection('_')
assert c1 != c2
def test_max_connections(self):
pool = self.get_pool(max_connections=2)
pool.get_connection('_')
pool.get_connection('_')
with pytest.raises(redis.ConnectionError):
pool.get_connection('_')
def test_reuse_previously_released_connection(self):
pool = self.get_pool()
c1 = pool.get_connection('_')
pool.release(c1)
c2 = pool.get_connection('_')
assert c1 == c2
def test_repr_contains_db_info_tcp(self):
connection_kwargs = {'host': 'localhost', 'port': 6379, 'db': 1}
pool = self.get_pool(connection_kwargs=connection_kwargs,
connection_class=redis.Connection)
expected = 'ConnectionPool<Connection<host=localhost,port=6379,db=1>>'
assert repr(pool) == expected
def test_repr_contains_db_info_unix(self):
connection_kwargs = {'path': '/abc', 'db': 1}
pool = self.get_pool(connection_kwargs=connection_kwargs,
connection_class=redis.UnixDomainSocketConnection)
expected = 'ConnectionPool<UnixDomainSocketConnection<path=/abc,db=1>>'
assert repr(pool) == expected
class TestBlockingConnectionPool(object):
def get_pool(self, connection_kwargs=None, max_connections=10, timeout=20):
connection_kwargs = connection_kwargs or {}
pool = redis.BlockingConnectionPool(connection_class=DummyConnection,
max_connections=max_connections,
timeout=timeout,
**connection_kwargs)
return pool
def test_connection_creation(self):
connection_kwargs = {'foo': 'bar', 'biz': 'baz'}
pool = self.get_pool(connection_kwargs=connection_kwargs)
connection = pool.get_connection('_')
assert isinstance(connection, DummyConnection)
assert connection.kwargs == connection_kwargs
def test_multiple_connections(self):
pool = self.get_pool()
c1 = pool.get_connection('_')
c2 = pool.get_connection('_')
assert c1 != c2
def test_connection_pool_blocks_until_timeout(self):
"When out of connections, block for timeout seconds, then raise"
pool = self.get_pool(max_connections=1, timeout=0.1)
pool.get_connection('_')
start = time.time()
with pytest.raises(redis.ConnectionError):
pool.get_connection('_')
# we should have waited at least 0.1 seconds
assert time.time() - start >= 0.1
def connection_pool_blocks_until_another_connection_released(self):
"""
When out of connections, block until another connection is released
to the pool
"""
pool = self.get_pool(max_connections=1, timeout=2)
c1 = pool.get_connection('_')
def target():
time.sleep(0.1)
pool.release(c1)
Thread(target=target).start()
start = time.time()
pool.get_connection('_')
assert time.time() - start >= 0.1
def test_reuse_previously_released_connection(self):
pool = self.get_pool()
c1 = pool.get_connection('_')
pool.release(c1)
c2 = pool.get_connection('_')
assert c1 == c2
def test_repr_contains_db_info_tcp(self):
pool = redis.ConnectionPool(host='localhost', port=6379, db=0)
expected = 'ConnectionPool<Connection<host=localhost,port=6379,db=0>>'
assert repr(pool) == expected
def test_repr_contains_db_info_unix(self):
pool = redis.ConnectionPool(
connection_class=redis.UnixDomainSocketConnection,
path='abc',
db=0,
)
expected = 'ConnectionPool<UnixDomainSocketConnection<path=abc,db=0>>'
assert repr(pool) == expected
class TestConnectionPoolURLParsing(object):
def test_defaults(self):
pool = redis.ConnectionPool.from_url('redis://localhost')
assert pool.connection_class == redis.Connection
assert pool.connection_kwargs == {
'host': 'localhost',
'port': 6379,
'db': 0,
'password': None,
}
def test_hostname(self):
pool = redis.ConnectionPool.from_url('redis://myhost')
assert pool.connection_class == redis.Connection
assert pool.connection_kwargs == {
'host': 'myhost',
'port': 6379,
'db': 0,
'password': None,
}
def test_port(self):
pool = redis.ConnectionPool.from_url('redis://localhost:6380')
assert pool.connection_class == redis.Connection
assert pool.connection_kwargs == {
'host': 'localhost',
'port': 6380,
'db': 0,
'password': None,
}
def test_password(self):
pool = redis.ConnectionPool.from_url('redis://:mypassword@localhost')
assert pool.connection_class == redis.Connection
assert pool.connection_kwargs == {
'host': 'localhost',
'port': 6379,
'db': 0,
'password': 'mypassword',
}
def test_db_as_argument(self):
pool = redis.ConnectionPool.from_url('redis://localhost', db='1')
assert pool.connection_class == redis.Connection
assert pool.connection_kwargs == {
'host': 'localhost',
'port': 6379,
'db': 1,
'password': None,
}
def test_db_in_path(self):
pool = redis.ConnectionPool.from_url('redis://localhost/2', db='1')
assert pool.connection_class == redis.Connection
assert pool.connection_kwargs == {
'host': 'localhost',
'port': 6379,
'db': 2,
'password': None,
}
def test_db_in_querystring(self):
pool = redis.ConnectionPool.from_url('redis://localhost/2?db=3',
db='1')
assert pool.connection_class == redis.Connection
assert pool.connection_kwargs == {
'host': 'localhost',
'port': 6379,
'db': 3,
'password': None,
}
def test_extra_querystring_options(self):
pool = redis.ConnectionPool.from_url('redis://localhost?a=1&b=2')
assert pool.connection_class == redis.Connection
assert pool.connection_kwargs == {
'host': 'localhost',
'port': 6379,
'db': 0,
'password': None,
'a': '1',
'b': '2'
}
def test_calling_from_subclass_returns_correct_instance(self):
pool = redis.BlockingConnectionPool.from_url('redis://localhost')
assert isinstance(pool, redis.BlockingConnectionPool)
def test_client_creates_connection_pool(self):
r = redis.StrictRedis.from_url('redis://myhost')
assert r.connection_pool.connection_class == redis.Connection
assert r.connection_pool.connection_kwargs == {
'host': 'myhost',
'port': 6379,
'db': 0,
'password': None,
}
class TestConnectionPoolUnixSocketURLParsing(object):
def test_defaults(self):
pool = redis.ConnectionPool.from_url('unix:///socket')
assert pool.connection_class == redis.UnixDomainSocketConnection
assert pool.connection_kwargs == {
'path': '/socket',
'db': 0,
'password': None,
}
def test_password(self):
pool = redis.ConnectionPool.from_url('unix://:mypassword@/socket')
assert pool.connection_class == redis.UnixDomainSocketConnection
assert pool.connection_kwargs == {
'path': '/socket',
'db': 0,
'password': 'mypassword',
}
def test_db_as_argument(self):
pool = redis.ConnectionPool.from_url('unix:///socket', db=1)
assert pool.connection_class == redis.UnixDomainSocketConnection
assert pool.connection_kwargs == {
'path': '/socket',
'db': 1,
'password': None,
}
def test_db_in_querystring(self):
pool = redis.ConnectionPool.from_url('unix:///socket?db=2', db=1)
assert pool.connection_class == redis.UnixDomainSocketConnection
assert pool.connection_kwargs == {
'path': '/socket',
'db': 2,
'password': None,
}
def test_extra_querystring_options(self):
pool = redis.ConnectionPool.from_url('unix:///socket?a=1&b=2')
assert pool.connection_class == redis.UnixDomainSocketConnection
assert pool.connection_kwargs == {
'path': '/socket',
'db': 0,
'password': None,
'a': '1',
'b': '2'
}
class TestSSLConnectionURLParsing(object):
@pytest.mark.skipif(not ssl_available, reason="SSL not installed")
def test_defaults(self):
pool = redis.ConnectionPool.from_url('rediss://localhost')
assert pool.connection_class == redis.SSLConnection
assert pool.connection_kwargs == {
'host': 'localhost',
'port': 6379,
'db': 0,
'password': None,
}
@pytest.mark.skipif(not ssl_available, reason="SSL not installed")
def test_cert_reqs_options(self):
import ssl
pool = redis.ConnectionPool.from_url('rediss://?ssl_cert_reqs=none')
assert pool.get_connection('_').cert_reqs == ssl.CERT_NONE
pool = redis.ConnectionPool.from_url(
'rediss://?ssl_cert_reqs=optional')
assert pool.get_connection('_').cert_reqs == ssl.CERT_OPTIONAL
pool = redis.ConnectionPool.from_url(
'rediss://?ssl_cert_reqs=required')
assert pool.get_connection('_').cert_reqs == ssl.CERT_REQUIRED
class TestConnection(object):
def test_on_connect_error(self):
"""
An error in Connection.on_connect should disconnect from the server
see for details: https://github.com/andymccurdy/redis-py/issues/368
"""
# this assumes the Redis server being tested against doesn't have
# 9999 databases ;)
bad_connection = redis.Redis(db=9999)
# an error should be raised on connect
with pytest.raises(redis.RedisError):
bad_connection.info()
pool = bad_connection.connection_pool
assert len(pool._available_connections) == 1
assert not pool._available_connections[0]._sock
@skip_if_server_version_lt('2.8.8')
def test_busy_loading_disconnects_socket(self, r):
"""
If Redis raises a LOADING error, the connection should be
disconnected and a BusyLoadingError raised
"""
with pytest.raises(redis.BusyLoadingError):
r.execute_command('DEBUG', 'ERROR', 'LOADING fake message')
pool = r.connection_pool
assert len(pool._available_connections) == 1
assert not pool._available_connections[0]._sock
@skip_if_server_version_lt('2.8.8')
def test_busy_loading_from_pipeline_immediate_command(self, r):
"""
BusyLoadingErrors should raise from Pipelines that execute a
command immediately, like WATCH does.
"""
pipe = r.pipeline()
with pytest.raises(redis.BusyLoadingError):
pipe.immediate_execute_command('DEBUG', 'ERROR',
'LOADING fake message')
pool = r.connection_pool
assert not pipe.connection
assert len(pool._available_connections) == 1
assert not pool._available_connections[0]._sock
@skip_if_server_version_lt('2.8.8')
def test_busy_loading_from_pipeline(self, r):
"""
BusyLoadingErrors should be raised from a pipeline execution
regardless of the raise_on_error flag.
"""
pipe = r.pipeline()
pipe.execute_command('DEBUG', 'ERROR', 'LOADING fake message')
with pytest.raises(redis.BusyLoadingError):
pipe.execute()
pool = r.connection_pool
assert not pipe.connection
assert len(pool._available_connections) == 1
assert not pool._available_connections[0]._sock
@skip_if_server_version_lt('2.8.8')
def test_read_only_error(self, r):
"READONLY errors get turned in ReadOnlyError exceptions"
with pytest.raises(redis.ReadOnlyError):
r.execute_command('DEBUG', 'ERROR', 'READONLY blah blah')
def test_connect_from_url_tcp(self):
connection = redis.Redis.from_url('redis://localhost')
pool = connection.connection_pool
assert re.match('(.*)<(.*)<(.*)>>', repr(pool)).groups() == (
'ConnectionPool',
'Connection',
'host=localhost,port=6379,db=0',
)
def test_connect_from_url_unix(self):
connection = redis.Redis.from_url('unix:///path/to/socket')
pool = connection.connection_pool
assert re.match('(.*)<(.*)<(.*)>>', repr(pool)).groups() == (
'ConnectionPool',
'UnixDomainSocketConnection',
'path=/path/to/socket,db=0',
)

View File

@ -1,33 +0,0 @@
from __future__ import with_statement
import pytest
from redis._compat import unichr, u, unicode
from .conftest import r as _redis_client
class TestEncoding(object):
@pytest.fixture()
def r(self, request):
return _redis_client(request=request, decode_responses=True)
def test_simple_encoding(self, r):
unicode_string = unichr(3456) + u('abcd') + unichr(3421)
r['unicode-string'] = unicode_string
cached_val = r['unicode-string']
assert isinstance(cached_val, unicode)
assert unicode_string == cached_val
def test_list_encoding(self, r):
unicode_string = unichr(3456) + u('abcd') + unichr(3421)
result = [unicode_string, unicode_string, unicode_string]
r.rpush('a', *result)
assert r.lrange('a', 0, -1) == result
class TestCommandsAndTokensArentEncoded(object):
@pytest.fixture()
def r(self, request):
return _redis_client(request=request, charset='utf-16')
def test_basic_command(self, r):
r.set('hello', 'world')

View File

@ -1,167 +0,0 @@
from __future__ import with_statement
import pytest
import time
from redis.exceptions import LockError, ResponseError
from redis.lock import Lock, LuaLock
class TestLock(object):
lock_class = Lock
def get_lock(self, redis, *args, **kwargs):
kwargs['lock_class'] = self.lock_class
return redis.lock(*args, **kwargs)
def test_lock(self, sr):
lock = self.get_lock(sr, 'foo')
assert lock.acquire(blocking=False)
assert sr.get('foo') == lock.local.token
assert sr.ttl('foo') == -1
lock.release()
assert sr.get('foo') is None
def test_competing_locks(self, sr):
lock1 = self.get_lock(sr, 'foo')
lock2 = self.get_lock(sr, 'foo')
assert lock1.acquire(blocking=False)
assert not lock2.acquire(blocking=False)
lock1.release()
assert lock2.acquire(blocking=False)
assert not lock1.acquire(blocking=False)
lock2.release()
def test_timeout(self, sr):
lock = self.get_lock(sr, 'foo', timeout=10)
assert lock.acquire(blocking=False)
assert 8 < sr.ttl('foo') <= 10
lock.release()
def test_float_timeout(self, sr):
lock = self.get_lock(sr, 'foo', timeout=9.5)
assert lock.acquire(blocking=False)
assert 8 < sr.pttl('foo') <= 9500
lock.release()
def test_blocking_timeout(self, sr):
lock1 = self.get_lock(sr, 'foo')
assert lock1.acquire(blocking=False)
lock2 = self.get_lock(sr, 'foo', blocking_timeout=0.2)
start = time.time()
assert not lock2.acquire()
assert (time.time() - start) > 0.2
lock1.release()
def test_context_manager(self, sr):
# blocking_timeout prevents a deadlock if the lock can't be acquired
# for some reason
with self.get_lock(sr, 'foo', blocking_timeout=0.2) as lock:
assert sr.get('foo') == lock.local.token
assert sr.get('foo') is None
def test_high_sleep_raises_error(self, sr):
"If sleep is higher than timeout, it should raise an error"
with pytest.raises(LockError):
self.get_lock(sr, 'foo', timeout=1, sleep=2)
def test_releasing_unlocked_lock_raises_error(self, sr):
lock = self.get_lock(sr, 'foo')
with pytest.raises(LockError):
lock.release()
def test_releasing_lock_no_longer_owned_raises_error(self, sr):
lock = self.get_lock(sr, 'foo')
lock.acquire(blocking=False)
# manually change the token
sr.set('foo', 'a')
with pytest.raises(LockError):
lock.release()
# even though we errored, the token is still cleared
assert lock.local.token is None
def test_extend_lock(self, sr):
lock = self.get_lock(sr, 'foo', timeout=10)
assert lock.acquire(blocking=False)
assert 8000 < sr.pttl('foo') <= 10000
assert lock.extend(10)
assert 16000 < sr.pttl('foo') <= 20000
lock.release()
def test_extend_lock_float(self, sr):
lock = self.get_lock(sr, 'foo', timeout=10.0)
assert lock.acquire(blocking=False)
assert 8000 < sr.pttl('foo') <= 10000
assert lock.extend(10.0)
assert 16000 < sr.pttl('foo') <= 20000
lock.release()
def test_extending_unlocked_lock_raises_error(self, sr):
lock = self.get_lock(sr, 'foo', timeout=10)
with pytest.raises(LockError):
lock.extend(10)
def test_extending_lock_with_no_timeout_raises_error(self, sr):
lock = self.get_lock(sr, 'foo')
assert lock.acquire(blocking=False)
with pytest.raises(LockError):
lock.extend(10)
lock.release()
def test_extending_lock_no_longer_owned_raises_error(self, sr):
lock = self.get_lock(sr, 'foo')
assert lock.acquire(blocking=False)
sr.set('foo', 'a')
with pytest.raises(LockError):
lock.extend(10)
class TestLuaLock(TestLock):
lock_class = LuaLock
class TestLockClassSelection(object):
def test_lock_class_argument(self, sr):
lock = sr.lock('foo', lock_class=Lock)
assert type(lock) == Lock
lock = sr.lock('foo', lock_class=LuaLock)
assert type(lock) == LuaLock
def test_cached_lualock_flag(self, sr):
try:
sr._use_lua_lock = True
lock = sr.lock('foo')
assert type(lock) == LuaLock
finally:
sr._use_lua_lock = None
def test_cached_lock_flag(self, sr):
try:
sr._use_lua_lock = False
lock = sr.lock('foo')
assert type(lock) == Lock
finally:
sr._use_lua_lock = None
def test_lua_compatible_server(self, sr, monkeypatch):
@classmethod
def mock_register(cls, redis):
return
monkeypatch.setattr(LuaLock, 'register_scripts', mock_register)
try:
lock = sr.lock('foo')
assert type(lock) == LuaLock
assert sr._use_lua_lock is True
finally:
sr._use_lua_lock = None
def test_lua_unavailable(self, sr, monkeypatch):
@classmethod
def mock_register(cls, redis):
raise ResponseError()
monkeypatch.setattr(LuaLock, 'register_scripts', mock_register)
try:
lock = sr.lock('foo')
assert type(lock) == Lock
assert sr._use_lua_lock is False
finally:
sr._use_lua_lock = None

View File

@ -1,226 +0,0 @@
from __future__ import with_statement
import pytest
import redis
from redis._compat import b, u, unichr, unicode
class TestPipeline(object):
def test_pipeline(self, r):
with r.pipeline() as pipe:
pipe.set('a', 'a1').get('a').zadd('z', z1=1).zadd('z', z2=4)
pipe.zincrby('z', 'z1').zrange('z', 0, 5, withscores=True)
assert pipe.execute() == \
[
True,
b('a1'),
True,
True,
2.0,
[(b('z1'), 2.0), (b('z2'), 4)],
]
def test_pipeline_length(self, r):
with r.pipeline() as pipe:
# Initially empty.
assert len(pipe) == 0
assert not pipe
# Fill 'er up!
pipe.set('a', 'a1').set('b', 'b1').set('c', 'c1')
assert len(pipe) == 3
assert pipe
# Execute calls reset(), so empty once again.
pipe.execute()
assert len(pipe) == 0
assert not pipe
def test_pipeline_no_transaction(self, r):
with r.pipeline(transaction=False) as pipe:
pipe.set('a', 'a1').set('b', 'b1').set('c', 'c1')
assert pipe.execute() == [True, True, True]
assert r['a'] == b('a1')
assert r['b'] == b('b1')
assert r['c'] == b('c1')
def test_pipeline_no_transaction_watch(self, r):
r['a'] = 0
with r.pipeline(transaction=False) as pipe:
pipe.watch('a')
a = pipe.get('a')
pipe.multi()
pipe.set('a', int(a) + 1)
assert pipe.execute() == [True]
def test_pipeline_no_transaction_watch_failure(self, r):
r['a'] = 0
with r.pipeline(transaction=False) as pipe:
pipe.watch('a')
a = pipe.get('a')
r['a'] = 'bad'
pipe.multi()
pipe.set('a', int(a) + 1)
with pytest.raises(redis.WatchError):
pipe.execute()
assert r['a'] == b('bad')
def test_exec_error_in_response(self, r):
"""
an invalid pipeline command at exec time adds the exception instance
to the list of returned values
"""
r['c'] = 'a'
with r.pipeline() as pipe:
pipe.set('a', 1).set('b', 2).lpush('c', 3).set('d', 4)
result = pipe.execute(raise_on_error=False)
assert result[0]
assert r['a'] == b('1')
assert result[1]
assert r['b'] == b('2')
# we can't lpush to a key that's a string value, so this should
# be a ResponseError exception
assert isinstance(result[2], redis.ResponseError)
assert r['c'] == b('a')
# since this isn't a transaction, the other commands after the
# error are still executed
assert result[3]
assert r['d'] == b('4')
# make sure the pipe was restored to a working state
assert pipe.set('z', 'zzz').execute() == [True]
assert r['z'] == b('zzz')
def test_exec_error_raised(self, r):
r['c'] = 'a'
with r.pipeline() as pipe:
pipe.set('a', 1).set('b', 2).lpush('c', 3).set('d', 4)
with pytest.raises(redis.ResponseError) as ex:
pipe.execute()
assert unicode(ex.value).startswith('Command # 3 (LPUSH c 3) of '
'pipeline caused error: ')
# make sure the pipe was restored to a working state
assert pipe.set('z', 'zzz').execute() == [True]
assert r['z'] == b('zzz')
def test_parse_error_raised(self, r):
with r.pipeline() as pipe:
# the zrem is invalid because we don't pass any keys to it
pipe.set('a', 1).zrem('b').set('b', 2)
with pytest.raises(redis.ResponseError) as ex:
pipe.execute()
assert unicode(ex.value).startswith('Command # 2 (ZREM b) of '
'pipeline caused error: ')
# make sure the pipe was restored to a working state
assert pipe.set('z', 'zzz').execute() == [True]
assert r['z'] == b('zzz')
def test_watch_succeed(self, r):
r['a'] = 1
r['b'] = 2
with r.pipeline() as pipe:
pipe.watch('a', 'b')
assert pipe.watching
a_value = pipe.get('a')
b_value = pipe.get('b')
assert a_value == b('1')
assert b_value == b('2')
pipe.multi()
pipe.set('c', 3)
assert pipe.execute() == [True]
assert not pipe.watching
def test_watch_failure(self, r):
r['a'] = 1
r['b'] = 2
with r.pipeline() as pipe:
pipe.watch('a', 'b')
r['b'] = 3
pipe.multi()
pipe.get('a')
with pytest.raises(redis.WatchError):
pipe.execute()
assert not pipe.watching
def test_unwatch(self, r):
r['a'] = 1
r['b'] = 2
with r.pipeline() as pipe:
pipe.watch('a', 'b')
r['b'] = 3
pipe.unwatch()
assert not pipe.watching
pipe.get('a')
assert pipe.execute() == [b('1')]
def test_transaction_callable(self, r):
r['a'] = 1
r['b'] = 2
has_run = []
def my_transaction(pipe):
a_value = pipe.get('a')
assert a_value in (b('1'), b('2'))
b_value = pipe.get('b')
assert b_value == b('2')
# silly run-once code... incr's "a" so WatchError should be raised
# forcing this all to run again. this should incr "a" once to "2"
if not has_run:
r.incr('a')
has_run.append('it has')
pipe.multi()
pipe.set('c', int(a_value) + int(b_value))
result = r.transaction(my_transaction, 'a', 'b')
assert result == [True]
assert r['c'] == b('4')
def test_exec_error_in_no_transaction_pipeline(self, r):
r['a'] = 1
with r.pipeline(transaction=False) as pipe:
pipe.llen('a')
pipe.expire('a', 100)
with pytest.raises(redis.ResponseError) as ex:
pipe.execute()
assert unicode(ex.value).startswith('Command # 1 (LLEN a) of '
'pipeline caused error: ')
assert r['a'] == b('1')
def test_exec_error_in_no_transaction_pipeline_unicode_command(self, r):
key = unichr(3456) + u('abcd') + unichr(3421)
r[key] = 1
with r.pipeline(transaction=False) as pipe:
pipe.llen(key)
pipe.expire(key, 100)
with pytest.raises(redis.ResponseError) as ex:
pipe.execute()
expected = unicode('Command # 1 (LLEN %s) of pipeline caused '
'error: ') % key
assert unicode(ex.value).startswith(expected)
assert r[key] == b('1')

View File

@ -1,392 +0,0 @@
from __future__ import with_statement
import pytest
import time
import redis
from redis.exceptions import ConnectionError
from redis._compat import basestring, u, unichr
from .conftest import r as _redis_client
def wait_for_message(pubsub, timeout=0.1, ignore_subscribe_messages=False):
now = time.time()
timeout = now + timeout
while now < timeout:
message = pubsub.get_message(
ignore_subscribe_messages=ignore_subscribe_messages)
if message is not None:
return message
time.sleep(0.01)
now = time.time()
return None
def make_message(type, channel, data, pattern=None):
return {
'type': type,
'pattern': pattern and pattern.encode('utf-8') or None,
'channel': channel.encode('utf-8'),
'data': data.encode('utf-8') if isinstance(data, basestring) else data
}
def make_subscribe_test_data(pubsub, type):
if type == 'channel':
return {
'p': pubsub,
'sub_type': 'subscribe',
'unsub_type': 'unsubscribe',
'sub_func': pubsub.subscribe,
'unsub_func': pubsub.unsubscribe,
'keys': ['foo', 'bar', u('uni') + unichr(4456) + u('code')]
}
elif type == 'pattern':
return {
'p': pubsub,
'sub_type': 'psubscribe',
'unsub_type': 'punsubscribe',
'sub_func': pubsub.psubscribe,
'unsub_func': pubsub.punsubscribe,
'keys': ['f*', 'b*', u('uni') + unichr(4456) + u('*')]
}
assert False, 'invalid subscribe type: %s' % type
class TestPubSubSubscribeUnsubscribe(object):
def _test_subscribe_unsubscribe(self, p, sub_type, unsub_type, sub_func,
unsub_func, keys):
for key in keys:
assert sub_func(key) is None
# should be a message for each channel/pattern we just subscribed to
for i, key in enumerate(keys):
assert wait_for_message(p) == make_message(sub_type, key, i + 1)
for key in keys:
assert unsub_func(key) is None
# should be a message for each channel/pattern we just unsubscribed
# from
for i, key in enumerate(keys):
i = len(keys) - 1 - i
assert wait_for_message(p) == make_message(unsub_type, key, i)
def test_channel_subscribe_unsubscribe(self, r):
kwargs = make_subscribe_test_data(r.pubsub(), 'channel')
self._test_subscribe_unsubscribe(**kwargs)
def test_pattern_subscribe_unsubscribe(self, r):
kwargs = make_subscribe_test_data(r.pubsub(), 'pattern')
self._test_subscribe_unsubscribe(**kwargs)
def _test_resubscribe_on_reconnection(self, p, sub_type, unsub_type,
sub_func, unsub_func, keys):
for key in keys:
assert sub_func(key) is None
# should be a message for each channel/pattern we just subscribed to
for i, key in enumerate(keys):
assert wait_for_message(p) == make_message(sub_type, key, i + 1)
# manually disconnect
p.connection.disconnect()
# calling get_message again reconnects and resubscribes
# note, we may not re-subscribe to channels in exactly the same order
# so we have to do some extra checks to make sure we got them all
messages = []
for i in range(len(keys)):
messages.append(wait_for_message(p))
unique_channels = set()
assert len(messages) == len(keys)
for i, message in enumerate(messages):
assert message['type'] == sub_type
assert message['data'] == i + 1
assert isinstance(message['channel'], bytes)
channel = message['channel'].decode('utf-8')
unique_channels.add(channel)
assert len(unique_channels) == len(keys)
for channel in unique_channels:
assert channel in keys
def test_resubscribe_to_channels_on_reconnection(self, r):
kwargs = make_subscribe_test_data(r.pubsub(), 'channel')
self._test_resubscribe_on_reconnection(**kwargs)
def test_resubscribe_to_patterns_on_reconnection(self, r):
kwargs = make_subscribe_test_data(r.pubsub(), 'pattern')
self._test_resubscribe_on_reconnection(**kwargs)
def _test_subscribed_property(self, p, sub_type, unsub_type, sub_func,
unsub_func, keys):
assert p.subscribed is False
sub_func(keys[0])
# we're now subscribed even though we haven't processed the
# reply from the server just yet
assert p.subscribed is True
assert wait_for_message(p) == make_message(sub_type, keys[0], 1)
# we're still subscribed
assert p.subscribed is True
# unsubscribe from all channels
unsub_func()
# we're still technically subscribed until we process the
# response messages from the server
assert p.subscribed is True
assert wait_for_message(p) == make_message(unsub_type, keys[0], 0)
# now we're no longer subscribed as no more messages can be delivered
# to any channels we were listening to
assert p.subscribed is False
# subscribing again flips the flag back
sub_func(keys[0])
assert p.subscribed is True
assert wait_for_message(p) == make_message(sub_type, keys[0], 1)
# unsubscribe again
unsub_func()
assert p.subscribed is True
# subscribe to another channel before reading the unsubscribe response
sub_func(keys[1])
assert p.subscribed is True
# read the unsubscribe for key1
assert wait_for_message(p) == make_message(unsub_type, keys[0], 0)
# we're still subscribed to key2, so subscribed should still be True
assert p.subscribed is True
# read the key2 subscribe message
assert wait_for_message(p) == make_message(sub_type, keys[1], 1)
unsub_func()
# haven't read the message yet, so we're still subscribed
assert p.subscribed is True
assert wait_for_message(p) == make_message(unsub_type, keys[1], 0)
# now we're finally unsubscribed
assert p.subscribed is False
def test_subscribe_property_with_channels(self, r):
kwargs = make_subscribe_test_data(r.pubsub(), 'channel')
self._test_subscribed_property(**kwargs)
def test_subscribe_property_with_patterns(self, r):
kwargs = make_subscribe_test_data(r.pubsub(), 'pattern')
self._test_subscribed_property(**kwargs)
def test_ignore_all_subscribe_messages(self, r):
p = r.pubsub(ignore_subscribe_messages=True)
checks = (
(p.subscribe, 'foo'),
(p.unsubscribe, 'foo'),
(p.psubscribe, 'f*'),
(p.punsubscribe, 'f*'),
)
assert p.subscribed is False
for func, channel in checks:
assert func(channel) is None
assert p.subscribed is True
assert wait_for_message(p) is None
assert p.subscribed is False
def test_ignore_individual_subscribe_messages(self, r):
p = r.pubsub()
checks = (
(p.subscribe, 'foo'),
(p.unsubscribe, 'foo'),
(p.psubscribe, 'f*'),
(p.punsubscribe, 'f*'),
)
assert p.subscribed is False
for func, channel in checks:
assert func(channel) is None
assert p.subscribed is True
message = wait_for_message(p, ignore_subscribe_messages=True)
assert message is None
assert p.subscribed is False
class TestPubSubMessages(object):
def setup_method(self, method):
self.message = None
def message_handler(self, message):
self.message = message
def test_published_message_to_channel(self, r):
p = r.pubsub(ignore_subscribe_messages=True)
p.subscribe('foo')
assert r.publish('foo', 'test message') == 1
message = wait_for_message(p)
assert isinstance(message, dict)
assert message == make_message('message', 'foo', 'test message')
def test_published_message_to_pattern(self, r):
p = r.pubsub(ignore_subscribe_messages=True)
p.subscribe('foo')
p.psubscribe('f*')
# 1 to pattern, 1 to channel
assert r.publish('foo', 'test message') == 2
message1 = wait_for_message(p)
message2 = wait_for_message(p)
assert isinstance(message1, dict)
assert isinstance(message2, dict)
expected = [
make_message('message', 'foo', 'test message'),
make_message('pmessage', 'foo', 'test message', pattern='f*')
]
assert message1 in expected
assert message2 in expected
assert message1 != message2
def test_channel_message_handler(self, r):
p = r.pubsub(ignore_subscribe_messages=True)
p.subscribe(foo=self.message_handler)
assert r.publish('foo', 'test message') == 1
assert wait_for_message(p) is None
assert self.message == make_message('message', 'foo', 'test message')
def test_pattern_message_handler(self, r):
p = r.pubsub(ignore_subscribe_messages=True)
p.psubscribe(**{'f*': self.message_handler})
assert r.publish('foo', 'test message') == 1
assert wait_for_message(p) is None
assert self.message == make_message('pmessage', 'foo', 'test message',
pattern='f*')
def test_unicode_channel_message_handler(self, r):
p = r.pubsub(ignore_subscribe_messages=True)
channel = u('uni') + unichr(4456) + u('code')
channels = {channel: self.message_handler}
p.subscribe(**channels)
assert r.publish(channel, 'test message') == 1
assert wait_for_message(p) is None
assert self.message == make_message('message', channel, 'test message')
def test_unicode_pattern_message_handler(self, r):
p = r.pubsub(ignore_subscribe_messages=True)
pattern = u('uni') + unichr(4456) + u('*')
channel = u('uni') + unichr(4456) + u('code')
p.psubscribe(**{pattern: self.message_handler})
assert r.publish(channel, 'test message') == 1
assert wait_for_message(p) is None
assert self.message == make_message('pmessage', channel,
'test message', pattern=pattern)
class TestPubSubAutoDecoding(object):
"These tests only validate that we get unicode values back"
channel = u('uni') + unichr(4456) + u('code')
pattern = u('uni') + unichr(4456) + u('*')
data = u('abc') + unichr(4458) + u('123')
def make_message(self, type, channel, data, pattern=None):
return {
'type': type,
'channel': channel,
'pattern': pattern,
'data': data
}
def setup_method(self, method):
self.message = None
def message_handler(self, message):
self.message = message
@pytest.fixture()
def r(self, request):
return _redis_client(request=request, decode_responses=True)
def test_channel_subscribe_unsubscribe(self, r):
p = r.pubsub()
p.subscribe(self.channel)
assert wait_for_message(p) == self.make_message('subscribe',
self.channel, 1)
p.unsubscribe(self.channel)
assert wait_for_message(p) == self.make_message('unsubscribe',
self.channel, 0)
def test_pattern_subscribe_unsubscribe(self, r):
p = r.pubsub()
p.psubscribe(self.pattern)
assert wait_for_message(p) == self.make_message('psubscribe',
self.pattern, 1)
p.punsubscribe(self.pattern)
assert wait_for_message(p) == self.make_message('punsubscribe',
self.pattern, 0)
def test_channel_publish(self, r):
p = r.pubsub(ignore_subscribe_messages=True)
p.subscribe(self.channel)
r.publish(self.channel, self.data)
assert wait_for_message(p) == self.make_message('message',
self.channel,
self.data)
def test_pattern_publish(self, r):
p = r.pubsub(ignore_subscribe_messages=True)
p.psubscribe(self.pattern)
r.publish(self.channel, self.data)
assert wait_for_message(p) == self.make_message('pmessage',
self.channel,
self.data,
pattern=self.pattern)
def test_channel_message_handler(self, r):
p = r.pubsub(ignore_subscribe_messages=True)
p.subscribe(**{self.channel: self.message_handler})
r.publish(self.channel, self.data)
assert wait_for_message(p) is None
assert self.message == self.make_message('message', self.channel,
self.data)
# test that we reconnected to the correct channel
p.connection.disconnect()
assert wait_for_message(p) is None # should reconnect
new_data = self.data + u('new data')
r.publish(self.channel, new_data)
assert wait_for_message(p) is None
assert self.message == self.make_message('message', self.channel,
new_data)
def test_pattern_message_handler(self, r):
p = r.pubsub(ignore_subscribe_messages=True)
p.psubscribe(**{self.pattern: self.message_handler})
r.publish(self.channel, self.data)
assert wait_for_message(p) is None
assert self.message == self.make_message('pmessage', self.channel,
self.data,
pattern=self.pattern)
# test that we reconnected to the correct pattern
p.connection.disconnect()
assert wait_for_message(p) is None # should reconnect
new_data = self.data + u('new data')
r.publish(self.channel, new_data)
assert wait_for_message(p) is None
assert self.message == self.make_message('pmessage', self.channel,
new_data,
pattern=self.pattern)
class TestPubSubRedisDown(object):
def test_channel_subscribe(self, r):
r = redis.Redis(host='localhost', port=6390)
p = r.pubsub()
with pytest.raises(ConnectionError):
p.subscribe('foo')

View File

@ -1,82 +0,0 @@
from __future__ import with_statement
import pytest
from redis import exceptions
from redis._compat import b
multiply_script = """
local value = redis.call('GET', KEYS[1])
value = tonumber(value)
return value * ARGV[1]"""
class TestScripting(object):
@pytest.fixture(autouse=True)
def reset_scripts(self, r):
r.script_flush()
def test_eval(self, r):
r.set('a', 2)
# 2 * 3 == 6
assert r.eval(multiply_script, 1, 'a', 3) == 6
def test_evalsha(self, r):
r.set('a', 2)
sha = r.script_load(multiply_script)
# 2 * 3 == 6
assert r.evalsha(sha, 1, 'a', 3) == 6
def test_evalsha_script_not_loaded(self, r):
r.set('a', 2)
sha = r.script_load(multiply_script)
# remove the script from Redis's cache
r.script_flush()
with pytest.raises(exceptions.NoScriptError):
r.evalsha(sha, 1, 'a', 3)
def test_script_loading(self, r):
# get the sha, then clear the cache
sha = r.script_load(multiply_script)
r.script_flush()
assert r.script_exists(sha) == [False]
r.script_load(multiply_script)
assert r.script_exists(sha) == [True]
def test_script_object(self, r):
r.set('a', 2)
multiply = r.register_script(multiply_script)
assert not multiply.sha
# test evalsha fail -> script load + retry
assert multiply(keys=['a'], args=[3]) == 6
assert multiply.sha
assert r.script_exists(multiply.sha) == [True]
# test first evalsha
assert multiply(keys=['a'], args=[3]) == 6
def test_script_object_in_pipeline(self, r):
multiply = r.register_script(multiply_script)
assert not multiply.sha
pipe = r.pipeline()
pipe.set('a', 2)
pipe.get('a')
multiply(keys=['a'], args=[3], client=pipe)
# even though the pipeline wasn't executed yet, we made sure the
# script was loaded and got a valid sha
assert multiply.sha
assert r.script_exists(multiply.sha) == [True]
# [SET worked, GET 'a', result of multiple script]
assert pipe.execute() == [True, b('2'), 6]
# purge the script from redis's cache and re-run the pipeline
# the multiply script object knows it's sha, so it shouldn't get
# reloaded until pipe.execute()
r.script_flush()
pipe = r.pipeline()
pipe.set('a', 2)
pipe.get('a')
assert multiply.sha
multiply(keys=['a'], args=[3], client=pipe)
assert r.script_exists(multiply.sha) == [False]
# [SET worked, GET 'a', result of multiple script]
assert pipe.execute() == [True, b('2'), 6]

View File

@ -1,173 +0,0 @@
from __future__ import with_statement
import pytest
from redis import exceptions
from redis.sentinel import (Sentinel, SentinelConnectionPool,
MasterNotFoundError, SlaveNotFoundError)
from redis._compat import next
import redis.sentinel
class SentinelTestClient(object):
def __init__(self, cluster, id):
self.cluster = cluster
self.id = id
def sentinel_masters(self):
self.cluster.connection_error_if_down(self)
return {self.cluster.service_name: self.cluster.master}
def sentinel_slaves(self, master_name):
self.cluster.connection_error_if_down(self)
if master_name != self.cluster.service_name:
return []
return self.cluster.slaves
class SentinelTestCluster(object):
def __init__(self, service_name='mymaster', ip='127.0.0.1', port=6379):
self.clients = {}
self.master = {
'ip': ip,
'port': port,
'is_master': True,
'is_sdown': False,
'is_odown': False,
'num-other-sentinels': 0,
}
self.service_name = service_name
self.slaves = []
self.nodes_down = set()
def connection_error_if_down(self, node):
if node.id in self.nodes_down:
raise exceptions.ConnectionError
def client(self, host, port, **kwargs):
return SentinelTestClient(self, (host, port))
@pytest.fixture()
def cluster(request):
def teardown():
redis.sentinel.StrictRedis = saved_StrictRedis
cluster = SentinelTestCluster()
saved_StrictRedis = redis.sentinel.StrictRedis
redis.sentinel.StrictRedis = cluster.client
request.addfinalizer(teardown)
return cluster
@pytest.fixture()
def sentinel(request, cluster):
return Sentinel([('foo', 26379), ('bar', 26379)])
def test_discover_master(sentinel):
address = sentinel.discover_master('mymaster')
assert address == ('127.0.0.1', 6379)
def test_discover_master_error(sentinel):
with pytest.raises(MasterNotFoundError):
sentinel.discover_master('xxx')
def test_discover_master_sentinel_down(cluster, sentinel):
# Put first sentinel 'foo' down
cluster.nodes_down.add(('foo', 26379))
address = sentinel.discover_master('mymaster')
assert address == ('127.0.0.1', 6379)
# 'bar' is now first sentinel
assert sentinel.sentinels[0].id == ('bar', 26379)
def test_master_min_other_sentinels(cluster):
sentinel = Sentinel([('foo', 26379)], min_other_sentinels=1)
# min_other_sentinels
with pytest.raises(MasterNotFoundError):
sentinel.discover_master('mymaster')
cluster.master['num-other-sentinels'] = 2
address = sentinel.discover_master('mymaster')
assert address == ('127.0.0.1', 6379)
def test_master_odown(cluster, sentinel):
cluster.master['is_odown'] = True
with pytest.raises(MasterNotFoundError):
sentinel.discover_master('mymaster')
def test_master_sdown(cluster, sentinel):
cluster.master['is_sdown'] = True
with pytest.raises(MasterNotFoundError):
sentinel.discover_master('mymaster')
def test_discover_slaves(cluster, sentinel):
assert sentinel.discover_slaves('mymaster') == []
cluster.slaves = [
{'ip': 'slave0', 'port': 1234, 'is_odown': False, 'is_sdown': False},
{'ip': 'slave1', 'port': 1234, 'is_odown': False, 'is_sdown': False},
]
assert sentinel.discover_slaves('mymaster') == [
('slave0', 1234), ('slave1', 1234)]
# slave0 -> ODOWN
cluster.slaves[0]['is_odown'] = True
assert sentinel.discover_slaves('mymaster') == [
('slave1', 1234)]
# slave1 -> SDOWN
cluster.slaves[1]['is_sdown'] = True
assert sentinel.discover_slaves('mymaster') == []
cluster.slaves[0]['is_odown'] = False
cluster.slaves[1]['is_sdown'] = False
# node0 -> DOWN
cluster.nodes_down.add(('foo', 26379))
assert sentinel.discover_slaves('mymaster') == [
('slave0', 1234), ('slave1', 1234)]
def test_master_for(cluster, sentinel):
master = sentinel.master_for('mymaster', db=9)
assert master.ping()
assert master.connection_pool.master_address == ('127.0.0.1', 6379)
# Use internal connection check
master = sentinel.master_for('mymaster', db=9, check_connection=True)
assert master.ping()
def test_slave_for(cluster, sentinel):
cluster.slaves = [
{'ip': '127.0.0.1', 'port': 6379,
'is_odown': False, 'is_sdown': False},
]
slave = sentinel.slave_for('mymaster', db=9)
assert slave.ping()
def test_slave_for_slave_not_found_error(cluster, sentinel):
cluster.master['is_odown'] = True
slave = sentinel.slave_for('mymaster', db=9)
with pytest.raises(SlaveNotFoundError):
slave.ping()
def test_slave_round_robin(cluster, sentinel):
cluster.slaves = [
{'ip': 'slave0', 'port': 6379, 'is_odown': False, 'is_sdown': False},
{'ip': 'slave1', 'port': 6379, 'is_odown': False, 'is_sdown': False},
]
pool = SentinelConnectionPool('mymaster', sentinel)
rotator = pool.rotate_slaves()
assert next(rotator) in (('slave0', 6379), ('slave1', 6379))
assert next(rotator) in (('slave0', 6379), ('slave1', 6379))
# Fallback to master
assert next(rotator) == ('127.0.0.1', 6379)
with pytest.raises(SlaveNotFoundError):
next(rotator)