add client for python

This commit is contained in:
silentsai 2014-06-18 10:24:40 +08:00
parent 852fce9f4c
commit a0bd2e90e5
17 changed files with 5849 additions and 0 deletions

View File

@ -0,0 +1,31 @@
from redis.client import Redis, StrictRedis
from redis.connection import (
BlockingConnectionPool,
ConnectionPool,
Connection,
UnixDomainSocketConnection
)
from redis.utils import from_url
from redis.exceptions import (
AuthenticationError,
ConnectionError,
BusyLoadingError,
DataError,
InvalidResponse,
PubSubError,
RedisError,
ResponseError,
WatchError,
)
__version__ = '2.7.6'
VERSION = tuple(map(int, __version__.split('.')))
__all__ = [
'Redis', 'StrictRedis', 'ConnectionPool', 'BlockingConnectionPool',
'Connection', 'UnixDomainSocketConnection',
'RedisError', 'ConnectionError', 'ResponseError', 'AuthenticationError',
'InvalidResponse', 'DataError', 'PubSubError', 'WatchError', 'from_url',
'BusyLoadingError'
]

View File

@ -0,0 +1,79 @@
"""Internal module for Python 2 backwards compatibility."""
import sys
if sys.version_info[0] < 3:
from urlparse import urlparse
from itertools import imap, izip
from string import letters as ascii_letters
from Queue import Queue
try:
from cStringIO import StringIO as BytesIO
except ImportError:
from StringIO import StringIO as BytesIO
iteritems = lambda x: x.iteritems()
iterkeys = lambda x: x.iterkeys()
itervalues = lambda x: x.itervalues()
nativestr = lambda x: \
x if isinstance(x, str) else x.encode('utf-8', 'replace')
u = lambda x: x.decode()
b = lambda x: x
next = lambda x: x.next()
byte_to_chr = lambda x: x
unichr = unichr
xrange = xrange
basestring = basestring
unicode = unicode
bytes = str
long = long
else:
from urllib.parse import urlparse
from io import BytesIO
from string import ascii_letters
from queue import Queue
iteritems = lambda x: iter(x.items())
iterkeys = lambda x: iter(x.keys())
itervalues = lambda x: iter(x.values())
byte_to_chr = lambda x: chr(x)
nativestr = lambda x: \
x if isinstance(x, str) else x.decode('utf-8', 'replace')
u = lambda x: x
b = lambda x: x.encode('iso-8859-1') if not isinstance(x, bytes) else x
next = next
unichr = chr
imap = map
izip = zip
xrange = range
basestring = str
unicode = str
bytes = bytes
long = int
try: # Python 3
from queue import LifoQueue, Empty, Full
except ImportError:
from Queue import Empty, Full
try: # Python 2.6 - 2.7
from Queue import LifoQueue
except ImportError: # Python 2.5
from Queue import Queue
# From the Python 2.7 lib. Python 2.5 already extracted the core
# methods to aid implementating different queue organisations.
class LifoQueue(Queue):
"Override queue methods to implement a last-in first-out queue."
def _init(self, maxsize):
self.maxsize = maxsize
self.queue = []
def _qsize(self, len=len):
return len(self.queue)
def _put(self, item):
self.queue.append(item)
def _get(self):
return self.queue.pop()

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,580 @@
from itertools import chain
import os
import socket
import sys
from redis._compat import (b, xrange, imap, byte_to_chr, unicode, bytes, long,
BytesIO, nativestr, basestring,
LifoQueue, Empty, Full)
from redis.exceptions import (
RedisError,
ConnectionError,
BusyLoadingError,
ResponseError,
InvalidResponse,
AuthenticationError,
NoScriptError,
ExecAbortError,
)
from redis.utils import HIREDIS_AVAILABLE
if HIREDIS_AVAILABLE:
import hiredis
SYM_STAR = b('*')
SYM_DOLLAR = b('$')
SYM_CRLF = b('\r\n')
SYM_LF = b('\n')
class PythonParser(object):
"Plain Python parsing class"
MAX_READ_LENGTH = 1000000
encoding = None
EXCEPTION_CLASSES = {
'ERR': ResponseError,
'EXECABORT': ExecAbortError,
'LOADING': BusyLoadingError,
'NOSCRIPT': NoScriptError,
}
def __init__(self):
self._fp = None
def __del__(self):
try:
self.on_disconnect()
except Exception:
pass
def on_connect(self, connection):
"Called when the socket connects"
self._fp = connection._sock.makefile('rb')
if connection.decode_responses:
self.encoding = connection.encoding
def on_disconnect(self):
"Called when the socket disconnects"
if self._fp is not None:
self._fp.close()
self._fp = None
def read(self, length=None):
"""
Read a line from the socket if no length is specified,
otherwise read ``length`` bytes. Always strip away the newlines.
"""
try:
if length is not None:
bytes_left = length + 2 # read the line ending
if length > self.MAX_READ_LENGTH:
# apparently reading more than 1MB or so from a windows
# socket can cause MemoryErrors. See:
# https://github.com/andymccurdy/redis-py/issues/205
# read smaller chunks at a time to work around this
try:
buf = BytesIO()
while bytes_left > 0:
read_len = min(bytes_left, self.MAX_READ_LENGTH)
buf.write(self._fp.read(read_len))
bytes_left -= read_len
buf.seek(0)
return buf.read(length)
finally:
buf.close()
return self._fp.read(bytes_left)[:-2]
# no length, read a full line
return self._fp.readline()[:-2]
except (socket.error, socket.timeout):
e = sys.exc_info()[1]
raise ConnectionError("Error while reading from socket: %s" %
(e.args,))
def parse_error(self, response):
"Parse an error response"
error_code = response.split(' ')[0]
if error_code in self.EXCEPTION_CLASSES:
response = response[len(error_code) + 1:]
return self.EXCEPTION_CLASSES[error_code](response)
return ResponseError(response)
def read_response(self):
response = self.read()
if not response:
raise ConnectionError("Socket closed on remote end")
byte, response = byte_to_chr(response[0]), response[1:]
if byte not in ('-', '+', ':', '$', '*'):
raise InvalidResponse("Protocol Error")
# server returned an error
if byte == '-':
response = nativestr(response)
error = self.parse_error(response)
# if the error is a ConnectionError, raise immediately so the user
# is notified
if isinstance(error, ConnectionError):
raise error
# otherwise, we're dealing with a ResponseError that might belong
# inside a pipeline response. the connection's read_response()
# and/or the pipeline's execute() will raise this error if
# necessary, so just return the exception instance here.
return error
# single value
elif byte == '+':
pass
# int value
elif byte == ':':
response = long(response)
# bulk response
elif byte == '$':
length = int(response)
if length == -1:
return None
response = self.read(length)
# multi-bulk response
elif byte == '*':
length = int(response)
if length == -1:
return None
response = [self.read_response() for i in xrange(length)]
if isinstance(response, bytes) and self.encoding:
response = response.decode(self.encoding)
return response
class HiredisParser(object):
"Parser class for connections using Hiredis"
def __init__(self):
if not HIREDIS_AVAILABLE:
raise RedisError("Hiredis is not installed")
def __del__(self):
try:
self.on_disconnect()
except Exception:
pass
def on_connect(self, connection):
self._sock = connection._sock
kwargs = {
'protocolError': InvalidResponse,
'replyError': ResponseError,
}
if connection.decode_responses:
kwargs['encoding'] = connection.encoding
self._reader = hiredis.Reader(**kwargs)
def on_disconnect(self):
self._sock = None
self._reader = None
def read_response(self):
if not self._reader:
raise ConnectionError("Socket closed on remote end")
response = self._reader.gets()
while response is False:
try:
buffer = self._sock.recv(4096)
except (socket.error, socket.timeout):
e = sys.exc_info()[1]
raise ConnectionError("Error while reading from socket: %s" %
(e.args,))
if not buffer:
raise ConnectionError("Socket closed on remote end")
self._reader.feed(buffer)
# proactively, but not conclusively, check if more data is in the
# buffer. if the data received doesn't end with \n, there's more.
if not buffer.endswith(SYM_LF):
continue
response = self._reader.gets()
return response
if HIREDIS_AVAILABLE:
DefaultParser = HiredisParser
else:
DefaultParser = PythonParser
class Connection(object):
"Manages TCP communication to and from a Redis server"
def __init__(self, host='localhost', port=6379, db=0, password=None,
socket_timeout=None, encoding='utf-8',
encoding_errors='strict', decode_responses=False,
parser_class=DefaultParser):
self.pid = os.getpid()
self.host = host
self.port = port
self.db = db
self.password = password
self.socket_timeout = socket_timeout
self.encoding = encoding
self.encoding_errors = encoding_errors
self.decode_responses = decode_responses
self._sock = None
self._parser = parser_class()
def __del__(self):
try:
self.disconnect()
except Exception:
pass
def connect(self):
"Connects to the Redis server if not already connected"
if self._sock:
return
try:
sock = self._connect()
except socket.error:
e = sys.exc_info()[1]
raise ConnectionError(self._error_message(e))
self._sock = sock
self.on_connect()
def _connect(self):
"Create a TCP socket connection"
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(self.socket_timeout)
sock.connect((self.host, self.port))
return sock
def _error_message(self, exception):
# args for socket.error can either be (errno, "message")
# or just "message"
if len(exception.args) == 1:
return "Error connecting to %s:%s. %s." % \
(self.host, self.port, exception.args[0])
else:
return "Error %s connecting %s:%s. %s." % \
(exception.args[0], self.host, self.port, exception.args[1])
def on_connect(self):
"Initialize the connection, authenticate and select a database"
self._parser.on_connect(self)
# if a password is specified, authenticate
if self.password:
self.send_command('AUTH', self.password)
if nativestr(self.read_response()) != 'OK':
raise AuthenticationError('Invalid Password')
# if a database is specified, switch to it
if self.db:
self.send_command('SELECT', self.db)
if nativestr(self.read_response()) != 'OK':
raise ConnectionError('Invalid Database')
def disconnect(self):
"Disconnects from the Redis server"
self._parser.on_disconnect()
if self._sock is None:
return
try:
self._sock.close()
except socket.error:
pass
self._sock = None
def send_packed_command(self, command):
"Send an already packed command to the Redis server"
if not self._sock:
self.connect()
try:
self._sock.sendall(command)
except socket.error:
e = sys.exc_info()[1]
self.disconnect()
if len(e.args) == 1:
_errno, errmsg = 'UNKNOWN', e.args[0]
else:
_errno, errmsg = e.args
raise ConnectionError("Error %s while writing to socket. %s." %
(_errno, errmsg))
except Exception:
self.disconnect()
raise
def send_command(self, *args):
"Pack and send a command to the Redis server"
self.send_packed_command(self.pack_command(*args))
def read_response(self):
"Read the response from a previously sent command"
try:
response = self._parser.read_response()
except Exception:
self.disconnect()
raise
if isinstance(response, ResponseError):
raise response
return response
def encode(self, value):
"Return a bytestring representation of the value"
if isinstance(value, bytes):
return value
if isinstance(value, float):
value = repr(value)
if not isinstance(value, basestring):
value = str(value)
if isinstance(value, unicode):
value = value.encode(self.encoding, self.encoding_errors)
return value
def pack_command(self, *args):
"Pack a series of arguments into a value Redis command"
output = SYM_STAR + b(str(len(args))) + SYM_CRLF
for enc_value in imap(self.encode, args):
output += SYM_DOLLAR
output += b(str(len(enc_value)))
output += SYM_CRLF
output += enc_value
output += SYM_CRLF
return output
class UnixDomainSocketConnection(Connection):
def __init__(self, path='', db=0, password=None,
socket_timeout=None, encoding='utf-8',
encoding_errors='strict', decode_responses=False,
parser_class=DefaultParser):
self.pid = os.getpid()
self.path = path
self.db = db
self.password = password
self.socket_timeout = socket_timeout
self.encoding = encoding
self.encoding_errors = encoding_errors
self.decode_responses = decode_responses
self._sock = None
self._parser = parser_class()
def _connect(self):
"Create a Unix domain socket connection"
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
sock.settimeout(self.socket_timeout)
sock.connect(self.path)
return sock
def _error_message(self, exception):
# args for socket.error can either be (errno, "message")
# or just "message"
if len(exception.args) == 1:
return "Error connecting to unix socket: %s. %s." % \
(self.path, exception.args[0])
else:
return "Error %s connecting to unix socket: %s. %s." % \
(exception.args[0], self.path, exception.args[1])
# TODO: add ability to block waiting on a connection to be released
class ConnectionPool(object):
"Generic connection pool"
def __init__(self, connection_class=Connection, max_connections=None,
**connection_kwargs):
self.pid = os.getpid()
self.connection_class = connection_class
self.connection_kwargs = connection_kwargs
self.max_connections = max_connections or 2 ** 31
self._created_connections = 0
self._available_connections = []
self._in_use_connections = set()
def _checkpid(self):
if self.pid != os.getpid():
self.disconnect()
self.__init__(self.connection_class, self.max_connections,
**self.connection_kwargs)
def get_connection(self, command_name, *keys, **options):
"Get a connection from the pool"
self._checkpid()
try:
connection = self._available_connections.pop()
except IndexError:
connection = self.make_connection()
self._in_use_connections.add(connection)
return connection
def make_connection(self):
"Create a new connection"
if self._created_connections >= self.max_connections:
raise ConnectionError("Too many connections")
self._created_connections += 1
return self.connection_class(**self.connection_kwargs)
def release(self, connection):
"Releases the connection back to the pool"
self._checkpid()
if connection.pid == self.pid:
self._in_use_connections.remove(connection)
self._available_connections.append(connection)
def disconnect(self):
"Disconnects all connections in the pool"
all_conns = chain(self._available_connections,
self._in_use_connections)
for connection in all_conns:
connection.disconnect()
class BlockingConnectionPool(object):
"""
Thread-safe blocking connection pool::
>>> from redis.client import Redis
>>> client = Redis(connection_pool=BlockingConnectionPool())
It performs the same function as the default
``:py:class: ~redis.connection.ConnectionPool`` implementation, in that,
it maintains a pool of reusable connections that can be shared by
multiple redis clients (safely across threads if required).
The difference is that, in the event that a client tries to get a
connection from the pool when all of connections are in use, rather than
raising a ``:py:class: ~redis.exceptions.ConnectionError`` (as the default
``:py:class: ~redis.connection.ConnectionPool`` implementation does), it
makes the client wait ("blocks") for a specified number of seconds until
a connection becomes available.
Use ``max_connections`` to increase / decrease the pool size::
>>> pool = BlockingConnectionPool(max_connections=10)
Use ``timeout`` to tell it either how many seconds to wait for a connection
to become available, or to block forever:
# Block forever.
>>> pool = BlockingConnectionPool(timeout=None)
# Raise a ``ConnectionError`` after five seconds if a connection is
# not available.
>>> pool = BlockingConnectionPool(timeout=5)
"""
def __init__(self, max_connections=50, timeout=20, connection_class=None,
queue_class=None, **connection_kwargs):
"Compose and assign values."
# Compose.
if connection_class is None:
connection_class = Connection
if queue_class is None:
queue_class = LifoQueue
# Assign.
self.connection_class = connection_class
self.connection_kwargs = connection_kwargs
self.queue_class = queue_class
self.max_connections = max_connections
self.timeout = timeout
# Validate the ``max_connections``. With the "fill up the queue"
# algorithm we use, it must be a positive integer.
is_valid = isinstance(max_connections, int) and max_connections > 0
if not is_valid:
raise ValueError('``max_connections`` must be a positive integer')
# Get the current process id, so we can disconnect and reinstantiate if
# it changes.
self.pid = os.getpid()
# Create and fill up a thread safe queue with ``None`` values.
self.pool = self.queue_class(max_connections)
while True:
try:
self.pool.put_nowait(None)
except Full:
break
# Keep a list of actual connection instances so that we can
# disconnect them later.
self._connections = []
def _checkpid(self):
"""
Check the current process id. If it has changed, disconnect and
re-instantiate this connection pool instance.
"""
# Get the current process id.
pid = os.getpid()
# If it hasn't changed since we were instantiated, then we're fine, so
# just exit, remaining connected.
if self.pid == pid:
return
# If it has changed, then disconnect and re-instantiate.
self.disconnect()
self.reinstantiate()
def make_connection(self):
"Make a fresh connection."
connection = self.connection_class(**self.connection_kwargs)
self._connections.append(connection)
return connection
def get_connection(self, command_name, *keys, **options):
"""
Get a connection, blocking for ``self.timeout`` until a connection
is available from the pool.
If the connection returned is ``None`` then creates a new connection.
Because we use a last-in first-out queue, the existing connections
(having been returned to the pool after the initial ``None`` values
were added) will be returned before ``None`` values. This means we only
create new connections when we need to, i.e.: the actual number of
connections will only increase in response to demand.
"""
# Make sure we haven't changed process.
self._checkpid()
# Try and get a connection from the pool. If one isn't available within
# self.timeout then raise a ``ConnectionError``.
connection = None
try:
connection = self.pool.get(block=True, timeout=self.timeout)
except Empty:
# Note that this is not caught by the redis client and will be
# raised unless handled by application code. If you want never to
raise ConnectionError("No connection available.")
# If the ``connection`` is actually ``None`` then that's a cue to make
# a new connection to add to the pool.
if connection is None:
connection = self.make_connection()
return connection
def release(self, connection):
"Releases the connection back to the pool."
# Make sure we haven't changed process.
self._checkpid()
# Put the connection back into the pool.
try:
self.pool.put_nowait(connection)
except Full:
# This shouldn't normally happen but might perhaps happen after a
# reinstantiation. So, we can handle the exception by not putting
# the connection back on the pool, because we definitely do not
# want to reuse it.
pass
def disconnect(self):
"Disconnects all connections in the pool."
for connection in self._connections:
connection.disconnect()
def reinstantiate(self):
"""
Reinstatiate this instance within a new process with a new connection
pool set.
"""
self.__init__(max_connections=self.max_connections,
timeout=self.timeout,
connection_class=self.connection_class,
queue_class=self.queue_class, **self.connection_kwargs)

