diff --git a/client/ledis-py/ledis/client.py b/client/ledis-py/ledis/client.py index 4593a70..2e8a3f7 100644 --- a/client/ledis-py/ledis/client.py +++ b/client/ledis-py/ledis/client.py @@ -1,6 +1,7 @@ from __future__ import with_statement import datetime import time as mod_time +from itertools import chain, starmap from ledis._compat import (b, izip, imap, iteritems, basestring, long, nativestr, bytes) from ledis.connection import ConnectionPool, UnixDomainSocketConnection @@ -8,6 +9,8 @@ from ledis.exceptions import ( ConnectionError, DataError, LedisError, + ResponseError, + TxNotBeginError ) SYM_EMPTY = b('') @@ -169,6 +172,26 @@ class Ledis(object): def set_response_callback(self, command, callback): "Set a custom Response Callback" self.response_callbacks[command] = callback + + + # def pipeline(self, transaction=True, shard_hint=None): + # """ + # Return a new pipeline object that can queue multiple commands for + # later execution. ``transaction`` indicates whether all commands + # should be executed atomically. Apart from making a group of operations + # atomic, pipelines are useful for reducing the back-and-forth overhead + # between the client and server. + # """ + # return StrictPipeline( + # self.connection_pool, + # self.response_callbacks, + # transaction, + # shard_hint) + + def tx(self): + return Transaction( + self.connection_pool, + self.response_callbacks) #### COMMAND EXECUTION AND PROTOCOL PARSING #### @@ -868,4 +891,44 @@ class Ledis(object): def bpersist(self, name): "Removes an expiration on name" - return self.execute_command('BPERSIST', name) \ No newline at end of file + return self.execute_command('BPERSIST', name) + + +class Transaction(Ledis): + def __init__(self, connection_pool, response_callbacks): + self.connection_pool = connection_pool + self.response_callbacks = response_callbacks + self.connection = None + + def execute_command(self, *args, **options): + "Execute a command and return a parsed response" + command_name = args[0] + + connection = self.connection + if self.connection is None: + raise TxNotBeginError + + try: + connection.send_command(*args) + return self.parse_response(connection, command_name, **options) + except ConnectionError: + connection.disconnect() + connection.send_command(*args) + return self.parse_response(connection, command_name, **options) + + def begin(self): + self.connection = self.connection_pool.get_connection('begin') + return self.execute_command("BEGIN") + + def commit(self): + res = self.execute_command("COMMIT") + self.connection_pool.release(self.connection) + self.connection = None + return res + + def rollback(self): + res = self.execute_command("ROLLBACK") + self.connection_pool.release(self.connection) + self.connection = None + return res + diff --git a/client/ledis-py/ledis/exceptions.py b/client/ledis-py/ledis/exceptions.py index 91bada4..9150db6 100644 --- a/client/ledis-py/ledis/exceptions.py +++ b/client/ledis-py/ledis/exceptions.py @@ -16,6 +16,10 @@ class BusyLoadingError(ConnectionError): pass +class TimeoutError(LedisError): + pass + + class InvalidResponse(ServerError): pass @@ -30,3 +34,6 @@ class DataError(LedisError): class ExecAbortError(ResponseError): pass + +class TxNotBeginError(LedisError): + pass \ No newline at end of file diff --git a/client/ledis-py/tests/test_tx.py b/client/ledis-py/tests/test_tx.py new file mode 100644 index 0000000..b589dc7 --- /dev/null +++ b/client/ledis-py/tests/test_tx.py @@ -0,0 +1,38 @@ +import unittest +import sys +sys.path.append("..") + +import ledis + + +class TestTx(unittest.TestCase): + def setUp(self): + self.l = ledis.Ledis(port=6380) + + def tearDown(self): + self.l.delete("a") + + def test_commit(self): + tx = self.l.tx() + self.l.set("a", "no-tx") + assert self.l.get("a") == "no-tx" + tx.begin() + tx.set("a", "tx") + assert self.l.get("a") == "no-tx" + assert tx.get("a") == "tx" + + tx.commit() + assert self.l.get("a") == "tx" + + def test_rollback(self): + tx = self.l.tx() + self.l.set("a", "no-tx") + assert self.l.get("a") == "no-tx" + + tx.begin() + tx.set("a", "tx") + assert tx.get("a") == "tx" + assert self.l.get("a") == "no-tx" + + tx.rollback() + assert self.l.get("a") == "no-tx" \ No newline at end of file