mirror of https://github.com/ledisdb/ledisdb.git
add client for python
This commit is contained in:
parent
852fce9f4c
commit
a0bd2e90e5
|
@ -0,0 +1,31 @@
|
||||||
|
from redis.client import Redis, StrictRedis
|
||||||
|
from redis.connection import (
|
||||||
|
BlockingConnectionPool,
|
||||||
|
ConnectionPool,
|
||||||
|
Connection,
|
||||||
|
UnixDomainSocketConnection
|
||||||
|
)
|
||||||
|
from redis.utils import from_url
|
||||||
|
from redis.exceptions import (
|
||||||
|
AuthenticationError,
|
||||||
|
ConnectionError,
|
||||||
|
BusyLoadingError,
|
||||||
|
DataError,
|
||||||
|
InvalidResponse,
|
||||||
|
PubSubError,
|
||||||
|
RedisError,
|
||||||
|
ResponseError,
|
||||||
|
WatchError,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
__version__ = '2.7.6'
|
||||||
|
VERSION = tuple(map(int, __version__.split('.')))
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'Redis', 'StrictRedis', 'ConnectionPool', 'BlockingConnectionPool',
|
||||||
|
'Connection', 'UnixDomainSocketConnection',
|
||||||
|
'RedisError', 'ConnectionError', 'ResponseError', 'AuthenticationError',
|
||||||
|
'InvalidResponse', 'DataError', 'PubSubError', 'WatchError', 'from_url',
|
||||||
|
'BusyLoadingError'
|
||||||
|
]
|
|
@ -0,0 +1,79 @@
|
||||||
|
"""Internal module for Python 2 backwards compatibility."""
|
||||||
|
import sys
|
||||||
|
|
||||||
|
|
||||||
|
if sys.version_info[0] < 3:
|
||||||
|
from urlparse import urlparse
|
||||||
|
from itertools import imap, izip
|
||||||
|
from string import letters as ascii_letters
|
||||||
|
from Queue import Queue
|
||||||
|
try:
|
||||||
|
from cStringIO import StringIO as BytesIO
|
||||||
|
except ImportError:
|
||||||
|
from StringIO import StringIO as BytesIO
|
||||||
|
|
||||||
|
iteritems = lambda x: x.iteritems()
|
||||||
|
iterkeys = lambda x: x.iterkeys()
|
||||||
|
itervalues = lambda x: x.itervalues()
|
||||||
|
nativestr = lambda x: \
|
||||||
|
x if isinstance(x, str) else x.encode('utf-8', 'replace')
|
||||||
|
u = lambda x: x.decode()
|
||||||
|
b = lambda x: x
|
||||||
|
next = lambda x: x.next()
|
||||||
|
byte_to_chr = lambda x: x
|
||||||
|
unichr = unichr
|
||||||
|
xrange = xrange
|
||||||
|
basestring = basestring
|
||||||
|
unicode = unicode
|
||||||
|
bytes = str
|
||||||
|
long = long
|
||||||
|
else:
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
from io import BytesIO
|
||||||
|
from string import ascii_letters
|
||||||
|
from queue import Queue
|
||||||
|
|
||||||
|
iteritems = lambda x: iter(x.items())
|
||||||
|
iterkeys = lambda x: iter(x.keys())
|
||||||
|
itervalues = lambda x: iter(x.values())
|
||||||
|
byte_to_chr = lambda x: chr(x)
|
||||||
|
nativestr = lambda x: \
|
||||||
|
x if isinstance(x, str) else x.decode('utf-8', 'replace')
|
||||||
|
u = lambda x: x
|
||||||
|
b = lambda x: x.encode('iso-8859-1') if not isinstance(x, bytes) else x
|
||||||
|
next = next
|
||||||
|
unichr = chr
|
||||||
|
imap = map
|
||||||
|
izip = zip
|
||||||
|
xrange = range
|
||||||
|
basestring = str
|
||||||
|
unicode = str
|
||||||
|
bytes = bytes
|
||||||
|
long = int
|
||||||
|
|
||||||
|
try: # Python 3
|
||||||
|
from queue import LifoQueue, Empty, Full
|
||||||
|
except ImportError:
|
||||||
|
from Queue import Empty, Full
|
||||||
|
try: # Python 2.6 - 2.7
|
||||||
|
from Queue import LifoQueue
|
||||||
|
except ImportError: # Python 2.5
|
||||||
|
from Queue import Queue
|
||||||
|
# From the Python 2.7 lib. Python 2.5 already extracted the core
|
||||||
|
# methods to aid implementating different queue organisations.
|
||||||
|
|
||||||
|
class LifoQueue(Queue):
|
||||||
|
"Override queue methods to implement a last-in first-out queue."
|
||||||
|
|
||||||
|
def _init(self, maxsize):
|
||||||
|
self.maxsize = maxsize
|
||||||
|
self.queue = []
|
||||||
|
|
||||||
|
def _qsize(self, len=len):
|
||||||
|
return len(self.queue)
|
||||||
|
|
||||||
|
def _put(self, item):
|
||||||
|
self.queue.append(item)
|
||||||
|
|
||||||
|
def _get(self):
|
||||||
|
return self.queue.pop()
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,580 @@
|
||||||
|
from itertools import chain
|
||||||
|
import os
|
||||||
|
import socket
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from redis._compat import (b, xrange, imap, byte_to_chr, unicode, bytes, long,
|
||||||
|
BytesIO, nativestr, basestring,
|
||||||
|
LifoQueue, Empty, Full)
|
||||||
|
from redis.exceptions import (
|
||||||
|
RedisError,
|
||||||
|
ConnectionError,
|
||||||
|
BusyLoadingError,
|
||||||
|
ResponseError,
|
||||||
|
InvalidResponse,
|
||||||
|
AuthenticationError,
|
||||||
|
NoScriptError,
|
||||||
|
ExecAbortError,
|
||||||
|
)
|
||||||
|
from redis.utils import HIREDIS_AVAILABLE
|
||||||
|
if HIREDIS_AVAILABLE:
|
||||||
|
import hiredis
|
||||||
|
|
||||||
|
|
||||||
|
SYM_STAR = b('*')
|
||||||
|
SYM_DOLLAR = b('$')
|
||||||
|
SYM_CRLF = b('\r\n')
|
||||||
|
SYM_LF = b('\n')
|
||||||
|
|
||||||
|
|
||||||
|
class PythonParser(object):
|
||||||
|
"Plain Python parsing class"
|
||||||
|
MAX_READ_LENGTH = 1000000
|
||||||
|
encoding = None
|
||||||
|
|
||||||
|
EXCEPTION_CLASSES = {
|
||||||
|
'ERR': ResponseError,
|
||||||
|
'EXECABORT': ExecAbortError,
|
||||||
|
'LOADING': BusyLoadingError,
|
||||||
|
'NOSCRIPT': NoScriptError,
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._fp = None
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
try:
|
||||||
|
self.on_disconnect()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_connect(self, connection):
|
||||||
|
"Called when the socket connects"
|
||||||
|
self._fp = connection._sock.makefile('rb')
|
||||||
|
if connection.decode_responses:
|
||||||
|
self.encoding = connection.encoding
|
||||||
|
|
||||||
|
def on_disconnect(self):
|
||||||
|
"Called when the socket disconnects"
|
||||||
|
if self._fp is not None:
|
||||||
|
self._fp.close()
|
||||||
|
self._fp = None
|
||||||
|
|
||||||
|
def read(self, length=None):
|
||||||
|
"""
|
||||||
|
Read a line from the socket if no length is specified,
|
||||||
|
otherwise read ``length`` bytes. Always strip away the newlines.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if length is not None:
|
||||||
|
bytes_left = length + 2 # read the line ending
|
||||||
|
if length > self.MAX_READ_LENGTH:
|
||||||
|
# apparently reading more than 1MB or so from a windows
|
||||||
|
# socket can cause MemoryErrors. See:
|
||||||
|
# https://github.com/andymccurdy/redis-py/issues/205
|
||||||
|
# read smaller chunks at a time to work around this
|
||||||
|
try:
|
||||||
|
buf = BytesIO()
|
||||||
|
while bytes_left > 0:
|
||||||
|
read_len = min(bytes_left, self.MAX_READ_LENGTH)
|
||||||
|
buf.write(self._fp.read(read_len))
|
||||||
|
bytes_left -= read_len
|
||||||
|
buf.seek(0)
|
||||||
|
return buf.read(length)
|
||||||
|
finally:
|
||||||
|
buf.close()
|
||||||
|
return self._fp.read(bytes_left)[:-2]
|
||||||
|
|
||||||
|
# no length, read a full line
|
||||||
|
return self._fp.readline()[:-2]
|
||||||
|
except (socket.error, socket.timeout):
|
||||||
|
e = sys.exc_info()[1]
|
||||||
|
raise ConnectionError("Error while reading from socket: %s" %
|
||||||
|
(e.args,))
|
||||||
|
|
||||||
|
def parse_error(self, response):
|
||||||
|
"Parse an error response"
|
||||||
|
error_code = response.split(' ')[0]
|
||||||
|
if error_code in self.EXCEPTION_CLASSES:
|
||||||
|
response = response[len(error_code) + 1:]
|
||||||
|
return self.EXCEPTION_CLASSES[error_code](response)
|
||||||
|
return ResponseError(response)
|
||||||
|
|
||||||
|
def read_response(self):
|
||||||
|
response = self.read()
|
||||||
|
if not response:
|
||||||
|
raise ConnectionError("Socket closed on remote end")
|
||||||
|
|
||||||
|
byte, response = byte_to_chr(response[0]), response[1:]
|
||||||
|
|
||||||
|
if byte not in ('-', '+', ':', '$', '*'):
|
||||||
|
raise InvalidResponse("Protocol Error")
|
||||||
|
|
||||||
|
# server returned an error
|
||||||
|
if byte == '-':
|
||||||
|
response = nativestr(response)
|
||||||
|
error = self.parse_error(response)
|
||||||
|
# if the error is a ConnectionError, raise immediately so the user
|
||||||
|
# is notified
|
||||||
|
if isinstance(error, ConnectionError):
|
||||||
|
raise error
|
||||||
|
# otherwise, we're dealing with a ResponseError that might belong
|
||||||
|
# inside a pipeline response. the connection's read_response()
|
||||||
|
# and/or the pipeline's execute() will raise this error if
|
||||||
|
# necessary, so just return the exception instance here.
|
||||||
|
return error
|
||||||
|
# single value
|
||||||
|
elif byte == '+':
|
||||||
|
pass
|
||||||
|
# int value
|
||||||
|
elif byte == ':':
|
||||||
|
response = long(response)
|
||||||
|
# bulk response
|
||||||
|
elif byte == '$':
|
||||||
|
length = int(response)
|
||||||
|
if length == -1:
|
||||||
|
return None
|
||||||
|
response = self.read(length)
|
||||||
|
# multi-bulk response
|
||||||
|
elif byte == '*':
|
||||||
|
length = int(response)
|
||||||
|
if length == -1:
|
||||||
|
return None
|
||||||
|
response = [self.read_response() for i in xrange(length)]
|
||||||
|
if isinstance(response, bytes) and self.encoding:
|
||||||
|
response = response.decode(self.encoding)
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
class HiredisParser(object):
|
||||||
|
"Parser class for connections using Hiredis"
|
||||||
|
def __init__(self):
|
||||||
|
if not HIREDIS_AVAILABLE:
|
||||||
|
raise RedisError("Hiredis is not installed")
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
try:
|
||||||
|
self.on_disconnect()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_connect(self, connection):
|
||||||
|
self._sock = connection._sock
|
||||||
|
kwargs = {
|
||||||
|
'protocolError': InvalidResponse,
|
||||||
|
'replyError': ResponseError,
|
||||||
|
}
|
||||||
|
if connection.decode_responses:
|
||||||
|
kwargs['encoding'] = connection.encoding
|
||||||
|
self._reader = hiredis.Reader(**kwargs)
|
||||||
|
|
||||||
|
def on_disconnect(self):
|
||||||
|
self._sock = None
|
||||||
|
self._reader = None
|
||||||
|
|
||||||
|
def read_response(self):
|
||||||
|
if not self._reader:
|
||||||
|
raise ConnectionError("Socket closed on remote end")
|
||||||
|
response = self._reader.gets()
|
||||||
|
while response is False:
|
||||||
|
try:
|
||||||
|
buffer = self._sock.recv(4096)
|
||||||
|
except (socket.error, socket.timeout):
|
||||||
|
e = sys.exc_info()[1]
|
||||||
|
raise ConnectionError("Error while reading from socket: %s" %
|
||||||
|
(e.args,))
|
||||||
|
if not buffer:
|
||||||
|
raise ConnectionError("Socket closed on remote end")
|
||||||
|
self._reader.feed(buffer)
|
||||||
|
# proactively, but not conclusively, check if more data is in the
|
||||||
|
# buffer. if the data received doesn't end with \n, there's more.
|
||||||
|
if not buffer.endswith(SYM_LF):
|
||||||
|
continue
|
||||||
|
response = self._reader.gets()
|
||||||
|
return response
|
||||||
|
|
||||||
|
if HIREDIS_AVAILABLE:
|
||||||
|
DefaultParser = HiredisParser
|
||||||
|
else:
|
||||||
|
DefaultParser = PythonParser
|
||||||
|
|
||||||
|
|
||||||
|
class Connection(object):
|
||||||
|
"Manages TCP communication to and from a Redis server"
|
||||||
|
def __init__(self, host='localhost', port=6379, db=0, password=None,
|
||||||
|
socket_timeout=None, encoding='utf-8',
|
||||||
|
encoding_errors='strict', decode_responses=False,
|
||||||
|
parser_class=DefaultParser):
|
||||||
|
self.pid = os.getpid()
|
||||||
|
self.host = host
|
||||||
|
self.port = port
|
||||||
|
self.db = db
|
||||||
|
self.password = password
|
||||||
|
self.socket_timeout = socket_timeout
|
||||||
|
self.encoding = encoding
|
||||||
|
self.encoding_errors = encoding_errors
|
||||||
|
self.decode_responses = decode_responses
|
||||||
|
self._sock = None
|
||||||
|
self._parser = parser_class()
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
try:
|
||||||
|
self.disconnect()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def connect(self):
|
||||||
|
"Connects to the Redis server if not already connected"
|
||||||
|
if self._sock:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
sock = self._connect()
|
||||||
|
except socket.error:
|
||||||
|
e = sys.exc_info()[1]
|
||||||
|
raise ConnectionError(self._error_message(e))
|
||||||
|
|
||||||
|
self._sock = sock
|
||||||
|
self.on_connect()
|
||||||
|
|
||||||
|
def _connect(self):
|
||||||
|
"Create a TCP socket connection"
|
||||||
|
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
|
sock.settimeout(self.socket_timeout)
|
||||||
|
sock.connect((self.host, self.port))
|
||||||
|
return sock
|
||||||
|
|
||||||
|
def _error_message(self, exception):
|
||||||
|
# args for socket.error can either be (errno, "message")
|
||||||
|
# or just "message"
|
||||||
|
if len(exception.args) == 1:
|
||||||
|
return "Error connecting to %s:%s. %s." % \
|
||||||
|
(self.host, self.port, exception.args[0])
|
||||||
|
else:
|
||||||
|
return "Error %s connecting %s:%s. %s." % \
|
||||||
|
(exception.args[0], self.host, self.port, exception.args[1])
|
||||||
|
|
||||||
|
def on_connect(self):
|
||||||
|
"Initialize the connection, authenticate and select a database"
|
||||||
|
self._parser.on_connect(self)
|
||||||
|
|
||||||
|
# if a password is specified, authenticate
|
||||||
|
if self.password:
|
||||||
|
self.send_command('AUTH', self.password)
|
||||||
|
if nativestr(self.read_response()) != 'OK':
|
||||||
|
raise AuthenticationError('Invalid Password')
|
||||||
|
|
||||||
|
# if a database is specified, switch to it
|
||||||
|
if self.db:
|
||||||
|
self.send_command('SELECT', self.db)
|
||||||
|
if nativestr(self.read_response()) != 'OK':
|
||||||
|
raise ConnectionError('Invalid Database')
|
||||||
|
|
||||||
|
def disconnect(self):
|
||||||
|
"Disconnects from the Redis server"
|
||||||
|
self._parser.on_disconnect()
|
||||||
|
if self._sock is None:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
self._sock.close()
|
||||||
|
except socket.error:
|
||||||
|
pass
|
||||||
|
self._sock = None
|
||||||
|
|
||||||
|
def send_packed_command(self, command):
|
||||||
|
"Send an already packed command to the Redis server"
|
||||||
|
if not self._sock:
|
||||||
|
self.connect()
|
||||||
|
try:
|
||||||
|
self._sock.sendall(command)
|
||||||
|
except socket.error:
|
||||||
|
e = sys.exc_info()[1]
|
||||||
|
self.disconnect()
|
||||||
|
if len(e.args) == 1:
|
||||||
|
_errno, errmsg = 'UNKNOWN', e.args[0]
|
||||||
|
else:
|
||||||
|
_errno, errmsg = e.args
|
||||||
|
raise ConnectionError("Error %s while writing to socket. %s." %
|
||||||
|
(_errno, errmsg))
|
||||||
|
except Exception:
|
||||||
|
self.disconnect()
|
||||||
|
raise
|
||||||
|
|
||||||
|
def send_command(self, *args):
|
||||||
|
"Pack and send a command to the Redis server"
|
||||||
|
self.send_packed_command(self.pack_command(*args))
|
||||||
|
|
||||||
|
def read_response(self):
|
||||||
|
"Read the response from a previously sent command"
|
||||||
|
try:
|
||||||
|
response = self._parser.read_response()
|
||||||
|
except Exception:
|
||||||
|
self.disconnect()
|
||||||
|
raise
|
||||||
|
if isinstance(response, ResponseError):
|
||||||
|
raise response
|
||||||
|
return response
|
||||||
|
|
||||||
|
def encode(self, value):
|
||||||
|
"Return a bytestring representation of the value"
|
||||||
|
if isinstance(value, bytes):
|
||||||
|
return value
|
||||||
|
if isinstance(value, float):
|
||||||
|
value = repr(value)
|
||||||
|
if not isinstance(value, basestring):
|
||||||
|
value = str(value)
|
||||||
|
if isinstance(value, unicode):
|
||||||
|
value = value.encode(self.encoding, self.encoding_errors)
|
||||||
|
return value
|
||||||
|
|
||||||
|
def pack_command(self, *args):
|
||||||
|
"Pack a series of arguments into a value Redis command"
|
||||||
|
output = SYM_STAR + b(str(len(args))) + SYM_CRLF
|
||||||
|
for enc_value in imap(self.encode, args):
|
||||||
|
output += SYM_DOLLAR
|
||||||
|
output += b(str(len(enc_value)))
|
||||||
|
output += SYM_CRLF
|
||||||
|
output += enc_value
|
||||||
|
output += SYM_CRLF
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class UnixDomainSocketConnection(Connection):
|
||||||
|
def __init__(self, path='', db=0, password=None,
|
||||||
|
socket_timeout=None, encoding='utf-8',
|
||||||
|
encoding_errors='strict', decode_responses=False,
|
||||||
|
parser_class=DefaultParser):
|
||||||
|
self.pid = os.getpid()
|
||||||
|
self.path = path
|
||||||
|
self.db = db
|
||||||
|
self.password = password
|
||||||
|
self.socket_timeout = socket_timeout
|
||||||
|
self.encoding = encoding
|
||||||
|
self.encoding_errors = encoding_errors
|
||||||
|
self.decode_responses = decode_responses
|
||||||
|
self._sock = None
|
||||||
|
self._parser = parser_class()
|
||||||
|
|
||||||
|
def _connect(self):
|
||||||
|
"Create a Unix domain socket connection"
|
||||||
|
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||||
|
sock.settimeout(self.socket_timeout)
|
||||||
|
sock.connect(self.path)
|
||||||
|
return sock
|
||||||
|
|
||||||
|
def _error_message(self, exception):
|
||||||
|
# args for socket.error can either be (errno, "message")
|
||||||
|
# or just "message"
|
||||||
|
if len(exception.args) == 1:
|
||||||
|
return "Error connecting to unix socket: %s. %s." % \
|
||||||
|
(self.path, exception.args[0])
|
||||||
|
else:
|
||||||
|
return "Error %s connecting to unix socket: %s. %s." % \
|
||||||
|
(exception.args[0], self.path, exception.args[1])
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: add ability to block waiting on a connection to be released
|
||||||
|
class ConnectionPool(object):
|
||||||
|
"Generic connection pool"
|
||||||
|
def __init__(self, connection_class=Connection, max_connections=None,
|
||||||
|
**connection_kwargs):
|
||||||
|
self.pid = os.getpid()
|
||||||
|
self.connection_class = connection_class
|
||||||
|
self.connection_kwargs = connection_kwargs
|
||||||
|
self.max_connections = max_connections or 2 ** 31
|
||||||
|
self._created_connections = 0
|
||||||
|
self._available_connections = []
|
||||||
|
self._in_use_connections = set()
|
||||||
|
|
||||||
|
def _checkpid(self):
|
||||||
|
if self.pid != os.getpid():
|
||||||
|
self.disconnect()
|
||||||
|
self.__init__(self.connection_class, self.max_connections,
|
||||||
|
**self.connection_kwargs)
|
||||||
|
|
||||||
|
def get_connection(self, command_name, *keys, **options):
|
||||||
|
"Get a connection from the pool"
|
||||||
|
self._checkpid()
|
||||||
|
try:
|
||||||
|
connection = self._available_connections.pop()
|
||||||
|
except IndexError:
|
||||||
|
connection = self.make_connection()
|
||||||
|
self._in_use_connections.add(connection)
|
||||||
|
return connection
|
||||||
|
|
||||||
|
def make_connection(self):
|
||||||
|
"Create a new connection"
|
||||||
|
if self._created_connections >= self.max_connections:
|
||||||
|
raise ConnectionError("Too many connections")
|
||||||
|
self._created_connections += 1
|
||||||
|
return self.connection_class(**self.connection_kwargs)
|
||||||
|
|
||||||
|
def release(self, connection):
|
||||||
|
"Releases the connection back to the pool"
|
||||||
|
self._checkpid()
|
||||||
|
if connection.pid == self.pid:
|
||||||
|
self._in_use_connections.remove(connection)
|
||||||
|
self._available_connections.append(connection)
|
||||||
|
|
||||||
|
def disconnect(self):
|
||||||
|
"Disconnects all connections in the pool"
|
||||||
|
all_conns = chain(self._available_connections,
|
||||||
|
self._in_use_connections)
|
||||||
|
for connection in all_conns:
|
||||||
|
connection.disconnect()
|
||||||
|
|
||||||
|
|
||||||
|
class BlockingConnectionPool(object):
|
||||||
|
"""
|
||||||
|
Thread-safe blocking connection pool::
|
||||||
|
|
||||||
|
>>> from redis.client import Redis
|
||||||
|
>>> client = Redis(connection_pool=BlockingConnectionPool())
|
||||||
|
|
||||||
|
It performs the same function as the default
|
||||||
|
``:py:class: ~redis.connection.ConnectionPool`` implementation, in that,
|
||||||
|
it maintains a pool of reusable connections that can be shared by
|
||||||
|
multiple redis clients (safely across threads if required).
|
||||||
|
|
||||||
|
The difference is that, in the event that a client tries to get a
|
||||||
|
connection from the pool when all of connections are in use, rather than
|
||||||
|
raising a ``:py:class: ~redis.exceptions.ConnectionError`` (as the default
|
||||||
|
``:py:class: ~redis.connection.ConnectionPool`` implementation does), it
|
||||||
|
makes the client wait ("blocks") for a specified number of seconds until
|
||||||
|
a connection becomes available.
|
||||||
|
|
||||||
|
Use ``max_connections`` to increase / decrease the pool size::
|
||||||
|
|
||||||
|
>>> pool = BlockingConnectionPool(max_connections=10)
|
||||||
|
|
||||||
|
Use ``timeout`` to tell it either how many seconds to wait for a connection
|
||||||
|
to become available, or to block forever:
|
||||||
|
|
||||||
|
# Block forever.
|
||||||
|
>>> pool = BlockingConnectionPool(timeout=None)
|
||||||
|
|
||||||
|
# Raise a ``ConnectionError`` after five seconds if a connection is
|
||||||
|
# not available.
|
||||||
|
>>> pool = BlockingConnectionPool(timeout=5)
|
||||||
|
"""
|
||||||
|
def __init__(self, max_connections=50, timeout=20, connection_class=None,
|
||||||
|
queue_class=None, **connection_kwargs):
|
||||||
|
"Compose and assign values."
|
||||||
|
# Compose.
|
||||||
|
if connection_class is None:
|
||||||
|
connection_class = Connection
|
||||||
|
if queue_class is None:
|
||||||
|
queue_class = LifoQueue
|
||||||
|
|
||||||
|
# Assign.
|
||||||
|
self.connection_class = connection_class
|
||||||
|
self.connection_kwargs = connection_kwargs
|
||||||
|
self.queue_class = queue_class
|
||||||
|
self.max_connections = max_connections
|
||||||
|
self.timeout = timeout
|
||||||
|
|
||||||
|
# Validate the ``max_connections``. With the "fill up the queue"
|
||||||
|
# algorithm we use, it must be a positive integer.
|
||||||
|
is_valid = isinstance(max_connections, int) and max_connections > 0
|
||||||
|
if not is_valid:
|
||||||
|
raise ValueError('``max_connections`` must be a positive integer')
|
||||||
|
|
||||||
|
# Get the current process id, so we can disconnect and reinstantiate if
|
||||||
|
# it changes.
|
||||||
|
self.pid = os.getpid()
|
||||||
|
|
||||||
|
# Create and fill up a thread safe queue with ``None`` values.
|
||||||
|
self.pool = self.queue_class(max_connections)
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
self.pool.put_nowait(None)
|
||||||
|
except Full:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Keep a list of actual connection instances so that we can
|
||||||
|
# disconnect them later.
|
||||||
|
self._connections = []
|
||||||
|
|
||||||
|
def _checkpid(self):
|
||||||
|
"""
|
||||||
|
Check the current process id. If it has changed, disconnect and
|
||||||
|
re-instantiate this connection pool instance.
|
||||||
|
"""
|
||||||
|
# Get the current process id.
|
||||||
|
pid = os.getpid()
|
||||||
|
|
||||||
|
# If it hasn't changed since we were instantiated, then we're fine, so
|
||||||
|
# just exit, remaining connected.
|
||||||
|
if self.pid == pid:
|
||||||
|
return
|
||||||
|
|
||||||
|
# If it has changed, then disconnect and re-instantiate.
|
||||||
|
self.disconnect()
|
||||||
|
self.reinstantiate()
|
||||||
|
|
||||||
|
def make_connection(self):
|
||||||
|
"Make a fresh connection."
|
||||||
|
connection = self.connection_class(**self.connection_kwargs)
|
||||||
|
self._connections.append(connection)
|
||||||
|
return connection
|
||||||
|
|
||||||
|
def get_connection(self, command_name, *keys, **options):
|
||||||
|
"""
|
||||||
|
Get a connection, blocking for ``self.timeout`` until a connection
|
||||||
|
is available from the pool.
|
||||||
|
|
||||||
|
If the connection returned is ``None`` then creates a new connection.
|
||||||
|
Because we use a last-in first-out queue, the existing connections
|
||||||
|
(having been returned to the pool after the initial ``None`` values
|
||||||
|
were added) will be returned before ``None`` values. This means we only
|
||||||
|
create new connections when we need to, i.e.: the actual number of
|
||||||
|
connections will only increase in response to demand.
|
||||||
|
"""
|
||||||
|
# Make sure we haven't changed process.
|
||||||
|
self._checkpid()
|
||||||
|
|
||||||
|
# Try and get a connection from the pool. If one isn't available within
|
||||||
|
# self.timeout then raise a ``ConnectionError``.
|
||||||
|
connection = None
|
||||||
|
try:
|
||||||
|
connection = self.pool.get(block=True, timeout=self.timeout)
|
||||||
|
except Empty:
|
||||||
|
# Note that this is not caught by the redis client and will be
|
||||||
|
# raised unless handled by application code. If you want never to
|
||||||
|
raise ConnectionError("No connection available.")
|
||||||
|
|
||||||
|
# If the ``connection`` is actually ``None`` then that's a cue to make
|
||||||
|
# a new connection to add to the pool.
|
||||||
|
if connection is None:
|
||||||
|
connection = self.make_connection()
|
||||||
|
|
||||||
|
return connection
|
||||||
|
|
||||||
|
def release(self, connection):
|
||||||
|
"Releases the connection back to the pool."
|
||||||
|
# Make sure we haven't changed process.
|
||||||
|
self._checkpid()
|
||||||
|
|
||||||
|
# Put the connection back into the pool.
|
||||||
|
try:
|
||||||
|
self.pool.put_nowait(connection)
|
||||||
|
except Full:
|
||||||
|
# This shouldn't normally happen but might perhaps happen after a
|
||||||
|
# reinstantiation. So, we can handle the exception by not putting
|
||||||
|
# the connection back on the pool, because we definitely do not
|
||||||
|
# want to reuse it.
|
||||||
|
pass
|
||||||
|
|
||||||
|
def disconnect(self):
|
||||||
|
"Disconnects all connections in the pool."
|
||||||
|
for connection in self._connections:
|
||||||
|
connection.disconnect()
|
||||||
|
|
||||||
|
def reinstantiate(self):
|
||||||
|
"""
|
||||||
|
Reinstatiate this instance within a new process with a new connection
|
||||||
|
pool set.
|
||||||
|
"""
|
||||||
|
self.__init__(max_connections=self.max_connections,
|
||||||
|
timeout=self.timeout,
|
||||||
|
connection_class=self.connection_class,
|
||||||
|
queue_class=self.queue_class, **self.connection_kwargs)
|
|
@ -0,0 +1,49 @@
|
||||||
|
"Core exceptions raised by the Redis client"
|
||||||
|
|
||||||
|
|
||||||
|
class RedisError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class AuthenticationError(RedisError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ServerError(RedisError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionError(ServerError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BusyLoadingError(ConnectionError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidResponse(ServerError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ResponseError(RedisError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DataError(RedisError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class PubSubError(RedisError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class WatchError(RedisError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class NoScriptError(ResponseError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ExecAbortError(ResponseError):
|
||||||
|
pass
|
|
@ -0,0 +1,16 @@
|
||||||
|
try:
|
||||||
|
import hiredis
|
||||||
|
HIREDIS_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
HIREDIS_AVAILABLE = False
|
||||||
|
|
||||||
|
|
||||||
|
def from_url(url, db=None, **kwargs):
|
||||||
|
"""
|
||||||
|
Returns an active Redis client generated from the given database URL.
|
||||||
|
|
||||||
|
Will attempt to extract the database id from the path url fragment, if
|
||||||
|
none is provided.
|
||||||
|
"""
|
||||||
|
from redis.client import Redis
|
||||||
|
return Redis.from_url(url, db, **kwargs)
|
|
@ -0,0 +1,61 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from redis import __version__
|
||||||
|
|
||||||
|
try:
|
||||||
|
from setuptools import setup
|
||||||
|
from setuptools.command.test import test as TestCommand
|
||||||
|
|
||||||
|
class PyTest(TestCommand):
|
||||||
|
def finalize_options(self):
|
||||||
|
TestCommand.finalize_options(self)
|
||||||
|
self.test_args = []
|
||||||
|
self.test_suite = True
|
||||||
|
|
||||||
|
def run_tests(self):
|
||||||
|
# import here, because outside the eggs aren't loaded
|
||||||
|
import pytest
|
||||||
|
errno = pytest.main(self.test_args)
|
||||||
|
sys.exit(errno)
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
|
||||||
|
from distutils.core import setup
|
||||||
|
PyTest = lambda x: x
|
||||||
|
|
||||||
|
f = open(os.path.join(os.path.dirname(__file__), 'README.rst'))
|
||||||
|
long_description = f.read()
|
||||||
|
f.close()
|
||||||
|
|
||||||
|
setup(
|
||||||
|
name='redis',
|
||||||
|
version=__version__,
|
||||||
|
description='Python client for Redis key-value store',
|
||||||
|
long_description=long_description,
|
||||||
|
url='http://github.com/andymccurdy/redis-py',
|
||||||
|
author='Andy McCurdy',
|
||||||
|
author_email='sedrik@gmail.com',
|
||||||
|
maintainer='Andy McCurdy',
|
||||||
|
maintainer_email='sedrik@gmail.com',
|
||||||
|
keywords=['Redis', 'key-value store'],
|
||||||
|
license='MIT',
|
||||||
|
packages=['redis'],
|
||||||
|
tests_require=['pytest>=2.5.0'],
|
||||||
|
cmdclass={'test': PyTest},
|
||||||
|
classifiers=[
|
||||||
|
'Development Status :: 5 - Production/Stable',
|
||||||
|
'Environment :: Console',
|
||||||
|
'Intended Audience :: Developers',
|
||||||
|
'License :: OSI Approved :: MIT License',
|
||||||
|
'Operating System :: OS Independent',
|
||||||
|
'Programming Language :: Python',
|
||||||
|
'Programming Language :: Python :: 2.6',
|
||||||
|
'Programming Language :: Python :: 2.7',
|
||||||
|
'Programming Language :: Python :: 3',
|
||||||
|
'Programming Language :: Python :: 3.2',
|
||||||
|
'Programming Language :: Python :: 3.3',
|
||||||
|
'Programming Language :: Python :: 3.4',
|
||||||
|
]
|
||||||
|
)
|
|
@ -0,0 +1,46 @@
|
||||||
|
import pytest
|
||||||
|
import redis
|
||||||
|
|
||||||
|
from distutils.version import StrictVersion
|
||||||
|
|
||||||
|
|
||||||
|
_REDIS_VERSIONS = {}
|
||||||
|
|
||||||
|
|
||||||
|
def get_version(**kwargs):
|
||||||
|
params = {'host': 'localhost', 'port': 6379, 'db': 9}
|
||||||
|
params.update(kwargs)
|
||||||
|
key = '%s:%s' % (params['host'], params['port'])
|
||||||
|
if key not in _REDIS_VERSIONS:
|
||||||
|
client = redis.Redis(**params)
|
||||||
|
_REDIS_VERSIONS[key] = client.info()['redis_version']
|
||||||
|
client.connection_pool.disconnect()
|
||||||
|
return _REDIS_VERSIONS[key]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_client(cls, request=None, **kwargs):
|
||||||
|
params = {'host': 'localhost', 'port': 6379, 'db': 9}
|
||||||
|
params.update(kwargs)
|
||||||
|
client = cls(**params)
|
||||||
|
client.flushdb()
|
||||||
|
if request:
|
||||||
|
def teardown():
|
||||||
|
client.flushdb()
|
||||||
|
client.connection_pool.disconnect()
|
||||||
|
request.addfinalizer(teardown)
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
def skip_if_server_version_lt(min_version):
|
||||||
|
check = StrictVersion(get_version()) < StrictVersion(min_version)
|
||||||
|
return pytest.mark.skipif(check, reason="")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def r(request, **kwargs):
|
||||||
|
return _get_client(redis.Redis, request, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def sr(request, **kwargs):
|
||||||
|
return _get_client(redis.StrictRedis, request, **kwargs)
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,402 @@
|
||||||
|
from __future__ import with_statement
|
||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
import redis
|
||||||
|
import time
|
||||||
|
import re
|
||||||
|
|
||||||
|
from threading import Thread
|
||||||
|
from redis.connection import ssl_available
|
||||||
|
from .conftest import skip_if_server_version_lt
|
||||||
|
|
||||||
|
|
||||||
|
class DummyConnection(object):
|
||||||
|
description_format = "DummyConnection<>"
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
self.kwargs = kwargs
|
||||||
|
self.pid = os.getpid()
|
||||||
|
|
||||||
|
|
||||||
|
class TestConnectionPool(object):
|
||||||
|
def get_pool(self, connection_kwargs=None, max_connections=None,
|
||||||
|
connection_class=DummyConnection):
|
||||||
|
connection_kwargs = connection_kwargs or {}
|
||||||
|
pool = redis.ConnectionPool(
|
||||||
|
connection_class=connection_class,
|
||||||
|
max_connections=max_connections,
|
||||||
|
**connection_kwargs)
|
||||||
|
return pool
|
||||||
|
|
||||||
|
def test_connection_creation(self):
|
||||||
|
connection_kwargs = {'foo': 'bar', 'biz': 'baz'}
|
||||||
|
pool = self.get_pool(connection_kwargs=connection_kwargs)
|
||||||
|
connection = pool.get_connection('_')
|
||||||
|
assert isinstance(connection, DummyConnection)
|
||||||
|
assert connection.kwargs == connection_kwargs
|
||||||
|
|
||||||
|
def test_multiple_connections(self):
|
||||||
|
pool = self.get_pool()
|
||||||
|
c1 = pool.get_connection('_')
|
||||||
|
c2 = pool.get_connection('_')
|
||||||
|
assert c1 != c2
|
||||||
|
|
||||||
|
def test_max_connections(self):
|
||||||
|
pool = self.get_pool(max_connections=2)
|
||||||
|
pool.get_connection('_')
|
||||||
|
pool.get_connection('_')
|
||||||
|
with pytest.raises(redis.ConnectionError):
|
||||||
|
pool.get_connection('_')
|
||||||
|
|
||||||
|
def test_reuse_previously_released_connection(self):
|
||||||
|
pool = self.get_pool()
|
||||||
|
c1 = pool.get_connection('_')
|
||||||
|
pool.release(c1)
|
||||||
|
c2 = pool.get_connection('_')
|
||||||
|
assert c1 == c2
|
||||||
|
|
||||||
|
def test_repr_contains_db_info_tcp(self):
|
||||||
|
connection_kwargs = {'host': 'localhost', 'port': 6379, 'db': 1}
|
||||||
|
pool = self.get_pool(connection_kwargs=connection_kwargs,
|
||||||
|
connection_class=redis.Connection)
|
||||||
|
expected = 'ConnectionPool<Connection<host=localhost,port=6379,db=1>>'
|
||||||
|
assert repr(pool) == expected
|
||||||
|
|
||||||
|
def test_repr_contains_db_info_unix(self):
|
||||||
|
connection_kwargs = {'path': '/abc', 'db': 1}
|
||||||
|
pool = self.get_pool(connection_kwargs=connection_kwargs,
|
||||||
|
connection_class=redis.UnixDomainSocketConnection)
|
||||||
|
expected = 'ConnectionPool<UnixDomainSocketConnection<path=/abc,db=1>>'
|
||||||
|
assert repr(pool) == expected
|
||||||
|
|
||||||
|
|
||||||
|
class TestBlockingConnectionPool(object):
|
||||||
|
def get_pool(self, connection_kwargs=None, max_connections=10, timeout=20):
|
||||||
|
connection_kwargs = connection_kwargs or {}
|
||||||
|
pool = redis.BlockingConnectionPool(connection_class=DummyConnection,
|
||||||
|
max_connections=max_connections,
|
||||||
|
timeout=timeout,
|
||||||
|
**connection_kwargs)
|
||||||
|
return pool
|
||||||
|
|
||||||
|
def test_connection_creation(self):
|
||||||
|
connection_kwargs = {'foo': 'bar', 'biz': 'baz'}
|
||||||
|
pool = self.get_pool(connection_kwargs=connection_kwargs)
|
||||||
|
connection = pool.get_connection('_')
|
||||||
|
assert isinstance(connection, DummyConnection)
|
||||||
|
assert connection.kwargs == connection_kwargs
|
||||||
|
|
||||||
|
def test_multiple_connections(self):
|
||||||
|
pool = self.get_pool()
|
||||||
|
c1 = pool.get_connection('_')
|
||||||
|
c2 = pool.get_connection('_')
|
||||||
|
assert c1 != c2
|
||||||
|
|
||||||
|
def test_connection_pool_blocks_until_timeout(self):
|
||||||
|
"When out of connections, block for timeout seconds, then raise"
|
||||||
|
pool = self.get_pool(max_connections=1, timeout=0.1)
|
||||||
|
pool.get_connection('_')
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
|
with pytest.raises(redis.ConnectionError):
|
||||||
|
pool.get_connection('_')
|
||||||
|
# we should have waited at least 0.1 seconds
|
||||||
|
assert time.time() - start >= 0.1
|
||||||
|
|
||||||
|
def connection_pool_blocks_until_another_connection_released(self):
|
||||||
|
"""
|
||||||
|
When out of connections, block until another connection is released
|
||||||
|
to the pool
|
||||||
|
"""
|
||||||
|
pool = self.get_pool(max_connections=1, timeout=2)
|
||||||
|
c1 = pool.get_connection('_')
|
||||||
|
|
||||||
|
def target():
|
||||||
|
time.sleep(0.1)
|
||||||
|
pool.release(c1)
|
||||||
|
|
||||||
|
Thread(target=target).start()
|
||||||
|
start = time.time()
|
||||||
|
pool.get_connection('_')
|
||||||
|
assert time.time() - start >= 0.1
|
||||||
|
|
||||||
|
def test_reuse_previously_released_connection(self):
|
||||||
|
pool = self.get_pool()
|
||||||
|
c1 = pool.get_connection('_')
|
||||||
|
pool.release(c1)
|
||||||
|
c2 = pool.get_connection('_')
|
||||||
|
assert c1 == c2
|
||||||
|
|
||||||
|
def test_repr_contains_db_info_tcp(self):
|
||||||
|
pool = redis.ConnectionPool(host='localhost', port=6379, db=0)
|
||||||
|
expected = 'ConnectionPool<Connection<host=localhost,port=6379,db=0>>'
|
||||||
|
assert repr(pool) == expected
|
||||||
|
|
||||||
|
def test_repr_contains_db_info_unix(self):
|
||||||
|
pool = redis.ConnectionPool(
|
||||||
|
connection_class=redis.UnixDomainSocketConnection,
|
||||||
|
path='abc',
|
||||||
|
db=0,
|
||||||
|
)
|
||||||
|
expected = 'ConnectionPool<UnixDomainSocketConnection<path=abc,db=0>>'
|
||||||
|
assert repr(pool) == expected
|
||||||
|
|
||||||
|
|
||||||
|
class TestConnectionPoolURLParsing(object):
|
||||||
|
def test_defaults(self):
|
||||||
|
pool = redis.ConnectionPool.from_url('redis://localhost')
|
||||||
|
assert pool.connection_class == redis.Connection
|
||||||
|
assert pool.connection_kwargs == {
|
||||||
|
'host': 'localhost',
|
||||||
|
'port': 6379,
|
||||||
|
'db': 0,
|
||||||
|
'password': None,
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_hostname(self):
|
||||||
|
pool = redis.ConnectionPool.from_url('redis://myhost')
|
||||||
|
assert pool.connection_class == redis.Connection
|
||||||
|
assert pool.connection_kwargs == {
|
||||||
|
'host': 'myhost',
|
||||||
|
'port': 6379,
|
||||||
|
'db': 0,
|
||||||
|
'password': None,
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_port(self):
|
||||||
|
pool = redis.ConnectionPool.from_url('redis://localhost:6380')
|
||||||
|
assert pool.connection_class == redis.Connection
|
||||||
|
assert pool.connection_kwargs == {
|
||||||
|
'host': 'localhost',
|
||||||
|
'port': 6380,
|
||||||
|
'db': 0,
|
||||||
|
'password': None,
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_password(self):
|
||||||
|
pool = redis.ConnectionPool.from_url('redis://:mypassword@localhost')
|
||||||
|
assert pool.connection_class == redis.Connection
|
||||||
|
assert pool.connection_kwargs == {
|
||||||
|
'host': 'localhost',
|
||||||
|
'port': 6379,
|
||||||
|
'db': 0,
|
||||||
|
'password': 'mypassword',
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_db_as_argument(self):
|
||||||
|
pool = redis.ConnectionPool.from_url('redis://localhost', db='1')
|
||||||
|
assert pool.connection_class == redis.Connection
|
||||||
|
assert pool.connection_kwargs == {
|
||||||
|
'host': 'localhost',
|
||||||
|
'port': 6379,
|
||||||
|
'db': 1,
|
||||||
|
'password': None,
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_db_in_path(self):
|
||||||
|
pool = redis.ConnectionPool.from_url('redis://localhost/2', db='1')
|
||||||
|
assert pool.connection_class == redis.Connection
|
||||||
|
assert pool.connection_kwargs == {
|
||||||
|
'host': 'localhost',
|
||||||
|
'port': 6379,
|
||||||
|
'db': 2,
|
||||||
|
'password': None,
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_db_in_querystring(self):
|
||||||
|
pool = redis.ConnectionPool.from_url('redis://localhost/2?db=3',
|
||||||
|
db='1')
|
||||||
|
assert pool.connection_class == redis.Connection
|
||||||
|
assert pool.connection_kwargs == {
|
||||||
|
'host': 'localhost',
|
||||||
|
'port': 6379,
|
||||||
|
'db': 3,
|
||||||
|
'password': None,
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_extra_querystring_options(self):
|
||||||
|
pool = redis.ConnectionPool.from_url('redis://localhost?a=1&b=2')
|
||||||
|
assert pool.connection_class == redis.Connection
|
||||||
|
assert pool.connection_kwargs == {
|
||||||
|
'host': 'localhost',
|
||||||
|
'port': 6379,
|
||||||
|
'db': 0,
|
||||||
|
'password': None,
|
||||||
|
'a': '1',
|
||||||
|
'b': '2'
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_calling_from_subclass_returns_correct_instance(self):
|
||||||
|
pool = redis.BlockingConnectionPool.from_url('redis://localhost')
|
||||||
|
assert isinstance(pool, redis.BlockingConnectionPool)
|
||||||
|
|
||||||
|
def test_client_creates_connection_pool(self):
|
||||||
|
r = redis.StrictRedis.from_url('redis://myhost')
|
||||||
|
assert r.connection_pool.connection_class == redis.Connection
|
||||||
|
assert r.connection_pool.connection_kwargs == {
|
||||||
|
'host': 'myhost',
|
||||||
|
'port': 6379,
|
||||||
|
'db': 0,
|
||||||
|
'password': None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TestConnectionPoolUnixSocketURLParsing(object):
|
||||||
|
def test_defaults(self):
|
||||||
|
pool = redis.ConnectionPool.from_url('unix:///socket')
|
||||||
|
assert pool.connection_class == redis.UnixDomainSocketConnection
|
||||||
|
assert pool.connection_kwargs == {
|
||||||
|
'path': '/socket',
|
||||||
|
'db': 0,
|
||||||
|
'password': None,
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_password(self):
|
||||||
|
pool = redis.ConnectionPool.from_url('unix://:mypassword@/socket')
|
||||||
|
assert pool.connection_class == redis.UnixDomainSocketConnection
|
||||||
|
assert pool.connection_kwargs == {
|
||||||
|
'path': '/socket',
|
||||||
|
'db': 0,
|
||||||
|
'password': 'mypassword',
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_db_as_argument(self):
|
||||||
|
pool = redis.ConnectionPool.from_url('unix:///socket', db=1)
|
||||||
|
assert pool.connection_class == redis.UnixDomainSocketConnection
|
||||||
|
assert pool.connection_kwargs == {
|
||||||
|
'path': '/socket',
|
||||||
|
'db': 1,
|
||||||
|
'password': None,
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_db_in_querystring(self):
|
||||||
|
pool = redis.ConnectionPool.from_url('unix:///socket?db=2', db=1)
|
||||||
|
assert pool.connection_class == redis.UnixDomainSocketConnection
|
||||||
|
assert pool.connection_kwargs == {
|
||||||
|
'path': '/socket',
|
||||||
|
'db': 2,
|
||||||
|
'password': None,
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_extra_querystring_options(self):
|
||||||
|
pool = redis.ConnectionPool.from_url('unix:///socket?a=1&b=2')
|
||||||
|
assert pool.connection_class == redis.UnixDomainSocketConnection
|
||||||
|
assert pool.connection_kwargs == {
|
||||||
|
'path': '/socket',
|
||||||
|
'db': 0,
|
||||||
|
'password': None,
|
||||||
|
'a': '1',
|
||||||
|
'b': '2'
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TestSSLConnectionURLParsing(object):
|
||||||
|
@pytest.mark.skipif(not ssl_available, reason="SSL not installed")
|
||||||
|
def test_defaults(self):
|
||||||
|
pool = redis.ConnectionPool.from_url('rediss://localhost')
|
||||||
|
assert pool.connection_class == redis.SSLConnection
|
||||||
|
assert pool.connection_kwargs == {
|
||||||
|
'host': 'localhost',
|
||||||
|
'port': 6379,
|
||||||
|
'db': 0,
|
||||||
|
'password': None,
|
||||||
|
}
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not ssl_available, reason="SSL not installed")
|
||||||
|
def test_cert_reqs_options(self):
|
||||||
|
import ssl
|
||||||
|
pool = redis.ConnectionPool.from_url('rediss://?ssl_cert_reqs=none')
|
||||||
|
assert pool.get_connection('_').cert_reqs == ssl.CERT_NONE
|
||||||
|
|
||||||
|
pool = redis.ConnectionPool.from_url(
|
||||||
|
'rediss://?ssl_cert_reqs=optional')
|
||||||
|
assert pool.get_connection('_').cert_reqs == ssl.CERT_OPTIONAL
|
||||||
|
|
||||||
|
pool = redis.ConnectionPool.from_url(
|
||||||
|
'rediss://?ssl_cert_reqs=required')
|
||||||
|
assert pool.get_connection('_').cert_reqs == ssl.CERT_REQUIRED
|
||||||
|
|
||||||
|
|
||||||
|
class TestConnection(object):
|
||||||
|
def test_on_connect_error(self):
|
||||||
|
"""
|
||||||
|
An error in Connection.on_connect should disconnect from the server
|
||||||
|
see for details: https://github.com/andymccurdy/redis-py/issues/368
|
||||||
|
"""
|
||||||
|
# this assumes the Redis server being tested against doesn't have
|
||||||
|
# 9999 databases ;)
|
||||||
|
bad_connection = redis.Redis(db=9999)
|
||||||
|
# an error should be raised on connect
|
||||||
|
with pytest.raises(redis.RedisError):
|
||||||
|
bad_connection.info()
|
||||||
|
pool = bad_connection.connection_pool
|
||||||
|
assert len(pool._available_connections) == 1
|
||||||
|
assert not pool._available_connections[0]._sock
|
||||||
|
|
||||||
|
@skip_if_server_version_lt('2.8.8')
|
||||||
|
def test_busy_loading_disconnects_socket(self, r):
|
||||||
|
"""
|
||||||
|
If Redis raises a LOADING error, the connection should be
|
||||||
|
disconnected and a BusyLoadingError raised
|
||||||
|
"""
|
||||||
|
with pytest.raises(redis.BusyLoadingError):
|
||||||
|
r.execute_command('DEBUG', 'ERROR', 'LOADING fake message')
|
||||||
|
pool = r.connection_pool
|
||||||
|
assert len(pool._available_connections) == 1
|
||||||
|
assert not pool._available_connections[0]._sock
|
||||||
|
|
||||||
|
@skip_if_server_version_lt('2.8.8')
|
||||||
|
def test_busy_loading_from_pipeline_immediate_command(self, r):
|
||||||
|
"""
|
||||||
|
BusyLoadingErrors should raise from Pipelines that execute a
|
||||||
|
command immediately, like WATCH does.
|
||||||
|
"""
|
||||||
|
pipe = r.pipeline()
|
||||||
|
with pytest.raises(redis.BusyLoadingError):
|
||||||
|
pipe.immediate_execute_command('DEBUG', 'ERROR',
|
||||||
|
'LOADING fake message')
|
||||||
|
pool = r.connection_pool
|
||||||
|
assert not pipe.connection
|
||||||
|
assert len(pool._available_connections) == 1
|
||||||
|
assert not pool._available_connections[0]._sock
|
||||||
|
|
||||||
|
@skip_if_server_version_lt('2.8.8')
|
||||||
|
def test_busy_loading_from_pipeline(self, r):
|
||||||
|
"""
|
||||||
|
BusyLoadingErrors should be raised from a pipeline execution
|
||||||
|
regardless of the raise_on_error flag.
|
||||||
|
"""
|
||||||
|
pipe = r.pipeline()
|
||||||
|
pipe.execute_command('DEBUG', 'ERROR', 'LOADING fake message')
|
||||||
|
with pytest.raises(redis.BusyLoadingError):
|
||||||
|
pipe.execute()
|
||||||
|
pool = r.connection_pool
|
||||||
|
assert not pipe.connection
|
||||||
|
assert len(pool._available_connections) == 1
|
||||||
|
assert not pool._available_connections[0]._sock
|
||||||
|
|
||||||
|
@skip_if_server_version_lt('2.8.8')
|
||||||
|
def test_read_only_error(self, r):
|
||||||
|
"READONLY errors get turned in ReadOnlyError exceptions"
|
||||||
|
with pytest.raises(redis.ReadOnlyError):
|
||||||
|
r.execute_command('DEBUG', 'ERROR', 'READONLY blah blah')
|
||||||
|
|
||||||
|
def test_connect_from_url_tcp(self):
|
||||||
|
connection = redis.Redis.from_url('redis://localhost')
|
||||||
|
pool = connection.connection_pool
|
||||||
|
|
||||||
|
assert re.match('(.*)<(.*)<(.*)>>', repr(pool)).groups() == (
|
||||||
|
'ConnectionPool',
|
||||||
|
'Connection',
|
||||||
|
'host=localhost,port=6379,db=0',
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_connect_from_url_unix(self):
|
||||||
|
connection = redis.Redis.from_url('unix:///path/to/socket')
|
||||||
|
pool = connection.connection_pool
|
||||||
|
|
||||||
|
assert re.match('(.*)<(.*)<(.*)>>', repr(pool)).groups() == (
|
||||||
|
'ConnectionPool',
|
||||||
|
'UnixDomainSocketConnection',
|
||||||
|
'path=/path/to/socket,db=0',
|
||||||
|
)
|
|
@ -0,0 +1,33 @@
|
||||||
|
from __future__ import with_statement
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from redis._compat import unichr, u, unicode
|
||||||
|
from .conftest import r as _redis_client
|
||||||
|
|
||||||
|
|
||||||
|
class TestEncoding(object):
|
||||||
|
@pytest.fixture()
|
||||||
|
def r(self, request):
|
||||||
|
return _redis_client(request=request, decode_responses=True)
|
||||||
|
|
||||||
|
def test_simple_encoding(self, r):
|
||||||
|
unicode_string = unichr(3456) + u('abcd') + unichr(3421)
|
||||||
|
r['unicode-string'] = unicode_string
|
||||||
|
cached_val = r['unicode-string']
|
||||||
|
assert isinstance(cached_val, unicode)
|
||||||
|
assert unicode_string == cached_val
|
||||||
|
|
||||||
|
def test_list_encoding(self, r):
|
||||||
|
unicode_string = unichr(3456) + u('abcd') + unichr(3421)
|
||||||
|
result = [unicode_string, unicode_string, unicode_string]
|
||||||
|
r.rpush('a', *result)
|
||||||
|
assert r.lrange('a', 0, -1) == result
|
||||||
|
|
||||||
|
|
||||||
|
class TestCommandsAndTokensArentEncoded(object):
|
||||||
|
@pytest.fixture()
|
||||||
|
def r(self, request):
|
||||||
|
return _redis_client(request=request, charset='utf-16')
|
||||||
|
|
||||||
|
def test_basic_command(self, r):
|
||||||
|
r.set('hello', 'world')
|
|
@ -0,0 +1,167 @@
|
||||||
|
from __future__ import with_statement
|
||||||
|
import pytest
|
||||||
|
import time
|
||||||
|
|
||||||
|
from redis.exceptions import LockError, ResponseError
|
||||||
|
from redis.lock import Lock, LuaLock
|
||||||
|
|
||||||
|
|
||||||
|
class TestLock(object):
|
||||||
|
lock_class = Lock
|
||||||
|
|
||||||
|
def get_lock(self, redis, *args, **kwargs):
|
||||||
|
kwargs['lock_class'] = self.lock_class
|
||||||
|
return redis.lock(*args, **kwargs)
|
||||||
|
|
||||||
|
def test_lock(self, sr):
|
||||||
|
lock = self.get_lock(sr, 'foo')
|
||||||
|
assert lock.acquire(blocking=False)
|
||||||
|
assert sr.get('foo') == lock.local.token
|
||||||
|
assert sr.ttl('foo') == -1
|
||||||
|
lock.release()
|
||||||
|
assert sr.get('foo') is None
|
||||||
|
|
||||||
|
def test_competing_locks(self, sr):
|
||||||
|
lock1 = self.get_lock(sr, 'foo')
|
||||||
|
lock2 = self.get_lock(sr, 'foo')
|
||||||
|
assert lock1.acquire(blocking=False)
|
||||||
|
assert not lock2.acquire(blocking=False)
|
||||||
|
lock1.release()
|
||||||
|
assert lock2.acquire(blocking=False)
|
||||||
|
assert not lock1.acquire(blocking=False)
|
||||||
|
lock2.release()
|
||||||
|
|
||||||
|
def test_timeout(self, sr):
|
||||||
|
lock = self.get_lock(sr, 'foo', timeout=10)
|
||||||
|
assert lock.acquire(blocking=False)
|
||||||
|
assert 8 < sr.ttl('foo') <= 10
|
||||||
|
lock.release()
|
||||||
|
|
||||||
|
def test_float_timeout(self, sr):
|
||||||
|
lock = self.get_lock(sr, 'foo', timeout=9.5)
|
||||||
|
assert lock.acquire(blocking=False)
|
||||||
|
assert 8 < sr.pttl('foo') <= 9500
|
||||||
|
lock.release()
|
||||||
|
|
||||||
|
def test_blocking_timeout(self, sr):
|
||||||
|
lock1 = self.get_lock(sr, 'foo')
|
||||||
|
assert lock1.acquire(blocking=False)
|
||||||
|
lock2 = self.get_lock(sr, 'foo', blocking_timeout=0.2)
|
||||||
|
start = time.time()
|
||||||
|
assert not lock2.acquire()
|
||||||
|
assert (time.time() - start) > 0.2
|
||||||
|
lock1.release()
|
||||||
|
|
||||||
|
def test_context_manager(self, sr):
|
||||||
|
# blocking_timeout prevents a deadlock if the lock can't be acquired
|
||||||
|
# for some reason
|
||||||
|
with self.get_lock(sr, 'foo', blocking_timeout=0.2) as lock:
|
||||||
|
assert sr.get('foo') == lock.local.token
|
||||||
|
assert sr.get('foo') is None
|
||||||
|
|
||||||
|
def test_high_sleep_raises_error(self, sr):
|
||||||
|
"If sleep is higher than timeout, it should raise an error"
|
||||||
|
with pytest.raises(LockError):
|
||||||
|
self.get_lock(sr, 'foo', timeout=1, sleep=2)
|
||||||
|
|
||||||
|
def test_releasing_unlocked_lock_raises_error(self, sr):
|
||||||
|
lock = self.get_lock(sr, 'foo')
|
||||||
|
with pytest.raises(LockError):
|
||||||
|
lock.release()
|
||||||
|
|
||||||
|
def test_releasing_lock_no_longer_owned_raises_error(self, sr):
|
||||||
|
lock = self.get_lock(sr, 'foo')
|
||||||
|
lock.acquire(blocking=False)
|
||||||
|
# manually change the token
|
||||||
|
sr.set('foo', 'a')
|
||||||
|
with pytest.raises(LockError):
|
||||||
|
lock.release()
|
||||||
|
# even though we errored, the token is still cleared
|
||||||
|
assert lock.local.token is None
|
||||||
|
|
||||||
|
def test_extend_lock(self, sr):
|
||||||
|
lock = self.get_lock(sr, 'foo', timeout=10)
|
||||||
|
assert lock.acquire(blocking=False)
|
||||||
|
assert 8000 < sr.pttl('foo') <= 10000
|
||||||
|
assert lock.extend(10)
|
||||||
|
assert 16000 < sr.pttl('foo') <= 20000
|
||||||
|
lock.release()
|
||||||
|
|
||||||
|
def test_extend_lock_float(self, sr):
|
||||||
|
lock = self.get_lock(sr, 'foo', timeout=10.0)
|
||||||
|
assert lock.acquire(blocking=False)
|
||||||
|
assert 8000 < sr.pttl('foo') <= 10000
|
||||||
|
assert lock.extend(10.0)
|
||||||
|
assert 16000 < sr.pttl('foo') <= 20000
|
||||||
|
lock.release()
|
||||||
|
|
||||||
|
def test_extending_unlocked_lock_raises_error(self, sr):
|
||||||
|
lock = self.get_lock(sr, 'foo', timeout=10)
|
||||||
|
with pytest.raises(LockError):
|
||||||
|
lock.extend(10)
|
||||||
|
|
||||||
|
def test_extending_lock_with_no_timeout_raises_error(self, sr):
|
||||||
|
lock = self.get_lock(sr, 'foo')
|
||||||
|
assert lock.acquire(blocking=False)
|
||||||
|
with pytest.raises(LockError):
|
||||||
|
lock.extend(10)
|
||||||
|
lock.release()
|
||||||
|
|
||||||
|
def test_extending_lock_no_longer_owned_raises_error(self, sr):
|
||||||
|
lock = self.get_lock(sr, 'foo')
|
||||||
|
assert lock.acquire(blocking=False)
|
||||||
|
sr.set('foo', 'a')
|
||||||
|
with pytest.raises(LockError):
|
||||||
|
lock.extend(10)
|
||||||
|
|
||||||
|
|
||||||
|
class TestLuaLock(TestLock):
|
||||||
|
lock_class = LuaLock
|
||||||
|
|
||||||
|
|
||||||
|
class TestLockClassSelection(object):
|
||||||
|
def test_lock_class_argument(self, sr):
|
||||||
|
lock = sr.lock('foo', lock_class=Lock)
|
||||||
|
assert type(lock) == Lock
|
||||||
|
lock = sr.lock('foo', lock_class=LuaLock)
|
||||||
|
assert type(lock) == LuaLock
|
||||||
|
|
||||||
|
def test_cached_lualock_flag(self, sr):
|
||||||
|
try:
|
||||||
|
sr._use_lua_lock = True
|
||||||
|
lock = sr.lock('foo')
|
||||||
|
assert type(lock) == LuaLock
|
||||||
|
finally:
|
||||||
|
sr._use_lua_lock = None
|
||||||
|
|
||||||
|
def test_cached_lock_flag(self, sr):
|
||||||
|
try:
|
||||||
|
sr._use_lua_lock = False
|
||||||
|
lock = sr.lock('foo')
|
||||||
|
assert type(lock) == Lock
|
||||||
|
finally:
|
||||||
|
sr._use_lua_lock = None
|
||||||
|
|
||||||
|
def test_lua_compatible_server(self, sr, monkeypatch):
|
||||||
|
@classmethod
|
||||||
|
def mock_register(cls, redis):
|
||||||
|
return
|
||||||
|
monkeypatch.setattr(LuaLock, 'register_scripts', mock_register)
|
||||||
|
try:
|
||||||
|
lock = sr.lock('foo')
|
||||||
|
assert type(lock) == LuaLock
|
||||||
|
assert sr._use_lua_lock is True
|
||||||
|
finally:
|
||||||
|
sr._use_lua_lock = None
|
||||||
|
|
||||||
|
def test_lua_unavailable(self, sr, monkeypatch):
|
||||||
|
@classmethod
|
||||||
|
def mock_register(cls, redis):
|
||||||
|
raise ResponseError()
|
||||||
|
monkeypatch.setattr(LuaLock, 'register_scripts', mock_register)
|
||||||
|
try:
|
||||||
|
lock = sr.lock('foo')
|
||||||
|
assert type(lock) == Lock
|
||||||
|
assert sr._use_lua_lock is False
|
||||||
|
finally:
|
||||||
|
sr._use_lua_lock = None
|
|
@ -0,0 +1,226 @@
|
||||||
|
from __future__ import with_statement
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import redis
|
||||||
|
from redis._compat import b, u, unichr, unicode
|
||||||
|
|
||||||
|
|
||||||
|
class TestPipeline(object):
|
||||||
|
def test_pipeline(self, r):
|
||||||
|
with r.pipeline() as pipe:
|
||||||
|
pipe.set('a', 'a1').get('a').zadd('z', z1=1).zadd('z', z2=4)
|
||||||
|
pipe.zincrby('z', 'z1').zrange('z', 0, 5, withscores=True)
|
||||||
|
assert pipe.execute() == \
|
||||||
|
[
|
||||||
|
True,
|
||||||
|
b('a1'),
|
||||||
|
True,
|
||||||
|
True,
|
||||||
|
2.0,
|
||||||
|
[(b('z1'), 2.0), (b('z2'), 4)],
|
||||||
|
]
|
||||||
|
|
||||||
|
def test_pipeline_length(self, r):
|
||||||
|
with r.pipeline() as pipe:
|
||||||
|
# Initially empty.
|
||||||
|
assert len(pipe) == 0
|
||||||
|
assert not pipe
|
||||||
|
|
||||||
|
# Fill 'er up!
|
||||||
|
pipe.set('a', 'a1').set('b', 'b1').set('c', 'c1')
|
||||||
|
assert len(pipe) == 3
|
||||||
|
assert pipe
|
||||||
|
|
||||||
|
# Execute calls reset(), so empty once again.
|
||||||
|
pipe.execute()
|
||||||
|
assert len(pipe) == 0
|
||||||
|
assert not pipe
|
||||||
|
|
||||||
|
def test_pipeline_no_transaction(self, r):
|
||||||
|
with r.pipeline(transaction=False) as pipe:
|
||||||
|
pipe.set('a', 'a1').set('b', 'b1').set('c', 'c1')
|
||||||
|
assert pipe.execute() == [True, True, True]
|
||||||
|
assert r['a'] == b('a1')
|
||||||
|
assert r['b'] == b('b1')
|
||||||
|
assert r['c'] == b('c1')
|
||||||
|
|
||||||
|
def test_pipeline_no_transaction_watch(self, r):
|
||||||
|
r['a'] = 0
|
||||||
|
|
||||||
|
with r.pipeline(transaction=False) as pipe:
|
||||||
|
pipe.watch('a')
|
||||||
|
a = pipe.get('a')
|
||||||
|
|
||||||
|
pipe.multi()
|
||||||
|
pipe.set('a', int(a) + 1)
|
||||||
|
assert pipe.execute() == [True]
|
||||||
|
|
||||||
|
def test_pipeline_no_transaction_watch_failure(self, r):
|
||||||
|
r['a'] = 0
|
||||||
|
|
||||||
|
with r.pipeline(transaction=False) as pipe:
|
||||||
|
pipe.watch('a')
|
||||||
|
a = pipe.get('a')
|
||||||
|
|
||||||
|
r['a'] = 'bad'
|
||||||
|
|
||||||
|
pipe.multi()
|
||||||
|
pipe.set('a', int(a) + 1)
|
||||||
|
|
||||||
|
with pytest.raises(redis.WatchError):
|
||||||
|
pipe.execute()
|
||||||
|
|
||||||
|
assert r['a'] == b('bad')
|
||||||
|
|
||||||
|
def test_exec_error_in_response(self, r):
|
||||||
|
"""
|
||||||
|
an invalid pipeline command at exec time adds the exception instance
|
||||||
|
to the list of returned values
|
||||||
|
"""
|
||||||
|
r['c'] = 'a'
|
||||||
|
with r.pipeline() as pipe:
|
||||||
|
pipe.set('a', 1).set('b', 2).lpush('c', 3).set('d', 4)
|
||||||
|
result = pipe.execute(raise_on_error=False)
|
||||||
|
|
||||||
|
assert result[0]
|
||||||
|
assert r['a'] == b('1')
|
||||||
|
assert result[1]
|
||||||
|
assert r['b'] == b('2')
|
||||||
|
|
||||||
|
# we can't lpush to a key that's a string value, so this should
|
||||||
|
# be a ResponseError exception
|
||||||
|
assert isinstance(result[2], redis.ResponseError)
|
||||||
|
assert r['c'] == b('a')
|
||||||
|
|
||||||
|
# since this isn't a transaction, the other commands after the
|
||||||
|
# error are still executed
|
||||||
|
assert result[3]
|
||||||
|
assert r['d'] == b('4')
|
||||||
|
|
||||||
|
# make sure the pipe was restored to a working state
|
||||||
|
assert pipe.set('z', 'zzz').execute() == [True]
|
||||||
|
assert r['z'] == b('zzz')
|
||||||
|
|
||||||
|
def test_exec_error_raised(self, r):
|
||||||
|
r['c'] = 'a'
|
||||||
|
with r.pipeline() as pipe:
|
||||||
|
pipe.set('a', 1).set('b', 2).lpush('c', 3).set('d', 4)
|
||||||
|
with pytest.raises(redis.ResponseError) as ex:
|
||||||
|
pipe.execute()
|
||||||
|
assert unicode(ex.value).startswith('Command # 3 (LPUSH c 3) of '
|
||||||
|
'pipeline caused error: ')
|
||||||
|
|
||||||
|
# make sure the pipe was restored to a working state
|
||||||
|
assert pipe.set('z', 'zzz').execute() == [True]
|
||||||
|
assert r['z'] == b('zzz')
|
||||||
|
|
||||||
|
def test_parse_error_raised(self, r):
|
||||||
|
with r.pipeline() as pipe:
|
||||||
|
# the zrem is invalid because we don't pass any keys to it
|
||||||
|
pipe.set('a', 1).zrem('b').set('b', 2)
|
||||||
|
with pytest.raises(redis.ResponseError) as ex:
|
||||||
|
pipe.execute()
|
||||||
|
|
||||||
|
assert unicode(ex.value).startswith('Command # 2 (ZREM b) of '
|
||||||
|
'pipeline caused error: ')
|
||||||
|
|
||||||
|
# make sure the pipe was restored to a working state
|
||||||
|
assert pipe.set('z', 'zzz').execute() == [True]
|
||||||
|
assert r['z'] == b('zzz')
|
||||||
|
|
||||||
|
def test_watch_succeed(self, r):
|
||||||
|
r['a'] = 1
|
||||||
|
r['b'] = 2
|
||||||
|
|
||||||
|
with r.pipeline() as pipe:
|
||||||
|
pipe.watch('a', 'b')
|
||||||
|
assert pipe.watching
|
||||||
|
a_value = pipe.get('a')
|
||||||
|
b_value = pipe.get('b')
|
||||||
|
assert a_value == b('1')
|
||||||
|
assert b_value == b('2')
|
||||||
|
pipe.multi()
|
||||||
|
|
||||||
|
pipe.set('c', 3)
|
||||||
|
assert pipe.execute() == [True]
|
||||||
|
assert not pipe.watching
|
||||||
|
|
||||||
|
def test_watch_failure(self, r):
|
||||||
|
r['a'] = 1
|
||||||
|
r['b'] = 2
|
||||||
|
|
||||||
|
with r.pipeline() as pipe:
|
||||||
|
pipe.watch('a', 'b')
|
||||||
|
r['b'] = 3
|
||||||
|
pipe.multi()
|
||||||
|
pipe.get('a')
|
||||||
|
with pytest.raises(redis.WatchError):
|
||||||
|
pipe.execute()
|
||||||
|
|
||||||
|
assert not pipe.watching
|
||||||
|
|
||||||
|
def test_unwatch(self, r):
|
||||||
|
r['a'] = 1
|
||||||
|
r['b'] = 2
|
||||||
|
|
||||||
|
with r.pipeline() as pipe:
|
||||||
|
pipe.watch('a', 'b')
|
||||||
|
r['b'] = 3
|
||||||
|
pipe.unwatch()
|
||||||
|
assert not pipe.watching
|
||||||
|
pipe.get('a')
|
||||||
|
assert pipe.execute() == [b('1')]
|
||||||
|
|
||||||
|
def test_transaction_callable(self, r):
|
||||||
|
r['a'] = 1
|
||||||
|
r['b'] = 2
|
||||||
|
has_run = []
|
||||||
|
|
||||||
|
def my_transaction(pipe):
|
||||||
|
a_value = pipe.get('a')
|
||||||
|
assert a_value in (b('1'), b('2'))
|
||||||
|
b_value = pipe.get('b')
|
||||||
|
assert b_value == b('2')
|
||||||
|
|
||||||
|
# silly run-once code... incr's "a" so WatchError should be raised
|
||||||
|
# forcing this all to run again. this should incr "a" once to "2"
|
||||||
|
if not has_run:
|
||||||
|
r.incr('a')
|
||||||
|
has_run.append('it has')
|
||||||
|
|
||||||
|
pipe.multi()
|
||||||
|
pipe.set('c', int(a_value) + int(b_value))
|
||||||
|
|
||||||
|
result = r.transaction(my_transaction, 'a', 'b')
|
||||||
|
assert result == [True]
|
||||||
|
assert r['c'] == b('4')
|
||||||
|
|
||||||
|
def test_exec_error_in_no_transaction_pipeline(self, r):
|
||||||
|
r['a'] = 1
|
||||||
|
with r.pipeline(transaction=False) as pipe:
|
||||||
|
pipe.llen('a')
|
||||||
|
pipe.expire('a', 100)
|
||||||
|
|
||||||
|
with pytest.raises(redis.ResponseError) as ex:
|
||||||
|
pipe.execute()
|
||||||
|
|
||||||
|
assert unicode(ex.value).startswith('Command # 1 (LLEN a) of '
|
||||||
|
'pipeline caused error: ')
|
||||||
|
|
||||||
|
assert r['a'] == b('1')
|
||||||
|
|
||||||
|
def test_exec_error_in_no_transaction_pipeline_unicode_command(self, r):
|
||||||
|
key = unichr(3456) + u('abcd') + unichr(3421)
|
||||||
|
r[key] = 1
|
||||||
|
with r.pipeline(transaction=False) as pipe:
|
||||||
|
pipe.llen(key)
|
||||||
|
pipe.expire(key, 100)
|
||||||
|
|
||||||
|
with pytest.raises(redis.ResponseError) as ex:
|
||||||
|
pipe.execute()
|
||||||
|
|
||||||
|
expected = unicode('Command # 1 (LLEN %s) of pipeline caused '
|
||||||
|
'error: ') % key
|
||||||
|
assert unicode(ex.value).startswith(expected)
|
||||||
|
|
||||||
|
assert r[key] == b('1')
|
|
@ -0,0 +1,392 @@
|
||||||
|
from __future__ import with_statement
|
||||||
|
import pytest
|
||||||
|
import time
|
||||||
|
|
||||||
|
import redis
|
||||||
|
from redis.exceptions import ConnectionError
|
||||||
|
from redis._compat import basestring, u, unichr
|
||||||
|
|
||||||
|
from .conftest import r as _redis_client
|
||||||
|
|
||||||
|
|
||||||
|
def wait_for_message(pubsub, timeout=0.1, ignore_subscribe_messages=False):
|
||||||
|
now = time.time()
|
||||||
|
timeout = now + timeout
|
||||||
|
while now < timeout:
|
||||||
|
message = pubsub.get_message(
|
||||||
|
ignore_subscribe_messages=ignore_subscribe_messages)
|
||||||
|
if message is not None:
|
||||||
|
return message
|
||||||
|
time.sleep(0.01)
|
||||||
|
now = time.time()
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def make_message(type, channel, data, pattern=None):
|
||||||
|
return {
|
||||||
|
'type': type,
|
||||||
|
'pattern': pattern and pattern.encode('utf-8') or None,
|
||||||
|
'channel': channel.encode('utf-8'),
|
||||||
|
'data': data.encode('utf-8') if isinstance(data, basestring) else data
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def make_subscribe_test_data(pubsub, type):
|
||||||
|
if type == 'channel':
|
||||||
|
return {
|
||||||
|
'p': pubsub,
|
||||||
|
'sub_type': 'subscribe',
|
||||||
|
'unsub_type': 'unsubscribe',
|
||||||
|
'sub_func': pubsub.subscribe,
|
||||||
|
'unsub_func': pubsub.unsubscribe,
|
||||||
|
'keys': ['foo', 'bar', u('uni') + unichr(4456) + u('code')]
|
||||||
|
}
|
||||||
|
elif type == 'pattern':
|
||||||
|
return {
|
||||||
|
'p': pubsub,
|
||||||
|
'sub_type': 'psubscribe',
|
||||||
|
'unsub_type': 'punsubscribe',
|
||||||
|
'sub_func': pubsub.psubscribe,
|
||||||
|
'unsub_func': pubsub.punsubscribe,
|
||||||
|
'keys': ['f*', 'b*', u('uni') + unichr(4456) + u('*')]
|
||||||
|
}
|
||||||
|
assert False, 'invalid subscribe type: %s' % type
|
||||||
|
|
||||||
|
|
||||||
|
class TestPubSubSubscribeUnsubscribe(object):
|
||||||
|
|
||||||
|
def _test_subscribe_unsubscribe(self, p, sub_type, unsub_type, sub_func,
|
||||||
|
unsub_func, keys):
|
||||||
|
for key in keys:
|
||||||
|
assert sub_func(key) is None
|
||||||
|
|
||||||
|
# should be a message for each channel/pattern we just subscribed to
|
||||||
|
for i, key in enumerate(keys):
|
||||||
|
assert wait_for_message(p) == make_message(sub_type, key, i + 1)
|
||||||
|
|
||||||
|
for key in keys:
|
||||||
|
assert unsub_func(key) is None
|
||||||
|
|
||||||
|
# should be a message for each channel/pattern we just unsubscribed
|
||||||
|
# from
|
||||||
|
for i, key in enumerate(keys):
|
||||||
|
i = len(keys) - 1 - i
|
||||||
|
assert wait_for_message(p) == make_message(unsub_type, key, i)
|
||||||
|
|
||||||
|
def test_channel_subscribe_unsubscribe(self, r):
|
||||||
|
kwargs = make_subscribe_test_data(r.pubsub(), 'channel')
|
||||||
|
self._test_subscribe_unsubscribe(**kwargs)
|
||||||
|
|
||||||
|
def test_pattern_subscribe_unsubscribe(self, r):
|
||||||
|
kwargs = make_subscribe_test_data(r.pubsub(), 'pattern')
|
||||||
|
self._test_subscribe_unsubscribe(**kwargs)
|
||||||
|
|
||||||
|
def _test_resubscribe_on_reconnection(self, p, sub_type, unsub_type,
|
||||||
|
sub_func, unsub_func, keys):
|
||||||
|
|
||||||
|
for key in keys:
|
||||||
|
assert sub_func(key) is None
|
||||||
|
|
||||||
|
# should be a message for each channel/pattern we just subscribed to
|
||||||
|
for i, key in enumerate(keys):
|
||||||
|
assert wait_for_message(p) == make_message(sub_type, key, i + 1)
|
||||||
|
|
||||||
|
# manually disconnect
|
||||||
|
p.connection.disconnect()
|
||||||
|
|
||||||
|
# calling get_message again reconnects and resubscribes
|
||||||
|
# note, we may not re-subscribe to channels in exactly the same order
|
||||||
|
# so we have to do some extra checks to make sure we got them all
|
||||||
|
messages = []
|
||||||
|
for i in range(len(keys)):
|
||||||
|
messages.append(wait_for_message(p))
|
||||||
|
|
||||||
|
unique_channels = set()
|
||||||
|
assert len(messages) == len(keys)
|
||||||
|
for i, message in enumerate(messages):
|
||||||
|
assert message['type'] == sub_type
|
||||||
|
assert message['data'] == i + 1
|
||||||
|
assert isinstance(message['channel'], bytes)
|
||||||
|
channel = message['channel'].decode('utf-8')
|
||||||
|
unique_channels.add(channel)
|
||||||
|
|
||||||
|
assert len(unique_channels) == len(keys)
|
||||||
|
for channel in unique_channels:
|
||||||
|
assert channel in keys
|
||||||
|
|
||||||
|
def test_resubscribe_to_channels_on_reconnection(self, r):
|
||||||
|
kwargs = make_subscribe_test_data(r.pubsub(), 'channel')
|
||||||
|
self._test_resubscribe_on_reconnection(**kwargs)
|
||||||
|
|
||||||
|
def test_resubscribe_to_patterns_on_reconnection(self, r):
|
||||||
|
kwargs = make_subscribe_test_data(r.pubsub(), 'pattern')
|
||||||
|
self._test_resubscribe_on_reconnection(**kwargs)
|
||||||
|
|
||||||
|
def _test_subscribed_property(self, p, sub_type, unsub_type, sub_func,
|
||||||
|
unsub_func, keys):
|
||||||
|
|
||||||
|
assert p.subscribed is False
|
||||||
|
sub_func(keys[0])
|
||||||
|
# we're now subscribed even though we haven't processed the
|
||||||
|
# reply from the server just yet
|
||||||
|
assert p.subscribed is True
|
||||||
|
assert wait_for_message(p) == make_message(sub_type, keys[0], 1)
|
||||||
|
# we're still subscribed
|
||||||
|
assert p.subscribed is True
|
||||||
|
|
||||||
|
# unsubscribe from all channels
|
||||||
|
unsub_func()
|
||||||
|
# we're still technically subscribed until we process the
|
||||||
|
# response messages from the server
|
||||||
|
assert p.subscribed is True
|
||||||
|
assert wait_for_message(p) == make_message(unsub_type, keys[0], 0)
|
||||||
|
# now we're no longer subscribed as no more messages can be delivered
|
||||||
|
# to any channels we were listening to
|
||||||
|
assert p.subscribed is False
|
||||||
|
|
||||||
|
# subscribing again flips the flag back
|
||||||
|
sub_func(keys[0])
|
||||||
|
assert p.subscribed is True
|
||||||
|
assert wait_for_message(p) == make_message(sub_type, keys[0], 1)
|
||||||
|
|
||||||
|
# unsubscribe again
|
||||||
|
unsub_func()
|
||||||
|
assert p.subscribed is True
|
||||||
|
# subscribe to another channel before reading the unsubscribe response
|
||||||
|
sub_func(keys[1])
|
||||||
|
assert p.subscribed is True
|
||||||
|
# read the unsubscribe for key1
|
||||||
|
assert wait_for_message(p) == make_message(unsub_type, keys[0], 0)
|
||||||
|
# we're still subscribed to key2, so subscribed should still be True
|
||||||
|
assert p.subscribed is True
|
||||||
|
# read the key2 subscribe message
|
||||||
|
assert wait_for_message(p) == make_message(sub_type, keys[1], 1)
|
||||||
|
unsub_func()
|
||||||
|
# haven't read the message yet, so we're still subscribed
|
||||||
|
assert p.subscribed is True
|
||||||
|
assert wait_for_message(p) == make_message(unsub_type, keys[1], 0)
|
||||||
|
# now we're finally unsubscribed
|
||||||
|
assert p.subscribed is False
|
||||||
|
|
||||||
|
def test_subscribe_property_with_channels(self, r):
|
||||||
|
kwargs = make_subscribe_test_data(r.pubsub(), 'channel')
|
||||||
|
self._test_subscribed_property(**kwargs)
|
||||||
|
|
||||||
|
def test_subscribe_property_with_patterns(self, r):
|
||||||
|
kwargs = make_subscribe_test_data(r.pubsub(), 'pattern')
|
||||||
|
self._test_subscribed_property(**kwargs)
|
||||||
|
|
||||||
|
def test_ignore_all_subscribe_messages(self, r):
|
||||||
|
p = r.pubsub(ignore_subscribe_messages=True)
|
||||||
|
|
||||||
|
checks = (
|
||||||
|
(p.subscribe, 'foo'),
|
||||||
|
(p.unsubscribe, 'foo'),
|
||||||
|
(p.psubscribe, 'f*'),
|
||||||
|
(p.punsubscribe, 'f*'),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert p.subscribed is False
|
||||||
|
for func, channel in checks:
|
||||||
|
assert func(channel) is None
|
||||||
|
assert p.subscribed is True
|
||||||
|
assert wait_for_message(p) is None
|
||||||
|
assert p.subscribed is False
|
||||||
|
|
||||||
|
def test_ignore_individual_subscribe_messages(self, r):
|
||||||
|
p = r.pubsub()
|
||||||
|
|
||||||
|
checks = (
|
||||||
|
(p.subscribe, 'foo'),
|
||||||
|
(p.unsubscribe, 'foo'),
|
||||||
|
(p.psubscribe, 'f*'),
|
||||||
|
(p.punsubscribe, 'f*'),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert p.subscribed is False
|
||||||
|
for func, channel in checks:
|
||||||
|
assert func(channel) is None
|
||||||
|
assert p.subscribed is True
|
||||||
|
message = wait_for_message(p, ignore_subscribe_messages=True)
|
||||||
|
assert message is None
|
||||||
|
assert p.subscribed is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestPubSubMessages(object):
|
||||||
|
def setup_method(self, method):
|
||||||
|
self.message = None
|
||||||
|
|
||||||
|
def message_handler(self, message):
|
||||||
|
self.message = message
|
||||||
|
|
||||||
|
def test_published_message_to_channel(self, r):
|
||||||
|
p = r.pubsub(ignore_subscribe_messages=True)
|
||||||
|
p.subscribe('foo')
|
||||||
|
assert r.publish('foo', 'test message') == 1
|
||||||
|
|
||||||
|
message = wait_for_message(p)
|
||||||
|
assert isinstance(message, dict)
|
||||||
|
assert message == make_message('message', 'foo', 'test message')
|
||||||
|
|
||||||
|
def test_published_message_to_pattern(self, r):
|
||||||
|
p = r.pubsub(ignore_subscribe_messages=True)
|
||||||
|
p.subscribe('foo')
|
||||||
|
p.psubscribe('f*')
|
||||||
|
# 1 to pattern, 1 to channel
|
||||||
|
assert r.publish('foo', 'test message') == 2
|
||||||
|
|
||||||
|
message1 = wait_for_message(p)
|
||||||
|
message2 = wait_for_message(p)
|
||||||
|
assert isinstance(message1, dict)
|
||||||
|
assert isinstance(message2, dict)
|
||||||
|
|
||||||
|
expected = [
|
||||||
|
make_message('message', 'foo', 'test message'),
|
||||||
|
make_message('pmessage', 'foo', 'test message', pattern='f*')
|
||||||
|
]
|
||||||
|
|
||||||
|
assert message1 in expected
|
||||||
|
assert message2 in expected
|
||||||
|
assert message1 != message2
|
||||||
|
|
||||||
|
def test_channel_message_handler(self, r):
|
||||||
|
p = r.pubsub(ignore_subscribe_messages=True)
|
||||||
|
p.subscribe(foo=self.message_handler)
|
||||||
|
assert r.publish('foo', 'test message') == 1
|
||||||
|
assert wait_for_message(p) is None
|
||||||
|
assert self.message == make_message('message', 'foo', 'test message')
|
||||||
|
|
||||||
|
def test_pattern_message_handler(self, r):
|
||||||
|
p = r.pubsub(ignore_subscribe_messages=True)
|
||||||
|
p.psubscribe(**{'f*': self.message_handler})
|
||||||
|
assert r.publish('foo', 'test message') == 1
|
||||||
|
assert wait_for_message(p) is None
|
||||||
|
assert self.message == make_message('pmessage', 'foo', 'test message',
|
||||||
|
pattern='f*')
|
||||||
|
|
||||||
|
def test_unicode_channel_message_handler(self, r):
|
||||||
|
p = r.pubsub(ignore_subscribe_messages=True)
|
||||||
|
channel = u('uni') + unichr(4456) + u('code')
|
||||||
|
channels = {channel: self.message_handler}
|
||||||
|
p.subscribe(**channels)
|
||||||
|
assert r.publish(channel, 'test message') == 1
|
||||||
|
assert wait_for_message(p) is None
|
||||||
|
assert self.message == make_message('message', channel, 'test message')
|
||||||
|
|
||||||
|
def test_unicode_pattern_message_handler(self, r):
|
||||||
|
p = r.pubsub(ignore_subscribe_messages=True)
|
||||||
|
pattern = u('uni') + unichr(4456) + u('*')
|
||||||
|
channel = u('uni') + unichr(4456) + u('code')
|
||||||
|
p.psubscribe(**{pattern: self.message_handler})
|
||||||
|
assert r.publish(channel, 'test message') == 1
|
||||||
|
assert wait_for_message(p) is None
|
||||||
|
assert self.message == make_message('pmessage', channel,
|
||||||
|
'test message', pattern=pattern)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPubSubAutoDecoding(object):
|
||||||
|
"These tests only validate that we get unicode values back"
|
||||||
|
|
||||||
|
channel = u('uni') + unichr(4456) + u('code')
|
||||||
|
pattern = u('uni') + unichr(4456) + u('*')
|
||||||
|
data = u('abc') + unichr(4458) + u('123')
|
||||||
|
|
||||||
|
def make_message(self, type, channel, data, pattern=None):
|
||||||
|
return {
|
||||||
|
'type': type,
|
||||||
|
'channel': channel,
|
||||||
|
'pattern': pattern,
|
||||||
|
'data': data
|
||||||
|
}
|
||||||
|
|
||||||
|
def setup_method(self, method):
|
||||||
|
self.message = None
|
||||||
|
|
||||||
|
def message_handler(self, message):
|
||||||
|
self.message = message
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def r(self, request):
|
||||||
|
return _redis_client(request=request, decode_responses=True)
|
||||||
|
|
||||||
|
def test_channel_subscribe_unsubscribe(self, r):
|
||||||
|
p = r.pubsub()
|
||||||
|
p.subscribe(self.channel)
|
||||||
|
assert wait_for_message(p) == self.make_message('subscribe',
|
||||||
|
self.channel, 1)
|
||||||
|
|
||||||
|
p.unsubscribe(self.channel)
|
||||||
|
assert wait_for_message(p) == self.make_message('unsubscribe',
|
||||||
|
self.channel, 0)
|
||||||
|
|
||||||
|
def test_pattern_subscribe_unsubscribe(self, r):
|
||||||
|
p = r.pubsub()
|
||||||
|
p.psubscribe(self.pattern)
|
||||||
|
assert wait_for_message(p) == self.make_message('psubscribe',
|
||||||
|
self.pattern, 1)
|
||||||
|
|
||||||
|
p.punsubscribe(self.pattern)
|
||||||
|
assert wait_for_message(p) == self.make_message('punsubscribe',
|
||||||
|
self.pattern, 0)
|
||||||
|
|
||||||
|
def test_channel_publish(self, r):
|
||||||
|
p = r.pubsub(ignore_subscribe_messages=True)
|
||||||
|
p.subscribe(self.channel)
|
||||||
|
r.publish(self.channel, self.data)
|
||||||
|
assert wait_for_message(p) == self.make_message('message',
|
||||||
|
self.channel,
|
||||||
|
self.data)
|
||||||
|
|
||||||
|
def test_pattern_publish(self, r):
|
||||||
|
p = r.pubsub(ignore_subscribe_messages=True)
|
||||||
|
p.psubscribe(self.pattern)
|
||||||
|
r.publish(self.channel, self.data)
|
||||||
|
assert wait_for_message(p) == self.make_message('pmessage',
|
||||||
|
self.channel,
|
||||||
|
self.data,
|
||||||
|
pattern=self.pattern)
|
||||||
|
|
||||||
|
def test_channel_message_handler(self, r):
|
||||||
|
p = r.pubsub(ignore_subscribe_messages=True)
|
||||||
|
p.subscribe(**{self.channel: self.message_handler})
|
||||||
|
r.publish(self.channel, self.data)
|
||||||
|
assert wait_for_message(p) is None
|
||||||
|
assert self.message == self.make_message('message', self.channel,
|
||||||
|
self.data)
|
||||||
|
|
||||||
|
# test that we reconnected to the correct channel
|
||||||
|
p.connection.disconnect()
|
||||||
|
assert wait_for_message(p) is None # should reconnect
|
||||||
|
new_data = self.data + u('new data')
|
||||||
|
r.publish(self.channel, new_data)
|
||||||
|
assert wait_for_message(p) is None
|
||||||
|
assert self.message == self.make_message('message', self.channel,
|
||||||
|
new_data)
|
||||||
|
|
||||||
|
def test_pattern_message_handler(self, r):
|
||||||
|
p = r.pubsub(ignore_subscribe_messages=True)
|
||||||
|
p.psubscribe(**{self.pattern: self.message_handler})
|
||||||
|
r.publish(self.channel, self.data)
|
||||||
|
assert wait_for_message(p) is None
|
||||||
|
assert self.message == self.make_message('pmessage', self.channel,
|
||||||
|
self.data,
|
||||||
|
pattern=self.pattern)
|
||||||
|
|
||||||
|
# test that we reconnected to the correct pattern
|
||||||
|
p.connection.disconnect()
|
||||||
|
assert wait_for_message(p) is None # should reconnect
|
||||||
|
new_data = self.data + u('new data')
|
||||||
|
r.publish(self.channel, new_data)
|
||||||
|
assert wait_for_message(p) is None
|
||||||
|
assert self.message == self.make_message('pmessage', self.channel,
|
||||||
|
new_data,
|
||||||
|
pattern=self.pattern)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPubSubRedisDown(object):
|
||||||
|
|
||||||
|
def test_channel_subscribe(self, r):
|
||||||
|
r = redis.Redis(host='localhost', port=6390)
|
||||||
|
p = r.pubsub()
|
||||||
|
with pytest.raises(ConnectionError):
|
||||||
|
p.subscribe('foo')
|
|
@ -0,0 +1,82 @@
|
||||||
|
from __future__ import with_statement
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from redis import exceptions
|
||||||
|
from redis._compat import b
|
||||||
|
|
||||||
|
|
||||||
|
multiply_script = """
|
||||||
|
local value = redis.call('GET', KEYS[1])
|
||||||
|
value = tonumber(value)
|
||||||
|
return value * ARGV[1]"""
|
||||||
|
|
||||||
|
|
||||||
|
class TestScripting(object):
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def reset_scripts(self, r):
|
||||||
|
r.script_flush()
|
||||||
|
|
||||||
|
def test_eval(self, r):
|
||||||
|
r.set('a', 2)
|
||||||
|
# 2 * 3 == 6
|
||||||
|
assert r.eval(multiply_script, 1, 'a', 3) == 6
|
||||||
|
|
||||||
|
def test_evalsha(self, r):
|
||||||
|
r.set('a', 2)
|
||||||
|
sha = r.script_load(multiply_script)
|
||||||
|
# 2 * 3 == 6
|
||||||
|
assert r.evalsha(sha, 1, 'a', 3) == 6
|
||||||
|
|
||||||
|
def test_evalsha_script_not_loaded(self, r):
|
||||||
|
r.set('a', 2)
|
||||||
|
sha = r.script_load(multiply_script)
|
||||||
|
# remove the script from Redis's cache
|
||||||
|
r.script_flush()
|
||||||
|
with pytest.raises(exceptions.NoScriptError):
|
||||||
|
r.evalsha(sha, 1, 'a', 3)
|
||||||
|
|
||||||
|
def test_script_loading(self, r):
|
||||||
|
# get the sha, then clear the cache
|
||||||
|
sha = r.script_load(multiply_script)
|
||||||
|
r.script_flush()
|
||||||
|
assert r.script_exists(sha) == [False]
|
||||||
|
r.script_load(multiply_script)
|
||||||
|
assert r.script_exists(sha) == [True]
|
||||||
|
|
||||||
|
def test_script_object(self, r):
|
||||||
|
r.set('a', 2)
|
||||||
|
multiply = r.register_script(multiply_script)
|
||||||
|
assert not multiply.sha
|
||||||
|
# test evalsha fail -> script load + retry
|
||||||
|
assert multiply(keys=['a'], args=[3]) == 6
|
||||||
|
assert multiply.sha
|
||||||
|
assert r.script_exists(multiply.sha) == [True]
|
||||||
|
# test first evalsha
|
||||||
|
assert multiply(keys=['a'], args=[3]) == 6
|
||||||
|
|
||||||
|
def test_script_object_in_pipeline(self, r):
|
||||||
|
multiply = r.register_script(multiply_script)
|
||||||
|
assert not multiply.sha
|
||||||
|
pipe = r.pipeline()
|
||||||
|
pipe.set('a', 2)
|
||||||
|
pipe.get('a')
|
||||||
|
multiply(keys=['a'], args=[3], client=pipe)
|
||||||
|
# even though the pipeline wasn't executed yet, we made sure the
|
||||||
|
# script was loaded and got a valid sha
|
||||||
|
assert multiply.sha
|
||||||
|
assert r.script_exists(multiply.sha) == [True]
|
||||||
|
# [SET worked, GET 'a', result of multiple script]
|
||||||
|
assert pipe.execute() == [True, b('2'), 6]
|
||||||
|
|
||||||
|
# purge the script from redis's cache and re-run the pipeline
|
||||||
|
# the multiply script object knows it's sha, so it shouldn't get
|
||||||
|
# reloaded until pipe.execute()
|
||||||
|
r.script_flush()
|
||||||
|
pipe = r.pipeline()
|
||||||
|
pipe.set('a', 2)
|
||||||
|
pipe.get('a')
|
||||||
|
assert multiply.sha
|
||||||
|
multiply(keys=['a'], args=[3], client=pipe)
|
||||||
|
assert r.script_exists(multiply.sha) == [False]
|
||||||
|
# [SET worked, GET 'a', result of multiple script]
|
||||||
|
assert pipe.execute() == [True, b('2'), 6]
|
|
@ -0,0 +1,173 @@
|
||||||
|
from __future__ import with_statement
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from redis import exceptions
|
||||||
|
from redis.sentinel import (Sentinel, SentinelConnectionPool,
|
||||||
|
MasterNotFoundError, SlaveNotFoundError)
|
||||||
|
from redis._compat import next
|
||||||
|
import redis.sentinel
|
||||||
|
|
||||||
|
|
||||||
|
class SentinelTestClient(object):
|
||||||
|
def __init__(self, cluster, id):
|
||||||
|
self.cluster = cluster
|
||||||
|
self.id = id
|
||||||
|
|
||||||
|
def sentinel_masters(self):
|
||||||
|
self.cluster.connection_error_if_down(self)
|
||||||
|
return {self.cluster.service_name: self.cluster.master}
|
||||||
|
|
||||||
|
def sentinel_slaves(self, master_name):
|
||||||
|
self.cluster.connection_error_if_down(self)
|
||||||
|
if master_name != self.cluster.service_name:
|
||||||
|
return []
|
||||||
|
return self.cluster.slaves
|
||||||
|
|
||||||
|
|
||||||
|
class SentinelTestCluster(object):
|
||||||
|
def __init__(self, service_name='mymaster', ip='127.0.0.1', port=6379):
|
||||||
|
self.clients = {}
|
||||||
|
self.master = {
|
||||||
|
'ip': ip,
|
||||||
|
'port': port,
|
||||||
|
'is_master': True,
|
||||||
|
'is_sdown': False,
|
||||||
|
'is_odown': False,
|
||||||
|
'num-other-sentinels': 0,
|
||||||
|
}
|
||||||
|
self.service_name = service_name
|
||||||
|
self.slaves = []
|
||||||
|
self.nodes_down = set()
|
||||||
|
|
||||||
|
def connection_error_if_down(self, node):
|
||||||
|
if node.id in self.nodes_down:
|
||||||
|
raise exceptions.ConnectionError
|
||||||
|
|
||||||
|
def client(self, host, port, **kwargs):
|
||||||
|
return SentinelTestClient(self, (host, port))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def cluster(request):
|
||||||
|
def teardown():
|
||||||
|
redis.sentinel.StrictRedis = saved_StrictRedis
|
||||||
|
cluster = SentinelTestCluster()
|
||||||
|
saved_StrictRedis = redis.sentinel.StrictRedis
|
||||||
|
redis.sentinel.StrictRedis = cluster.client
|
||||||
|
request.addfinalizer(teardown)
|
||||||
|
return cluster
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def sentinel(request, cluster):
|
||||||
|
return Sentinel([('foo', 26379), ('bar', 26379)])
|
||||||
|
|
||||||
|
|
||||||
|
def test_discover_master(sentinel):
|
||||||
|
address = sentinel.discover_master('mymaster')
|
||||||
|
assert address == ('127.0.0.1', 6379)
|
||||||
|
|
||||||
|
|
||||||
|
def test_discover_master_error(sentinel):
|
||||||
|
with pytest.raises(MasterNotFoundError):
|
||||||
|
sentinel.discover_master('xxx')
|
||||||
|
|
||||||
|
|
||||||
|
def test_discover_master_sentinel_down(cluster, sentinel):
|
||||||
|
# Put first sentinel 'foo' down
|
||||||
|
cluster.nodes_down.add(('foo', 26379))
|
||||||
|
address = sentinel.discover_master('mymaster')
|
||||||
|
assert address == ('127.0.0.1', 6379)
|
||||||
|
# 'bar' is now first sentinel
|
||||||
|
assert sentinel.sentinels[0].id == ('bar', 26379)
|
||||||
|
|
||||||
|
|
||||||
|
def test_master_min_other_sentinels(cluster):
|
||||||
|
sentinel = Sentinel([('foo', 26379)], min_other_sentinels=1)
|
||||||
|
# min_other_sentinels
|
||||||
|
with pytest.raises(MasterNotFoundError):
|
||||||
|
sentinel.discover_master('mymaster')
|
||||||
|
cluster.master['num-other-sentinels'] = 2
|
||||||
|
address = sentinel.discover_master('mymaster')
|
||||||
|
assert address == ('127.0.0.1', 6379)
|
||||||
|
|
||||||
|
|
||||||
|
def test_master_odown(cluster, sentinel):
|
||||||
|
cluster.master['is_odown'] = True
|
||||||
|
with pytest.raises(MasterNotFoundError):
|
||||||
|
sentinel.discover_master('mymaster')
|
||||||
|
|
||||||
|
|
||||||
|
def test_master_sdown(cluster, sentinel):
|
||||||
|
cluster.master['is_sdown'] = True
|
||||||
|
with pytest.raises(MasterNotFoundError):
|
||||||
|
sentinel.discover_master('mymaster')
|
||||||
|
|
||||||
|
|
||||||
|
def test_discover_slaves(cluster, sentinel):
|
||||||
|
assert sentinel.discover_slaves('mymaster') == []
|
||||||
|
|
||||||
|
cluster.slaves = [
|
||||||
|
{'ip': 'slave0', 'port': 1234, 'is_odown': False, 'is_sdown': False},
|
||||||
|
{'ip': 'slave1', 'port': 1234, 'is_odown': False, 'is_sdown': False},
|
||||||
|
]
|
||||||
|
assert sentinel.discover_slaves('mymaster') == [
|
||||||
|
('slave0', 1234), ('slave1', 1234)]
|
||||||
|
|
||||||
|
# slave0 -> ODOWN
|
||||||
|
cluster.slaves[0]['is_odown'] = True
|
||||||
|
assert sentinel.discover_slaves('mymaster') == [
|
||||||
|
('slave1', 1234)]
|
||||||
|
|
||||||
|
# slave1 -> SDOWN
|
||||||
|
cluster.slaves[1]['is_sdown'] = True
|
||||||
|
assert sentinel.discover_slaves('mymaster') == []
|
||||||
|
|
||||||
|
cluster.slaves[0]['is_odown'] = False
|
||||||
|
cluster.slaves[1]['is_sdown'] = False
|
||||||
|
|
||||||
|
# node0 -> DOWN
|
||||||
|
cluster.nodes_down.add(('foo', 26379))
|
||||||
|
assert sentinel.discover_slaves('mymaster') == [
|
||||||
|
('slave0', 1234), ('slave1', 1234)]
|
||||||
|
|
||||||
|
|
||||||
|
def test_master_for(cluster, sentinel):
|
||||||
|
master = sentinel.master_for('mymaster', db=9)
|
||||||
|
assert master.ping()
|
||||||
|
assert master.connection_pool.master_address == ('127.0.0.1', 6379)
|
||||||
|
|
||||||
|
# Use internal connection check
|
||||||
|
master = sentinel.master_for('mymaster', db=9, check_connection=True)
|
||||||
|
assert master.ping()
|
||||||
|
|
||||||
|
|
||||||
|
def test_slave_for(cluster, sentinel):
|
||||||
|
cluster.slaves = [
|
||||||
|
{'ip': '127.0.0.1', 'port': 6379,
|
||||||
|
'is_odown': False, 'is_sdown': False},
|
||||||
|
]
|
||||||
|
slave = sentinel.slave_for('mymaster', db=9)
|
||||||
|
assert slave.ping()
|
||||||
|
|
||||||
|
|
||||||
|
def test_slave_for_slave_not_found_error(cluster, sentinel):
|
||||||
|
cluster.master['is_odown'] = True
|
||||||
|
slave = sentinel.slave_for('mymaster', db=9)
|
||||||
|
with pytest.raises(SlaveNotFoundError):
|
||||||
|
slave.ping()
|
||||||
|
|
||||||
|
|
||||||
|
def test_slave_round_robin(cluster, sentinel):
|
||||||
|
cluster.slaves = [
|
||||||
|
{'ip': 'slave0', 'port': 6379, 'is_odown': False, 'is_sdown': False},
|
||||||
|
{'ip': 'slave1', 'port': 6379, 'is_odown': False, 'is_sdown': False},
|
||||||
|
]
|
||||||
|
pool = SentinelConnectionPool('mymaster', sentinel)
|
||||||
|
rotator = pool.rotate_slaves()
|
||||||
|
assert next(rotator) in (('slave0', 6379), ('slave1', 6379))
|
||||||
|
assert next(rotator) in (('slave0', 6379), ('slave1', 6379))
|
||||||
|
# Fallback to master
|
||||||
|
assert next(rotator) == ('127.0.0.1', 6379)
|
||||||
|
with pytest.raises(SlaveNotFoundError):
|
||||||
|
next(rotator)
|
Loading…
Reference in New Issue