View File

@ -0,0 +1,49 @@
"Core exceptions raised by the Redis client"
class RedisError(Exception):
pass
class AuthenticationError(RedisError):
pass
class ServerError(RedisError):
pass
class ConnectionError(ServerError):
pass
class BusyLoadingError(ConnectionError):
pass
class InvalidResponse(ServerError):
pass
class ResponseError(RedisError):
pass
class DataError(RedisError):
pass
class PubSubError(RedisError):
pass
class WatchError(RedisError):
pass
class NoScriptError(ResponseError):
pass
class ExecAbortError(ResponseError):
pass

View File

@ -0,0 +1,16 @@
try:
import hiredis
HIREDIS_AVAILABLE = True
except ImportError:
HIREDIS_AVAILABLE = False
def from_url(url, db=None, **kwargs):
"""
Returns an active Redis client generated from the given database URL.
Will attempt to extract the database id from the path url fragment, if
none is provided.
"""
from redis.client import Redis
return Redis.from_url(url, db, **kwargs)

61
client/ledis-py/setup.py Normal file
View File

@ -0,0 +1,61 @@
#!/usr/bin/env python
import os
import sys
from redis import __version__
try:
from setuptools import setup
from setuptools.command.test import test as TestCommand
class PyTest(TestCommand):
def finalize_options(self):
TestCommand.finalize_options(self)
self.test_args = []
self.test_suite = True
def run_tests(self):
# import here, because outside the eggs aren't loaded
import pytest
errno = pytest.main(self.test_args)
sys.exit(errno)
except ImportError:
from distutils.core import setup
PyTest = lambda x: x
f = open(os.path.join(os.path.dirname(__file__), 'README.rst'))
long_description = f.read()
f.close()
setup(
name='redis',
version=__version__,
description='Python client for Redis key-value store',
long_description=long_description,
url='http://github.com/andymccurdy/redis-py',
author='Andy McCurdy',
author_email='sedrik@gmail.com',
maintainer='Andy McCurdy',
maintainer_email='sedrik@gmail.com',
keywords=['Redis', 'key-value store'],
license='MIT',
packages=['redis'],
tests_require=['pytest>=2.5.0'],
cmdclass={'test': PyTest},
classifiers=[
'Development Status :: 5 - Production/Stable',
'Environment :: Console',
'Intended Audience :: Developers',
'License :: OSI Approved :: MIT License',
'Operating System :: OS Independent',
'Programming Language :: Python',
'Programming Language :: Python :: 2.6',
'Programming Language :: Python :: 2.7',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.2',
'Programming Language :: Python :: 3.3',
'Programming Language :: Python :: 3.4',
]
)

