add tx support for ledis-py

This commit is contained in:
holys 2014-08-26 10:56:29 +08:00
parent 7285c50bfc
commit 63d5c98580
3 changed files with 109 additions and 1 deletions

View File

@ -1,6 +1,7 @@
from __future__ import with_statement from __future__ import with_statement
import datetime import datetime
import time as mod_time import time as mod_time
from itertools import chain, starmap
from ledis._compat import (b, izip, imap, iteritems, from ledis._compat import (b, izip, imap, iteritems,
basestring, long, nativestr, bytes) basestring, long, nativestr, bytes)
from ledis.connection import ConnectionPool, UnixDomainSocketConnection from ledis.connection import ConnectionPool, UnixDomainSocketConnection
@ -8,6 +9,8 @@ from ledis.exceptions import (
ConnectionError, ConnectionError,
DataError, DataError,
LedisError, LedisError,
ResponseError,
TxNotBeginError
) )
SYM_EMPTY = b('') SYM_EMPTY = b('')
@ -170,6 +173,26 @@ class Ledis(object):
"Set a custom Response Callback" "Set a custom Response Callback"
self.response_callbacks[command] = callback self.response_callbacks[command] = callback
# def pipeline(self, transaction=True, shard_hint=None):
# """
# Return a new pipeline object that can queue multiple commands for
# later execution. ``transaction`` indicates whether all commands
# should be executed atomically. Apart from making a group of operations
# atomic, pipelines are useful for reducing the back-and-forth overhead
# between the client and server.
# """
# return StrictPipeline(
# self.connection_pool,
# self.response_callbacks,
# transaction,
# shard_hint)
def tx(self):
return Transaction(
self.connection_pool,
self.response_callbacks)
#### COMMAND EXECUTION AND PROTOCOL PARSING #### #### COMMAND EXECUTION AND PROTOCOL PARSING ####
def execute_command(self, *args, **options): def execute_command(self, *args, **options):
@ -869,3 +892,43 @@ class Ledis(object):
def bpersist(self, name): def bpersist(self, name):
"Removes an expiration on name" "Removes an expiration on name"
return self.execute_command('BPERSIST', name) return self.execute_command('BPERSIST', name)
class Transaction(Ledis):
def __init__(self, connection_pool, response_callbacks):
self.connection_pool = connection_pool
self.response_callbacks = response_callbacks
self.connection = None
def execute_command(self, *args, **options):
"Execute a command and return a parsed response"
command_name = args[0]
connection = self.connection
if self.connection is None:
raise TxNotBeginError
try:
connection.send_command(*args)
return self.parse_response(connection, command_name, **options)
except ConnectionError:
connection.disconnect()
connection.send_command(*args)
return self.parse_response(connection, command_name, **options)
def begin(self):
self.connection = self.connection_pool.get_connection('begin')
return self.execute_command("BEGIN")
def commit(self):
res = self.execute_command("COMMIT")
self.connection_pool.release(self.connection)
self.connection = None
return res
def rollback(self):
res = self.execute_command("ROLLBACK")
self.connection_pool.release(self.connection)
self.connection = None
return res

View File

@ -16,6 +16,10 @@ class BusyLoadingError(ConnectionError):
pass pass
class TimeoutError(LedisError):
pass
class InvalidResponse(ServerError): class InvalidResponse(ServerError):
pass pass
@ -30,3 +34,6 @@ class DataError(LedisError):
class ExecAbortError(ResponseError): class ExecAbortError(ResponseError):
pass pass
class TxNotBeginError(LedisError):
pass

View File

@ -0,0 +1,38 @@
import unittest
import sys
sys.path.append("..")
import ledis
class TestTx(unittest.TestCase):
def setUp(self):
self.l = ledis.Ledis(port=6380)
def tearDown(self):
self.l.delete("a")
def test_commit(self):
tx = self.l.tx()
self.l.set("a", "no-tx")
assert self.l.get("a") == "no-tx"
tx.begin()
tx.set("a", "tx")
assert self.l.get("a") == "no-tx"
assert tx.get("a") == "tx"
tx.commit()
assert self.l.get("a") == "tx"
def test_rollback(self):
tx = self.l.tx()
self.l.set("a", "no-tx")
assert self.l.get("a") == "no-tx"
tx.begin()
tx.set("a", "tx")
assert tx.get("a") == "tx"
assert self.l.get("a") == "no-tx"
tx.rollback()
assert self.l.get("a") == "no-tx"