forked from mirror/ledisdb
add tx support for ledis-py
This commit is contained in:
parent
7285c50bfc
commit
63d5c98580
|
@ -1,6 +1,7 @@
|
|||
from __future__ import with_statement
|
||||
import datetime
|
||||
import time as mod_time
|
||||
from itertools import chain, starmap
|
||||
from ledis._compat import (b, izip, imap, iteritems,
|
||||
basestring, long, nativestr, bytes)
|
||||
from ledis.connection import ConnectionPool, UnixDomainSocketConnection
|
||||
|
@ -8,6 +9,8 @@ from ledis.exceptions import (
|
|||
ConnectionError,
|
||||
DataError,
|
||||
LedisError,
|
||||
ResponseError,
|
||||
TxNotBeginError
|
||||
)
|
||||
|
||||
SYM_EMPTY = b('')
|
||||
|
@ -170,6 +173,26 @@ class Ledis(object):
|
|||
"Set a custom Response Callback"
|
||||
self.response_callbacks[command] = callback
|
||||
|
||||
|
||||
# def pipeline(self, transaction=True, shard_hint=None):
|
||||
# """
|
||||
# Return a new pipeline object that can queue multiple commands for
|
||||
# later execution. ``transaction`` indicates whether all commands
|
||||
# should be executed atomically. Apart from making a group of operations
|
||||
# atomic, pipelines are useful for reducing the back-and-forth overhead
|
||||
# between the client and server.
|
||||
# """
|
||||
# return StrictPipeline(
|
||||
# self.connection_pool,
|
||||
# self.response_callbacks,
|
||||
# transaction,
|
||||
# shard_hint)
|
||||
|
||||
def tx(self):
|
||||
return Transaction(
|
||||
self.connection_pool,
|
||||
self.response_callbacks)
|
||||
|
||||
#### COMMAND EXECUTION AND PROTOCOL PARSING ####
|
||||
|
||||
def execute_command(self, *args, **options):
|
||||
|
@ -869,3 +892,43 @@ class Ledis(object):
|
|||
def bpersist(self, name):
|
||||
"Removes an expiration on name"
|
||||
return self.execute_command('BPERSIST', name)
|
||||
|
||||
|
||||
class Transaction(Ledis):
|
||||
def __init__(self, connection_pool, response_callbacks):
|
||||
self.connection_pool = connection_pool
|
||||
self.response_callbacks = response_callbacks
|
||||
self.connection = None
|
||||
|
||||
def execute_command(self, *args, **options):
|
||||
"Execute a command and return a parsed response"
|
||||
command_name = args[0]
|
||||
|
||||
connection = self.connection
|
||||
if self.connection is None:
|
||||
raise TxNotBeginError
|
||||
|
||||
try:
|
||||
connection.send_command(*args)
|
||||
return self.parse_response(connection, command_name, **options)
|
||||
except ConnectionError:
|
||||
connection.disconnect()
|
||||
connection.send_command(*args)
|
||||
return self.parse_response(connection, command_name, **options)
|
||||
|
||||
def begin(self):
|
||||
self.connection = self.connection_pool.get_connection('begin')
|
||||
return self.execute_command("BEGIN")
|
||||
|
||||
def commit(self):
|
||||
res = self.execute_command("COMMIT")
|
||||
self.connection_pool.release(self.connection)
|
||||
self.connection = None
|
||||
return res
|
||||
|
||||
def rollback(self):
|
||||
res = self.execute_command("ROLLBACK")
|
||||
self.connection_pool.release(self.connection)
|
||||
self.connection = None
|
||||
return res
|
||||
|
||||
|
|
|
@ -16,6 +16,10 @@ class BusyLoadingError(ConnectionError):
|
|||
pass
|
||||
|
||||
|
||||
class TimeoutError(LedisError):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidResponse(ServerError):
|
||||
pass
|
||||
|
||||
|
@ -30,3 +34,6 @@ class DataError(LedisError):
|
|||
|
||||
class ExecAbortError(ResponseError):
|
||||
pass
|
||||
|
||||
class TxNotBeginError(LedisError):
|
||||
pass
|
|
@ -0,0 +1,38 @@
|
|||
import unittest
|
||||
import sys
|
||||
sys.path.append("..")
|
||||
|
||||
import ledis
|
||||
|
||||
|
||||
class TestTx(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.l = ledis.Ledis(port=6380)
|
||||
|
||||
def tearDown(self):
|
||||
self.l.delete("a")
|
||||
|
||||
def test_commit(self):
|
||||
tx = self.l.tx()
|
||||
self.l.set("a", "no-tx")
|
||||
assert self.l.get("a") == "no-tx"
|
||||
tx.begin()
|
||||
tx.set("a", "tx")
|
||||
assert self.l.get("a") == "no-tx"
|
||||
assert tx.get("a") == "tx"
|
||||
|
||||
tx.commit()
|
||||
assert self.l.get("a") == "tx"
|
||||
|
||||
def test_rollback(self):
|
||||
tx = self.l.tx()
|
||||
self.l.set("a", "no-tx")
|
||||
assert self.l.get("a") == "no-tx"
|
||||
|
||||
tx.begin()
|
||||
tx.set("a", "tx")
|
||||
assert tx.get("a") == "tx"
|
||||
assert self.l.get("a") == "no-tx"
|
||||
|
||||
tx.rollback()
|
||||
assert self.l.get("a") == "no-tx"
|
Loading…
Reference in New Issue