View File

View File

@ -0,0 +1,46 @@
import pytest
import redis
from distutils.version import StrictVersion
_REDIS_VERSIONS = {}
def get_version(**kwargs):
params = {'host': 'localhost', 'port': 6379, 'db': 9}
params.update(kwargs)
key = '%s:%s' % (params['host'], params['port'])
if key not in _REDIS_VERSIONS:
client = redis.Redis(**params)
_REDIS_VERSIONS[key] = client.info()['redis_version']
client.connection_pool.disconnect()
return _REDIS_VERSIONS[key]
def _get_client(cls, request=None, **kwargs):
params = {'host': 'localhost', 'port': 6379, 'db': 9}
params.update(kwargs)
client = cls(**params)
client.flushdb()
if request:
def teardown():
client.flushdb()
client.connection_pool.disconnect()
request.addfinalizer(teardown)
return client
def skip_if_server_version_lt(min_version):
check = StrictVersion(get_version()) < StrictVersion(min_version)
return pytest.mark.skipif(check, reason="")
@pytest.fixture()
def r(request, **kwargs):
return _get_client(redis.Redis, request, **kwargs)
@pytest.fixture()
def sr(request, **kwargs):
return _get_client(redis.StrictRedis, request, **kwargs)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,402 @@
from __future__ import with_statement
import os
import pytest
import redis
import time
import re
from threading import Thread
from redis.connection import ssl_available
from .conftest import skip_if_server_version_lt
class DummyConnection(object):
description_format = "DummyConnection<>"
def __init__(self, **kwargs):
self.kwargs = kwargs
self.pid = os.getpid()
class TestConnectionPool(object):
def get_pool(self, connection_kwargs=None, max_connections=None,
connection_class=DummyConnection):
connection_kwargs = connection_kwargs or {}
pool = redis.ConnectionPool(
connection_class=connection_class,
max_connections=max_connections,
**connection_kwargs)
return pool
def test_connection_creation(self):
connection_kwargs = {'foo': 'bar', 'biz': 'baz'}
pool = self.get_pool(connection_kwargs=connection_kwargs)
connection = pool.get_connection('_')
assert isinstance(connection, DummyConnection)
assert connection.kwargs == connection_kwargs
def test_multiple_connections(self):
pool = self.get_pool()
c1 = pool.get_connection('_')
c2 = pool.get_connection('_')
assert c1 != c2
def test_max_connections(self):
pool = self.get_pool(max_connections=2)
pool.get_connection('_')
pool.get_connection('_')
with pytest.raises(redis.ConnectionError):
pool.get_connection('_')
def test_reuse_previously_released_connection(self):
pool = self.get_pool()
c1 = pool.get_connection('_')
pool.release(c1)
c2 = pool.get_connection('_')
assert c1 == c2
def test_repr_contains_db_info_tcp(self):
connection_kwargs = {'host': 'localhost', 'port': 6379, 'db': 1}
pool = self.get_pool(connection_kwargs=connection_kwargs,
connection_class=redis.Connection)
expected = 'ConnectionPool<Connection<host=localhost,port=6379,db=1>>'
assert repr(pool) == expected
def test_repr_contains_db_info_unix(self):
connection_kwargs = {'path': '/abc', 'db': 1}
pool = self.get_pool(connection_kwargs=connection_kwargs,
connection_class=redis.UnixDomainSocketConnection)
expected = 'ConnectionPool<UnixDomainSocketConnection<path=/abc,db=1>>'
assert repr(pool) == expected
class TestBlockingConnectionPool(object):
def get_pool(self, connection_kwargs=None, max_connections=10, timeout=20):
connection_kwargs = connection_kwargs or {}
pool = redis.BlockingConnectionPool(connection_class=DummyConnection,
max_connections=max_connections,
timeout=timeout,
**connection_kwargs)
return pool
def test_connection_creation(self):
connection_kwargs = {'foo': 'bar', 'biz': 'baz'}
pool = self.get_pool(connection_kwargs=connection_kwargs)
connection = pool.get_connection('_')
assert isinstance(connection, DummyConnection)
assert connection.kwargs == connection_kwargs
def test_multiple_connections(self):
pool = self.get_pool()
c1 = pool.get_connection('_')
c2 = pool.get_connection('_')
assert c1 != c2
def test_connection_pool_blocks_until_timeout(self):
"When out of connections, block for timeout seconds, then raise"
pool = self.get_pool(max_connections=1, timeout=0.1)
pool.get_connection('_')
start = time.time()
with pytest.raises(redis.ConnectionError):
pool.get_connection('_')
# we should have waited at least 0.1 seconds
assert time.time() - start >= 0.1
def connection_pool_blocks_until_another_connection_released(self):
"""
When out of connections, block until another connection is released
to the pool
"""
pool = self.get_pool(max_connections=1, timeout=2)
c1 = pool.get_connection('_')
def target():
time.sleep(0.1)
pool.release(c1)
Thread(target=target).start()
start = time.time()
pool.get_connection('_')
assert time.time() - start >= 0.1
def test_reuse_previously_released_connection(self):
pool = self.get_pool()
c1 = pool.get_connection('_')
pool.release(c1)
c2 = pool.get_connection('_')
assert c1 == c2
def test_repr_contains_db_info_tcp(self):
pool = redis.ConnectionPool(host='localhost', port=6379, db=0)
expected = 'ConnectionPool<Connection<host=localhost,port=6379,db=0>>'
assert repr(pool) == expected
def test_repr_contains_db_info_unix(self):
pool = redis.ConnectionPool(
connection_class=redis.UnixDomainSocketConnection,
path='abc',
db=0,
)
expected = 'ConnectionPool<UnixDomainSocketConnection<path=abc,db=0>>'
assert repr(pool) == expected
class TestConnectionPoolURLParsing(object):
def test_defaults(self):
pool = redis.ConnectionPool.from_url('redis://localhost')
assert pool.connection_class == redis.Connection
assert pool.connection_kwargs == {
'host': 'localhost',
'port': 6379,
'db': 0,
'password': None,
}
def test_hostname(self):
pool = redis.ConnectionPool.from_url('redis://myhost')
assert pool.connection_class == redis.Connection
assert pool.connection_kwargs == {
'host': 'myhost',
'port': 6379,
'db': 0,
'password': None,
}
def test_port(self):
pool = redis.ConnectionPool.from_url('redis://localhost:6380')
assert pool.connection_class == redis.Connection
assert pool.connection_kwargs == {
'host': 'localhost',
'port': 6380,
'db': 0,
'password': None,
}
def test_password(self):
pool = redis.ConnectionPool.from_url('redis://:mypassword@localhost')
assert pool.connection_class == redis.Connection
assert pool.connection_kwargs == {
'host': 'localhost',
'port': 6379,
'db': 0,
'password': 'mypassword',
}
def test_db_as_argument(self):
pool = redis.ConnectionPool.from_url('redis://localhost', db='1')
assert pool.connection_class == redis.Connection
assert pool.connection_kwargs == {
'host': 'localhost',
'port': 6379,
'db': 1,
'password': None,
}
def test_db_in_path(self):
pool = redis.ConnectionPool.from_url('redis://localhost/2', db='1')
assert pool.connection_class == redis.Connection
assert pool.connection_kwargs == {
'host': 'localhost',
'port': 6379,
'db': 2,
'password': None,
}
def test_db_in_querystring(self):
pool = redis.ConnectionPool.from_url('redis://localhost/2?db=3',
db='1')
assert pool.connection_class == redis.Connection
assert pool.connection_kwargs == {
'host': 'localhost',
'port': 6379,
'db': 3,
'password': None,
}
def test_extra_querystring_options(self):
pool = redis.ConnectionPool.from_url('redis://localhost?a=1&b=2')
assert pool.connection_class == redis.Connection
assert pool.connection_kwargs == {
'host': 'localhost',
'port': 6379,
'db': 0,
'password': None,
'a': '1',
'b': '2'
}
def test_calling_from_subclass_returns_correct_instance(self):
pool = redis.BlockingConnectionPool.from_url('redis://localhost')
assert isinstance(pool, redis.BlockingConnectionPool)
def test_client_creates_connection_pool(self):
r = redis.StrictRedis.from_url('redis://myhost')
assert r.connection_pool.connection_class == redis.Connection
assert r.connection_pool.connection_kwargs == {
'host': 'myhost',
'port': 6379,
'db': 0,
'password': None,
}
class TestConnectionPoolUnixSocketURLParsing(object):
def test_defaults(self):
pool = redis.ConnectionPool.from_url('unix:///socket')
assert pool.connection_class == redis.UnixDomainSocketConnection
assert pool.connection_kwargs == {
'path': '/socket',
'db': 0,
'password': None,
}
def test_password(self):
pool = redis.ConnectionPool.from_url('unix://:mypassword@/socket')
assert pool.connection_class == redis.UnixDomainSocketConnection
assert pool.connection_kwargs == {
'path': '/socket',
'db': 0,
'password': 'mypassword',
}
def test_db_as_argument(self):
pool = redis.ConnectionPool.from_url('unix:///socket', db=1)
assert pool.connection_class == redis.UnixDomainSocketConnection
assert pool.connection_kwargs == {
'path': '/socket',
'db': 1,
'password': None,
}
def test_db_in_querystring(self):
pool = redis.ConnectionPool.from_url('unix:///socket?db=2', db=1)
assert pool.connection_class == redis.UnixDomainSocketConnection
assert pool.connection_kwargs == {
'path': '/socket',
'db': 2,
'password': None,
}
def test_extra_querystring_options(self):
pool = redis.ConnectionPool.from_url('unix:///socket?a=1&b=2')
assert pool.connection_class == redis.UnixDomainSocketConnection
assert pool.connection_kwargs == {
'path': '/socket',
'db': 0,
'password': None,
'a': '1',
'b': '2'
}
class TestSSLConnectionURLParsing(object):
@pytest.mark.skipif(not ssl_available, reason="SSL not installed")
def test_defaults(self):
pool = redis.ConnectionPool.from_url('rediss://localhost')
assert pool.connection_class == redis.SSLConnection
assert pool.connection_kwargs == {
'host': 'localhost',
'port': 6379,
'db': 0,
'password': None,
}
@pytest.mark.skipif(not ssl_available, reason="SSL not installed")
def test_cert_reqs_options(self):
import ssl
pool = redis.ConnectionPool.from_url('rediss://?ssl_cert_reqs=none')
assert pool.get_connection('_').cert_reqs == ssl.CERT_NONE
pool = redis.ConnectionPool.from_url(
'rediss://?ssl_cert_reqs=optional')
assert pool.get_connection('_').cert_reqs == ssl.CERT_OPTIONAL
pool = redis.ConnectionPool.from_url(
'rediss://?ssl_cert_reqs=required')
assert pool.get_connection('_').cert_reqs == ssl.CERT_REQUIRED
class TestConnection(object):
def test_on_connect_error(self):
"""
An error in Connection.on_connect should disconnect from the server
see for details: https://github.com/andymccurdy/redis-py/issues/368
"""
# this assumes the Redis server being tested against doesn't have
# 9999 databases ;)
bad_connection = redis.Redis(db=9999)
# an error should be raised on connect
with pytest.raises(redis.RedisError):
bad_connection.info()
pool = bad_connection.connection_pool
assert len(pool._available_connections) == 1
assert not pool._available_connections[0]._sock
@skip_if_server_version_lt('2.8.8')
def test_busy_loading_disconnects_socket(self, r):
"""
If Redis raises a LOADING error, the connection should be
disconnected and a BusyLoadingError raised
"""
with pytest.raises(redis.BusyLoadingError):
r.execute_command('DEBUG', 'ERROR', 'LOADING fake message')
pool = r.connection_pool
assert len(pool._available_connections) == 1
assert not pool._available_connections[0]._sock
@skip_if_server_version_lt('2.8.8')
def test_busy_loading_from_pipeline_immediate_command(self, r):
"""
BusyLoadingErrors should raise from Pipelines that execute a
command immediately, like WATCH does.
"""
pipe = r.pipeline()
with pytest.raises(redis.BusyLoadingError):
pipe.immediate_execute_command('DEBUG', 'ERROR',
'LOADING fake message')
pool = r.connection_pool
assert not pipe.connection
assert len(pool._available_connections) == 1
assert not pool._available_connections[0]._sock
@skip_if_server_version_lt('2.8.8')
def test_busy_loading_from_pipeline(self, r):
"""
BusyLoadingErrors should be raised from a pipeline execution
regardless of the raise_on_error flag.
"""
pipe = r.pipeline()
pipe.execute_command('DEBUG', 'ERROR', 'LOADING fake message')
with pytest.raises(redis.BusyLoadingError):
pipe.execute()
pool = r.connection_pool
assert not pipe.connection
assert len(pool._available_connections) == 1
assert not pool._available_connections[0]._sock
@skip_if_server_version_lt('2.8.8')
def test_read_only_error(self, r):
"READONLY errors get turned in ReadOnlyError exceptions"
with pytest.raises(redis.ReadOnlyError):
r.execute_command('DEBUG', 'ERROR', 'READONLY blah blah')
def test_connect_from_url_tcp(self):
connection = redis.Redis.from_url('redis://localhost')
pool = connection.connection_pool
assert re.match('(.*)<(.*)<(.*)>>', repr(pool)).groups() == (
'ConnectionPool',
'Connection',
'host=localhost,port=6379,db=0',
)
def test_connect_from_url_unix(self):
connection = redis.Redis.from_url('unix:///path/to/socket')
pool = connection.connection_pool
assert re.match('(.*)<(.*)<(.*)>>', repr(pool)).groups() == (
'ConnectionPool',
'UnixDomainSocketConnection',
'path=/path/to/socket,db=0',
)

View File

@ -0,0 +1,33 @@
from __future__ import with_statement
import pytest
from redis._compat import unichr, u, unicode
from .conftest import r as _redis_client
class TestEncoding(object):
@pytest.fixture()
def r(self, request):
return _redis_client(request=request, decode_responses=True)
def test_simple_encoding(self, r):
unicode_string = unichr(3456) + u('abcd') + unichr(3421)
r['unicode-string'] = unicode_string
cached_val = r['unicode-string']
assert isinstance(cached_val, unicode)
assert unicode_string == cached_val
def test_list_encoding(self, r):
unicode_string = unichr(3456) + u('abcd') + unichr(3421)
result = [unicode_string, unicode_string, unicode_string]
r.rpush('a', *result)
assert r.lrange('a', 0, -1) == result
class TestCommandsAndTokensArentEncoded(object):
@pytest.fixture()
def r(self, request):
return _redis_client(request=request, charset='utf-16')
def test_basic_command(self, r):
r.set('hello', 'world')

View File

@ -0,0 +1,167 @@
from __future__ import with_statement
import pytest
import time
from redis.exceptions import LockError, ResponseError
from redis.lock import Lock, LuaLock
class TestLock(object):
lock_class = Lock
def get_lock(self, redis, *args, **kwargs):
kwargs['lock_class'] = self.lock_class
return redis.lock(*args, **kwargs)
def test_lock(self, sr):
lock = self.get_lock(sr, 'foo')
assert lock.acquire(blocking=False)
assert sr.get('foo') == lock.local.token
assert sr.ttl('foo') == -1
lock.release()
assert sr.get('foo') is None
def test_competing_locks(self, sr):
lock1 = self.get_lock(sr, 'foo')
lock2 = self.get_lock(sr, 'foo')
assert lock1.acquire(blocking=False)
assert not lock2.acquire(blocking=False)
lock1.release()
assert lock2.acquire(blocking=False)
assert not lock1.acquire(blocking=False)
lock2.release()
def test_timeout(self, sr):
lock = self.get_lock(sr, 'foo', timeout=10)
assert lock.acquire(blocking=False)
assert 8 < sr.ttl('foo') <= 10
lock.release()
def test_float_timeout(self, sr):
lock = self.get_lock(sr, 'foo', timeout=9.5)
assert lock.acquire(blocking=False)
assert 8 < sr.pttl('foo') <= 9500
lock.release()
def test_blocking_timeout(self, sr):
lock1 = self.get_lock(sr, 'foo')
assert lock1.acquire(blocking=False)
lock2 = self.get_lock(sr, 'foo', blocking_timeout=0.2)
start = time.time()
assert not lock2.acquire()
assert (time.time() - start) > 0.2
lock1.release()
def test_context_manager(self, sr):
# blocking_timeout prevents a deadlock if the lock can't be acquired
# for some reason
with self.get_lock(sr, 'foo', blocking_timeout=0.2) as lock:
assert sr.get('foo') == lock.local.token
assert sr.get('foo') is None
def test_high_sleep_raises_error(self, sr):
"If sleep is higher than timeout, it should raise an error"
with pytest.raises(LockError):
self.get_lock(sr, 'foo', timeout=1, sleep=2)
def test_releasing_unlocked_lock_raises_error(self, sr):
lock = self.get_lock(sr, 'foo')
with pytest.raises(LockError):
lock.release()
def test_releasing_lock_no_longer_owned_raises_error(self, sr):
lock = self.get_lock(sr, 'foo')
lock.acquire(blocking=False)
# manually change the token
sr.set('foo', 'a')
with pytest.raises(LockError):
lock.release()
# even though we errored, the token is still cleared
assert lock.local.token is None
def test_extend_lock(self, sr):
lock = self.get_lock(sr, 'foo', timeout=10)
assert lock.acquire(blocking=False)
assert 8000 < sr.pttl('foo') <= 10000
assert lock.extend(10)
assert 16000 < sr.pttl('foo') <= 20000
lock.release()
def test_extend_lock_float(self, sr):
lock = self.get_lock(sr, 'foo', timeout=10.0)
assert lock.acquire(blocking=False)
assert 8000 < sr.pttl('foo') <= 10000
assert lock.extend(10.0)
assert 16000 < sr.pttl('foo') <= 20000
lock.release()
def test_extending_unlocked_lock_raises_error(self, sr):
lock = self.get_lock(sr, 'foo', timeout=10)
with pytest.raises(LockError):
lock.extend(10)
def test_extending_lock_with_no_timeout_raises_error(self, sr):
lock = self.get_lock(sr, 'foo')
assert lock.acquire(blocking=False)
with pytest.raises(LockError):
lock.extend(10)
lock.release()
def test_extending_lock_no_longer_owned_raises_error(self, sr):
lock = self.get_lock(sr, 'foo')
assert lock.acquire(blocking=False)
sr.set('foo', 'a')
with pytest.raises(LockError):
lock.extend(10)
class TestLuaLock(TestLock):
lock_class = LuaLock
class TestLockClassSelection(object):
def test_lock_class_argument(self, sr):
lock = sr.lock('foo', lock_class=Lock)
assert type(lock) == Lock
lock = sr.lock('foo', lock_class=LuaLock)
assert type(lock) == LuaLock
def test_cached_lualock_flag(self, sr):
try:
sr._use_lua_lock = True
lock = sr.lock('foo')
assert type(lock) == LuaLock
finally:
sr._use_lua_lock = None
def test_cached_lock_flag(self, sr):
try:
sr._use_lua_lock = False
lock = sr.lock('foo')
assert type(lock) == Lock
finally:
sr._use_lua_lock = None
def test_lua_compatible_server(self, sr, monkeypatch):
@classmethod
def mock_register(cls, redis):
return
monkeypatch.setattr(LuaLock, 'register_scripts', mock_register)
try:
lock = sr.lock('foo')
assert type(lock) == LuaLock
assert sr._use_lua_lock is True
finally:
sr._use_lua_lock = None
def test_lua_unavailable(self, sr, monkeypatch):
@classmethod
def mock_register(cls, redis):
raise ResponseError()
monkeypatch.setattr(LuaLock, 'register_scripts', mock_register)
try:
lock = sr.lock('foo')
assert type(lock) == Lock
assert sr._use_lua_lock is False
finally:
sr._use_lua_lock = None

View File

@ -0,0 +1,226 @@
from __future__ import with_statement
import pytest
import redis
from redis._compat import b, u, unichr, unicode
class TestPipeline(object):
def test_pipeline(self, r):
with r.pipeline() as pipe:
pipe.set('a', 'a1').get('a').zadd('z', z1=1).zadd('z', z2=4)
pipe.zincrby('z', 'z1').zrange('z', 0, 5, withscores=True)
assert pipe.execute() == \
[
True,
b('a1'),
True,
True,
2.0,
[(b('z1'), 2.0), (b('z2'), 4)],
]
def test_pipeline_length(self, r):
with r.pipeline() as pipe:
# Initially empty.
assert len(pipe) == 0
assert not pipe
# Fill 'er up!
pipe.set('a', 'a1').set('b', 'b1').set('c', 'c1')
assert len(pipe) == 3
assert pipe
# Execute calls reset(), so empty once again.
pipe.execute()
assert len(pipe) == 0
assert not pipe
def test_pipeline_no_transaction(self, r):
with r.pipeline(transaction=False) as pipe:
pipe.set('a', 'a1').set('b', 'b1').set('c', 'c1')
assert pipe.execute() == [True, True, True]
assert r['a'] == b('a1')
assert r['b'] == b('b1')
assert r['c'] == b('c1')
def test_pipeline_no_transaction_watch(self, r):
r['a'] = 0
with r.pipeline(transaction=False) as pipe:
pipe.watch('a')
a = pipe.get('a')
pipe.multi()
pipe.set('a', int(a) + 1)
assert pipe.execute() == [True]
def test_pipeline_no_transaction_watch_failure(self, r):
r['a'] = 0
with r.pipeline(transaction=False) as pipe:
pipe.watch('a')
a = pipe.get('a')
r['a'] = 'bad'
pipe.multi()
pipe.set('a', int(a) + 1)
with pytest.raises(redis.WatchError):
pipe.execute()
assert r['a'] == b('bad')
def test_exec_error_in_response(self, r):
"""
an invalid pipeline command at exec time adds the exception instance
to the list of returned values
"""
r['c'] = 'a'
with r.pipeline() as pipe:
pipe.set('a', 1).set('b', 2).lpush('c', 3).set('d', 4)
result = pipe.execute(raise_on_error=False)
assert result[0]
assert r['a'] == b('1')
assert result[1]
assert r['b'] == b('2')
# we can't lpush to a key that's a string value, so this should
# be a ResponseError exception
assert isinstance(result[2], redis.ResponseError)
assert r['c'] == b('a')
# since this isn't a transaction, the other commands after the
# error are still executed
assert result[3]
assert r['d'] == b('4')
# make sure the pipe was restored to a working state
assert pipe.set('z', 'zzz').execute() == [True]
assert r['z'] == b('zzz')
def test_exec_error_raised(self, r):
r['c'] = 'a'
with r.pipeline() as pipe:
pipe.set('a', 1).set('b', 2).lpush('c', 3).set('d', 4)
with pytest.raises(redis.ResponseError) as ex:
pipe.execute()
assert unicode(ex.value).startswith('Command # 3 (LPUSH c 3) of '
'pipeline caused error: ')
# make sure the pipe was restored to a working state
assert pipe.set('z', 'zzz').execute() == [True]
assert r['z'] == b('zzz')
def test_parse_error_raised(self, r):
with r.pipeline() as pipe:
# the zrem is invalid because we don't pass any keys to it
pipe.set('a', 1).zrem('b').set('b', 2)
with pytest.raises(redis.ResponseError) as ex:
pipe.execute()
assert unicode(ex.value).startswith('Command # 2 (ZREM b) of '
'pipeline caused error: ')
# make sure the pipe was restored to a working state
assert pipe.set('z', 'zzz').execute() == [True]
assert r['z'] == b('zzz')
def test_watch_succeed(self, r):
r['a'] = 1
r['b'] = 2
with r.pipeline() as pipe:
pipe.watch('a', 'b')
assert pipe.watching
a_value = pipe.get('a')
b_value = pipe.get('b')
assert a_value == b('1')
assert b_value == b('2')
pipe.multi()
pipe.set('c', 3)
assert pipe.execute() == [True]
assert not pipe.watching
def test_watch_failure(self, r):
r['a'] = 1
r['b'] = 2
with r.pipeline() as pipe:
pipe.watch('a', 'b')
r['b'] = 3
pipe.multi()
pipe.get('a')
with pytest.raises(redis.WatchError):
pipe.execute()
assert not pipe.watching
def test_unwatch(self, r):
r['a'] = 1
r['b'] = 2
with r.pipeline() as pipe:
pipe.watch('a', 'b')
r['b'] = 3
pipe.unwatch()
assert not pipe.watching
pipe.get('a')
assert pipe.execute() == [b('1')]
def test_transaction_callable(self, r):
r['a'] = 1
r['b'] = 2
has_run = []
def my_transaction(pipe):
a_value = pipe.get('a')
assert a_value in (b('1'), b('2'))
b_value = pipe.get('b')
assert b_value == b('2')
# silly run-once code... incr's "a" so WatchError should be raised
# forcing this all to run again. this should incr "a" once to "2"
if not has_run:
r.incr('a')
has_run.append('it has')
pipe.multi()
pipe.set('c', int(a_value) + int(b_value))
result = r.transaction(my_transaction, 'a', 'b')
assert result == [True]
assert r['c'] == b('4')
def test_exec_error_in_no_transaction_pipeline(self, r):
r['a'] = 1
with r.pipeline(transaction=False) as pipe:
pipe.llen('a')
pipe.expire('a', 100)
with pytest.raises(redis.ResponseError) as ex:
pipe.execute()
assert unicode(ex.value).startswith('Command # 1 (LLEN a) of '
'pipeline caused error: ')
assert r['a'] == b('1')
def test_exec_error_in_no_transaction_pipeline_unicode_command(self, r):
key = unichr(3456) + u('abcd') + unichr(3421)
r[key] = 1
with r.pipeline(transaction=False) as pipe:
pipe.llen(key)
pipe.expire(key, 100)
with pytest.raises(redis.ResponseError) as ex:
pipe.execute()
expected = unicode('Command # 1 (LLEN %s) of pipeline caused '
'error: ') % key
assert unicode(ex.value).startswith(expected)
assert r[key] == b('1')

View File

@ -0,0 +1,392 @@
from __future__ import with_statement
import pytest
import time
import redis
from redis.exceptions import ConnectionError
from redis._compat import basestring, u, unichr
from .conftest import r as _redis_client
def wait_for_message(pubsub, timeout=0.1, ignore_subscribe_messages=False):
now = time.time()
timeout = now + timeout
while now < timeout:
message = pubsub.get_message(
ignore_subscribe_messages=ignore_subscribe_messages)
if message is not None:
return message
time.sleep(0.01)
now = time.time()
return None
def make_message(type, channel, data, pattern=None):
return {
'type': type,
'pattern': pattern and pattern.encode('utf-8') or None,
'channel': channel.encode('utf-8'),
'data': data.encode('utf-8') if isinstance(data, basestring) else data
}
def make_subscribe_test_data(pubsub, type):
if type == 'channel':
return {
'p': pubsub,
'sub_type': 'subscribe',
'unsub_type': 'unsubscribe',
'sub_func': pubsub.subscribe,
'unsub_func': pubsub.unsubscribe,
'keys': ['foo', 'bar', u('uni') + unichr(4456) + u('code')]
}
elif type == 'pattern':
return {
'p': pubsub,
'sub_type': 'psubscribe',
'unsub_type': 'punsubscribe',
'sub_func': pubsub.psubscribe,
'unsub_func': pubsub.punsubscribe,
'keys': ['f*', 'b*', u('uni') + unichr(4456) + u('*')]
}
assert False, 'invalid subscribe type: %s' % type
class TestPubSubSubscribeUnsubscribe(object):
def _test_subscribe_unsubscribe(self, p, sub_type, unsub_type, sub_func,
unsub_func, keys):
for key in keys:
assert sub_func(key) is None
# should be a message for each channel/pattern we just subscribed to
for i, key in enumerate(keys):
assert wait_for_message(p) == make_message(sub_type, key, i + 1)
for key in keys:
assert unsub_func(key) is None
# should be a message for each channel/pattern we just unsubscribed
# from
for i, key in enumerate(keys):
i = len(keys) - 1 - i
assert wait_for_message(p) == make_message(unsub_type, key, i)
def test_channel_subscribe_unsubscribe(self, r):
kwargs = make_subscribe_test_data(r.pubsub(), 'channel')
self._test_subscribe_unsubscribe(**kwargs)
def test_pattern_subscribe_unsubscribe(self, r):
kwargs = make_subscribe_test_data(r.pubsub(), 'pattern')
self._test_subscribe_unsubscribe(**kwargs)
def _test_resubscribe_on_reconnection(self, p, sub_type, unsub_type,
sub_func, unsub_func, keys):
for key in keys:
assert sub_func(key) is None
# should be a message for each channel/pattern we just subscribed to
for i, key in enumerate(keys):
assert wait_for_message(p) == make_message(sub_type, key, i + 1)
# manually disconnect
p.connection.disconnect()
# calling get_message again reconnects and resubscribes
# note, we may not re-subscribe to channels in exactly the same order
# so we have to do some extra checks to make sure we got them all
messages = []
for i in range(len(keys)):
messages.append(wait_for_message(p))
unique_channels = set()
assert len(messages) == len(keys)
for i, message in enumerate(messages):
assert message['type'] == sub_type
assert message['data'] == i + 1
assert isinstance(message['channel'], bytes)
channel = message['channel'].decode('utf-8')
unique_channels.add(channel)
assert len(unique_channels) == len(keys)
for channel in unique_channels:
assert channel in keys
def test_resubscribe_to_channels_on_reconnection(self, r):
kwargs = make_subscribe_test_data(r.pubsub(), 'channel')
self._test_resubscribe_on_reconnection(**kwargs)
def test_resubscribe_to_patterns_on_reconnection(self, r):
kwargs = make_subscribe_test_data(r.pubsub(), 'pattern')
self._test_resubscribe_on_reconnection(**kwargs)
def _test_subscribed_property(self, p, sub_type, unsub_type, sub_func,
unsub_func, keys):
assert p.subscribed is False
sub_func(keys[0])
# we're now subscribed even though we haven't processed the
# reply from the server just yet
assert p.subscribed is True
assert wait_for_message(p) == make_message(sub_type, keys[0], 1)
# we're still subscribed
assert p.subscribed is True
# unsubscribe from all channels
unsub_func()
# we're still technically subscribed until we process the
# response messages from the server
assert p.subscribed is True
assert wait_for_message(p) == make_message(unsub_type, keys[0], 0)
# now we're no longer subscribed as no more messages can be delivered
# to any channels we were listening to
assert p.subscribed is False
# subscribing again flips the flag back
sub_func(keys[0])
assert p.subscribed is True
assert wait_for_message(p) == make_message(sub_type, keys[0], 1)
# unsubscribe again
unsub_func()
assert p.subscribed is True
# subscribe to another channel before reading the unsubscribe response
sub_func(keys[1])
assert p.subscribed is True
# read the unsubscribe for key1
assert wait_for_message(p) == make_message(unsub_type, keys[0], 0)
# we're still subscribed to key2, so subscribed should still be True
assert p.subscribed is True
# read the key2 subscribe message
assert wait_for_message(p) == make_message(sub_type, keys[1], 1)
unsub_func()
# haven't read the message yet, so we're still subscribed
assert p.subscribed is True
assert wait_for_message(p) == make_message(unsub_type, keys[1], 0)
# now we're finally unsubscribed
assert p.subscribed is False
def test_subscribe_property_with_channels(self, r):
kwargs = make_subscribe_test_data(r.pubsub(), 'channel')
self._test_subscribed_property(**kwargs)
def test_subscribe_property_with_patterns(self, r):
kwargs = make_subscribe_test_data(r.pubsub(), 'pattern')
self._test_subscribed_property(**kwargs)
def test_ignore_all_subscribe_messages(self, r):
p = r.pubsub(ignore_subscribe_messages=True)
checks = (
(p.subscribe, 'foo'),
(p.unsubscribe, 'foo'),
(p.psubscribe, 'f*'),
(p.punsubscribe, 'f*'),
)
assert p.subscribed is False
for func, channel in checks:
assert func(channel) is None
assert p.subscribed is True
assert wait_for_message(p) is None
assert p.subscribed is False
def test_ignore_individual_subscribe_messages(self, r):
p = r.pubsub()
checks = (
(p.subscribe, 'foo'),
(p.unsubscribe, 'foo'),
(p.psubscribe, 'f*'),
(p.punsubscribe, 'f*'),
)
assert p.subscribed is False
for func, channel in checks:
assert func(channel) is None
assert p.subscribed is True
message = wait_for_message(p, ignore_subscribe_messages=True)
assert message is None
assert p.subscribed is False
class TestPubSubMessages(object):
def setup_method(self, method):
self.message = None
def message_handler(self, message):
self.message = message
def test_published_message_to_channel(self, r):
p = r.pubsub(ignore_subscribe_messages=True)
p.subscribe('foo')
assert r.publish('foo', 'test message') == 1
message = wait_for_message(p)
assert isinstance(message, dict)
assert message == make_message('message', 'foo', 'test message')
def test_published_message_to_pattern(self, r):
p = r.pubsub(ignore_subscribe_messages=True)
p.subscribe('foo')
p.psubscribe('f*')
# 1 to pattern, 1 to channel
assert r.publish('foo', 'test message') == 2
message1 = wait_for_message(p)
message2 = wait_for_message(p)
assert isinstance(message1, dict)
assert isinstance(message2, dict)
expected = [
make_message('message', 'foo', 'test message'),
make_message('pmessage', 'foo', 'test message', pattern='f*')
]
assert message1 in expected
assert message2 in expected
assert message1 != message2
def test_channel_message_handler(self, r):
p = r.pubsub(ignore_subscribe_messages=True)
p.subscribe(foo=self.message_handler)
assert r.publish('foo', 'test message') == 1
assert wait_for_message(p) is None
assert self.message == make_message('message', 'foo', 'test message')
def test_pattern_message_handler(self, r):
p = r.pubsub(ignore_subscribe_messages=True)
p.psubscribe(**{'f*': self.message_handler})
assert r.publish('foo', 'test message') == 1
assert wait_for_message(p) is None
assert self.message == make_message('pmessage', 'foo', 'test message',
pattern='f*')
def test_unicode_channel_message_handler(self, r):
p = r.pubsub(ignore_subscribe_messages=True)
channel = u('uni') + unichr(4456) + u('code')
channels = {channel: self.message_handler}
p.subscribe(**channels)
assert r.publish(channel, 'test message') == 1
assert wait_for_message(p) is None
assert self.message == make_message('message', channel, 'test message')
def test_unicode_pattern_message_handler(self, r):
p = r.pubsub(ignore_subscribe_messages=True)
pattern = u('uni') + unichr(4456) + u('*')
channel = u('uni') + unichr(4456) + u('code')
p.psubscribe(**{pattern: self.message_handler})
assert r.publish(channel, 'test message') == 1
assert wait_for_message(p) is None
assert self.message == make_message('pmessage', channel,
'test message', pattern=pattern)
class TestPubSubAutoDecoding(object):
"These tests only validate that we get unicode values back"
channel = u('uni') + unichr(4456) + u('code')
pattern = u('uni') + unichr(4456) + u('*')
data = u('abc') + unichr(4458) + u('123')
def make_message(self, type, channel, data, pattern=None):
return {
'type': type,
'channel': channel,
'pattern': pattern,
'data': data
}
def setup_method(self, method):
self.message = None
def message_handler(self, message):
self.message = message
@pytest.fixture()
def r(self, request):
return _redis_client(request=request, decode_responses=True)
def test_channel_subscribe_unsubscribe(self, r):
p = r.pubsub()
p.subscribe(self.channel)
assert wait_for_message(p) == self.make_message('subscribe',
self.channel, 1)
p.unsubscribe(self.channel)
assert wait_for_message(p) == self.make_message('unsubscribe',
self.channel, 0)
def test_pattern_subscribe_unsubscribe(self, r):
p = r.pubsub()
p.psubscribe(self.pattern)
assert wait_for_message(p) == self.make_message('psubscribe',
self.pattern, 1)
p.punsubscribe(self.pattern)
assert wait_for_message(p) == self.make_message('punsubscribe',
self.pattern, 0)
def test_channel_publish(self, r):
p = r.pubsub(ignore_subscribe_messages=True)
p.subscribe(self.channel)
r.publish(self.channel, self.data)
assert wait_for_message(p) == self.make_message('message',
self.channel,
self.data)
def test_pattern_publish(self, r):
p = r.pubsub(ignore_subscribe_messages=True)
p.psubscribe(self.pattern)
r.publish(self.channel, self.data)
assert wait_for_message(p) == self.make_message('pmessage',
self.channel,
self.data,
pattern=self.pattern)
def test_channel_message_handler(self, r):
p = r.pubsub(ignore_subscribe_messages=True)
p.subscribe(**{self.channel: self.message_handler})
r.publish(self.channel, self.data)
assert wait_for_message(p) is None
assert self.message == self.make_message('message', self.channel,
self.data)
# test that we reconnected to the correct channel
p.connection.disconnect()
assert wait_for_message(p) is None # should reconnect
new_data = self.data + u('new data')
r.publish(self.channel, new_data)
assert wait_for_message(p) is None
assert self.message == self.make_message('message', self.channel,
new_data)
def test_pattern_message_handler(self, r):
p = r.pubsub(ignore_subscribe_messages=True)
p.psubscribe(**{self.pattern: self.message_handler})
r.publish(self.channel, self.data)
assert wait_for_message(p) is None
assert self.message == self.make_message('pmessage', self.channel,
self.data,
pattern=self.pattern)
# test that we reconnected to the correct pattern
p.connection.disconnect()
assert wait_for_message(p) is None # should reconnect
new_data = self.data + u('new data')
r.publish(self.channel, new_data)
assert wait_for_message(p) is None
assert self.message == self.make_message('pmessage', self.channel,
new_data,
pattern=self.pattern)
class TestPubSubRedisDown(object):
def test_channel_subscribe(self, r):
r = redis.Redis(host='localhost', port=6390)
p = r.pubsub()
with pytest.raises(ConnectionError):
p.subscribe('foo')

View File

@ -0,0 +1,82 @@
from __future__ import with_statement
import pytest
from redis import exceptions
from redis._compat import b
multiply_script = """
local value = redis.call('GET', KEYS[1])
value = tonumber(value)
return value * ARGV[1]"""
class TestScripting(object):
@pytest.fixture(autouse=True)
def reset_scripts(self, r):
r.script_flush()
def test_eval(self, r):
r.set('a', 2)
# 2 * 3 == 6
assert r.eval(multiply_script, 1, 'a', 3) == 6
def test_evalsha(self, r):
r.set('a', 2)
sha = r.script_load(multiply_script)
# 2 * 3 == 6
assert r.evalsha(sha, 1, 'a', 3) == 6
def test_evalsha_script_not_loaded(self, r):
r.set('a', 2)
sha = r.script_load(multiply_script)
# remove the script from Redis's cache
r.script_flush()
with pytest.raises(exceptions.NoScriptError):
r.evalsha(sha, 1, 'a', 3)
def test_script_loading(self, r):
# get the sha, then clear the cache
sha = r.script_load(multiply_script)
r.script_flush()
assert r.script_exists(sha) == [False]
r.script_load(multiply_script)
assert r.script_exists(sha) == [True]
def test_script_object(self, r):
r.set('a', 2)
multiply = r.register_script(multiply_script)
assert not multiply.sha
# test evalsha fail -> script load + retry
assert multiply(keys=['a'], args=[3]) == 6
assert multiply.sha
assert r.script_exists(multiply.sha) == [True]
# test first evalsha
assert multiply(keys=['a'], args=[3]) == 6
def test_script_object_in_pipeline(self, r):
multiply = r.register_script(multiply_script)
assert not multiply.sha
pipe = r.pipeline()
pipe.set('a', 2)
pipe.get('a')
multiply(keys=['a'], args=[3], client=pipe)
# even though the pipeline wasn't executed yet, we made sure the
# script was loaded and got a valid sha
assert multiply.sha
assert r.script_exists(multiply.sha) == [True]
# [SET worked, GET 'a', result of multiple script]
assert pipe.execute() == [True, b('2'), 6]
# purge the script from redis's cache and re-run the pipeline
# the multiply script object knows it's sha, so it shouldn't get
# reloaded until pipe.execute()
r.script_flush()
pipe = r.pipeline()
pipe.set('a', 2)
pipe.get('a')
assert multiply.sha
multiply(keys=['a'], args=[3], client=pipe)
assert r.script_exists(multiply.sha) == [False]
# [SET worked, GET 'a', result of multiple script]
assert pipe.execute() == [True, b('2'), 6]

View File

@ -0,0 +1,173 @@
from __future__ import with_statement
import pytest
from redis import exceptions
from redis.sentinel import (Sentinel, SentinelConnectionPool,
MasterNotFoundError, SlaveNotFoundError)
from redis._compat import next
import redis.sentinel
class SentinelTestClient(object):
def __init__(self, cluster, id):
self.cluster = cluster
self.id = id
def sentinel_masters(self):
self.cluster.connection_error_if_down(self)
return {self.cluster.service_name: self.cluster.master}
def sentinel_slaves(self, master_name):
self.cluster.connection_error_if_down(self)
if master_name != self.cluster.service_name:
return []
return self.cluster.slaves
class SentinelTestCluster(object):
def __init__(self, service_name='mymaster', ip='127.0.0.1', port=6379):
self.clients = {}
self.master = {
'ip': ip,
'port': port,
'is_master': True,
'is_sdown': False,
'is_odown': False,
'num-other-sentinels': 0,
}
self.service_name = service_name
self.slaves = []
self.nodes_down = set()
def connection_error_if_down(self, node):
if node.id in self.nodes_down:
raise exceptions.ConnectionError
def client(self, host, port, **kwargs):
return SentinelTestClient(self, (host, port))
@pytest.fixture()
def cluster(request):
def teardown():
redis.sentinel.StrictRedis = saved_StrictRedis
cluster = SentinelTestCluster()
saved_StrictRedis = redis.sentinel.StrictRedis
redis.sentinel.StrictRedis = cluster.client
request.addfinalizer(teardown)
return cluster
@pytest.fixture()
def sentinel(request, cluster):
return Sentinel([('foo', 26379), ('bar', 26379)])
def test_discover_master(sentinel):
address = sentinel.discover_master('mymaster')
assert address == ('127.0.0.1', 6379)
def test_discover_master_error(sentinel):
with pytest.raises(MasterNotFoundError):
sentinel.discover_master('xxx')
def test_discover_master_sentinel_down(cluster, sentinel):
# Put first sentinel 'foo' down
cluster.nodes_down.add(('foo', 26379))
address = sentinel.discover_master('mymaster')
assert address == ('127.0.0.1', 6379)
# 'bar' is now first sentinel
assert sentinel.sentinels[0].id == ('bar', 26379)
def test_master_min_other_sentinels(cluster):
sentinel = Sentinel([('foo', 26379)], min_other_sentinels=1)
# min_other_sentinels
with pytest.raises(MasterNotFoundError):
sentinel.discover_master('mymaster')
cluster.master['num-other-sentinels'] = 2
address = sentinel.discover_master('mymaster')
assert address == ('127.0.0.1', 6379)
def test_master_odown(cluster, sentinel):
cluster.master['is_odown'] = True
with pytest.raises(MasterNotFoundError):
sentinel.discover_master('mymaster')
def test_master_sdown(cluster, sentinel):
cluster.master['is_sdown'] = True
with pytest.raises(MasterNotFoundError):
sentinel.discover_master('mymaster')
def test_discover_slaves(cluster, sentinel):
assert sentinel.discover_slaves('mymaster') == []
cluster.slaves = [
{'ip': 'slave0', 'port': 1234, 'is_odown': False, 'is_sdown': False},
{'ip': 'slave1', 'port': 1234, 'is_odown': False, 'is_sdown': False},
]
assert sentinel.discover_slaves('mymaster') == [
('slave0', 1234), ('slave1', 1234)]
# slave0 -> ODOWN
cluster.slaves[0]['is_odown'] = True
assert sentinel.discover_slaves('mymaster') == [
('slave1', 1234)]
# slave1 -> SDOWN
cluster.slaves[1]['is_sdown'] = True
assert sentinel.discover_slaves('mymaster') == []
cluster.slaves[0]['is_odown'] = False
cluster.slaves[1]['is_sdown'] = False
# node0 -> DOWN
cluster.nodes_down.add(('foo', 26379))
assert sentinel.discover_slaves('mymaster') == [
('slave0', 1234), ('slave1', 1234)]
def test_master_for(cluster, sentinel):
master = sentinel.master_for('mymaster', db=9)
assert master.ping()
assert master.connection_pool.master_address == ('127.0.0.1', 6379)
# Use internal connection check
master = sentinel.master_for('mymaster', db=9, check_connection=True)
assert master.ping()
def test_slave_for(cluster, sentinel):
cluster.slaves = [
{'ip': '127.0.0.1', 'port': 6379,
'is_odown': False, 'is_sdown': False},
]
slave = sentinel.slave_for('mymaster', db=9)
assert slave.ping()
def test_slave_for_slave_not_found_error(cluster, sentinel):
cluster.master['is_odown'] = True
slave = sentinel.slave_for('mymaster', db=9)
with pytest.raises(SlaveNotFoundError):
slave.ping()
def test_slave_round_robin(cluster, sentinel):
cluster.slaves = [
{'ip': 'slave0', 'port': 6379, 'is_odown': False, 'is_sdown': False},
{'ip': 'slave1', 'port': 6379, 'is_odown': False, 'is_sdown': False},
]
pool = SentinelConnectionPool('mymaster', sentinel)
rotator = pool.rotate_slaves()
assert next(rotator) in (('slave0', 6379), ('slave1', 6379))
assert next(rotator) in (('slave0', 6379), ('slave1', 6379))
# Fallback to master
assert next(rotator) == ('127.0.0.1', 6379)
with pytest.raises(SlaveNotFoundError):
next(rotator)