From 7cfd84b20bb5391c0abc0a0300401fdacd997722 Mon Sep 17 00:00:00 2001 From: siddontang Date: Fri, 13 Jun 2014 08:37:51 +0800 Subject: [PATCH 1/9] add openresty lexis client --- client/openresty/ledis.lua | 362 +++++++++++++++++++++++++++++++++++++ 1 file changed, 362 insertions(+) create mode 100644 client/openresty/ledis.lua diff --git a/client/openresty/ledis.lua b/client/openresty/ledis.lua new file mode 100644 index 0000000..f567ad9 --- /dev/null +++ b/client/openresty/ledis.lua @@ -0,0 +1,362 @@ +--- refer from openresty redis lib + +local sub = string.sub +local byte = string.byte +local tcp = ngx.socket.tcp +local concat = table.concat +local null = ngx.null +local pairs = pairs +local unpack = unpack +local setmetatable = setmetatable +local tonumber = tonumber +local error = error + + +local ok, new_tab = pcall(require, "table.new") +if not ok then + new_tab = function (narr, nrec) return {} end +end + + +local _M = new_tab(0, 155) +_M._VERSION = '0.01' + + +local commands = { + --[[kv]] + "decr", + "decrby", + "del", + "exists", + "get", + "getset", + "incr", + "incrby", + "mget", + "mset", + "set", + "setnx", + + --[[hash]] + "hdel", + "hexists", + "hget", + "hgetall", + "hincrby", + "hkeys", + "hlen", + "hmget", + --[["hmset",]] + "hset", + "hvals", + "hclear", + + --[[list]] + "lindex", + "llen", + "lpop", + "lrange", + "lpush", + "rpop", + "rpush", + "lclear", + + --[[zset]] + "zadd", + "zcard", + "zcount", + "zincrby", + "zrange", + "zrangebyscore", + "zrank", + "zrem", + "zremrangebyrank", + "zremrangebyscore", + "zrevrange", + "zrevrank", + "zrevrangebyscore", + "zscore", + "zclear", + + --[[server]] + "ping", + "echo", + "select" +} + + + +local mt = { __index = _M } + + +function _M.new(self) + local sock, err = tcp() + if not sock then + return nil, err + end + return setmetatable({ sock = sock }, mt) +end + + +function _M.set_timeout(self, timeout) + local sock = self.sock + if not sock then + return nil, "not initialized" + end + + return sock:settimeout(timeout) +end + + +function _M.connect(self, ...) + local sock = self.sock + if not sock then + return nil, "not initialized" + end + + return sock:connect(...) +end + + +function _M.set_keepalive(self, ...) + local sock = self.sock + if not sock then + return nil, "not initialized" + end + + return sock:setkeepalive(...) +end + + +function _M.get_reused_times(self) + local sock = self.sock + if not sock then + return nil, "not initialized" + end + + return sock:getreusedtimes() +end + + +local function close(self) + local sock = self.sock + if not sock then + return nil, "not initialized" + end + + return sock:close() +end +_M.close = close + + +local function _read_reply(self, sock) + local line, err = sock:receive() + if not line then + if err == "timeout" then + sock:close() + end + return nil, err + end + + local prefix = byte(line) + + if prefix == 36 then -- char '$' + -- print("bulk reply") + + local size = tonumber(sub(line, 2)) + if size < 0 then + return null + end + + local data, err = sock:receive(size) + if not data then + if err == "timeout" then + sock:close() + end + return nil, err + end + + local dummy, err = sock:receive(2) -- ignore CRLF + if not dummy then + return nil, err + end + + return data + + elseif prefix == 43 then -- char '+' + -- print("status reply") + + return sub(line, 2) + + elseif prefix == 42 then -- char '*' + local n = tonumber(sub(line, 2)) + + -- print("multi-bulk reply: ", n) + if n < 0 then + return null + end + + local vals = new_tab(n, 0); + local nvals = 0 + for i = 1, n do + local res, err = _read_reply(self, sock) + if res then + nvals = nvals + 1 + vals[nvals] = res + + elseif res == nil then + return nil, err + + else + -- be a valid redis error value + nvals = nvals + 1 + vals[nvals] = {false, err} + end + end + + return vals + + elseif prefix == 58 then -- char ':' + -- print("integer reply") + return tonumber(sub(line, 2)) + + elseif prefix == 45 then -- char '-' + -- print("error reply: ", n) + + return false, sub(line, 2) + + else + return nil, "unkown prefix: \"" .. prefix .. "\"" + end +end + + +local function _gen_req(args) + local nargs = #args + + local req = new_tab(nargs + 1, 0) + req[1] = "*" .. nargs .. "\r\n" + local nbits = 1 + + for i = 1, nargs do + local arg = args[i] + nbits = nbits + 1 + + if not arg then + req[nbits] = "$-1\r\n" + + else + if type(arg) ~= "string" then + arg = tostring(arg) + end + req[nbits] = "$" .. #arg .. "\r\n" .. arg .. "\r\n" + end + end + + -- it is faster to do string concatenation on the Lua land + return concat(req) +end + + +local function _do_cmd(self, ...) + local args = {...} + + local sock = self.sock + if not sock then + return nil, "not initialized" + end + + local req = _gen_req(args) + + local reqs = self._reqs + if reqs then + reqs[#reqs + 1] = req + return + end + + -- print("request: ", table.concat(req)) + + local bytes, err = sock:send(req) + if not bytes then + return nil, err + end + + return _read_reply(self, sock) +end + + + + +function _M.read_reply(self) + local sock = self.sock + if not sock then + return nil, "not initialized" + end + + local res, err = _read_reply(self, sock) + + return res, err +end + + +for i = 1, #commands do + local cmd = commands[i] + + _M[cmd] = + function (self, ...) + return _do_cmd(self, cmd, ...) + end +end + + +function _M.hmset(self, hashname, ...) + local args = {...} + if #args == 1 then + local t = args[1] + + local n = 0 + for k, v in pairs(t) do + n = n + 2 + end + + local array = new_tab(n, 0) + + local i = 0 + for k, v in pairs(t) do + array[i + 1] = k + array[i + 2] = v + i = i + 2 + end + -- print("key", hashname) + return _do_cmd(self, "hmset", hashname, unpack(array)) + end + + -- backwards compatibility + return _do_cmd(self, "hmset", hashname, ...) +end + + +function _M.array_to_hash(self, t) + local n = #t + -- print("n = ", n) + local h = new_tab(0, n / 2) + for i = 1, n, 2 do + h[t[i]] = t[i + 1] + end + return h +end + + +function _M.add_commands(...) + local cmds = {...} + for i = 1, #cmds do + local cmd = cmds[i] + _M[cmd] = + function (self, ...) + return _do_cmd(self, cmd, ...) + end + end +end + + +return _M \ No newline at end of file From 852fce9f4ca0bebb591adb9edf10d60f74e82597 Mon Sep 17 00:00:00 2001 From: silentsai Date: Mon, 16 Jun 2014 19:24:37 +0800 Subject: [PATCH 2/9] add server commands of expire/ttl --- server/cmd_hash.go | 58 ++++++++++++++++++++++++++++++++++++++ server/cmd_kv.go | 62 +++++++++++++++++++++++++++++++++++++++++ server/cmd_list.go | 59 ++++++++++++++++++++++++++++++++++++++- server/cmd_ttl_test.go | 63 ++++++++++++++++++++++++++++++++++++++++++ server/cmd_zset.go | 60 +++++++++++++++++++++++++++++++++++++++- 5 files changed, 300 insertions(+), 2 deletions(-) create mode 100644 server/cmd_ttl_test.go diff --git a/server/cmd_hash.go b/server/cmd_hash.go index b86e339..bce1be2 100644 --- a/server/cmd_hash.go +++ b/server/cmd_hash.go @@ -207,6 +207,61 @@ func hclearCommand(c *client) error { return nil } +func hexpireCommand(c *client) error { + args := c.args + if len(args) == 0 { + return ErrCmdParams + } + + duration, err := ledis.StrInt64(args[1], nil) + if err != nil { + return err + } + + if v, err := c.db.HExpire(args[0], duration); err != nil { + return err + } else { + c.writeInteger(v) + } + + return nil +} + +func hexpireAtCommand(c *client) error { + args := c.args + if len(args) == 0 { + return ErrCmdParams + } + + when, err := ledis.StrInt64(args[1], nil) + if err != nil { + return err + } + + if v, err := c.db.HExpireAt(args[0], when); err != nil { + return err + } else { + c.writeInteger(v) + } + + return nil +} + +func httlCommand(c *client) error { + args := c.args + if len(args) == 0 { + return ErrCmdParams + } + + if v, err := c.db.HTTL(args[0]); err != nil { + return err + } else { + c.writeInteger(v) + } + + return nil +} + func init() { register("hdel", hdelCommand) register("hexists", hexistsCommand) @@ -223,4 +278,7 @@ func init() { //ledisdb special command register("hclear", hclearCommand) + register("hexpire", hexpireCommand) + register("hexpireat", hexpireAtCommand) + register("httl", httlCommand) } diff --git a/server/cmd_kv.go b/server/cmd_kv.go index 1c42cec..f70a508 100644 --- a/server/cmd_kv.go +++ b/server/cmd_kv.go @@ -203,6 +203,65 @@ func mgetCommand(c *client) error { return nil } +func expireCommand(c *client) error { + args := c.args + if len(args) == 0 { + return ErrCmdParams + } + + duration, err := ledis.StrInt64(args[1], nil) + if err != nil { + return err + } + + if v, err := c.db.Expire(args[0], duration); err != nil { + return err + } else { + c.writeInteger(v) + } + + return nil +} + +func expireAtCommand(c *client) error { + args := c.args + if len(args) == 0 { + return ErrCmdParams + } + + when, err := ledis.StrInt64(args[1], nil) + if err != nil { + return err + } + + if v, err := c.db.ExpireAt(args[0], when); err != nil { + return err + } else { + c.writeInteger(v) + } + + return nil +} + +func ttlCommand(c *client) error { + args := c.args + if len(args) == 0 { + return ErrCmdParams + } + + if v, err := c.db.TTL(args[0]); err != nil { + return err + } else { + c.writeInteger(v) + } + + return nil +} + +// func (db *DB) Expire(key []byte, duration int6 +// func (db *DB) ExpireAt(key []byte, when int64) +// func (db *DB) TTL(key []byte) (int64, error) + func init() { register("decr", decrCommand) register("decrby", decrbyCommand) @@ -216,4 +275,7 @@ func init() { register("mset", msetCommand) register("set", setCommand) register("setnx", setnxCommand) + register("expire", expireCommand) + register("expireat", expireAtCommand) + register("ttl", ttlCommand) } diff --git a/server/cmd_list.go b/server/cmd_list.go index 193e27d..58d64f3 100644 --- a/server/cmd_list.go +++ b/server/cmd_list.go @@ -143,6 +143,61 @@ func lclearCommand(c *client) error { return nil } +func lexpireCommand(c *client) error { + args := c.args + if len(args) == 0 { + return ErrCmdParams + } + + duration, err := ledis.StrInt64(args[1], nil) + if err != nil { + return err + } + + if v, err := c.db.LExpire(args[0], duration); err != nil { + return err + } else { + c.writeInteger(v) + } + + return nil +} + +func lexpireAtCommand(c *client) error { + args := c.args + if len(args) == 0 { + return ErrCmdParams + } + + when, err := ledis.StrInt64(args[1], nil) + if err != nil { + return err + } + + if v, err := c.db.LExpireAt(args[0], when); err != nil { + return err + } else { + c.writeInteger(v) + } + + return nil +} + +func lttlCommand(c *client) error { + args := c.args + if len(args) == 0 { + return ErrCmdParams + } + + if v, err := c.db.LTTL(args[0]); err != nil { + return err + } else { + c.writeInteger(v) + } + + return nil +} + func init() { register("lindex", lindexCommand) register("llen", llenCommand) @@ -155,5 +210,7 @@ func init() { //ledisdb special command register("lclear", lclearCommand) - + register("lexpire", lexpireCommand) + register("lexpireat", lexpireAtCommand) + register("lttl", lttlCommand) } diff --git a/server/cmd_ttl_test.go b/server/cmd_ttl_test.go new file mode 100644 index 0000000..da90d6c --- /dev/null +++ b/server/cmd_ttl_test.go @@ -0,0 +1,63 @@ +package server + +import ( + "github.com/garyburd/redigo/redis" + "testing" + "time" +) + +func now() int64 { + return time.Now().Unix() +} + +func TestKVExpire(t *testing.T) { + c := getTestConn() + defer c.Close() + + k := "a_ttl" + c.Do("set", k, "123") + + // expire + ttl + exp := int64(10) + if n, err := redis.Int(c.Do("expire", k, exp)); err != nil { + t.Fatal(err) + } else if n != 1 { + t.Fatal(n) + } + + if ttl, err := redis.Int64(c.Do("ttl", k)); err != nil { + t.Fatal(err) + } else if ttl != exp { + t.Fatal(ttl) + } + + // expireat + ttl + tm := now() + 3 + if n, err := redis.Int(c.Do("expireat", k, tm)); err != nil { + t.Fatal(err) + } else if n != 1 { + t.Fatal(n) + } + + if ttl, err := redis.Int64(c.Do("ttl", k)); err != nil { + t.Fatal(err) + } else if ttl != 3 { + t.Fatal(ttl) + } + + kErr := "not_exist_ttl" + + // err - expire, expireat + if n, err := redis.Int(c.Do("expire", kErr, tm)); err != nil || n != 0 { + t.Fatal(false) + } + + if n, err := redis.Int(c.Do("expireat", kErr, tm)); err != nil || n != 0 { + t.Fatal(false) + } + + if n, err := redis.Int(c.Do("ttl", kErr)); err != nil || n != -1 { + t.Fatal(false) + } + +} diff --git a/server/cmd_zset.go b/server/cmd_zset.go index c72fa71..dfc24f8 100644 --- a/server/cmd_zset.go +++ b/server/cmd_zset.go @@ -421,6 +421,61 @@ func zclearCommand(c *client) error { return nil } +func zexpireCommand(c *client) error { + args := c.args + if len(args) == 0 { + return ErrCmdParams + } + + duration, err := ledis.StrInt64(args[1], nil) + if err != nil { + return err + } + + if v, err := c.db.ZExpire(args[0], duration); err != nil { + return err + } else { + c.writeInteger(v) + } + + return nil +} + +func zexpireAtCommand(c *client) error { + args := c.args + if len(args) == 0 { + return ErrCmdParams + } + + when, err := ledis.StrInt64(args[1], nil) + if err != nil { + return err + } + + if v, err := c.db.ZExpireAt(args[0], when); err != nil { + return err + } else { + c.writeInteger(v) + } + + return nil +} + +func zttlCommand(c *client) error { + args := c.args + if len(args) == 0 { + return ErrCmdParams + } + + if v, err := c.db.ZTTL(args[0]); err != nil { + return err + } else { + c.writeInteger(v) + } + + return nil +} + func init() { register("zadd", zaddCommand) register("zcard", zcardCommand) @@ -438,6 +493,9 @@ func init() { register("zscore", zscoreCommand) //ledisdb special command - register("zclear", zclearCommand) + register("zclear", zclearCommand) + register("zexpire", zexpireCommand) + register("zexpireat", zexpireAtCommand) + register("zttl", zttlCommand) } From a0bd2e90e55e3afeb283db076962735011feedb5 Mon Sep 17 00:00:00 2001 From: silentsai Date: Wed, 18 Jun 2014 10:24:40 +0800 Subject: [PATCH 3/9] add client for python --- client/ledis-py/ledis/__init__.py | 31 + client/ledis-py/ledis/_compat.py | 79 + client/ledis-py/ledis/client.py | 2093 +++++++++++++++++ client/ledis-py/ledis/connection.py | 580 +++++ client/ledis-py/ledis/exceptions.py | 49 + client/ledis-py/ledis/utils.py | 16 + client/ledis-py/setup.py | 61 + client/ledis-py/tests/__init__.py | 0 client/ledis-py/tests/conftest.py | 46 + client/ledis-py/tests/test_commands.py | 1419 +++++++++++ client/ledis-py/tests/test_connection_pool.py | 402 ++++ client/ledis-py/tests/test_encoding.py | 33 + client/ledis-py/tests/test_lock.py | 167 ++ client/ledis-py/tests/test_pipeline.py | 226 ++ client/ledis-py/tests/test_pubsub.py | 392 +++ client/ledis-py/tests/test_scripting.py | 82 + client/ledis-py/tests/test_sentinel.py | 173 ++ 17 files changed, 5849 insertions(+) create mode 100644 client/ledis-py/ledis/__init__.py create mode 100644 client/ledis-py/ledis/_compat.py create mode 100644 client/ledis-py/ledis/client.py create mode 100644 client/ledis-py/ledis/connection.py create mode 100644 client/ledis-py/ledis/exceptions.py create mode 100644 client/ledis-py/ledis/utils.py create mode 100644 client/ledis-py/setup.py create mode 100644 client/ledis-py/tests/__init__.py create mode 100644 client/ledis-py/tests/conftest.py create mode 100644 client/ledis-py/tests/test_commands.py create mode 100644 client/ledis-py/tests/test_connection_pool.py create mode 100644 client/ledis-py/tests/test_encoding.py create mode 100644 client/ledis-py/tests/test_lock.py create mode 100644 client/ledis-py/tests/test_pipeline.py create mode 100644 client/ledis-py/tests/test_pubsub.py create mode 100644 client/ledis-py/tests/test_scripting.py create mode 100644 client/ledis-py/tests/test_sentinel.py diff --git a/client/ledis-py/ledis/__init__.py b/client/ledis-py/ledis/__init__.py new file mode 100644 index 0000000..9549ec7 --- /dev/null +++ b/client/ledis-py/ledis/__init__.py @@ -0,0 +1,31 @@ +from redis.client import Redis, StrictRedis +from redis.connection import ( + BlockingConnectionPool, + ConnectionPool, + Connection, + UnixDomainSocketConnection +) +from redis.utils import from_url +from redis.exceptions import ( + AuthenticationError, + ConnectionError, + BusyLoadingError, + DataError, + InvalidResponse, + PubSubError, + RedisError, + ResponseError, + WatchError, +) + + +__version__ = '2.7.6' +VERSION = tuple(map(int, __version__.split('.'))) + +__all__ = [ + 'Redis', 'StrictRedis', 'ConnectionPool', 'BlockingConnectionPool', + 'Connection', 'UnixDomainSocketConnection', + 'RedisError', 'ConnectionError', 'ResponseError', 'AuthenticationError', + 'InvalidResponse', 'DataError', 'PubSubError', 'WatchError', 'from_url', + 'BusyLoadingError' +] diff --git a/client/ledis-py/ledis/_compat.py b/client/ledis-py/ledis/_compat.py new file mode 100644 index 0000000..38a7316 --- /dev/null +++ b/client/ledis-py/ledis/_compat.py @@ -0,0 +1,79 @@ +"""Internal module for Python 2 backwards compatibility.""" +import sys + + +if sys.version_info[0] < 3: + from urlparse import urlparse + from itertools import imap, izip + from string import letters as ascii_letters + from Queue import Queue + try: + from cStringIO import StringIO as BytesIO + except ImportError: + from StringIO import StringIO as BytesIO + + iteritems = lambda x: x.iteritems() + iterkeys = lambda x: x.iterkeys() + itervalues = lambda x: x.itervalues() + nativestr = lambda x: \ + x if isinstance(x, str) else x.encode('utf-8', 'replace') + u = lambda x: x.decode() + b = lambda x: x + next = lambda x: x.next() + byte_to_chr = lambda x: x + unichr = unichr + xrange = xrange + basestring = basestring + unicode = unicode + bytes = str + long = long +else: + from urllib.parse import urlparse + from io import BytesIO + from string import ascii_letters + from queue import Queue + + iteritems = lambda x: iter(x.items()) + iterkeys = lambda x: iter(x.keys()) + itervalues = lambda x: iter(x.values()) + byte_to_chr = lambda x: chr(x) + nativestr = lambda x: \ + x if isinstance(x, str) else x.decode('utf-8', 'replace') + u = lambda x: x + b = lambda x: x.encode('iso-8859-1') if not isinstance(x, bytes) else x + next = next + unichr = chr + imap = map + izip = zip + xrange = range + basestring = str + unicode = str + bytes = bytes + long = int + +try: # Python 3 + from queue import LifoQueue, Empty, Full +except ImportError: + from Queue import Empty, Full + try: # Python 2.6 - 2.7 + from Queue import LifoQueue + except ImportError: # Python 2.5 + from Queue import Queue + # From the Python 2.7 lib. Python 2.5 already extracted the core + # methods to aid implementating different queue organisations. + + class LifoQueue(Queue): + "Override queue methods to implement a last-in first-out queue." + + def _init(self, maxsize): + self.maxsize = maxsize + self.queue = [] + + def _qsize(self, len=len): + return len(self.queue) + + def _put(self, item): + self.queue.append(item) + + def _get(self): + return self.queue.pop() diff --git a/client/ledis-py/ledis/client.py b/client/ledis-py/ledis/client.py new file mode 100644 index 0000000..f98dc2b --- /dev/null +++ b/client/ledis-py/ledis/client.py @@ -0,0 +1,2093 @@ +from __future__ import with_statement +from itertools import chain, starmap +import datetime +import sys +import warnings +import time as mod_time +from redis._compat import (b, izip, imap, iteritems, iterkeys, itervalues, + basestring, long, nativestr, urlparse, bytes) +from redis.connection import ConnectionPool, UnixDomainSocketConnection +from redis.exceptions import ( + ConnectionError, + DataError, + RedisError, + ResponseError, + WatchError, + NoScriptError, + ExecAbortError, +) + +SYM_EMPTY = b('') + + +def list_or_args(keys, args): + # returns a single list combining keys and args + try: + iter(keys) + # a string or bytes instance can be iterated, but indicates + # keys wasn't passed as a list + if isinstance(keys, (basestring, bytes)): + keys = [keys] + except TypeError: + keys = [keys] + if args: + keys.extend(args) + return keys + + +def timestamp_to_datetime(response): + "Converts a unix timestamp to a Python datetime object" + if not response: + return None + try: + response = int(response) + except ValueError: + return None + return datetime.datetime.fromtimestamp(response) + + +def string_keys_to_dict(key_string, callback): + return dict.fromkeys(key_string.split(), callback) + + +def dict_merge(*dicts): + merged = {} + [merged.update(d) for d in dicts] + return merged + + +def parse_debug_object(response): + "Parse the results of Redis's DEBUG OBJECT command into a Python dict" + # The 'type' of the object is the first item in the response, but isn't + # prefixed with a name + response = nativestr(response) + response = 'type:' + response + response = dict([kv.split(':') for kv in response.split()]) + + # parse some expected int values from the string response + # note: this cmd isn't spec'd so these may not appear in all redis versions + int_fields = ('refcount', 'serializedlength', 'lru', 'lru_seconds_idle') + for field in int_fields: + if field in response: + response[field] = int(response[field]) + + return response + + +def parse_object(response, infotype): + "Parse the results of an OBJECT command" + if infotype in ('idletime', 'refcount'): + return int(response) + return response + + +def parse_info(response): + "Parse the result of Redis's INFO command into a Python dict" + info = {} + response = nativestr(response) + + def get_value(value): + if ',' not in value or '=' not in value: + try: + if '.' in value: + return float(value) + else: + return int(value) + except ValueError: + return value + else: + sub_dict = {} + for item in value.split(','): + k, v = item.rsplit('=', 1) + sub_dict[k] = get_value(v) + return sub_dict + + for line in response.splitlines(): + if line and not line.startswith('#'): + key, value = line.split(':') + info[key] = get_value(value) + return info + + +def pairs_to_dict(response): + "Create a dict given a list of key/value pairs" + it = iter(response) + return dict(izip(it, it)) + + +def zset_score_pairs(response, **options): + """ + If ``withscores`` is specified in the options, return the response as + a list of (value, score) pairs + """ + if not response or not options['withscores']: + return response + score_cast_func = options.get('score_cast_func', float) + it = iter(response) + return list(izip(it, imap(score_cast_func, it))) + + +def sort_return_tuples(response, **options): + """ + If ``groups`` is specified, return the response as a list of + n-element tuples with n being the value found in options['groups'] + """ + if not response or not options['groups']: + return response + n = options['groups'] + return list(izip(*[response[i::n] for i in range(n)])) + + +def int_or_none(response): + if response is None: + return None + return int(response) + + +def float_or_none(response): + if response is None: + return None + return float(response) + + +def parse_client(response, **options): + parse = options['parse'] + if parse == 'LIST': + clients = [] + for c in nativestr(response).splitlines(): + clients.append(dict([pair.split('=') for pair in c.split(' ')])) + return clients + elif parse == 'KILL': + return bool(response) + elif parse == 'GETNAME': + return response and nativestr(response) + elif parse == 'SETNAME': + return nativestr(response) == 'OK' + + +def parse_config(response, **options): + if options['parse'] == 'GET': + response = [nativestr(i) if i is not None else None for i in response] + return response and pairs_to_dict(response) or {} + return nativestr(response) == 'OK' + + +def parse_script(response, **options): + parse = options['parse'] + if parse in ('FLUSH', 'KILL'): + return response == 'OK' + if parse == 'EXISTS': + return list(imap(bool, response)) + return response + + +class StrictRedis(object): + """ + Implementation of the Redis protocol. + + This abstract class provides a Python interface to all Redis commands + and an implementation of the Redis protocol. + + Connection and Pipeline derive from this, implementing how + the commands are sent and received to the Redis server + """ + RESPONSE_CALLBACKS = dict_merge( + string_keys_to_dict( + 'AUTH EXISTS EXPIRE EXPIREAT HEXISTS HMSET MOVE MSETNX PERSIST ' + 'PSETEX RENAMENX SISMEMBER SMOVE SETEX SETNX', + bool + ), + string_keys_to_dict( + 'BITCOUNT DECRBY DEL GETBIT HDEL HLEN INCRBY LINSERT LLEN LPUSHX ' + 'RPUSHX SADD SCARD SDIFFSTORE SETBIT SETRANGE SINTERSTORE SREM ' + 'STRLEN SUNIONSTORE ZADD ZCARD ZREM ZREMRANGEBYRANK ' + 'ZREMRANGEBYSCORE', + int + ), + string_keys_to_dict('INCRBYFLOAT HINCRBYFLOAT', float), + string_keys_to_dict( + # these return OK, or int if redis-server is >=1.3.4 + 'LPUSH RPUSH', + lambda r: isinstance(r, long) and r or nativestr(r) == 'OK' + ), + string_keys_to_dict('SORT', sort_return_tuples), + string_keys_to_dict('ZSCORE ZINCRBY', float_or_none), + string_keys_to_dict( + 'FLUSHALL FLUSHDB LSET LTRIM MSET RENAME ' + 'SAVE SELECT SHUTDOWN SLAVEOF WATCH UNWATCH', + lambda r: nativestr(r) == 'OK' + ), + string_keys_to_dict('BLPOP BRPOP', lambda r: r and tuple(r) or None), + string_keys_to_dict( + 'SDIFF SINTER SMEMBERS SUNION', + lambda r: r and set(r) or set() + ), + string_keys_to_dict( + 'ZRANGE ZRANGEBYSCORE ZREVRANGE ZREVRANGEBYSCORE', + zset_score_pairs + ), + string_keys_to_dict('ZRANK ZREVRANK', int_or_none), + { + 'BGREWRITEAOF': ( + lambda r: nativestr(r) == ('Background rewriting of AOF ' + 'file started') + ), + 'BGSAVE': lambda r: nativestr(r) == 'Background saving started', + 'CLIENT': parse_client, + 'CONFIG': parse_config, + 'DEBUG': parse_debug_object, + 'HGETALL': lambda r: r and pairs_to_dict(r) or {}, + 'INFO': parse_info, + 'LASTSAVE': timestamp_to_datetime, + 'OBJECT': parse_object, + 'PING': lambda r: nativestr(r) == 'PONG', + 'RANDOMKEY': lambda r: r and r or None, + 'SCRIPT': parse_script, + 'SET': lambda r: r and nativestr(r) == 'OK', + 'TIME': lambda x: (int(x[0]), int(x[1])) + } + ) + + @classmethod + def from_url(cls, url, db=None, **kwargs): + """ + Return a Redis client object configured from the given URL. + + For example:: + + redis://username:password@localhost:6379/0 + + If ``db`` is None, this method will attempt to extract the database ID + from the URL path component. + + Any additional keyword arguments will be passed along to the Redis + class's initializer. + """ + url = urlparse(url) + + # We only support redis:// schemes. + assert url.scheme == 'redis' or not url.scheme + + # Extract the database ID from the path component if hasn't been given. + if db is None: + try: + db = int(url.path.replace('/', '')) + except (AttributeError, ValueError): + db = 0 + + return cls(host=url.hostname, port=int(url.port or 6379), db=db, + password=url.password, **kwargs) + + def __init__(self, host='localhost', port=6379, + db=0, password=None, socket_timeout=None, + connection_pool=None, charset='utf-8', + errors='strict', decode_responses=False, + unix_socket_path=None): + if not connection_pool: + kwargs = { + 'db': db, + 'password': password, + 'socket_timeout': socket_timeout, + 'encoding': charset, + 'encoding_errors': errors, + 'decode_responses': decode_responses, + } + # based on input, setup appropriate connection args + if unix_socket_path: + kwargs.update({ + 'path': unix_socket_path, + 'connection_class': UnixDomainSocketConnection + }) + else: + kwargs.update({ + 'host': host, + 'port': port + }) + connection_pool = ConnectionPool(**kwargs) + self.connection_pool = connection_pool + + self.response_callbacks = self.__class__.RESPONSE_CALLBACKS.copy() + + def set_response_callback(self, command, callback): + "Set a custom Response Callback" + self.response_callbacks[command] = callback + + def pipeline(self, transaction=True, shard_hint=None): + """ + Return a new pipeline object that can queue multiple commands for + later execution. ``transaction`` indicates whether all commands + should be executed atomically. Apart from making a group of operations + atomic, pipelines are useful for reducing the back-and-forth overhead + between the client and server. + """ + return StrictPipeline( + self.connection_pool, + self.response_callbacks, + transaction, + shard_hint) + + def transaction(self, func, *watches, **kwargs): + """ + Convenience method for executing the callable `func` as a transaction + while watching all keys specified in `watches`. The 'func' callable + should expect a single arguement which is a Pipeline object. + """ + shard_hint = kwargs.pop('shard_hint', None) + value_from_callable = kwargs.pop('value_from_callable', False) + with self.pipeline(True, shard_hint) as pipe: + while 1: + try: + if watches: + pipe.watch(*watches) + func_value = func(pipe) + exec_value = pipe.execute() + return func_value if value_from_callable else exec_value + except WatchError: + continue + + def lock(self, name, timeout=None, sleep=0.1): + """ + Return a new Lock object using key ``name`` that mimics + the behavior of threading.Lock. + + If specified, ``timeout`` indicates a maximum life for the lock. + By default, it will remain locked until release() is called. + + ``sleep`` indicates the amount of time to sleep per loop iteration + when the lock is in blocking mode and another client is currently + holding the lock. + """ + return Lock(self, name, timeout=timeout, sleep=sleep) + + def pubsub(self, shard_hint=None): + """ + Return a Publish/Subscribe object. With this object, you can + subscribe to channels and listen for messages that get published to + them. + """ + return PubSub(self.connection_pool, shard_hint) + + #### COMMAND EXECUTION AND PROTOCOL PARSING #### + def execute_command(self, *args, **options): + "Execute a command and return a parsed response" + pool = self.connection_pool + command_name = args[0] + connection = pool.get_connection(command_name, **options) + try: + connection.send_command(*args) + return self.parse_response(connection, command_name, **options) + except ConnectionError: + connection.disconnect() + connection.send_command(*args) + return self.parse_response(connection, command_name, **options) + finally: + pool.release(connection) + + def parse_response(self, connection, command_name, **options): + "Parses a response from the Redis server" + response = connection.read_response() + if command_name in self.response_callbacks: + return self.response_callbacks[command_name](response, **options) + return response + + #### SERVER INFORMATION #### + def bgrewriteaof(self): + "Tell the Redis server to rewrite the AOF file from data in memory." + return self.execute_command('BGREWRITEAOF') + + def bgsave(self): + """ + Tell the Redis server to save its data to disk. Unlike save(), + this method is asynchronous and returns immediately. + """ + return self.execute_command('BGSAVE') + + def client_kill(self, address): + "Disconnects the client at ``address`` (ip:port)" + return self.execute_command('CLIENT', 'KILL', address, parse='KILL') + + def client_list(self): + "Returns a list of currently connected clients" + return self.execute_command('CLIENT', 'LIST', parse='LIST') + + def client_getname(self): + "Returns the current connection name" + return self.execute_command('CLIENT', 'GETNAME', parse='GETNAME') + + def client_setname(self, name): + "Sets the current connection name" + return self.execute_command('CLIENT', 'SETNAME', name, parse='SETNAME') + + def config_get(self, pattern="*"): + "Return a dictionary of configuration based on the ``pattern``" + return self.execute_command('CONFIG', 'GET', pattern, parse='GET') + + def config_set(self, name, value): + "Set config item ``name`` with ``value``" + return self.execute_command('CONFIG', 'SET', name, value, parse='SET') + + def config_resetstat(self): + "Reset runtime statistics" + return self.execute_command('CONFIG', 'RESETSTAT', parse='RESETSTAT') + + def dbsize(self): + "Returns the number of keys in the current database" + return self.execute_command('DBSIZE') + + def debug_object(self, key): + "Returns version specific metainformation about a give key" + return self.execute_command('DEBUG', 'OBJECT', key) + + def echo(self, value): + "Echo the string back from the server" + return self.execute_command('ECHO', value) + + def flushall(self): + "Delete all keys in all databases on the current host" + return self.execute_command('FLUSHALL') + + def flushdb(self): + "Delete all keys in the current database" + return self.execute_command('FLUSHDB') + + def info(self, section=None): + """ + Returns a dictionary containing information about the Redis server + + The ``section`` option can be used to select a specific section + of information + + The section option is not supported by older versions of Redis Server, + and will generate ResponseError + """ + if section is None: + return self.execute_command('INFO') + else: + return self.execute_command('INFO', section) + + def lastsave(self): + """ + Return a Python datetime object representing the last time the + Redis database was saved to disk + """ + return self.execute_command('LASTSAVE') + + def object(self, infotype, key): + "Return the encoding, idletime, or refcount about the key" + return self.execute_command('OBJECT', infotype, key, infotype=infotype) + + def ping(self): + "Ping the Redis server" + return self.execute_command('PING') + + def save(self): + """ + Tell the Redis server to save its data to disk, + blocking until the save is complete + """ + return self.execute_command('SAVE') + + def shutdown(self): + "Shutdown the server" + try: + self.execute_command('SHUTDOWN') + except ConnectionError: + # a ConnectionError here is expected + return + raise RedisError("SHUTDOWN seems to have failed.") + + def slaveof(self, host=None, port=None): + """ + Set the server to be a replicated slave of the instance identified + by the ``host`` and ``port``. If called without arguements, the + instance is promoted to a master instead. + """ + if host is None and port is None: + return self.execute_command("SLAVEOF", "NO", "ONE") + return self.execute_command("SLAVEOF", host, port) + + def time(self): + """ + Returns the server time as a 2-item tuple of ints: + (seconds since epoch, microseconds into this second). + """ + return self.execute_command('TIME') + + #### BASIC KEY COMMANDS #### + def append(self, key, value): + """ + Appends the string ``value`` to the value at ``key``. If ``key`` + doesn't already exist, create it with a value of ``value``. + Returns the new length of the value at ``key``. + """ + return self.execute_command('APPEND', key, value) + + def bitcount(self, key, start=None, end=None): + """ + Returns the count of set bits in the value of ``key``. Optional + ``start`` and ``end`` paramaters indicate which bytes to consider + """ + params = [key] + if start is not None and end is not None: + params.append(start) + params.append(end) + elif (start is not None and end is None) or \ + (end is not None and start is None): + raise RedisError("Both start and end must be specified") + return self.execute_command('BITCOUNT', *params) + + def bitop(self, operation, dest, *keys): + """ + Perform a bitwise operation using ``operation`` between ``keys`` and + store the result in ``dest``. + """ + return self.execute_command('BITOP', operation, dest, *keys) + + def decr(self, name, amount=1): + """ + Decrements the value of ``key`` by ``amount``. If no key exists, + the value will be initialized as 0 - ``amount`` + """ + return self.execute_command('DECRBY', name, amount) + + def delete(self, *names): + "Delete one or more keys specified by ``names``" + return self.execute_command('DEL', *names) + __delitem__ = delete + + def exists(self, name): + "Returns a boolean indicating whether key ``name`` exists" + return self.execute_command('EXISTS', name) + __contains__ = exists + + def expire(self, name, time): + """ + Set an expire flag on key ``name`` for ``time`` seconds. ``time`` + can be represented by an integer or a Python timedelta object. + """ + if isinstance(time, datetime.timedelta): + time = time.seconds + time.days * 24 * 3600 + return self.execute_command('EXPIRE', name, time) + + def expireat(self, name, when): + """ + Set an expire flag on key ``name``. ``when`` can be represented + as an integer indicating unix time or a Python datetime object. + """ + if isinstance(when, datetime.datetime): + when = int(mod_time.mktime(when.timetuple())) + return self.execute_command('EXPIREAT', name, when) + + def get(self, name): + """ + Return the value at key ``name``, or None if the key doesn't exist + """ + return self.execute_command('GET', name) + + def __getitem__(self, name): + """ + Return the value at key ``name``, raises a KeyError if the key + doesn't exist. + """ + value = self.get(name) + if value: + return value + raise KeyError(name) + + def getbit(self, name, offset): + "Returns a boolean indicating the value of ``offset`` in ``name``" + return self.execute_command('GETBIT', name, offset) + + def getrange(self, key, start, end): + """ + Returns the substring of the string value stored at ``key``, + determined by the offsets ``start`` and ``end`` (both are inclusive) + """ + return self.execute_command('GETRANGE', key, start, end) + + def getset(self, name, value): + """ + Set the value at key ``name`` to ``value`` if key doesn't exist + Return the value at key ``name`` atomically + """ + return self.execute_command('GETSET', name, value) + + def incr(self, name, amount=1): + """ + Increments the value of ``key`` by ``amount``. If no key exists, + the value will be initialized as ``amount`` + """ + return self.execute_command('INCRBY', name, amount) + + def incrby(self, name, amount=1): + """ + Increments the value of ``key`` by ``amount``. If no key exists, + the value will be initialized as ``amount`` + """ + + # An alias for ``incr()``, because it is already implemented + # as INCRBY redis command. + return self.incr(name, amount) + + def incrbyfloat(self, name, amount=1.0): + """ + Increments the value at key ``name`` by floating ``amount``. + If no key exists, the value will be initialized as ``amount`` + """ + return self.execute_command('INCRBYFLOAT', name, amount) + + def keys(self, pattern='*'): + "Returns a list of keys matching ``pattern``" + return self.execute_command('KEYS', pattern) + + def mget(self, keys, *args): + """ + Returns a list of values ordered identically to ``keys`` + """ + args = list_or_args(keys, args) + return self.execute_command('MGET', *args) + + def mset(self, *args, **kwargs): + """ + Sets key/values based on a mapping. Mapping can be supplied as a single + dictionary argument or as kwargs. + """ + if args: + if len(args) != 1 or not isinstance(args[0], dict): + raise RedisError('MSET requires **kwargs or a single dict arg') + kwargs.update(args[0]) + items = [] + for pair in iteritems(kwargs): + items.extend(pair) + return self.execute_command('MSET', *items) + + def msetnx(self, *args, **kwargs): + """ + Sets key/values based on a mapping if none of the keys are already set. + Mapping can be supplied as a single dictionary argument or as kwargs. + Returns a boolean indicating if the operation was successful. + """ + if args: + if len(args) != 1 or not isinstance(args[0], dict): + raise RedisError('MSETNX requires **kwargs or a single ' + 'dict arg') + kwargs.update(args[0]) + items = [] + for pair in iteritems(kwargs): + items.extend(pair) + return self.execute_command('MSETNX', *items) + + def move(self, name, db): + "Moves the key ``name`` to a different Redis database ``db``" + return self.execute_command('MOVE', name, db) + + def persist(self, name): + "Removes an expiration on ``name``" + return self.execute_command('PERSIST', name) + + def pexpire(self, name, time): + """ + Set an expire flag on key ``name`` for ``time`` milliseconds. + ``time`` can be represented by an integer or a Python timedelta + object. + """ + if isinstance(time, datetime.timedelta): + ms = int(time.microseconds / 1000) + time = (time.seconds + time.days * 24 * 3600) * 1000 + ms + return self.execute_command('PEXPIRE', name, time) + + def pexpireat(self, name, when): + """ + Set an expire flag on key ``name``. ``when`` can be represented + as an integer representing unix time in milliseconds (unix time * 1000) + or a Python datetime object. + """ + if isinstance(when, datetime.datetime): + ms = int(when.microsecond / 1000) + when = int(mod_time.mktime(when.timetuple())) * 1000 + ms + return self.execute_command('PEXPIREAT', name, when) + + def psetex(self, name, time_ms, value): + """ + Set the value of key ``name`` to ``value`` that expires in ``time_ms`` + milliseconds. ``time_ms`` can be represented by an integer or a Python + timedelta object + """ + if isinstance(time_ms, datetime.timedelta): + ms = int(time_ms.microseconds / 1000) + time_ms = (time_ms.seconds + time_ms.days * 24 * 3600) * 1000 + ms + return self.execute_command('PSETEX', name, time_ms, value) + + def pttl(self, name): + "Returns the number of milliseconds until the key ``name`` will expire" + return self.execute_command('PTTL', name) + + def randomkey(self): + "Returns the name of a random key" + return self.execute_command('RANDOMKEY') + + def rename(self, src, dst): + """ + Rename key ``src`` to ``dst`` + """ + return self.execute_command('RENAME', src, dst) + + def renamenx(self, src, dst): + "Rename key ``src`` to ``dst`` if ``dst`` doesn't already exist" + return self.execute_command('RENAMENX', src, dst) + + def set(self, name, value, ex=None, px=None, nx=False, xx=False): + """ + Set the value at key ``name`` to ``value`` + + ``ex`` sets an expire flag on key ``name`` for ``ex`` seconds. + + ``px`` sets an expire flag on key ``name`` for ``px`` milliseconds. + + ``nx`` if set to True, set the value at key ``name`` to ``value`` if it + does not already exist. + + ``xx`` if set to True, set the value at key ``name`` to ``value`` if it + already exists. + """ + pieces = [name, value] + if ex: + pieces.append('EX') + if isinstance(ex, datetime.timedelta): + ex = ex.seconds + ex.days * 24 * 3600 + pieces.append(ex) + if px: + pieces.append('PX') + if isinstance(px, datetime.timedelta): + ms = int(px.microseconds / 1000) + px = (px.seconds + px.days * 24 * 3600) * 1000 + ms + pieces.append(px) + + if nx: + pieces.append('NX') + if xx: + pieces.append('XX') + return self.execute_command('SET', *pieces) + __setitem__ = set + + def setbit(self, name, offset, value): + """ + Flag the ``offset`` in ``name`` as ``value``. Returns a boolean + indicating the previous value of ``offset``. + """ + value = value and 1 or 0 + return self.execute_command('SETBIT', name, offset, value) + + def setex(self, name, time, value): + """ + Set the value of key ``name`` to ``value`` that expires in ``time`` + seconds. ``time`` can be represented by an integer or a Python + timedelta object. + """ + if isinstance(time, datetime.timedelta): + time = time.seconds + time.days * 24 * 3600 + return self.execute_command('SETEX', name, time, value) + + def setnx(self, name, value): + "Set the value of key ``name`` to ``value`` if key doesn't exist" + return self.execute_command('SETNX', name, value) + + def setrange(self, name, offset, value): + """ + Overwrite bytes in the value of ``name`` starting at ``offset`` with + ``value``. If ``offset`` plus the length of ``value`` exceeds the + length of the original value, the new value will be larger than before. + If ``offset`` exceeds the length of the original value, null bytes + will be used to pad between the end of the previous value and the start + of what's being injected. + + Returns the length of the new string. + """ + return self.execute_command('SETRANGE', name, offset, value) + + def strlen(self, name): + "Return the number of bytes stored in the value of ``name``" + return self.execute_command('STRLEN', name) + + def substr(self, name, start, end=-1): + """ + Return a substring of the string at key ``name``. ``start`` and ``end`` + are 0-based integers specifying the portion of the string to return. + """ + return self.execute_command('SUBSTR', name, start, end) + + def ttl(self, name): + "Returns the number of seconds until the key ``name`` will expire" + return self.execute_command('TTL', name) + + def type(self, name): + "Returns the type of key ``name``" + return self.execute_command('TYPE', name) + + def watch(self, *names): + """ + Watches the values at keys ``names``, or None if the key doesn't exist + """ + warnings.warn(DeprecationWarning('Call WATCH from a Pipeline object')) + + def unwatch(self): + """ + Unwatches the value at key ``name``, or None of the key doesn't exist + """ + warnings.warn( + DeprecationWarning('Call UNWATCH from a Pipeline object')) + + #### LIST COMMANDS #### + def blpop(self, keys, timeout=0): + """ + LPOP a value off of the first non-empty list + named in the ``keys`` list. + + If none of the lists in ``keys`` has a value to LPOP, then block + for ``timeout`` seconds, or until a value gets pushed on to one + of the lists. + + If timeout is 0, then block indefinitely. + """ + if timeout is None: + timeout = 0 + if isinstance(keys, basestring): + keys = [keys] + else: + keys = list(keys) + keys.append(timeout) + return self.execute_command('BLPOP', *keys) + + def brpop(self, keys, timeout=0): + """ + RPOP a value off of the first non-empty list + named in the ``keys`` list. + + If none of the lists in ``keys`` has a value to LPOP, then block + for ``timeout`` seconds, or until a value gets pushed on to one + of the lists. + + If timeout is 0, then block indefinitely. + """ + if timeout is None: + timeout = 0 + if isinstance(keys, basestring): + keys = [keys] + else: + keys = list(keys) + keys.append(timeout) + return self.execute_command('BRPOP', *keys) + + def brpoplpush(self, src, dst, timeout=0): + """ + Pop a value off the tail of ``src``, push it on the head of ``dst`` + and then return it. + + This command blocks until a value is in ``src`` or until ``timeout`` + seconds elapse, whichever is first. A ``timeout`` value of 0 blocks + forever. + """ + if timeout is None: + timeout = 0 + return self.execute_command('BRPOPLPUSH', src, dst, timeout) + + def lindex(self, name, index): + """ + Return the item from list ``name`` at position ``index`` + + Negative indexes are supported and will return an item at the + end of the list + """ + return self.execute_command('LINDEX', name, index) + + def linsert(self, name, where, refvalue, value): + """ + Insert ``value`` in list ``name`` either immediately before or after + [``where``] ``refvalue`` + + Returns the new length of the list on success or -1 if ``refvalue`` + is not in the list. + """ + return self.execute_command('LINSERT', name, where, refvalue, value) + + def llen(self, name): + "Return the length of the list ``name``" + return self.execute_command('LLEN', name) + + def lpop(self, name): + "Remove and return the first item of the list ``name``" + return self.execute_command('LPOP', name) + + def lpush(self, name, *values): + "Push ``values`` onto the head of the list ``name``" + return self.execute_command('LPUSH', name, *values) + + def lpushx(self, name, value): + "Push ``value`` onto the head of the list ``name`` if ``name`` exists" + return self.execute_command('LPUSHX', name, value) + + def lrange(self, name, start, end): + """ + Return a slice of the list ``name`` between + position ``start`` and ``end`` + + ``start`` and ``end`` can be negative numbers just like + Python slicing notation + """ + return self.execute_command('LRANGE', name, start, end) + + def lrem(self, name, count, value): + """ + Remove the first ``count`` occurrences of elements equal to ``value`` + from the list stored at ``name``. + + The count argument influences the operation in the following ways: + count > 0: Remove elements equal to value moving from head to tail. + count < 0: Remove elements equal to value moving from tail to head. + count = 0: Remove all elements equal to value. + """ + return self.execute_command('LREM', name, count, value) + + def lset(self, name, index, value): + "Set ``position`` of list ``name`` to ``value``" + return self.execute_command('LSET', name, index, value) + + def ltrim(self, name, start, end): + """ + Trim the list ``name``, removing all values not within the slice + between ``start`` and ``end`` + + ``start`` and ``end`` can be negative numbers just like + Python slicing notation + """ + return self.execute_command('LTRIM', name, start, end) + + def rpop(self, name): + "Remove and return the last item of the list ``name``" + return self.execute_command('RPOP', name) + + def rpoplpush(self, src, dst): + """ + RPOP a value off of the ``src`` list and atomically LPUSH it + on to the ``dst`` list. Returns the value. + """ + return self.execute_command('RPOPLPUSH', src, dst) + + def rpush(self, name, *values): + "Push ``values`` onto the tail of the list ``name``" + return self.execute_command('RPUSH', name, *values) + + def rpushx(self, name, value): + "Push ``value`` onto the tail of the list ``name`` if ``name`` exists" + return self.execute_command('RPUSHX', name, value) + + def sort(self, name, start=None, num=None, by=None, get=None, + desc=False, alpha=False, store=None, groups=False): + """ + Sort and return the list, set or sorted set at ``name``. + + ``start`` and ``num`` allow for paging through the sorted data + + ``by`` allows using an external key to weight and sort the items. + Use an "*" to indicate where in the key the item value is located + + ``get`` allows for returning items from external keys rather than the + sorted data itself. Use an "*" to indicate where int he key + the item value is located + + ``desc`` allows for reversing the sort + + ``alpha`` allows for sorting lexicographically rather than numerically + + ``store`` allows for storing the result of the sort into + the key ``store`` + + ``groups`` if set to True and if ``get`` contains at least two + elements, sort will return a list of tuples, each containing the + values fetched from the arguments to ``get``. + + """ + if (start is not None and num is None) or \ + (num is not None and start is None): + raise RedisError("``start`` and ``num`` must both be specified") + + pieces = [name] + if by is not None: + pieces.append('BY') + pieces.append(by) + if start is not None and num is not None: + pieces.append('LIMIT') + pieces.append(start) + pieces.append(num) + if get is not None: + # If get is a string assume we want to get a single value. + # Otherwise assume it's an interable and we want to get multiple + # values. We can't just iterate blindly because strings are + # iterable. + if isinstance(get, basestring): + pieces.append('GET') + pieces.append(get) + else: + for g in get: + pieces.append('GET') + pieces.append(g) + if desc: + pieces.append('DESC') + if alpha: + pieces.append('ALPHA') + if store is not None: + pieces.append('STORE') + pieces.append(store) + + if groups: + if not get or isinstance(get, basestring) or len(get) < 2: + raise DataError('when using "groups" the "get" argument ' + 'must be specified and contain at least ' + 'two keys') + + options = {'groups': len(get) if groups else None} + return self.execute_command('SORT', *pieces, **options) + + #### SET COMMANDS #### + def sadd(self, name, *values): + "Add ``value(s)`` to set ``name``" + return self.execute_command('SADD', name, *values) + + def scard(self, name): + "Return the number of elements in set ``name``" + return self.execute_command('SCARD', name) + + def sdiff(self, keys, *args): + "Return the difference of sets specified by ``keys``" + args = list_or_args(keys, args) + return self.execute_command('SDIFF', *args) + + def sdiffstore(self, dest, keys, *args): + """ + Store the difference of sets specified by ``keys`` into a new + set named ``dest``. Returns the number of keys in the new set. + """ + args = list_or_args(keys, args) + return self.execute_command('SDIFFSTORE', dest, *args) + + def sinter(self, keys, *args): + "Return the intersection of sets specified by ``keys``" + args = list_or_args(keys, args) + return self.execute_command('SINTER', *args) + + def sinterstore(self, dest, keys, *args): + """ + Store the intersection of sets specified by ``keys`` into a new + set named ``dest``. Returns the number of keys in the new set. + """ + args = list_or_args(keys, args) + return self.execute_command('SINTERSTORE', dest, *args) + + def sismember(self, name, value): + "Return a boolean indicating if ``value`` is a member of set ``name``" + return self.execute_command('SISMEMBER', name, value) + + def smembers(self, name): + "Return all members of the set ``name``" + return self.execute_command('SMEMBERS', name) + + def smove(self, src, dst, value): + "Move ``value`` from set ``src`` to set ``dst`` atomically" + return self.execute_command('SMOVE', src, dst, value) + + def spop(self, name): + "Remove and return a random member of set ``name``" + return self.execute_command('SPOP', name) + + def srandmember(self, name, number=None): + """ + If ``number`` is None, returns a random member of set ``name``. + + If ``number`` is supplied, returns a list of ``number`` random + memebers of set ``name``. Note this is only available when running + Redis 2.6+. + """ + args = number and [number] or [] + return self.execute_command('SRANDMEMBER', name, *args) + + def srem(self, name, *values): + "Remove ``values`` from set ``name``" + return self.execute_command('SREM', name, *values) + + def sunion(self, keys, *args): + "Return the union of sets specifiued by ``keys``" + args = list_or_args(keys, args) + return self.execute_command('SUNION', *args) + + def sunionstore(self, dest, keys, *args): + """ + Store the union of sets specified by ``keys`` into a new + set named ``dest``. Returns the number of keys in the new set. + """ + args = list_or_args(keys, args) + return self.execute_command('SUNIONSTORE', dest, *args) + + #### SORTED SET COMMANDS #### + def zadd(self, name, *args, **kwargs): + """ + Set any number of score, element-name pairs to the key ``name``. Pairs + can be specified in two ways: + + As *args, in the form of: score1, name1, score2, name2, ... + or as **kwargs, in the form of: name1=score1, name2=score2, ... + + The following example would add four values to the 'my-key' key: + redis.zadd('my-key', 1.1, 'name1', 2.2, 'name2', name3=3.3, name4=4.4) + """ + pieces = [] + if args: + if len(args) % 2 != 0: + raise RedisError("ZADD requires an equal number of " + "values and scores") + pieces.extend(args) + for pair in iteritems(kwargs): + pieces.append(pair[1]) + pieces.append(pair[0]) + return self.execute_command('ZADD', name, *pieces) + + def zcard(self, name): + "Return the number of elements in the sorted set ``name``" + return self.execute_command('ZCARD', name) + + def zcount(self, name, min, max): + return self.execute_command('ZCOUNT', name, min, max) + + def zincrby(self, name, value, amount=1): + "Increment the score of ``value`` in sorted set ``name`` by ``amount``" + return self.execute_command('ZINCRBY', name, amount, value) + + def zinterstore(self, dest, keys, aggregate=None): + """ + Intersect multiple sorted sets specified by ``keys`` into + a new sorted set, ``dest``. Scores in the destination will be + aggregated based on the ``aggregate``, or SUM if none is provided. + """ + return self._zaggregate('ZINTERSTORE', dest, keys, aggregate) + + def zrange(self, name, start, end, desc=False, withscores=False, + score_cast_func=float): + """ + Return a range of values from sorted set ``name`` between + ``start`` and ``end`` sorted in ascending order. + + ``start`` and ``end`` can be negative, indicating the end of the range. + + ``desc`` a boolean indicating whether to sort the results descendingly + + ``withscores`` indicates to return the scores along with the values. + The return type is a list of (value, score) pairs + + ``score_cast_func`` a callable used to cast the score return value + """ + if desc: + return self.zrevrange(name, start, end, withscores, + score_cast_func) + pieces = ['ZRANGE', name, start, end] + if withscores: + pieces.append('withscores') + options = { + 'withscores': withscores, 'score_cast_func': score_cast_func} + return self.execute_command(*pieces, **options) + + def zrangebyscore(self, name, min, max, start=None, num=None, + withscores=False, score_cast_func=float): + """ + Return a range of values from the sorted set ``name`` with scores + between ``min`` and ``max``. + + If ``start`` and ``num`` are specified, then return a slice + of the range. + + ``withscores`` indicates to return the scores along with the values. + The return type is a list of (value, score) pairs + + `score_cast_func`` a callable used to cast the score return value + """ + if (start is not None and num is None) or \ + (num is not None and start is None): + raise RedisError("``start`` and ``num`` must both be specified") + pieces = ['ZRANGEBYSCORE', name, min, max] + if start is not None and num is not None: + pieces.extend(['LIMIT', start, num]) + if withscores: + pieces.append('withscores') + options = { + 'withscores': withscores, 'score_cast_func': score_cast_func} + return self.execute_command(*pieces, **options) + + def zrank(self, name, value): + """ + Returns a 0-based value indicating the rank of ``value`` in sorted set + ``name`` + """ + return self.execute_command('ZRANK', name, value) + + def zrem(self, name, *values): + "Remove member ``values`` from sorted set ``name``" + return self.execute_command('ZREM', name, *values) + + def zremrangebyrank(self, name, min, max): + """ + Remove all elements in the sorted set ``name`` with ranks between + ``min`` and ``max``. Values are 0-based, ordered from smallest score + to largest. Values can be negative indicating the highest scores. + Returns the number of elements removed + """ + return self.execute_command('ZREMRANGEBYRANK', name, min, max) + + def zremrangebyscore(self, name, min, max): + """ + Remove all elements in the sorted set ``name`` with scores + between ``min`` and ``max``. Returns the number of elements removed. + """ + return self.execute_command('ZREMRANGEBYSCORE', name, min, max) + + def zrevrange(self, name, start, num, withscores=False, + score_cast_func=float): + """ + Return a range of values from sorted set ``name`` between + ``start`` and ``num`` sorted in descending order. + + ``start`` and ``num`` can be negative, indicating the end of the range. + + ``withscores`` indicates to return the scores along with the values + The return type is a list of (value, score) pairs + + ``score_cast_func`` a callable used to cast the score return value + """ + pieces = ['ZREVRANGE', name, start, num] + if withscores: + pieces.append('withscores') + options = { + 'withscores': withscores, 'score_cast_func': score_cast_func} + return self.execute_command(*pieces, **options) + + def zrevrangebyscore(self, name, max, min, start=None, num=None, + withscores=False, score_cast_func=float): + """ + Return a range of values from the sorted set ``name`` with scores + between ``min`` and ``max`` in descending order. + + If ``start`` and ``num`` are specified, then return a slice + of the range. + + ``withscores`` indicates to return the scores along with the values. + The return type is a list of (value, score) pairs + + ``score_cast_func`` a callable used to cast the score return value + """ + if (start is not None and num is None) or \ + (num is not None and start is None): + raise RedisError("``start`` and ``num`` must both be specified") + pieces = ['ZREVRANGEBYSCORE', name, max, min] + if start is not None and num is not None: + pieces.extend(['LIMIT', start, num]) + if withscores: + pieces.append('withscores') + options = { + 'withscores': withscores, 'score_cast_func': score_cast_func} + return self.execute_command(*pieces, **options) + + def zrevrank(self, name, value): + """ + Returns a 0-based value indicating the descending rank of + ``value`` in sorted set ``name`` + """ + return self.execute_command('ZREVRANK', name, value) + + def zscore(self, name, value): + "Return the score of element ``value`` in sorted set ``name``" + return self.execute_command('ZSCORE', name, value) + + def zunionstore(self, dest, keys, aggregate=None): + """ + Union multiple sorted sets specified by ``keys`` into + a new sorted set, ``dest``. Scores in the destination will be + aggregated based on the ``aggregate``, or SUM if none is provided. + """ + return self._zaggregate('ZUNIONSTORE', dest, keys, aggregate) + + def _zaggregate(self, command, dest, keys, aggregate=None): + pieces = [command, dest, len(keys)] + if isinstance(keys, dict): + keys, weights = iterkeys(keys), itervalues(keys) + else: + weights = None + pieces.extend(keys) + if weights: + pieces.append('WEIGHTS') + pieces.extend(weights) + if aggregate: + pieces.append('AGGREGATE') + pieces.append(aggregate) + return self.execute_command(*pieces) + + #### HASH COMMANDS #### + def hdel(self, name, *keys): + "Delete ``keys`` from hash ``name``" + return self.execute_command('HDEL', name, *keys) + + def hexists(self, name, key): + "Returns a boolean indicating if ``key`` exists within hash ``name``" + return self.execute_command('HEXISTS', name, key) + + def hget(self, name, key): + "Return the value of ``key`` within the hash ``name``" + return self.execute_command('HGET', name, key) + + def hgetall(self, name): + "Return a Python dict of the hash's name/value pairs" + return self.execute_command('HGETALL', name) + + def hincrby(self, name, key, amount=1): + "Increment the value of ``key`` in hash ``name`` by ``amount``" + return self.execute_command('HINCRBY', name, key, amount) + + def hincrbyfloat(self, name, key, amount=1.0): + """ + Increment the value of ``key`` in hash ``name`` by floating ``amount`` + """ + return self.execute_command('HINCRBYFLOAT', name, key, amount) + + def hkeys(self, name): + "Return the list of keys within hash ``name``" + return self.execute_command('HKEYS', name) + + def hlen(self, name): + "Return the number of elements in hash ``name``" + return self.execute_command('HLEN', name) + + def hset(self, name, key, value): + """ + Set ``key`` to ``value`` within hash ``name`` + Returns 1 if HSET created a new field, otherwise 0 + """ + return self.execute_command('HSET', name, key, value) + + def hsetnx(self, name, key, value): + """ + Set ``key`` to ``value`` within hash ``name`` if ``key`` does not + exist. Returns 1 if HSETNX created a field, otherwise 0. + """ + return self.execute_command("HSETNX", name, key, value) + + def hmset(self, name, mapping): + """ + Sets each key in the ``mapping`` dict to its corresponding value + in the hash ``name`` + """ + if not mapping: + raise DataError("'hmset' with 'mapping' of length 0") + items = [] + for pair in iteritems(mapping): + items.extend(pair) + return self.execute_command('HMSET', name, *items) + + def hmget(self, name, keys, *args): + "Returns a list of values ordered identically to ``keys``" + args = list_or_args(keys, args) + return self.execute_command('HMGET', name, *args) + + def hvals(self, name): + "Return the list of values within hash ``name``" + return self.execute_command('HVALS', name) + + def publish(self, channel, message): + """ + Publish ``message`` on ``channel``. + Returns the number of subscribers the message was delivered to. + """ + return self.execute_command('PUBLISH', channel, message) + + def eval(self, script, numkeys, *keys_and_args): + """ + Execute the LUA ``script``, specifying the ``numkeys`` the script + will touch and the key names and argument values in ``keys_and_args``. + Returns the result of the script. + + In practice, use the object returned by ``register_script``. This + function exists purely for Redis API completion. + """ + return self.execute_command('EVAL', script, numkeys, *keys_and_args) + + def evalsha(self, sha, numkeys, *keys_and_args): + """ + Use the ``sha`` to execute a LUA script already registered via EVAL + or SCRIPT LOAD. Specify the ``numkeys`` the script will touch and the + key names and argument values in ``keys_and_args``. Returns the result + of the script. + + In practice, use the object returned by ``register_script``. This + function exists purely for Redis API completion. + """ + return self.execute_command('EVALSHA', sha, numkeys, *keys_and_args) + + def script_exists(self, *args): + """ + Check if a script exists in the script cache by specifying the SHAs of + each script as ``args``. Returns a list of boolean values indicating if + if each already script exists in the cache. + """ + options = {'parse': 'EXISTS'} + return self.execute_command('SCRIPT', 'EXISTS', *args, **options) + + def script_flush(self): + "Flush all scripts from the script cache" + options = {'parse': 'FLUSH'} + return self.execute_command('SCRIPT', 'FLUSH', **options) + + def script_kill(self): + "Kill the currently executing LUA script" + options = {'parse': 'KILL'} + return self.execute_command('SCRIPT', 'KILL', **options) + + def script_load(self, script): + "Load a LUA ``script`` into the script cache. Returns the SHA." + options = {'parse': 'LOAD'} + return self.execute_command('SCRIPT', 'LOAD', script, **options) + + def register_script(self, script): + """ + Register a LUA ``script`` specifying the ``keys`` it will touch. + Returns a Script object that is callable and hides the complexity of + deal with scripts, keys, and shas. This is the preferred way to work + with LUA scripts. + """ + return Script(self, script) + + +class Redis(StrictRedis): + """ + Provides backwards compatibility with older versions of redis-py that + changed arguments to some commands to be more Pythonic, sane, or by + accident. + """ + + # Overridden callbacks + RESPONSE_CALLBACKS = dict_merge( + StrictRedis.RESPONSE_CALLBACKS, + { + 'TTL': lambda r: r != -1 and r or None, + 'PTTL': lambda r: r != -1 and r or None, + } + ) + + def pipeline(self, transaction=True, shard_hint=None): + """ + Return a new pipeline object that can queue multiple commands for + later execution. ``transaction`` indicates whether all commands + should be executed atomically. Apart from making a group of operations + atomic, pipelines are useful for reducing the back-and-forth overhead + between the client and server. + """ + return Pipeline( + self.connection_pool, + self.response_callbacks, + transaction, + shard_hint) + + def setex(self, name, value, time): + """ + Set the value of key ``name`` to ``value`` that expires in ``time`` + seconds. ``time`` can be represented by an integer or a Python + timedelta object. + """ + if isinstance(time, datetime.timedelta): + time = time.seconds + time.days * 24 * 3600 + return self.execute_command('SETEX', name, time, value) + + def lrem(self, name, value, num=0): + """ + Remove the first ``num`` occurrences of elements equal to ``value`` + from the list stored at ``name``. + + The ``num`` argument influences the operation in the following ways: + num > 0: Remove elements equal to value moving from head to tail. + num < 0: Remove elements equal to value moving from tail to head. + num = 0: Remove all elements equal to value. + """ + return self.execute_command('LREM', name, num, value) + + def zadd(self, name, *args, **kwargs): + """ + NOTE: The order of arguments differs from that of the official ZADD + command. For backwards compatability, this method accepts arguments + in the form of name1, score1, name2, score2, while the official Redis + documents expects score1, name1, score2, name2. + + If you're looking to use the standard syntax, consider using the + StrictRedis class. See the API Reference section of the docs for more + information. + + Set any number of element-name, score pairs to the key ``name``. Pairs + can be specified in two ways: + + As *args, in the form of: name1, score1, name2, score2, ... + or as **kwargs, in the form of: name1=score1, name2=score2, ... + + The following example would add four values to the 'my-key' key: + redis.zadd('my-key', 'name1', 1.1, 'name2', 2.2, name3=3.3, name4=4.4) + """ + pieces = [] + if args: + if len(args) % 2 != 0: + raise RedisError("ZADD requires an equal number of " + "values and scores") + pieces.extend(reversed(args)) + for pair in iteritems(kwargs): + pieces.append(pair[1]) + pieces.append(pair[0]) + return self.execute_command('ZADD', name, *pieces) + + +class PubSub(object): + """ + PubSub provides publish, subscribe and listen support to Redis channels. + + After subscribing to one or more channels, the listen() method will block + until a message arrives on one of the subscribed channels. That message + will be returned and it's safe to start listening again. + """ + def __init__(self, connection_pool, shard_hint=None): + self.connection_pool = connection_pool + self.shard_hint = shard_hint + self.connection = None + self.channels = set() + self.patterns = set() + self.subscription_count = 0 + self.subscribe_commands = set( + ('subscribe', 'psubscribe', 'unsubscribe', 'punsubscribe') + ) + + def __del__(self): + try: + # if this object went out of scope prior to shutting down + # subscriptions, close the connection manually before + # returning it to the connection pool + if self.connection and (self.channels or self.patterns): + self.connection.disconnect() + self.reset() + except Exception: + pass + + def reset(self): + if self.connection: + self.connection.disconnect() + self.connection_pool.release(self.connection) + self.connection = None + + def close(self): + self.reset() + + def execute_command(self, *args, **kwargs): + "Execute a publish/subscribe command" + + # NOTE: don't parse the response in this function. it could pull a + # legitmate message off the stack if the connection is already + # subscribed to one or more channels + + if self.connection is None: + self.connection = self.connection_pool.get_connection( + 'pubsub', + self.shard_hint + ) + connection = self.connection + try: + connection.send_command(*args) + except ConnectionError: + connection.disconnect() + # Connect manually here. If the Redis server is down, this will + # fail and raise a ConnectionError as desired. + connection.connect() + # resubscribe to all channels and patterns before + # resending the current command + for channel in self.channels: + self.subscribe(channel) + for pattern in self.patterns: + self.psubscribe(pattern) + connection.send_command(*args) + + def parse_response(self): + "Parse the response from a publish/subscribe command" + response = self.connection.read_response() + if nativestr(response[0]) in self.subscribe_commands: + self.subscription_count = response[2] + # if we've just unsubscribed from the remaining channels, + # release the connection back to the pool + if not self.subscription_count: + self.reset() + return response + + def psubscribe(self, patterns): + "Subscribe to all channels matching any pattern in ``patterns``" + if isinstance(patterns, basestring): + patterns = [patterns] + for pattern in patterns: + self.patterns.add(pattern) + return self.execute_command('PSUBSCRIBE', *patterns) + + def punsubscribe(self, patterns=[]): + """ + Unsubscribe from any channel matching any pattern in ``patterns``. + If empty, unsubscribe from all channels. + """ + if isinstance(patterns, basestring): + patterns = [patterns] + for pattern in patterns: + try: + self.patterns.remove(pattern) + except KeyError: + pass + return self.execute_command('PUNSUBSCRIBE', *patterns) + + def subscribe(self, channels): + "Subscribe to ``channels``, waiting for messages to be published" + if isinstance(channels, basestring): + channels = [channels] + for channel in channels: + self.channels.add(channel) + return self.execute_command('SUBSCRIBE', *channels) + + def unsubscribe(self, channels=[]): + """ + Unsubscribe from ``channels``. If empty, unsubscribe + from all channels + """ + if isinstance(channels, basestring): + channels = [channels] + for channel in channels: + try: + self.channels.remove(channel) + except KeyError: + pass + return self.execute_command('UNSUBSCRIBE', *channels) + + def listen(self): + "Listen for messages on channels this client has been subscribed to" + while self.subscription_count or self.channels or self.patterns: + r = self.parse_response() + msg_type = nativestr(r[0]) + if msg_type == 'pmessage': + msg = { + 'type': msg_type, + 'pattern': nativestr(r[1]), + 'channel': nativestr(r[2]), + 'data': r[3] + } + else: + msg = { + 'type': msg_type, + 'pattern': None, + 'channel': nativestr(r[1]), + 'data': r[2] + } + yield msg + + +class BasePipeline(object): + """ + Pipelines provide a way to transmit multiple commands to the Redis server + in one transmission. This is convenient for batch processing, such as + saving all the values in a list to Redis. + + All commands executed within a pipeline are wrapped with MULTI and EXEC + calls. This guarantees all commands executed in the pipeline will be + executed atomically. + + Any command raising an exception does *not* halt the execution of + subsequent commands in the pipeline. Instead, the exception is caught + and its instance is placed into the response list returned by execute(). + Code iterating over the response list should be able to deal with an + instance of an exception as a potential value. In general, these will be + ResponseError exceptions, such as those raised when issuing a command + on a key of a different datatype. + """ + + UNWATCH_COMMANDS = set(('DISCARD', 'EXEC', 'UNWATCH')) + + def __init__(self, connection_pool, response_callbacks, transaction, + shard_hint): + self.connection_pool = connection_pool + self.connection = None + self.response_callbacks = response_callbacks + self.transaction = transaction + self.shard_hint = shard_hint + + self.watching = False + self.reset() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.reset() + + def __del__(self): + try: + self.reset() + except Exception: + pass + + def __len__(self): + return len(self.command_stack) + + def reset(self): + self.command_stack = [] + self.scripts = set() + # make sure to reset the connection state in the event that we were + # watching something + if self.watching and self.connection: + try: + # call this manually since our unwatch or + # immediate_execute_command methods can call reset() + self.connection.send_command('UNWATCH') + self.connection.read_response() + except ConnectionError: + # disconnect will also remove any previous WATCHes + self.connection.disconnect() + # clean up the other instance attributes + self.watching = False + self.explicit_transaction = False + # we can safely return the connection to the pool here since we're + # sure we're no longer WATCHing anything + if self.connection: + self.connection_pool.release(self.connection) + self.connection = None + + def multi(self): + """ + Start a transactional block of the pipeline after WATCH commands + are issued. End the transactional block with `execute`. + """ + if self.explicit_transaction: + raise RedisError('Cannot issue nested calls to MULTI') + if self.command_stack: + raise RedisError('Commands without an initial WATCH have already ' + 'been issued') + self.explicit_transaction = True + + def execute_command(self, *args, **kwargs): + if (self.watching or args[0] == 'WATCH') and \ + not self.explicit_transaction: + return self.immediate_execute_command(*args, **kwargs) + return self.pipeline_execute_command(*args, **kwargs) + + def immediate_execute_command(self, *args, **options): + """ + Execute a command immediately, but don't auto-retry on a + ConnectionError if we're already WATCHing a variable. Used when + issuing WATCH or subsequent commands retrieving their values but before + MULTI is called. + """ + command_name = args[0] + conn = self.connection + # if this is the first call, we need a connection + if not conn: + conn = self.connection_pool.get_connection(command_name, + self.shard_hint) + self.connection = conn + try: + conn.send_command(*args) + return self.parse_response(conn, command_name, **options) + except ConnectionError: + conn.disconnect() + # if we're not already watching, we can safely retry the command + # assuming it was a connection timeout + if not self.watching: + conn.send_command(*args) + return self.parse_response(conn, command_name, **options) + self.reset() + raise + + def pipeline_execute_command(self, *args, **options): + """ + Stage a command to be executed when execute() is next called + + Returns the current Pipeline object back so commands can be + chained together, such as: + + pipe = pipe.set('foo', 'bar').incr('baz').decr('bang') + + At some other point, you can then run: pipe.execute(), + which will execute all commands queued in the pipe. + """ + self.command_stack.append((args, options)) + return self + + def _execute_transaction(self, connection, commands, raise_on_error): + cmds = chain([(('MULTI', ), {})], commands, [(('EXEC', ), {})]) + all_cmds = SYM_EMPTY.join( + starmap(connection.pack_command, + [args for args, options in cmds])) + connection.send_packed_command(all_cmds) + # parse off the response for MULTI + self.parse_response(connection, '_') + # and all the other commands + errors = [] + for i, _ in enumerate(commands): + try: + self.parse_response(connection, '_') + except ResponseError: + errors.append((i, sys.exc_info()[1])) + + # parse the EXEC. + try: + response = self.parse_response(connection, '_') + except ExecAbortError: + if self.explicit_transaction: + self.immediate_execute_command('DISCARD') + if errors: + raise errors[0][1] + raise sys.exc_info()[1] + + if response is None: + raise WatchError("Watched variable changed.") + + # put any parse errors into the response + for i, e in errors: + response.insert(i, e) + + if len(response) != len(commands): + raise ResponseError("Wrong number of response items from " + "pipeline execution") + + # find any errors in the response and raise if necessary + if raise_on_error: + self.raise_first_error(response) + + # We have to run response callbacks manually + data = [] + for r, cmd in izip(response, commands): + if not isinstance(r, Exception): + args, options = cmd + command_name = args[0] + if command_name in self.response_callbacks: + r = self.response_callbacks[command_name](r, **options) + data.append(r) + return data + + def _execute_pipeline(self, connection, commands, raise_on_error): + # build up all commands into a single request to increase network perf + all_cmds = SYM_EMPTY.join( + starmap(connection.pack_command, + [args for args, options in commands])) + connection.send_packed_command(all_cmds) + + response = [] + for args, options in commands: + try: + response.append( + self.parse_response(connection, args[0], **options)) + except ResponseError: + response.append(sys.exc_info()[1]) + + if raise_on_error: + self.raise_first_error(response) + return response + + def raise_first_error(self, response): + for r in response: + if isinstance(r, ResponseError): + raise r + + def parse_response(self, connection, command_name, **options): + result = StrictRedis.parse_response( + self, connection, command_name, **options) + if command_name in self.UNWATCH_COMMANDS: + self.watching = False + elif command_name == 'WATCH': + self.watching = True + return result + + def load_scripts(self): + # make sure all scripts that are about to be run on this pipeline exist + scripts = list(self.scripts) + immediate = self.immediate_execute_command + shas = [s.sha for s in scripts] + exists = immediate('SCRIPT', 'EXISTS', *shas, **{'parse': 'EXISTS'}) + if not all(exists): + for s, exist in izip(scripts, exists): + if not exist: + immediate('SCRIPT', 'LOAD', s.script, **{'parse': 'LOAD'}) + + def execute(self, raise_on_error=True): + "Execute all the commands in the current pipeline" + if self.scripts: + self.load_scripts() + stack = self.command_stack + if self.transaction or self.explicit_transaction: + execute = self._execute_transaction + else: + execute = self._execute_pipeline + + conn = self.connection + if not conn: + conn = self.connection_pool.get_connection('MULTI', + self.shard_hint) + # assign to self.connection so reset() releases the connection + # back to the pool after we're done + self.connection = conn + + try: + return execute(conn, stack, raise_on_error) + except ConnectionError: + conn.disconnect() + # if we were watching a variable, the watch is no longer valid + # since this connection has died. raise a WatchError, which + # indicates the user should retry his transaction. If this is more + # than a temporary failure, the WATCH that the user next issue + # will fail, propegating the real ConnectionError + if self.watching: + raise WatchError("A ConnectionError occured on while watching " + "one or more keys") + # otherwise, it's safe to retry since the transaction isn't + # predicated on any state + return execute(conn, stack, raise_on_error) + finally: + self.reset() + + def watch(self, *names): + "Watches the values at keys ``names``" + if self.explicit_transaction: + raise RedisError('Cannot issue a WATCH after a MULTI') + return self.execute_command('WATCH', *names) + + def unwatch(self): + "Unwatches all previously specified keys" + return self.watching and self.execute_command('UNWATCH') or True + + def script_load_for_pipeline(self, script): + "Make sure scripts are loaded prior to pipeline execution" + self.scripts.add(script) + + +class StrictPipeline(BasePipeline, StrictRedis): + "Pipeline for the StrictRedis class" + pass + + +class Pipeline(BasePipeline, Redis): + "Pipeline for the Redis class" + pass + + +class Script(object): + "An executable LUA script object returned by ``register_script``" + + def __init__(self, registered_client, script): + self.registered_client = registered_client + self.script = script + self.sha = registered_client.script_load(script) + + def __call__(self, keys=[], args=[], client=None): + "Execute the script, passing any required ``args``" + if client is None: + client = self.registered_client + args = tuple(keys) + tuple(args) + # make sure the Redis server knows about the script + if isinstance(client, BasePipeline): + # make sure this script is good to go on pipeline + client.script_load_for_pipeline(self) + try: + return client.evalsha(self.sha, len(keys), *args) + except NoScriptError: + # Maybe the client is pointed to a differnet server than the client + # that created this instance? + self.sha = client.script_load(self.script) + return client.evalsha(self.sha, len(keys), *args) + + +class LockError(RedisError): + "Errors thrown from the Lock" + pass + + +class Lock(object): + """ + A shared, distributed Lock. Using Redis for locking allows the Lock + to be shared across processes and/or machines. + + It's left to the user to resolve deadlock issues and make sure + multiple clients play nicely together. + """ + + LOCK_FOREVER = float(2 ** 31 + 1) # 1 past max unix time + + def __init__(self, redis, name, timeout=None, sleep=0.1): + """ + Create a new Lock instnace named ``name`` using the Redis client + supplied by ``redis``. + + ``timeout`` indicates a maximum life for the lock. + By default, it will remain locked until release() is called. + + ``sleep`` indicates the amount of time to sleep per loop iteration + when the lock is in blocking mode and another client is currently + holding the lock. + + Note: If using ``timeout``, you should make sure all the hosts + that are running clients have their time synchronized with a network + time service like ntp. + """ + self.redis = redis + self.name = name + self.acquired_until = None + self.timeout = timeout + self.sleep = sleep + if self.timeout and self.sleep > self.timeout: + raise LockError("'sleep' must be less than 'timeout'") + + def __enter__(self): + return self.acquire() + + def __exit__(self, exc_type, exc_value, traceback): + self.release() + + def acquire(self, blocking=True): + """ + Use Redis to hold a shared, distributed lock named ``name``. + Returns True once the lock is acquired. + + If ``blocking`` is False, always return immediately. If the lock + was acquired, return True, otherwise return False. + """ + sleep = self.sleep + timeout = self.timeout + while 1: + unixtime = int(mod_time.time()) + if timeout: + timeout_at = unixtime + timeout + else: + timeout_at = Lock.LOCK_FOREVER + timeout_at = float(timeout_at) + if self.redis.setnx(self.name, timeout_at): + self.acquired_until = timeout_at + return True + # We want blocking, but didn't acquire the lock + # check to see if the current lock is expired + existing = float(self.redis.get(self.name) or 1) + if existing < unixtime: + # the previous lock is expired, attempt to overwrite it + existing = float(self.redis.getset(self.name, timeout_at) or 1) + if existing < unixtime: + # we successfully acquired the lock + self.acquired_until = timeout_at + return True + if not blocking: + return False + mod_time.sleep(sleep) + + def release(self): + "Releases the already acquired lock" + if self.acquired_until is None: + raise ValueError("Cannot release an unlocked lock") + existing = float(self.redis.get(self.name) or 1) + # if the lock time is in the future, delete the lock + if existing >= self.acquired_until: + self.redis.delete(self.name) + self.acquired_until = None diff --git a/client/ledis-py/ledis/connection.py b/client/ledis-py/ledis/connection.py new file mode 100644 index 0000000..4b509b1 --- /dev/null +++ b/client/ledis-py/ledis/connection.py @@ -0,0 +1,580 @@ +from itertools import chain +import os +import socket +import sys + +from redis._compat import (b, xrange, imap, byte_to_chr, unicode, bytes, long, + BytesIO, nativestr, basestring, + LifoQueue, Empty, Full) +from redis.exceptions import ( + RedisError, + ConnectionError, + BusyLoadingError, + ResponseError, + InvalidResponse, + AuthenticationError, + NoScriptError, + ExecAbortError, +) +from redis.utils import HIREDIS_AVAILABLE +if HIREDIS_AVAILABLE: + import hiredis + + +SYM_STAR = b('*') +SYM_DOLLAR = b('$') +SYM_CRLF = b('\r\n') +SYM_LF = b('\n') + + +class PythonParser(object): + "Plain Python parsing class" + MAX_READ_LENGTH = 1000000 + encoding = None + + EXCEPTION_CLASSES = { + 'ERR': ResponseError, + 'EXECABORT': ExecAbortError, + 'LOADING': BusyLoadingError, + 'NOSCRIPT': NoScriptError, + } + + def __init__(self): + self._fp = None + + def __del__(self): + try: + self.on_disconnect() + except Exception: + pass + + def on_connect(self, connection): + "Called when the socket connects" + self._fp = connection._sock.makefile('rb') + if connection.decode_responses: + self.encoding = connection.encoding + + def on_disconnect(self): + "Called when the socket disconnects" + if self._fp is not None: + self._fp.close() + self._fp = None + + def read(self, length=None): + """ + Read a line from the socket if no length is specified, + otherwise read ``length`` bytes. Always strip away the newlines. + """ + try: + if length is not None: + bytes_left = length + 2 # read the line ending + if length > self.MAX_READ_LENGTH: + # apparently reading more than 1MB or so from a windows + # socket can cause MemoryErrors. See: + # https://github.com/andymccurdy/redis-py/issues/205 + # read smaller chunks at a time to work around this + try: + buf = BytesIO() + while bytes_left > 0: + read_len = min(bytes_left, self.MAX_READ_LENGTH) + buf.write(self._fp.read(read_len)) + bytes_left -= read_len + buf.seek(0) + return buf.read(length) + finally: + buf.close() + return self._fp.read(bytes_left)[:-2] + + # no length, read a full line + return self._fp.readline()[:-2] + except (socket.error, socket.timeout): + e = sys.exc_info()[1] + raise ConnectionError("Error while reading from socket: %s" % + (e.args,)) + + def parse_error(self, response): + "Parse an error response" + error_code = response.split(' ')[0] + if error_code in self.EXCEPTION_CLASSES: + response = response[len(error_code) + 1:] + return self.EXCEPTION_CLASSES[error_code](response) + return ResponseError(response) + + def read_response(self): + response = self.read() + if not response: + raise ConnectionError("Socket closed on remote end") + + byte, response = byte_to_chr(response[0]), response[1:] + + if byte not in ('-', '+', ':', '$', '*'): + raise InvalidResponse("Protocol Error") + + # server returned an error + if byte == '-': + response = nativestr(response) + error = self.parse_error(response) + # if the error is a ConnectionError, raise immediately so the user + # is notified + if isinstance(error, ConnectionError): + raise error + # otherwise, we're dealing with a ResponseError that might belong + # inside a pipeline response. the connection's read_response() + # and/or the pipeline's execute() will raise this error if + # necessary, so just return the exception instance here. + return error + # single value + elif byte == '+': + pass + # int value + elif byte == ':': + response = long(response) + # bulk response + elif byte == '$': + length = int(response) + if length == -1: + return None + response = self.read(length) + # multi-bulk response + elif byte == '*': + length = int(response) + if length == -1: + return None + response = [self.read_response() for i in xrange(length)] + if isinstance(response, bytes) and self.encoding: + response = response.decode(self.encoding) + return response + + +class HiredisParser(object): + "Parser class for connections using Hiredis" + def __init__(self): + if not HIREDIS_AVAILABLE: + raise RedisError("Hiredis is not installed") + + def __del__(self): + try: + self.on_disconnect() + except Exception: + pass + + def on_connect(self, connection): + self._sock = connection._sock + kwargs = { + 'protocolError': InvalidResponse, + 'replyError': ResponseError, + } + if connection.decode_responses: + kwargs['encoding'] = connection.encoding + self._reader = hiredis.Reader(**kwargs) + + def on_disconnect(self): + self._sock = None + self._reader = None + + def read_response(self): + if not self._reader: + raise ConnectionError("Socket closed on remote end") + response = self._reader.gets() + while response is False: + try: + buffer = self._sock.recv(4096) + except (socket.error, socket.timeout): + e = sys.exc_info()[1] + raise ConnectionError("Error while reading from socket: %s" % + (e.args,)) + if not buffer: + raise ConnectionError("Socket closed on remote end") + self._reader.feed(buffer) + # proactively, but not conclusively, check if more data is in the + # buffer. if the data received doesn't end with \n, there's more. + if not buffer.endswith(SYM_LF): + continue + response = self._reader.gets() + return response + +if HIREDIS_AVAILABLE: + DefaultParser = HiredisParser +else: + DefaultParser = PythonParser + + +class Connection(object): + "Manages TCP communication to and from a Redis server" + def __init__(self, host='localhost', port=6379, db=0, password=None, + socket_timeout=None, encoding='utf-8', + encoding_errors='strict', decode_responses=False, + parser_class=DefaultParser): + self.pid = os.getpid() + self.host = host + self.port = port + self.db = db + self.password = password + self.socket_timeout = socket_timeout + self.encoding = encoding + self.encoding_errors = encoding_errors + self.decode_responses = decode_responses + self._sock = None + self._parser = parser_class() + + def __del__(self): + try: + self.disconnect() + except Exception: + pass + + def connect(self): + "Connects to the Redis server if not already connected" + if self._sock: + return + try: + sock = self._connect() + except socket.error: + e = sys.exc_info()[1] + raise ConnectionError(self._error_message(e)) + + self._sock = sock + self.on_connect() + + def _connect(self): + "Create a TCP socket connection" + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(self.socket_timeout) + sock.connect((self.host, self.port)) + return sock + + def _error_message(self, exception): + # args for socket.error can either be (errno, "message") + # or just "message" + if len(exception.args) == 1: + return "Error connecting to %s:%s. %s." % \ + (self.host, self.port, exception.args[0]) + else: + return "Error %s connecting %s:%s. %s." % \ + (exception.args[0], self.host, self.port, exception.args[1]) + + def on_connect(self): + "Initialize the connection, authenticate and select a database" + self._parser.on_connect(self) + + # if a password is specified, authenticate + if self.password: + self.send_command('AUTH', self.password) + if nativestr(self.read_response()) != 'OK': + raise AuthenticationError('Invalid Password') + + # if a database is specified, switch to it + if self.db: + self.send_command('SELECT', self.db) + if nativestr(self.read_response()) != 'OK': + raise ConnectionError('Invalid Database') + + def disconnect(self): + "Disconnects from the Redis server" + self._parser.on_disconnect() + if self._sock is None: + return + try: + self._sock.close() + except socket.error: + pass + self._sock = None + + def send_packed_command(self, command): + "Send an already packed command to the Redis server" + if not self._sock: + self.connect() + try: + self._sock.sendall(command) + except socket.error: + e = sys.exc_info()[1] + self.disconnect() + if len(e.args) == 1: + _errno, errmsg = 'UNKNOWN', e.args[0] + else: + _errno, errmsg = e.args + raise ConnectionError("Error %s while writing to socket. %s." % + (_errno, errmsg)) + except Exception: + self.disconnect() + raise + + def send_command(self, *args): + "Pack and send a command to the Redis server" + self.send_packed_command(self.pack_command(*args)) + + def read_response(self): + "Read the response from a previously sent command" + try: + response = self._parser.read_response() + except Exception: + self.disconnect() + raise + if isinstance(response, ResponseError): + raise response + return response + + def encode(self, value): + "Return a bytestring representation of the value" + if isinstance(value, bytes): + return value + if isinstance(value, float): + value = repr(value) + if not isinstance(value, basestring): + value = str(value) + if isinstance(value, unicode): + value = value.encode(self.encoding, self.encoding_errors) + return value + + def pack_command(self, *args): + "Pack a series of arguments into a value Redis command" + output = SYM_STAR + b(str(len(args))) + SYM_CRLF + for enc_value in imap(self.encode, args): + output += SYM_DOLLAR + output += b(str(len(enc_value))) + output += SYM_CRLF + output += enc_value + output += SYM_CRLF + return output + + +class UnixDomainSocketConnection(Connection): + def __init__(self, path='', db=0, password=None, + socket_timeout=None, encoding='utf-8', + encoding_errors='strict', decode_responses=False, + parser_class=DefaultParser): + self.pid = os.getpid() + self.path = path + self.db = db + self.password = password + self.socket_timeout = socket_timeout + self.encoding = encoding + self.encoding_errors = encoding_errors + self.decode_responses = decode_responses + self._sock = None + self._parser = parser_class() + + def _connect(self): + "Create a Unix domain socket connection" + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + sock.settimeout(self.socket_timeout) + sock.connect(self.path) + return sock + + def _error_message(self, exception): + # args for socket.error can either be (errno, "message") + # or just "message" + if len(exception.args) == 1: + return "Error connecting to unix socket: %s. %s." % \ + (self.path, exception.args[0]) + else: + return "Error %s connecting to unix socket: %s. %s." % \ + (exception.args[0], self.path, exception.args[1]) + + +# TODO: add ability to block waiting on a connection to be released +class ConnectionPool(object): + "Generic connection pool" + def __init__(self, connection_class=Connection, max_connections=None, + **connection_kwargs): + self.pid = os.getpid() + self.connection_class = connection_class + self.connection_kwargs = connection_kwargs + self.max_connections = max_connections or 2 ** 31 + self._created_connections = 0 + self._available_connections = [] + self._in_use_connections = set() + + def _checkpid(self): + if self.pid != os.getpid(): + self.disconnect() + self.__init__(self.connection_class, self.max_connections, + **self.connection_kwargs) + + def get_connection(self, command_name, *keys, **options): + "Get a connection from the pool" + self._checkpid() + try: + connection = self._available_connections.pop() + except IndexError: + connection = self.make_connection() + self._in_use_connections.add(connection) + return connection + + def make_connection(self): + "Create a new connection" + if self._created_connections >= self.max_connections: + raise ConnectionError("Too many connections") + self._created_connections += 1 + return self.connection_class(**self.connection_kwargs) + + def release(self, connection): + "Releases the connection back to the pool" + self._checkpid() + if connection.pid == self.pid: + self._in_use_connections.remove(connection) + self._available_connections.append(connection) + + def disconnect(self): + "Disconnects all connections in the pool" + all_conns = chain(self._available_connections, + self._in_use_connections) + for connection in all_conns: + connection.disconnect() + + +class BlockingConnectionPool(object): + """ + Thread-safe blocking connection pool:: + + >>> from redis.client import Redis + >>> client = Redis(connection_pool=BlockingConnectionPool()) + + It performs the same function as the default + ``:py:class: ~redis.connection.ConnectionPool`` implementation, in that, + it maintains a pool of reusable connections that can be shared by + multiple redis clients (safely across threads if required). + + The difference is that, in the event that a client tries to get a + connection from the pool when all of connections are in use, rather than + raising a ``:py:class: ~redis.exceptions.ConnectionError`` (as the default + ``:py:class: ~redis.connection.ConnectionPool`` implementation does), it + makes the client wait ("blocks") for a specified number of seconds until + a connection becomes available. + + Use ``max_connections`` to increase / decrease the pool size:: + + >>> pool = BlockingConnectionPool(max_connections=10) + + Use ``timeout`` to tell it either how many seconds to wait for a connection + to become available, or to block forever: + + # Block forever. + >>> pool = BlockingConnectionPool(timeout=None) + + # Raise a ``ConnectionError`` after five seconds if a connection is + # not available. + >>> pool = BlockingConnectionPool(timeout=5) + """ + def __init__(self, max_connections=50, timeout=20, connection_class=None, + queue_class=None, **connection_kwargs): + "Compose and assign values." + # Compose. + if connection_class is None: + connection_class = Connection + if queue_class is None: + queue_class = LifoQueue + + # Assign. + self.connection_class = connection_class + self.connection_kwargs = connection_kwargs + self.queue_class = queue_class + self.max_connections = max_connections + self.timeout = timeout + + # Validate the ``max_connections``. With the "fill up the queue" + # algorithm we use, it must be a positive integer. + is_valid = isinstance(max_connections, int) and max_connections > 0 + if not is_valid: + raise ValueError('``max_connections`` must be a positive integer') + + # Get the current process id, so we can disconnect and reinstantiate if + # it changes. + self.pid = os.getpid() + + # Create and fill up a thread safe queue with ``None`` values. + self.pool = self.queue_class(max_connections) + while True: + try: + self.pool.put_nowait(None) + except Full: + break + + # Keep a list of actual connection instances so that we can + # disconnect them later. + self._connections = [] + + def _checkpid(self): + """ + Check the current process id. If it has changed, disconnect and + re-instantiate this connection pool instance. + """ + # Get the current process id. + pid = os.getpid() + + # If it hasn't changed since we were instantiated, then we're fine, so + # just exit, remaining connected. + if self.pid == pid: + return + + # If it has changed, then disconnect and re-instantiate. + self.disconnect() + self.reinstantiate() + + def make_connection(self): + "Make a fresh connection." + connection = self.connection_class(**self.connection_kwargs) + self._connections.append(connection) + return connection + + def get_connection(self, command_name, *keys, **options): + """ + Get a connection, blocking for ``self.timeout`` until a connection + is available from the pool. + + If the connection returned is ``None`` then creates a new connection. + Because we use a last-in first-out queue, the existing connections + (having been returned to the pool after the initial ``None`` values + were added) will be returned before ``None`` values. This means we only + create new connections when we need to, i.e.: the actual number of + connections will only increase in response to demand. + """ + # Make sure we haven't changed process. + self._checkpid() + + # Try and get a connection from the pool. If one isn't available within + # self.timeout then raise a ``ConnectionError``. + connection = None + try: + connection = self.pool.get(block=True, timeout=self.timeout) + except Empty: + # Note that this is not caught by the redis client and will be + # raised unless handled by application code. If you want never to + raise ConnectionError("No connection available.") + + # If the ``connection`` is actually ``None`` then that's a cue to make + # a new connection to add to the pool. + if connection is None: + connection = self.make_connection() + + return connection + + def release(self, connection): + "Releases the connection back to the pool." + # Make sure we haven't changed process. + self._checkpid() + + # Put the connection back into the pool. + try: + self.pool.put_nowait(connection) + except Full: + # This shouldn't normally happen but might perhaps happen after a + # reinstantiation. So, we can handle the exception by not putting + # the connection back on the pool, because we definitely do not + # want to reuse it. + pass + + def disconnect(self): + "Disconnects all connections in the pool." + for connection in self._connections: + connection.disconnect() + + def reinstantiate(self): + """ + Reinstatiate this instance within a new process with a new connection + pool set. + """ + self.__init__(max_connections=self.max_connections, + timeout=self.timeout, + connection_class=self.connection_class, + queue_class=self.queue_class, **self.connection_kwargs) diff --git a/client/ledis-py/ledis/exceptions.py b/client/ledis-py/ledis/exceptions.py new file mode 100644 index 0000000..d67afa7 --- /dev/null +++ b/client/ledis-py/ledis/exceptions.py @@ -0,0 +1,49 @@ +"Core exceptions raised by the Redis client" + + +class RedisError(Exception): + pass + + +class AuthenticationError(RedisError): + pass + + +class ServerError(RedisError): + pass + + +class ConnectionError(ServerError): + pass + + +class BusyLoadingError(ConnectionError): + pass + + +class InvalidResponse(ServerError): + pass + + +class ResponseError(RedisError): + pass + + +class DataError(RedisError): + pass + + +class PubSubError(RedisError): + pass + + +class WatchError(RedisError): + pass + + +class NoScriptError(ResponseError): + pass + + +class ExecAbortError(ResponseError): + pass diff --git a/client/ledis-py/ledis/utils.py b/client/ledis-py/ledis/utils.py new file mode 100644 index 0000000..ee681bf --- /dev/null +++ b/client/ledis-py/ledis/utils.py @@ -0,0 +1,16 @@ +try: + import hiredis + HIREDIS_AVAILABLE = True +except ImportError: + HIREDIS_AVAILABLE = False + + +def from_url(url, db=None, **kwargs): + """ + Returns an active Redis client generated from the given database URL. + + Will attempt to extract the database id from the path url fragment, if + none is provided. + """ + from redis.client import Redis + return Redis.from_url(url, db, **kwargs) diff --git a/client/ledis-py/setup.py b/client/ledis-py/setup.py new file mode 100644 index 0000000..67706f1 --- /dev/null +++ b/client/ledis-py/setup.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python +import os +import sys + +from redis import __version__ + +try: + from setuptools import setup + from setuptools.command.test import test as TestCommand + + class PyTest(TestCommand): + def finalize_options(self): + TestCommand.finalize_options(self) + self.test_args = [] + self.test_suite = True + + def run_tests(self): + # import here, because outside the eggs aren't loaded + import pytest + errno = pytest.main(self.test_args) + sys.exit(errno) + +except ImportError: + + from distutils.core import setup + PyTest = lambda x: x + +f = open(os.path.join(os.path.dirname(__file__), 'README.rst')) +long_description = f.read() +f.close() + +setup( + name='redis', + version=__version__, + description='Python client for Redis key-value store', + long_description=long_description, + url='http://github.com/andymccurdy/redis-py', + author='Andy McCurdy', + author_email='sedrik@gmail.com', + maintainer='Andy McCurdy', + maintainer_email='sedrik@gmail.com', + keywords=['Redis', 'key-value store'], + license='MIT', + packages=['redis'], + tests_require=['pytest>=2.5.0'], + cmdclass={'test': PyTest}, + classifiers=[ + 'Development Status :: 5 - Production/Stable', + 'Environment :: Console', + 'Intended Audience :: Developers', + 'License :: OSI Approved :: MIT License', + 'Operating System :: OS Independent', + 'Programming Language :: Python', + 'Programming Language :: Python :: 2.6', + 'Programming Language :: Python :: 2.7', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.2', + 'Programming Language :: Python :: 3.3', + 'Programming Language :: Python :: 3.4', + ] +) diff --git a/client/ledis-py/tests/__init__.py b/client/ledis-py/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/client/ledis-py/tests/conftest.py b/client/ledis-py/tests/conftest.py new file mode 100644 index 0000000..bd0116b --- /dev/null +++ b/client/ledis-py/tests/conftest.py @@ -0,0 +1,46 @@ +import pytest +import redis + +from distutils.version import StrictVersion + + +_REDIS_VERSIONS = {} + + +def get_version(**kwargs): + params = {'host': 'localhost', 'port': 6379, 'db': 9} + params.update(kwargs) + key = '%s:%s' % (params['host'], params['port']) + if key not in _REDIS_VERSIONS: + client = redis.Redis(**params) + _REDIS_VERSIONS[key] = client.info()['redis_version'] + client.connection_pool.disconnect() + return _REDIS_VERSIONS[key] + + +def _get_client(cls, request=None, **kwargs): + params = {'host': 'localhost', 'port': 6379, 'db': 9} + params.update(kwargs) + client = cls(**params) + client.flushdb() + if request: + def teardown(): + client.flushdb() + client.connection_pool.disconnect() + request.addfinalizer(teardown) + return client + + +def skip_if_server_version_lt(min_version): + check = StrictVersion(get_version()) < StrictVersion(min_version) + return pytest.mark.skipif(check, reason="") + + +@pytest.fixture() +def r(request, **kwargs): + return _get_client(redis.Redis, request, **kwargs) + + +@pytest.fixture() +def sr(request, **kwargs): + return _get_client(redis.StrictRedis, request, **kwargs) diff --git a/client/ledis-py/tests/test_commands.py b/client/ledis-py/tests/test_commands.py new file mode 100644 index 0000000..aaff22e --- /dev/null +++ b/client/ledis-py/tests/test_commands.py @@ -0,0 +1,1419 @@ +from __future__ import with_statement +import binascii +import datetime +import pytest +import redis +import time + +from redis._compat import (unichr, u, b, ascii_letters, iteritems, iterkeys, + itervalues) +from redis.client import parse_info +from redis import exceptions + +from .conftest import skip_if_server_version_lt + + +@pytest.fixture() +def slowlog(request, r): + current_config = r.config_get() + old_slower_than_value = current_config['slowlog-log-slower-than'] + old_max_legnth_value = current_config['slowlog-max-len'] + + def cleanup(): + r.config_set('slowlog-log-slower-than', old_slower_than_value) + r.config_set('slowlog-max-len', old_max_legnth_value) + request.addfinalizer(cleanup) + + r.config_set('slowlog-log-slower-than', 0) + r.config_set('slowlog-max-len', 128) + + +def redis_server_time(client): + seconds, milliseconds = client.time() + timestamp = float('%s.%s' % (seconds, milliseconds)) + return datetime.datetime.fromtimestamp(timestamp) + + +# RESPONSE CALLBACKS +class TestResponseCallbacks(object): + "Tests for the response callback system" + + def test_response_callbacks(self, r): + assert r.response_callbacks == redis.Redis.RESPONSE_CALLBACKS + assert id(r.response_callbacks) != id(redis.Redis.RESPONSE_CALLBACKS) + r.set_response_callback('GET', lambda x: 'static') + r['a'] = 'foo' + assert r['a'] == 'static' + + +class TestRedisCommands(object): + + def test_command_on_invalid_key_type(self, r): + r.lpush('a', '1') + with pytest.raises(redis.ResponseError): + r['a'] + + # SERVER INFORMATION + def test_client_list(self, r): + clients = r.client_list() + assert isinstance(clients[0], dict) + assert 'addr' in clients[0] + + @skip_if_server_version_lt('2.6.9') + def test_client_getname(self, r): + assert r.client_getname() is None + + @skip_if_server_version_lt('2.6.9') + def test_client_setname(self, r): + assert r.client_setname('redis_py_test') + assert r.client_getname() == 'redis_py_test' + + def test_config_get(self, r): + data = r.config_get() + assert 'maxmemory' in data + assert data['maxmemory'].isdigit() + + def test_config_resetstat(self, r): + r.ping() + prior_commands_processed = int(r.info()['total_commands_processed']) + assert prior_commands_processed >= 1 + r.config_resetstat() + reset_commands_processed = int(r.info()['total_commands_processed']) + assert reset_commands_processed < prior_commands_processed + + def test_config_set(self, r): + data = r.config_get() + rdbname = data['dbfilename'] + try: + assert r.config_set('dbfilename', 'redis_py_test.rdb') + assert r.config_get()['dbfilename'] == 'redis_py_test.rdb' + finally: + assert r.config_set('dbfilename', rdbname) + + def test_dbsize(self, r): + r['a'] = 'foo' + r['b'] = 'bar' + assert r.dbsize() == 2 + + def test_echo(self, r): + assert r.echo('foo bar') == b('foo bar') + + def test_info(self, r): + r['a'] = 'foo' + r['b'] = 'bar' + info = r.info() + assert isinstance(info, dict) + assert info['db9']['keys'] == 2 + + def test_lastsave(self, r): + assert isinstance(r.lastsave(), datetime.datetime) + + def test_object(self, r): + r['a'] = 'foo' + assert isinstance(r.object('refcount', 'a'), int) + assert isinstance(r.object('idletime', 'a'), int) + assert r.object('encoding', 'a') == b('raw') + assert r.object('idletime', 'invalid-key') is None + + def test_ping(self, r): + assert r.ping() + + def test_slowlog_get(self, r, slowlog): + assert r.slowlog_reset() + unicode_string = unichr(3456) + u('abcd') + unichr(3421) + r.get(unicode_string) + slowlog = r.slowlog_get() + assert isinstance(slowlog, list) + commands = [log['command'] for log in slowlog] + + get_command = b(' ').join((b('GET'), unicode_string.encode('utf-8'))) + assert get_command in commands + assert b('SLOWLOG RESET') in commands + # the order should be ['GET ', 'SLOWLOG RESET'], + # but if other clients are executing commands at the same time, there + # could be commands, before, between, or after, so just check that + # the two we care about are in the appropriate order. + assert commands.index(get_command) < commands.index(b('SLOWLOG RESET')) + + # make sure other attributes are typed correctly + assert isinstance(slowlog[0]['start_time'], int) + assert isinstance(slowlog[0]['duration'], int) + + def test_slowlog_get_limit(self, r, slowlog): + assert r.slowlog_reset() + r.get('foo') + r.get('bar') + slowlog = r.slowlog_get(1) + assert isinstance(slowlog, list) + commands = [log['command'] for log in slowlog] + assert b('GET foo') not in commands + assert b('GET bar') in commands + + def test_slowlog_length(self, r, slowlog): + r.get('foo') + assert isinstance(r.slowlog_len(), int) + + @skip_if_server_version_lt('2.6.0') + def test_time(self, r): + t = r.time() + assert len(t) == 2 + assert isinstance(t[0], int) + assert isinstance(t[1], int) + + # BASIC KEY COMMANDS + def test_append(self, r): + assert r.append('a', 'a1') == 2 + assert r['a'] == b('a1') + assert r.append('a', 'a2') == 4 + assert r['a'] == b('a1a2') + + @skip_if_server_version_lt('2.6.0') + def test_bitcount(self, r): + r.setbit('a', 5, True) + assert r.bitcount('a') == 1 + r.setbit('a', 6, True) + assert r.bitcount('a') == 2 + r.setbit('a', 5, False) + assert r.bitcount('a') == 1 + r.setbit('a', 9, True) + r.setbit('a', 17, True) + r.setbit('a', 25, True) + r.setbit('a', 33, True) + assert r.bitcount('a') == 5 + assert r.bitcount('a', 0, -1) == 5 + assert r.bitcount('a', 2, 3) == 2 + assert r.bitcount('a', 2, -1) == 3 + assert r.bitcount('a', -2, -1) == 2 + assert r.bitcount('a', 1, 1) == 1 + + @skip_if_server_version_lt('2.6.0') + def test_bitop_not_empty_string(self, r): + r['a'] = '' + r.bitop('not', 'r', 'a') + assert r.get('r') is None + + @skip_if_server_version_lt('2.6.0') + def test_bitop_not(self, r): + test_str = b('\xAA\x00\xFF\x55') + correct = ~0xAA00FF55 & 0xFFFFFFFF + r['a'] = test_str + r.bitop('not', 'r', 'a') + assert int(binascii.hexlify(r['r']), 16) == correct + + @skip_if_server_version_lt('2.6.0') + def test_bitop_not_in_place(self, r): + test_str = b('\xAA\x00\xFF\x55') + correct = ~0xAA00FF55 & 0xFFFFFFFF + r['a'] = test_str + r.bitop('not', 'a', 'a') + assert int(binascii.hexlify(r['a']), 16) == correct + + @skip_if_server_version_lt('2.6.0') + def test_bitop_single_string(self, r): + test_str = b('\x01\x02\xFF') + r['a'] = test_str + r.bitop('and', 'res1', 'a') + r.bitop('or', 'res2', 'a') + r.bitop('xor', 'res3', 'a') + assert r['res1'] == test_str + assert r['res2'] == test_str + assert r['res3'] == test_str + + @skip_if_server_version_lt('2.6.0') + def test_bitop_string_operands(self, r): + r['a'] = b('\x01\x02\xFF\xFF') + r['b'] = b('\x01\x02\xFF') + r.bitop('and', 'res1', 'a', 'b') + r.bitop('or', 'res2', 'a', 'b') + r.bitop('xor', 'res3', 'a', 'b') + assert int(binascii.hexlify(r['res1']), 16) == 0x0102FF00 + assert int(binascii.hexlify(r['res2']), 16) == 0x0102FFFF + assert int(binascii.hexlify(r['res3']), 16) == 0x000000FF + + def test_decr(self, r): + assert r.decr('a') == -1 + assert r['a'] == b('-1') + assert r.decr('a') == -2 + assert r['a'] == b('-2') + assert r.decr('a', amount=5) == -7 + assert r['a'] == b('-7') + + def test_delete(self, r): + assert r.delete('a') == 0 + r['a'] = 'foo' + assert r.delete('a') == 1 + + def test_delete_with_multiple_keys(self, r): + r['a'] = 'foo' + r['b'] = 'bar' + assert r.delete('a', 'b') == 2 + assert r.get('a') is None + assert r.get('b') is None + + def test_delitem(self, r): + r['a'] = 'foo' + del r['a'] + assert r.get('a') is None + + @skip_if_server_version_lt('2.6.0') + def test_dump_and_restore(self, r): + r['a'] = 'foo' + dumped = r.dump('a') + del r['a'] + r.restore('a', 0, dumped) + assert r['a'] == b('foo') + + def test_exists(self, r): + assert not r.exists('a') + r['a'] = 'foo' + assert r.exists('a') + + def test_exists_contains(self, r): + assert 'a' not in r + r['a'] = 'foo' + assert 'a' in r + + def test_expire(self, r): + assert not r.expire('a', 10) + r['a'] = 'foo' + assert r.expire('a', 10) + assert 0 < r.ttl('a') <= 10 + assert r.persist('a') + assert not r.ttl('a') + + def test_expireat_datetime(self, r): + expire_at = redis_server_time(r) + datetime.timedelta(minutes=1) + r['a'] = 'foo' + assert r.expireat('a', expire_at) + assert 0 < r.ttl('a') <= 61 + + def test_expireat_no_key(self, r): + expire_at = redis_server_time(r) + datetime.timedelta(minutes=1) + assert not r.expireat('a', expire_at) + + def test_expireat_unixtime(self, r): + expire_at = redis_server_time(r) + datetime.timedelta(minutes=1) + r['a'] = 'foo' + expire_at_seconds = int(time.mktime(expire_at.timetuple())) + assert r.expireat('a', expire_at_seconds) + assert 0 < r.ttl('a') <= 61 + + def test_get_and_set(self, r): + # get and set can't be tested independently of each other + assert r.get('a') is None + byte_string = b('value') + integer = 5 + unicode_string = unichr(3456) + u('abcd') + unichr(3421) + assert r.set('byte_string', byte_string) + assert r.set('integer', 5) + assert r.set('unicode_string', unicode_string) + assert r.get('byte_string') == byte_string + assert r.get('integer') == b(str(integer)) + assert r.get('unicode_string').decode('utf-8') == unicode_string + + def test_getitem_and_setitem(self, r): + r['a'] = 'bar' + assert r['a'] == b('bar') + + def test_getitem_raises_keyerror_for_missing_key(self, r): + with pytest.raises(KeyError): + r['a'] + + def test_get_set_bit(self, r): + # no value + assert not r.getbit('a', 5) + # set bit 5 + assert not r.setbit('a', 5, True) + assert r.getbit('a', 5) + # unset bit 4 + assert not r.setbit('a', 4, False) + assert not r.getbit('a', 4) + # set bit 4 + assert not r.setbit('a', 4, True) + assert r.getbit('a', 4) + # set bit 5 again + assert r.setbit('a', 5, True) + assert r.getbit('a', 5) + + def test_getrange(self, r): + r['a'] = 'foo' + assert r.getrange('a', 0, 0) == b('f') + assert r.getrange('a', 0, 2) == b('foo') + assert r.getrange('a', 3, 4) == b('') + + def test_getset(self, r): + assert r.getset('a', 'foo') is None + assert r.getset('a', 'bar') == b('foo') + assert r.get('a') == b('bar') + + def test_incr(self, r): + assert r.incr('a') == 1 + assert r['a'] == b('1') + assert r.incr('a') == 2 + assert r['a'] == b('2') + assert r.incr('a', amount=5) == 7 + assert r['a'] == b('7') + + def test_incrby(self, r): + assert r.incrby('a') == 1 + assert r.incrby('a', 4) == 5 + assert r['a'] == b('5') + + @skip_if_server_version_lt('2.6.0') + def test_incrbyfloat(self, r): + assert r.incrbyfloat('a') == 1.0 + assert r['a'] == b('1') + assert r.incrbyfloat('a', 1.1) == 2.1 + assert float(r['a']) == float(2.1) + + def test_keys(self, r): + assert r.keys() == [] + keys_with_underscores = set([b('test_a'), b('test_b')]) + keys = keys_with_underscores.union(set([b('testc')])) + for key in keys: + r[key] = 1 + assert set(r.keys(pattern='test_*')) == keys_with_underscores + assert set(r.keys(pattern='test*')) == keys + + def test_mget(self, r): + assert r.mget(['a', 'b']) == [None, None] + r['a'] = '1' + r['b'] = '2' + r['c'] = '3' + assert r.mget('a', 'other', 'b', 'c') == [b('1'), None, b('2'), b('3')] + + def test_mset(self, r): + d = {'a': b('1'), 'b': b('2'), 'c': b('3')} + assert r.mset(d) + for k, v in iteritems(d): + assert r[k] == v + + def test_mset_kwargs(self, r): + d = {'a': b('1'), 'b': b('2'), 'c': b('3')} + assert r.mset(**d) + for k, v in iteritems(d): + assert r[k] == v + + def test_msetnx(self, r): + d = {'a': b('1'), 'b': b('2'), 'c': b('3')} + assert r.msetnx(d) + d2 = {'a': b('x'), 'd': b('4')} + assert not r.msetnx(d2) + for k, v in iteritems(d): + assert r[k] == v + assert r.get('d') is None + + def test_msetnx_kwargs(self, r): + d = {'a': b('1'), 'b': b('2'), 'c': b('3')} + assert r.msetnx(**d) + d2 = {'a': b('x'), 'd': b('4')} + assert not r.msetnx(**d2) + for k, v in iteritems(d): + assert r[k] == v + assert r.get('d') is None + + @skip_if_server_version_lt('2.6.0') + def test_pexpire(self, r): + assert not r.pexpire('a', 60000) + r['a'] = 'foo' + assert r.pexpire('a', 60000) + assert 0 < r.pttl('a') <= 60000 + assert r.persist('a') + assert r.pttl('a') is None + + @skip_if_server_version_lt('2.6.0') + def test_pexpireat_datetime(self, r): + expire_at = redis_server_time(r) + datetime.timedelta(minutes=1) + r['a'] = 'foo' + assert r.pexpireat('a', expire_at) + assert 0 < r.pttl('a') <= 61000 + + @skip_if_server_version_lt('2.6.0') + def test_pexpireat_no_key(self, r): + expire_at = redis_server_time(r) + datetime.timedelta(minutes=1) + assert not r.pexpireat('a', expire_at) + + @skip_if_server_version_lt('2.6.0') + def test_pexpireat_unixtime(self, r): + expire_at = redis_server_time(r) + datetime.timedelta(minutes=1) + r['a'] = 'foo' + expire_at_seconds = int(time.mktime(expire_at.timetuple())) * 1000 + assert r.pexpireat('a', expire_at_seconds) + assert 0 < r.pttl('a') <= 61000 + + @skip_if_server_version_lt('2.6.0') + def test_psetex(self, r): + assert r.psetex('a', 1000, 'value') + assert r['a'] == b('value') + assert 0 < r.pttl('a') <= 1000 + + @skip_if_server_version_lt('2.6.0') + def test_psetex_timedelta(self, r): + expire_at = datetime.timedelta(milliseconds=1000) + assert r.psetex('a', expire_at, 'value') + assert r['a'] == b('value') + assert 0 < r.pttl('a') <= 1000 + + def test_randomkey(self, r): + assert r.randomkey() is None + for key in ('a', 'b', 'c'): + r[key] = 1 + assert r.randomkey() in (b('a'), b('b'), b('c')) + + def test_rename(self, r): + r['a'] = '1' + assert r.rename('a', 'b') + assert r.get('a') is None + assert r['b'] == b('1') + + def test_renamenx(self, r): + r['a'] = '1' + r['b'] = '2' + assert not r.renamenx('a', 'b') + assert r['a'] == b('1') + assert r['b'] == b('2') + + @skip_if_server_version_lt('2.6.0') + def test_set_nx(self, r): + assert r.set('a', '1', nx=True) + assert not r.set('a', '2', nx=True) + assert r['a'] == b('1') + + @skip_if_server_version_lt('2.6.0') + def test_set_xx(self, r): + assert not r.set('a', '1', xx=True) + assert r.get('a') is None + r['a'] = 'bar' + assert r.set('a', '2', xx=True) + assert r.get('a') == b('2') + + @skip_if_server_version_lt('2.6.0') + def test_set_px(self, r): + assert r.set('a', '1', px=10000) + assert r['a'] == b('1') + assert 0 < r.pttl('a') <= 10000 + assert 0 < r.ttl('a') <= 10 + + @skip_if_server_version_lt('2.6.0') + def test_set_px_timedelta(self, r): + expire_at = datetime.timedelta(milliseconds=1000) + assert r.set('a', '1', px=expire_at) + assert 0 < r.pttl('a') <= 1000 + assert 0 < r.ttl('a') <= 1 + + @skip_if_server_version_lt('2.6.0') + def test_set_ex(self, r): + assert r.set('a', '1', ex=10) + assert 0 < r.ttl('a') <= 10 + + @skip_if_server_version_lt('2.6.0') + def test_set_ex_timedelta(self, r): + expire_at = datetime.timedelta(seconds=60) + assert r.set('a', '1', ex=expire_at) + assert 0 < r.ttl('a') <= 60 + + @skip_if_server_version_lt('2.6.0') + def test_set_multipleoptions(self, r): + r['a'] = 'val' + assert r.set('a', '1', xx=True, px=10000) + assert 0 < r.ttl('a') <= 10 + + def test_setex(self, r): + assert r.setex('a', '1', 60) + assert r['a'] == b('1') + assert 0 < r.ttl('a') <= 60 + + def test_setnx(self, r): + assert r.setnx('a', '1') + assert r['a'] == b('1') + assert not r.setnx('a', '2') + assert r['a'] == b('1') + + def test_setrange(self, r): + assert r.setrange('a', 5, 'foo') == 8 + assert r['a'] == b('\0\0\0\0\0foo') + r['a'] = 'abcdefghijh' + assert r.setrange('a', 6, '12345') == 11 + assert r['a'] == b('abcdef12345') + + def test_strlen(self, r): + r['a'] = 'foo' + assert r.strlen('a') == 3 + + def test_substr(self, r): + r['a'] = '0123456789' + assert r.substr('a', 0) == b('0123456789') + assert r.substr('a', 2) == b('23456789') + assert r.substr('a', 3, 5) == b('345') + assert r.substr('a', 3, -2) == b('345678') + + def test_type(self, r): + assert r.type('a') == b('none') + r['a'] = '1' + assert r.type('a') == b('string') + del r['a'] + r.lpush('a', '1') + assert r.type('a') == b('list') + del r['a'] + r.sadd('a', '1') + assert r.type('a') == b('set') + del r['a'] + r.zadd('a', **{'1': 1}) + assert r.type('a') == b('zset') + + # LIST COMMANDS + def test_blpop(self, r): + r.rpush('a', '1', '2') + r.rpush('b', '3', '4') + assert r.blpop(['b', 'a'], timeout=1) == (b('b'), b('3')) + assert r.blpop(['b', 'a'], timeout=1) == (b('b'), b('4')) + assert r.blpop(['b', 'a'], timeout=1) == (b('a'), b('1')) + assert r.blpop(['b', 'a'], timeout=1) == (b('a'), b('2')) + assert r.blpop(['b', 'a'], timeout=1) is None + r.rpush('c', '1') + assert r.blpop('c', timeout=1) == (b('c'), b('1')) + + def test_brpop(self, r): + r.rpush('a', '1', '2') + r.rpush('b', '3', '4') + assert r.brpop(['b', 'a'], timeout=1) == (b('b'), b('4')) + assert r.brpop(['b', 'a'], timeout=1) == (b('b'), b('3')) + assert r.brpop(['b', 'a'], timeout=1) == (b('a'), b('2')) + assert r.brpop(['b', 'a'], timeout=1) == (b('a'), b('1')) + assert r.brpop(['b', 'a'], timeout=1) is None + r.rpush('c', '1') + assert r.brpop('c', timeout=1) == (b('c'), b('1')) + + def test_brpoplpush(self, r): + r.rpush('a', '1', '2') + r.rpush('b', '3', '4') + assert r.brpoplpush('a', 'b') == b('2') + assert r.brpoplpush('a', 'b') == b('1') + assert r.brpoplpush('a', 'b', timeout=1) is None + assert r.lrange('a', 0, -1) == [] + assert r.lrange('b', 0, -1) == [b('1'), b('2'), b('3'), b('4')] + + def test_brpoplpush_empty_string(self, r): + r.rpush('a', '') + assert r.brpoplpush('a', 'b') == b('') + + def test_lindex(self, r): + r.rpush('a', '1', '2', '3') + assert r.lindex('a', '0') == b('1') + assert r.lindex('a', '1') == b('2') + assert r.lindex('a', '2') == b('3') + + def test_linsert(self, r): + r.rpush('a', '1', '2', '3') + assert r.linsert('a', 'after', '2', '2.5') == 4 + assert r.lrange('a', 0, -1) == [b('1'), b('2'), b('2.5'), b('3')] + assert r.linsert('a', 'before', '2', '1.5') == 5 + assert r.lrange('a', 0, -1) == \ + [b('1'), b('1.5'), b('2'), b('2.5'), b('3')] + + def test_llen(self, r): + r.rpush('a', '1', '2', '3') + assert r.llen('a') == 3 + + def test_lpop(self, r): + r.rpush('a', '1', '2', '3') + assert r.lpop('a') == b('1') + assert r.lpop('a') == b('2') + assert r.lpop('a') == b('3') + assert r.lpop('a') is None + + def test_lpush(self, r): + assert r.lpush('a', '1') == 1 + assert r.lpush('a', '2') == 2 + assert r.lpush('a', '3', '4') == 4 + assert r.lrange('a', 0, -1) == [b('4'), b('3'), b('2'), b('1')] + + def test_lpushx(self, r): + assert r.lpushx('a', '1') == 0 + assert r.lrange('a', 0, -1) == [] + r.rpush('a', '1', '2', '3') + assert r.lpushx('a', '4') == 4 + assert r.lrange('a', 0, -1) == [b('4'), b('1'), b('2'), b('3')] + + def test_lrange(self, r): + r.rpush('a', '1', '2', '3', '4', '5') + assert r.lrange('a', 0, 2) == [b('1'), b('2'), b('3')] + assert r.lrange('a', 2, 10) == [b('3'), b('4'), b('5')] + assert r.lrange('a', 0, -1) == [b('1'), b('2'), b('3'), b('4'), b('5')] + + def test_lrem(self, r): + r.rpush('a', '1', '1', '1', '1') + assert r.lrem('a', '1', 1) == 1 + assert r.lrange('a', 0, -1) == [b('1'), b('1'), b('1')] + assert r.lrem('a', '1') == 3 + assert r.lrange('a', 0, -1) == [] + + def test_lset(self, r): + r.rpush('a', '1', '2', '3') + assert r.lrange('a', 0, -1) == [b('1'), b('2'), b('3')] + assert r.lset('a', 1, '4') + assert r.lrange('a', 0, 2) == [b('1'), b('4'), b('3')] + + def test_ltrim(self, r): + r.rpush('a', '1', '2', '3') + assert r.ltrim('a', 0, 1) + assert r.lrange('a', 0, -1) == [b('1'), b('2')] + + def test_rpop(self, r): + r.rpush('a', '1', '2', '3') + assert r.rpop('a') == b('3') + assert r.rpop('a') == b('2') + assert r.rpop('a') == b('1') + assert r.rpop('a') is None + + def test_rpoplpush(self, r): + r.rpush('a', 'a1', 'a2', 'a3') + r.rpush('b', 'b1', 'b2', 'b3') + assert r.rpoplpush('a', 'b') == b('a3') + assert r.lrange('a', 0, -1) == [b('a1'), b('a2')] + assert r.lrange('b', 0, -1) == [b('a3'), b('b1'), b('b2'), b('b3')] + + def test_rpush(self, r): + assert r.rpush('a', '1') == 1 + assert r.rpush('a', '2') == 2 + assert r.rpush('a', '3', '4') == 4 + assert r.lrange('a', 0, -1) == [b('1'), b('2'), b('3'), b('4')] + + def test_rpushx(self, r): + assert r.rpushx('a', 'b') == 0 + assert r.lrange('a', 0, -1) == [] + r.rpush('a', '1', '2', '3') + assert r.rpushx('a', '4') == 4 + assert r.lrange('a', 0, -1) == [b('1'), b('2'), b('3'), b('4')] + + # SCAN COMMANDS + @skip_if_server_version_lt('2.8.0') + def test_scan(self, r): + r.set('a', 1) + r.set('b', 2) + r.set('c', 3) + cursor, keys = r.scan() + assert cursor == 0 + assert set(keys) == set([b('a'), b('b'), b('c')]) + _, keys = r.scan(match='a') + assert set(keys) == set([b('a')]) + + @skip_if_server_version_lt('2.8.0') + def test_scan_iter(self, r): + r.set('a', 1) + r.set('b', 2) + r.set('c', 3) + keys = list(r.scan_iter()) + assert set(keys) == set([b('a'), b('b'), b('c')]) + keys = list(r.scan_iter(match='a')) + assert set(keys) == set([b('a')]) + + @skip_if_server_version_lt('2.8.0') + def test_sscan(self, r): + r.sadd('a', 1, 2, 3) + cursor, members = r.sscan('a') + assert cursor == 0 + assert set(members) == set([b('1'), b('2'), b('3')]) + _, members = r.sscan('a', match=b('1')) + assert set(members) == set([b('1')]) + + @skip_if_server_version_lt('2.8.0') + def test_sscan_iter(self, r): + r.sadd('a', 1, 2, 3) + members = list(r.sscan_iter('a')) + assert set(members) == set([b('1'), b('2'), b('3')]) + members = list(r.sscan_iter('a', match=b('1'))) + assert set(members) == set([b('1')]) + + @skip_if_server_version_lt('2.8.0') + def test_hscan(self, r): + r.hmset('a', {'a': 1, 'b': 2, 'c': 3}) + cursor, dic = r.hscan('a') + assert cursor == 0 + assert dic == {b('a'): b('1'), b('b'): b('2'), b('c'): b('3')} + _, dic = r.hscan('a', match='a') + assert dic == {b('a'): b('1')} + + @skip_if_server_version_lt('2.8.0') + def test_hscan_iter(self, r): + r.hmset('a', {'a': 1, 'b': 2, 'c': 3}) + dic = dict(r.hscan_iter('a')) + assert dic == {b('a'): b('1'), b('b'): b('2'), b('c'): b('3')} + dic = dict(r.hscan_iter('a', match='a')) + assert dic == {b('a'): b('1')} + + @skip_if_server_version_lt('2.8.0') + def test_zscan(self, r): + r.zadd('a', 'a', 1, 'b', 2, 'c', 3) + cursor, pairs = r.zscan('a') + assert cursor == 0 + assert set(pairs) == set([(b('a'), 1), (b('b'), 2), (b('c'), 3)]) + _, pairs = r.zscan('a', match='a') + assert set(pairs) == set([(b('a'), 1)]) + + @skip_if_server_version_lt('2.8.0') + def test_zscan_iter(self, r): + r.zadd('a', 'a', 1, 'b', 2, 'c', 3) + pairs = list(r.zscan_iter('a')) + assert set(pairs) == set([(b('a'), 1), (b('b'), 2), (b('c'), 3)]) + pairs = list(r.zscan_iter('a', match='a')) + assert set(pairs) == set([(b('a'), 1)]) + + # SET COMMANDS + def test_sadd(self, r): + members = set([b('1'), b('2'), b('3')]) + r.sadd('a', *members) + assert r.smembers('a') == members + + def test_scard(self, r): + r.sadd('a', '1', '2', '3') + assert r.scard('a') == 3 + + def test_sdiff(self, r): + r.sadd('a', '1', '2', '3') + assert r.sdiff('a', 'b') == set([b('1'), b('2'), b('3')]) + r.sadd('b', '2', '3') + assert r.sdiff('a', 'b') == set([b('1')]) + + def test_sdiffstore(self, r): + r.sadd('a', '1', '2', '3') + assert r.sdiffstore('c', 'a', 'b') == 3 + assert r.smembers('c') == set([b('1'), b('2'), b('3')]) + r.sadd('b', '2', '3') + assert r.sdiffstore('c', 'a', 'b') == 1 + assert r.smembers('c') == set([b('1')]) + + def test_sinter(self, r): + r.sadd('a', '1', '2', '3') + assert r.sinter('a', 'b') == set() + r.sadd('b', '2', '3') + assert r.sinter('a', 'b') == set([b('2'), b('3')]) + + def test_sinterstore(self, r): + r.sadd('a', '1', '2', '3') + assert r.sinterstore('c', 'a', 'b') == 0 + assert r.smembers('c') == set() + r.sadd('b', '2', '3') + assert r.sinterstore('c', 'a', 'b') == 2 + assert r.smembers('c') == set([b('2'), b('3')]) + + def test_sismember(self, r): + r.sadd('a', '1', '2', '3') + assert r.sismember('a', '1') + assert r.sismember('a', '2') + assert r.sismember('a', '3') + assert not r.sismember('a', '4') + + def test_smembers(self, r): + r.sadd('a', '1', '2', '3') + assert r.smembers('a') == set([b('1'), b('2'), b('3')]) + + def test_smove(self, r): + r.sadd('a', 'a1', 'a2') + r.sadd('b', 'b1', 'b2') + assert r.smove('a', 'b', 'a1') + assert r.smembers('a') == set([b('a2')]) + assert r.smembers('b') == set([b('b1'), b('b2'), b('a1')]) + + def test_spop(self, r): + s = [b('1'), b('2'), b('3')] + r.sadd('a', *s) + value = r.spop('a') + assert value in s + assert r.smembers('a') == set(s) - set([value]) + + def test_srandmember(self, r): + s = [b('1'), b('2'), b('3')] + r.sadd('a', *s) + assert r.srandmember('a') in s + + @skip_if_server_version_lt('2.6.0') + def test_srandmember_multi_value(self, r): + s = [b('1'), b('2'), b('3')] + r.sadd('a', *s) + randoms = r.srandmember('a', number=2) + assert len(randoms) == 2 + assert set(randoms).intersection(s) == set(randoms) + + def test_srem(self, r): + r.sadd('a', '1', '2', '3', '4') + assert r.srem('a', '5') == 0 + assert r.srem('a', '2', '4') == 2 + assert r.smembers('a') == set([b('1'), b('3')]) + + def test_sunion(self, r): + r.sadd('a', '1', '2') + r.sadd('b', '2', '3') + assert r.sunion('a', 'b') == set([b('1'), b('2'), b('3')]) + + def test_sunionstore(self, r): + r.sadd('a', '1', '2') + r.sadd('b', '2', '3') + assert r.sunionstore('c', 'a', 'b') == 3 + assert r.smembers('c') == set([b('1'), b('2'), b('3')]) + + # SORTED SET COMMANDS + def test_zadd(self, r): + r.zadd('a', a1=1, a2=2, a3=3) + assert r.zrange('a', 0, -1) == [b('a1'), b('a2'), b('a3')] + + def test_zcard(self, r): + r.zadd('a', a1=1, a2=2, a3=3) + assert r.zcard('a') == 3 + + def test_zcount(self, r): + r.zadd('a', a1=1, a2=2, a3=3) + assert r.zcount('a', '-inf', '+inf') == 3 + assert r.zcount('a', 1, 2) == 2 + assert r.zcount('a', 10, 20) == 0 + + def test_zincrby(self, r): + r.zadd('a', a1=1, a2=2, a3=3) + assert r.zincrby('a', 'a2') == 3.0 + assert r.zincrby('a', 'a3', amount=5) == 8.0 + assert r.zscore('a', 'a2') == 3.0 + assert r.zscore('a', 'a3') == 8.0 + + @skip_if_server_version_lt('2.8.9') + def test_zlexcount(self, r): + r.zadd('a', a=0, b=0, c=0, d=0, e=0, f=0, g=0) + assert r.zlexcount('a', '-', '+') == 7 + assert r.zlexcount('a', '[b', '[f') == 5 + + def test_zinterstore_sum(self, r): + r.zadd('a', a1=1, a2=1, a3=1) + r.zadd('b', a1=2, a2=2, a3=2) + r.zadd('c', a1=6, a3=5, a4=4) + assert r.zinterstore('d', ['a', 'b', 'c']) == 2 + assert r.zrange('d', 0, -1, withscores=True) == \ + [(b('a3'), 8), (b('a1'), 9)] + + def test_zinterstore_max(self, r): + r.zadd('a', a1=1, a2=1, a3=1) + r.zadd('b', a1=2, a2=2, a3=2) + r.zadd('c', a1=6, a3=5, a4=4) + assert r.zinterstore('d', ['a', 'b', 'c'], aggregate='MAX') == 2 + assert r.zrange('d', 0, -1, withscores=True) == \ + [(b('a3'), 5), (b('a1'), 6)] + + def test_zinterstore_min(self, r): + r.zadd('a', a1=1, a2=2, a3=3) + r.zadd('b', a1=2, a2=3, a3=5) + r.zadd('c', a1=6, a3=5, a4=4) + assert r.zinterstore('d', ['a', 'b', 'c'], aggregate='MIN') == 2 + assert r.zrange('d', 0, -1, withscores=True) == \ + [(b('a1'), 1), (b('a3'), 3)] + + def test_zinterstore_with_weight(self, r): + r.zadd('a', a1=1, a2=1, a3=1) + r.zadd('b', a1=2, a2=2, a3=2) + r.zadd('c', a1=6, a3=5, a4=4) + assert r.zinterstore('d', {'a': 1, 'b': 2, 'c': 3}) == 2 + assert r.zrange('d', 0, -1, withscores=True) == \ + [(b('a3'), 20), (b('a1'), 23)] + + def test_zrange(self, r): + r.zadd('a', a1=1, a2=2, a3=3) + assert r.zrange('a', 0, 1) == [b('a1'), b('a2')] + assert r.zrange('a', 1, 2) == [b('a2'), b('a3')] + + # withscores + assert r.zrange('a', 0, 1, withscores=True) == \ + [(b('a1'), 1.0), (b('a2'), 2.0)] + assert r.zrange('a', 1, 2, withscores=True) == \ + [(b('a2'), 2.0), (b('a3'), 3.0)] + + # custom score function + assert r.zrange('a', 0, 1, withscores=True, score_cast_func=int) == \ + [(b('a1'), 1), (b('a2'), 2)] + + @skip_if_server_version_lt('2.8.9') + def test_zrangebylex(self, r): + r.zadd('a', a=0, b=0, c=0, d=0, e=0, f=0, g=0) + assert r.zrangebylex('a', '-', '[c') == [b('a'), b('b'), b('c')] + assert r.zrangebylex('a', '-', '(c') == [b('a'), b('b')] + assert r.zrangebylex('a', '[aaa', '(g') == \ + [b('b'), b('c'), b('d'), b('e'), b('f')] + assert r.zrangebylex('a', '[f', '+') == [b('f'), b('g')] + assert r.zrangebylex('a', '-', '+', start=3, num=2) == [b('d'), b('e')] + + def test_zrangebyscore(self, r): + r.zadd('a', a1=1, a2=2, a3=3, a4=4, a5=5) + assert r.zrangebyscore('a', 2, 4) == [b('a2'), b('a3'), b('a4')] + + # slicing with start/num + assert r.zrangebyscore('a', 2, 4, start=1, num=2) == \ + [b('a3'), b('a4')] + + # withscores + assert r.zrangebyscore('a', 2, 4, withscores=True) == \ + [(b('a2'), 2.0), (b('a3'), 3.0), (b('a4'), 4.0)] + + # custom score function + assert r.zrangebyscore('a', 2, 4, withscores=True, + score_cast_func=int) == \ + [(b('a2'), 2), (b('a3'), 3), (b('a4'), 4)] + + def test_zrank(self, r): + r.zadd('a', a1=1, a2=2, a3=3, a4=4, a5=5) + assert r.zrank('a', 'a1') == 0 + assert r.zrank('a', 'a2') == 1 + assert r.zrank('a', 'a6') is None + + def test_zrem(self, r): + r.zadd('a', a1=1, a2=2, a3=3) + assert r.zrem('a', 'a2') == 1 + assert r.zrange('a', 0, -1) == [b('a1'), b('a3')] + assert r.zrem('a', 'b') == 0 + assert r.zrange('a', 0, -1) == [b('a1'), b('a3')] + + def test_zrem_multiple_keys(self, r): + r.zadd('a', a1=1, a2=2, a3=3) + assert r.zrem('a', 'a1', 'a2') == 2 + assert r.zrange('a', 0, 5) == [b('a3')] + + @skip_if_server_version_lt('2.8.9') + def test_zremrangebylex(self, r): + r.zadd('a', a=0, b=0, c=0, d=0, e=0, f=0, g=0) + assert r.zremrangebylex('a', '-', '[c') == 3 + assert r.zrange('a', 0, -1) == [b('d'), b('e'), b('f'), b('g')] + assert r.zremrangebylex('a', '[f', '+') == 2 + assert r.zrange('a', 0, -1) == [b('d'), b('e')] + assert r.zremrangebylex('a', '[h', '+') == 0 + assert r.zrange('a', 0, -1) == [b('d'), b('e')] + + def test_zremrangebyrank(self, r): + r.zadd('a', a1=1, a2=2, a3=3, a4=4, a5=5) + assert r.zremrangebyrank('a', 1, 3) == 3 + assert r.zrange('a', 0, 5) == [b('a1'), b('a5')] + + def test_zremrangebyscore(self, r): + r.zadd('a', a1=1, a2=2, a3=3, a4=4, a5=5) + assert r.zremrangebyscore('a', 2, 4) == 3 + assert r.zrange('a', 0, -1) == [b('a1'), b('a5')] + assert r.zremrangebyscore('a', 2, 4) == 0 + assert r.zrange('a', 0, -1) == [b('a1'), b('a5')] + + def test_zrevrange(self, r): + r.zadd('a', a1=1, a2=2, a3=3) + assert r.zrevrange('a', 0, 1) == [b('a3'), b('a2')] + assert r.zrevrange('a', 1, 2) == [b('a2'), b('a1')] + + # withscores + assert r.zrevrange('a', 0, 1, withscores=True) == \ + [(b('a3'), 3.0), (b('a2'), 2.0)] + assert r.zrevrange('a', 1, 2, withscores=True) == \ + [(b('a2'), 2.0), (b('a1'), 1.0)] + + # custom score function + assert r.zrevrange('a', 0, 1, withscores=True, + score_cast_func=int) == \ + [(b('a3'), 3.0), (b('a2'), 2.0)] + + def test_zrevrangebyscore(self, r): + r.zadd('a', a1=1, a2=2, a3=3, a4=4, a5=5) + assert r.zrevrangebyscore('a', 4, 2) == [b('a4'), b('a3'), b('a2')] + + # slicing with start/num + assert r.zrevrangebyscore('a', 4, 2, start=1, num=2) == \ + [b('a3'), b('a2')] + + # withscores + assert r.zrevrangebyscore('a', 4, 2, withscores=True) == \ + [(b('a4'), 4.0), (b('a3'), 3.0), (b('a2'), 2.0)] + + # custom score function + assert r.zrevrangebyscore('a', 4, 2, withscores=True, + score_cast_func=int) == \ + [(b('a4'), 4), (b('a3'), 3), (b('a2'), 2)] + + def test_zrevrank(self, r): + r.zadd('a', a1=1, a2=2, a3=3, a4=4, a5=5) + assert r.zrevrank('a', 'a1') == 4 + assert r.zrevrank('a', 'a2') == 3 + assert r.zrevrank('a', 'a6') is None + + def test_zscore(self, r): + r.zadd('a', a1=1, a2=2, a3=3) + assert r.zscore('a', 'a1') == 1.0 + assert r.zscore('a', 'a2') == 2.0 + assert r.zscore('a', 'a4') is None + + def test_zunionstore_sum(self, r): + r.zadd('a', a1=1, a2=1, a3=1) + r.zadd('b', a1=2, a2=2, a3=2) + r.zadd('c', a1=6, a3=5, a4=4) + assert r.zunionstore('d', ['a', 'b', 'c']) == 4 + assert r.zrange('d', 0, -1, withscores=True) == \ + [(b('a2'), 3), (b('a4'), 4), (b('a3'), 8), (b('a1'), 9)] + + def test_zunionstore_max(self, r): + r.zadd('a', a1=1, a2=1, a3=1) + r.zadd('b', a1=2, a2=2, a3=2) + r.zadd('c', a1=6, a3=5, a4=4) + assert r.zunionstore('d', ['a', 'b', 'c'], aggregate='MAX') == 4 + assert r.zrange('d', 0, -1, withscores=True) == \ + [(b('a2'), 2), (b('a4'), 4), (b('a3'), 5), (b('a1'), 6)] + + def test_zunionstore_min(self, r): + r.zadd('a', a1=1, a2=2, a3=3) + r.zadd('b', a1=2, a2=2, a3=4) + r.zadd('c', a1=6, a3=5, a4=4) + assert r.zunionstore('d', ['a', 'b', 'c'], aggregate='MIN') == 4 + assert r.zrange('d', 0, -1, withscores=True) == \ + [(b('a1'), 1), (b('a2'), 2), (b('a3'), 3), (b('a4'), 4)] + + def test_zunionstore_with_weight(self, r): + r.zadd('a', a1=1, a2=1, a3=1) + r.zadd('b', a1=2, a2=2, a3=2) + r.zadd('c', a1=6, a3=5, a4=4) + assert r.zunionstore('d', {'a': 1, 'b': 2, 'c': 3}) == 4 + assert r.zrange('d', 0, -1, withscores=True) == \ + [(b('a2'), 5), (b('a4'), 12), (b('a3'), 20), (b('a1'), 23)] + + # HYPERLOGLOG TESTS + @skip_if_server_version_lt('2.8.9') + def test_pfadd(self, r): + members = set([b('1'), b('2'), b('3')]) + assert r.pfadd('a', *members) == 1 + assert r.pfadd('a', *members) == 0 + assert r.pfcount('a') == len(members) + + @skip_if_server_version_lt('2.8.9') + def test_pfcount(self, r): + members = set([b('1'), b('2'), b('3')]) + r.pfadd('a', *members) + assert r.pfcount('a') == len(members) + + @skip_if_server_version_lt('2.8.9') + def test_pfmerge(self, r): + mema = set([b('1'), b('2'), b('3')]) + memb = set([b('2'), b('3'), b('4')]) + memc = set([b('5'), b('6'), b('7')]) + r.pfadd('a', *mema) + r.pfadd('b', *memb) + r.pfadd('c', *memc) + r.pfmerge('d', 'c', 'a') + assert r.pfcount('d') == 6 + r.pfmerge('d', 'b') + assert r.pfcount('d') == 7 + + # HASH COMMANDS + def test_hget_and_hset(self, r): + r.hmset('a', {'1': 1, '2': 2, '3': 3}) + assert r.hget('a', '1') == b('1') + assert r.hget('a', '2') == b('2') + assert r.hget('a', '3') == b('3') + + # field was updated, redis returns 0 + assert r.hset('a', '2', 5) == 0 + assert r.hget('a', '2') == b('5') + + # field is new, redis returns 1 + assert r.hset('a', '4', 4) == 1 + assert r.hget('a', '4') == b('4') + + # key inside of hash that doesn't exist returns null value + assert r.hget('a', 'b') is None + + def test_hdel(self, r): + r.hmset('a', {'1': 1, '2': 2, '3': 3}) + assert r.hdel('a', '2') == 1 + assert r.hget('a', '2') is None + assert r.hdel('a', '1', '3') == 2 + assert r.hlen('a') == 0 + + def test_hexists(self, r): + r.hmset('a', {'1': 1, '2': 2, '3': 3}) + assert r.hexists('a', '1') + assert not r.hexists('a', '4') + + def test_hgetall(self, r): + h = {b('a1'): b('1'), b('a2'): b('2'), b('a3'): b('3')} + r.hmset('a', h) + assert r.hgetall('a') == h + + def test_hincrby(self, r): + assert r.hincrby('a', '1') == 1 + assert r.hincrby('a', '1', amount=2) == 3 + assert r.hincrby('a', '1', amount=-2) == 1 + + @skip_if_server_version_lt('2.6.0') + def test_hincrbyfloat(self, r): + assert r.hincrbyfloat('a', '1') == 1.0 + assert r.hincrbyfloat('a', '1') == 2.0 + assert r.hincrbyfloat('a', '1', 1.2) == 3.2 + + def test_hkeys(self, r): + h = {b('a1'): b('1'), b('a2'): b('2'), b('a3'): b('3')} + r.hmset('a', h) + local_keys = list(iterkeys(h)) + remote_keys = r.hkeys('a') + assert (sorted(local_keys) == sorted(remote_keys)) + + def test_hlen(self, r): + r.hmset('a', {'1': 1, '2': 2, '3': 3}) + assert r.hlen('a') == 3 + + def test_hmget(self, r): + assert r.hmset('a', {'a': 1, 'b': 2, 'c': 3}) + assert r.hmget('a', 'a', 'b', 'c') == [b('1'), b('2'), b('3')] + + def test_hmset(self, r): + h = {b('a'): b('1'), b('b'): b('2'), b('c'): b('3')} + assert r.hmset('a', h) + assert r.hgetall('a') == h + + def test_hsetnx(self, r): + # Initially set the hash field + assert r.hsetnx('a', '1', 1) + assert r.hget('a', '1') == b('1') + assert not r.hsetnx('a', '1', 2) + assert r.hget('a', '1') == b('1') + + def test_hvals(self, r): + h = {b('a1'): b('1'), b('a2'): b('2'), b('a3'): b('3')} + r.hmset('a', h) + local_vals = list(itervalues(h)) + remote_vals = r.hvals('a') + assert sorted(local_vals) == sorted(remote_vals) + + # SORT + def test_sort_basic(self, r): + r.rpush('a', '3', '2', '1', '4') + assert r.sort('a') == [b('1'), b('2'), b('3'), b('4')] + + def test_sort_limited(self, r): + r.rpush('a', '3', '2', '1', '4') + assert r.sort('a', start=1, num=2) == [b('2'), b('3')] + + def test_sort_by(self, r): + r['score:1'] = 8 + r['score:2'] = 3 + r['score:3'] = 5 + r.rpush('a', '3', '2', '1') + assert r.sort('a', by='score:*') == [b('2'), b('3'), b('1')] + + def test_sort_get(self, r): + r['user:1'] = 'u1' + r['user:2'] = 'u2' + r['user:3'] = 'u3' + r.rpush('a', '2', '3', '1') + assert r.sort('a', get='user:*') == [b('u1'), b('u2'), b('u3')] + + def test_sort_get_multi(self, r): + r['user:1'] = 'u1' + r['user:2'] = 'u2' + r['user:3'] = 'u3' + r.rpush('a', '2', '3', '1') + assert r.sort('a', get=('user:*', '#')) == \ + [b('u1'), b('1'), b('u2'), b('2'), b('u3'), b('3')] + + def test_sort_get_groups_two(self, r): + r['user:1'] = 'u1' + r['user:2'] = 'u2' + r['user:3'] = 'u3' + r.rpush('a', '2', '3', '1') + assert r.sort('a', get=('user:*', '#'), groups=True) == \ + [(b('u1'), b('1')), (b('u2'), b('2')), (b('u3'), b('3'))] + + def test_sort_groups_string_get(self, r): + r['user:1'] = 'u1' + r['user:2'] = 'u2' + r['user:3'] = 'u3' + r.rpush('a', '2', '3', '1') + with pytest.raises(exceptions.DataError): + r.sort('a', get='user:*', groups=True) + + def test_sort_groups_just_one_get(self, r): + r['user:1'] = 'u1' + r['user:2'] = 'u2' + r['user:3'] = 'u3' + r.rpush('a', '2', '3', '1') + with pytest.raises(exceptions.DataError): + r.sort('a', get=['user:*'], groups=True) + + def test_sort_groups_no_get(self, r): + r['user:1'] = 'u1' + r['user:2'] = 'u2' + r['user:3'] = 'u3' + r.rpush('a', '2', '3', '1') + with pytest.raises(exceptions.DataError): + r.sort('a', groups=True) + + def test_sort_groups_three_gets(self, r): + r['user:1'] = 'u1' + r['user:2'] = 'u2' + r['user:3'] = 'u3' + r['door:1'] = 'd1' + r['door:2'] = 'd2' + r['door:3'] = 'd3' + r.rpush('a', '2', '3', '1') + assert r.sort('a', get=('user:*', 'door:*', '#'), groups=True) == \ + [ + (b('u1'), b('d1'), b('1')), + (b('u2'), b('d2'), b('2')), + (b('u3'), b('d3'), b('3')) + ] + + def test_sort_desc(self, r): + r.rpush('a', '2', '3', '1') + assert r.sort('a', desc=True) == [b('3'), b('2'), b('1')] + + def test_sort_alpha(self, r): + r.rpush('a', 'e', 'c', 'b', 'd', 'a') + assert r.sort('a', alpha=True) == \ + [b('a'), b('b'), b('c'), b('d'), b('e')] + + def test_sort_store(self, r): + r.rpush('a', '2', '3', '1') + assert r.sort('a', store='sorted_values') == 3 + assert r.lrange('sorted_values', 0, -1) == [b('1'), b('2'), b('3')] + + def test_sort_all_options(self, r): + r['user:1:username'] = 'zeus' + r['user:2:username'] = 'titan' + r['user:3:username'] = 'hermes' + r['user:4:username'] = 'hercules' + r['user:5:username'] = 'apollo' + r['user:6:username'] = 'athena' + r['user:7:username'] = 'hades' + r['user:8:username'] = 'dionysus' + + r['user:1:favorite_drink'] = 'yuengling' + r['user:2:favorite_drink'] = 'rum' + r['user:3:favorite_drink'] = 'vodka' + r['user:4:favorite_drink'] = 'milk' + r['user:5:favorite_drink'] = 'pinot noir' + r['user:6:favorite_drink'] = 'water' + r['user:7:favorite_drink'] = 'gin' + r['user:8:favorite_drink'] = 'apple juice' + + r.rpush('gods', '5', '8', '3', '1', '2', '7', '6', '4') + num = r.sort('gods', start=2, num=4, by='user:*:username', + get='user:*:favorite_drink', desc=True, alpha=True, + store='sorted') + assert num == 4 + assert r.lrange('sorted', 0, 10) == \ + [b('vodka'), b('milk'), b('gin'), b('apple juice')] + + +class TestStrictCommands(object): + + def test_strict_zadd(self, sr): + sr.zadd('a', 1.0, 'a1', 2.0, 'a2', a3=3.0) + assert sr.zrange('a', 0, -1, withscores=True) == \ + [(b('a1'), 1.0), (b('a2'), 2.0), (b('a3'), 3.0)] + + def test_strict_lrem(self, sr): + sr.rpush('a', 'a1', 'a2', 'a3', 'a1') + sr.lrem('a', 0, 'a1') + assert sr.lrange('a', 0, -1) == [b('a2'), b('a3')] + + def test_strict_setex(self, sr): + assert sr.setex('a', 60, '1') + assert sr['a'] == b('1') + assert 0 < sr.ttl('a') <= 60 + + def test_strict_ttl(self, sr): + assert not sr.expire('a', 10) + sr['a'] = '1' + assert sr.expire('a', 10) + assert 0 < sr.ttl('a') <= 10 + assert sr.persist('a') + assert sr.ttl('a') == -1 + + @skip_if_server_version_lt('2.6.0') + def test_strict_pttl(self, sr): + assert not sr.pexpire('a', 10000) + sr['a'] = '1' + assert sr.pexpire('a', 10000) + assert 0 < sr.pttl('a') <= 10000 + assert sr.persist('a') + assert sr.pttl('a') == -1 + + +class TestBinarySave(object): + def test_binary_get_set(self, r): + assert r.set(' foo bar ', '123') + assert r.get(' foo bar ') == b('123') + + assert r.set(' foo\r\nbar\r\n ', '456') + assert r.get(' foo\r\nbar\r\n ') == b('456') + + assert r.set(' \r\n\t\x07\x13 ', '789') + assert r.get(' \r\n\t\x07\x13 ') == b('789') + + assert sorted(r.keys('*')) == \ + [b(' \r\n\t\x07\x13 '), b(' foo\r\nbar\r\n '), b(' foo bar ')] + + assert r.delete(' foo bar ') + assert r.delete(' foo\r\nbar\r\n ') + assert r.delete(' \r\n\t\x07\x13 ') + + def test_binary_lists(self, r): + mapping = { + b('foo bar'): [b('1'), b('2'), b('3')], + b('foo\r\nbar\r\n'): [b('4'), b('5'), b('6')], + b('foo\tbar\x07'): [b('7'), b('8'), b('9')], + } + # fill in lists + for key, value in iteritems(mapping): + r.rpush(key, *value) + + # check that KEYS returns all the keys as they are + assert sorted(r.keys('*')) == sorted(list(iterkeys(mapping))) + + # check that it is possible to get list content by key name + for key, value in iteritems(mapping): + assert r.lrange(key, 0, -1) == value + + def test_22_info(self, r): + """ + Older Redis versions contained 'allocation_stats' in INFO that + was the cause of a number of bugs when parsing. + """ + info = "allocation_stats:6=1,7=1,8=7141,9=180,10=92,11=116,12=5330," \ + "13=123,14=3091,15=11048,16=225842,17=1784,18=814,19=12020," \ + "20=2530,21=645,22=15113,23=8695,24=142860,25=318,26=3303," \ + "27=20561,28=54042,29=37390,30=1884,31=18071,32=31367,33=160," \ + "34=169,35=201,36=10155,37=1045,38=15078,39=22985,40=12523," \ + "41=15588,42=265,43=1287,44=142,45=382,46=945,47=426,48=171," \ + "49=56,50=516,51=43,52=41,53=46,54=54,55=75,56=647,57=332," \ + "58=32,59=39,60=48,61=35,62=62,63=32,64=221,65=26,66=30," \ + "67=36,68=41,69=44,70=26,71=144,72=169,73=24,74=37,75=25," \ + "76=42,77=21,78=126,79=374,80=27,81=40,82=43,83=47,84=46," \ + "85=114,86=34,87=37,88=7240,89=34,90=38,91=18,92=99,93=20," \ + "94=18,95=17,96=15,97=22,98=18,99=69,100=17,101=22,102=15," \ + "103=29,104=39,105=30,106=70,107=22,108=21,109=26,110=52," \ + "111=45,112=33,113=67,114=41,115=44,116=48,117=53,118=54," \ + "119=51,120=75,121=44,122=57,123=44,124=66,125=56,126=52," \ + "127=81,128=108,129=70,130=50,131=51,132=53,133=45,134=62," \ + "135=12,136=13,137=7,138=15,139=21,140=11,141=20,142=6,143=7," \ + "144=11,145=6,146=16,147=19,148=1112,149=1,151=83,154=1," \ + "155=1,156=1,157=1,160=1,161=1,162=2,166=1,169=1,170=1,171=2," \ + "172=1,174=1,176=2,177=9,178=34,179=73,180=30,181=1,185=3," \ + "187=1,188=1,189=1,192=1,196=1,198=1,200=1,201=1,204=1,205=1," \ + "207=1,208=1,209=1,214=2,215=31,216=78,217=28,218=5,219=2," \ + "220=1,222=1,225=1,227=1,234=1,242=1,250=1,252=1,253=1," \ + ">=256=203" + parsed = parse_info(info) + assert 'allocation_stats' in parsed + assert '6' in parsed['allocation_stats'] + assert '>=256' in parsed['allocation_stats'] + + def test_large_responses(self, r): + "The PythonParser has some special cases for return values > 1MB" + # load up 5MB of data into a key + data = ''.join([ascii_letters] * (5000000 // len(ascii_letters))) + r['a'] = data + assert r['a'] == b(data) + + def test_floating_point_encoding(self, r): + """ + High precision floating point values sent to the server should keep + precision. + """ + timestamp = 1349673917.939762 + r.zadd('a', 'a1', timestamp) + assert r.zscore('a', 'a1') == timestamp diff --git a/client/ledis-py/tests/test_connection_pool.py b/client/ledis-py/tests/test_connection_pool.py new file mode 100644 index 0000000..55ccce1 --- /dev/null +++ b/client/ledis-py/tests/test_connection_pool.py @@ -0,0 +1,402 @@ +from __future__ import with_statement +import os +import pytest +import redis +import time +import re + +from threading import Thread +from redis.connection import ssl_available +from .conftest import skip_if_server_version_lt + + +class DummyConnection(object): + description_format = "DummyConnection<>" + + def __init__(self, **kwargs): + self.kwargs = kwargs + self.pid = os.getpid() + + +class TestConnectionPool(object): + def get_pool(self, connection_kwargs=None, max_connections=None, + connection_class=DummyConnection): + connection_kwargs = connection_kwargs or {} + pool = redis.ConnectionPool( + connection_class=connection_class, + max_connections=max_connections, + **connection_kwargs) + return pool + + def test_connection_creation(self): + connection_kwargs = {'foo': 'bar', 'biz': 'baz'} + pool = self.get_pool(connection_kwargs=connection_kwargs) + connection = pool.get_connection('_') + assert isinstance(connection, DummyConnection) + assert connection.kwargs == connection_kwargs + + def test_multiple_connections(self): + pool = self.get_pool() + c1 = pool.get_connection('_') + c2 = pool.get_connection('_') + assert c1 != c2 + + def test_max_connections(self): + pool = self.get_pool(max_connections=2) + pool.get_connection('_') + pool.get_connection('_') + with pytest.raises(redis.ConnectionError): + pool.get_connection('_') + + def test_reuse_previously_released_connection(self): + pool = self.get_pool() + c1 = pool.get_connection('_') + pool.release(c1) + c2 = pool.get_connection('_') + assert c1 == c2 + + def test_repr_contains_db_info_tcp(self): + connection_kwargs = {'host': 'localhost', 'port': 6379, 'db': 1} + pool = self.get_pool(connection_kwargs=connection_kwargs, + connection_class=redis.Connection) + expected = 'ConnectionPool>' + assert repr(pool) == expected + + def test_repr_contains_db_info_unix(self): + connection_kwargs = {'path': '/abc', 'db': 1} + pool = self.get_pool(connection_kwargs=connection_kwargs, + connection_class=redis.UnixDomainSocketConnection) + expected = 'ConnectionPool>' + assert repr(pool) == expected + + +class TestBlockingConnectionPool(object): + def get_pool(self, connection_kwargs=None, max_connections=10, timeout=20): + connection_kwargs = connection_kwargs or {} + pool = redis.BlockingConnectionPool(connection_class=DummyConnection, + max_connections=max_connections, + timeout=timeout, + **connection_kwargs) + return pool + + def test_connection_creation(self): + connection_kwargs = {'foo': 'bar', 'biz': 'baz'} + pool = self.get_pool(connection_kwargs=connection_kwargs) + connection = pool.get_connection('_') + assert isinstance(connection, DummyConnection) + assert connection.kwargs == connection_kwargs + + def test_multiple_connections(self): + pool = self.get_pool() + c1 = pool.get_connection('_') + c2 = pool.get_connection('_') + assert c1 != c2 + + def test_connection_pool_blocks_until_timeout(self): + "When out of connections, block for timeout seconds, then raise" + pool = self.get_pool(max_connections=1, timeout=0.1) + pool.get_connection('_') + + start = time.time() + with pytest.raises(redis.ConnectionError): + pool.get_connection('_') + # we should have waited at least 0.1 seconds + assert time.time() - start >= 0.1 + + def connection_pool_blocks_until_another_connection_released(self): + """ + When out of connections, block until another connection is released + to the pool + """ + pool = self.get_pool(max_connections=1, timeout=2) + c1 = pool.get_connection('_') + + def target(): + time.sleep(0.1) + pool.release(c1) + + Thread(target=target).start() + start = time.time() + pool.get_connection('_') + assert time.time() - start >= 0.1 + + def test_reuse_previously_released_connection(self): + pool = self.get_pool() + c1 = pool.get_connection('_') + pool.release(c1) + c2 = pool.get_connection('_') + assert c1 == c2 + + def test_repr_contains_db_info_tcp(self): + pool = redis.ConnectionPool(host='localhost', port=6379, db=0) + expected = 'ConnectionPool>' + assert repr(pool) == expected + + def test_repr_contains_db_info_unix(self): + pool = redis.ConnectionPool( + connection_class=redis.UnixDomainSocketConnection, + path='abc', + db=0, + ) + expected = 'ConnectionPool>' + assert repr(pool) == expected + + +class TestConnectionPoolURLParsing(object): + def test_defaults(self): + pool = redis.ConnectionPool.from_url('redis://localhost') + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == { + 'host': 'localhost', + 'port': 6379, + 'db': 0, + 'password': None, + } + + def test_hostname(self): + pool = redis.ConnectionPool.from_url('redis://myhost') + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == { + 'host': 'myhost', + 'port': 6379, + 'db': 0, + 'password': None, + } + + def test_port(self): + pool = redis.ConnectionPool.from_url('redis://localhost:6380') + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == { + 'host': 'localhost', + 'port': 6380, + 'db': 0, + 'password': None, + } + + def test_password(self): + pool = redis.ConnectionPool.from_url('redis://:mypassword@localhost') + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == { + 'host': 'localhost', + 'port': 6379, + 'db': 0, + 'password': 'mypassword', + } + + def test_db_as_argument(self): + pool = redis.ConnectionPool.from_url('redis://localhost', db='1') + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == { + 'host': 'localhost', + 'port': 6379, + 'db': 1, + 'password': None, + } + + def test_db_in_path(self): + pool = redis.ConnectionPool.from_url('redis://localhost/2', db='1') + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == { + 'host': 'localhost', + 'port': 6379, + 'db': 2, + 'password': None, + } + + def test_db_in_querystring(self): + pool = redis.ConnectionPool.from_url('redis://localhost/2?db=3', + db='1') + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == { + 'host': 'localhost', + 'port': 6379, + 'db': 3, + 'password': None, + } + + def test_extra_querystring_options(self): + pool = redis.ConnectionPool.from_url('redis://localhost?a=1&b=2') + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == { + 'host': 'localhost', + 'port': 6379, + 'db': 0, + 'password': None, + 'a': '1', + 'b': '2' + } + + def test_calling_from_subclass_returns_correct_instance(self): + pool = redis.BlockingConnectionPool.from_url('redis://localhost') + assert isinstance(pool, redis.BlockingConnectionPool) + + def test_client_creates_connection_pool(self): + r = redis.StrictRedis.from_url('redis://myhost') + assert r.connection_pool.connection_class == redis.Connection + assert r.connection_pool.connection_kwargs == { + 'host': 'myhost', + 'port': 6379, + 'db': 0, + 'password': None, + } + + +class TestConnectionPoolUnixSocketURLParsing(object): + def test_defaults(self): + pool = redis.ConnectionPool.from_url('unix:///socket') + assert pool.connection_class == redis.UnixDomainSocketConnection + assert pool.connection_kwargs == { + 'path': '/socket', + 'db': 0, + 'password': None, + } + + def test_password(self): + pool = redis.ConnectionPool.from_url('unix://:mypassword@/socket') + assert pool.connection_class == redis.UnixDomainSocketConnection + assert pool.connection_kwargs == { + 'path': '/socket', + 'db': 0, + 'password': 'mypassword', + } + + def test_db_as_argument(self): + pool = redis.ConnectionPool.from_url('unix:///socket', db=1) + assert pool.connection_class == redis.UnixDomainSocketConnection + assert pool.connection_kwargs == { + 'path': '/socket', + 'db': 1, + 'password': None, + } + + def test_db_in_querystring(self): + pool = redis.ConnectionPool.from_url('unix:///socket?db=2', db=1) + assert pool.connection_class == redis.UnixDomainSocketConnection + assert pool.connection_kwargs == { + 'path': '/socket', + 'db': 2, + 'password': None, + } + + def test_extra_querystring_options(self): + pool = redis.ConnectionPool.from_url('unix:///socket?a=1&b=2') + assert pool.connection_class == redis.UnixDomainSocketConnection + assert pool.connection_kwargs == { + 'path': '/socket', + 'db': 0, + 'password': None, + 'a': '1', + 'b': '2' + } + + +class TestSSLConnectionURLParsing(object): + @pytest.mark.skipif(not ssl_available, reason="SSL not installed") + def test_defaults(self): + pool = redis.ConnectionPool.from_url('rediss://localhost') + assert pool.connection_class == redis.SSLConnection + assert pool.connection_kwargs == { + 'host': 'localhost', + 'port': 6379, + 'db': 0, + 'password': None, + } + + @pytest.mark.skipif(not ssl_available, reason="SSL not installed") + def test_cert_reqs_options(self): + import ssl + pool = redis.ConnectionPool.from_url('rediss://?ssl_cert_reqs=none') + assert pool.get_connection('_').cert_reqs == ssl.CERT_NONE + + pool = redis.ConnectionPool.from_url( + 'rediss://?ssl_cert_reqs=optional') + assert pool.get_connection('_').cert_reqs == ssl.CERT_OPTIONAL + + pool = redis.ConnectionPool.from_url( + 'rediss://?ssl_cert_reqs=required') + assert pool.get_connection('_').cert_reqs == ssl.CERT_REQUIRED + + +class TestConnection(object): + def test_on_connect_error(self): + """ + An error in Connection.on_connect should disconnect from the server + see for details: https://github.com/andymccurdy/redis-py/issues/368 + """ + # this assumes the Redis server being tested against doesn't have + # 9999 databases ;) + bad_connection = redis.Redis(db=9999) + # an error should be raised on connect + with pytest.raises(redis.RedisError): + bad_connection.info() + pool = bad_connection.connection_pool + assert len(pool._available_connections) == 1 + assert not pool._available_connections[0]._sock + + @skip_if_server_version_lt('2.8.8') + def test_busy_loading_disconnects_socket(self, r): + """ + If Redis raises a LOADING error, the connection should be + disconnected and a BusyLoadingError raised + """ + with pytest.raises(redis.BusyLoadingError): + r.execute_command('DEBUG', 'ERROR', 'LOADING fake message') + pool = r.connection_pool + assert len(pool._available_connections) == 1 + assert not pool._available_connections[0]._sock + + @skip_if_server_version_lt('2.8.8') + def test_busy_loading_from_pipeline_immediate_command(self, r): + """ + BusyLoadingErrors should raise from Pipelines that execute a + command immediately, like WATCH does. + """ + pipe = r.pipeline() + with pytest.raises(redis.BusyLoadingError): + pipe.immediate_execute_command('DEBUG', 'ERROR', + 'LOADING fake message') + pool = r.connection_pool + assert not pipe.connection + assert len(pool._available_connections) == 1 + assert not pool._available_connections[0]._sock + + @skip_if_server_version_lt('2.8.8') + def test_busy_loading_from_pipeline(self, r): + """ + BusyLoadingErrors should be raised from a pipeline execution + regardless of the raise_on_error flag. + """ + pipe = r.pipeline() + pipe.execute_command('DEBUG', 'ERROR', 'LOADING fake message') + with pytest.raises(redis.BusyLoadingError): + pipe.execute() + pool = r.connection_pool + assert not pipe.connection + assert len(pool._available_connections) == 1 + assert not pool._available_connections[0]._sock + + @skip_if_server_version_lt('2.8.8') + def test_read_only_error(self, r): + "READONLY errors get turned in ReadOnlyError exceptions" + with pytest.raises(redis.ReadOnlyError): + r.execute_command('DEBUG', 'ERROR', 'READONLY blah blah') + + def test_connect_from_url_tcp(self): + connection = redis.Redis.from_url('redis://localhost') + pool = connection.connection_pool + + assert re.match('(.*)<(.*)<(.*)>>', repr(pool)).groups() == ( + 'ConnectionPool', + 'Connection', + 'host=localhost,port=6379,db=0', + ) + + def test_connect_from_url_unix(self): + connection = redis.Redis.from_url('unix:///path/to/socket') + pool = connection.connection_pool + + assert re.match('(.*)<(.*)<(.*)>>', repr(pool)).groups() == ( + 'ConnectionPool', + 'UnixDomainSocketConnection', + 'path=/path/to/socket,db=0', + ) diff --git a/client/ledis-py/tests/test_encoding.py b/client/ledis-py/tests/test_encoding.py new file mode 100644 index 0000000..b1df0a5 --- /dev/null +++ b/client/ledis-py/tests/test_encoding.py @@ -0,0 +1,33 @@ +from __future__ import with_statement +import pytest + +from redis._compat import unichr, u, unicode +from .conftest import r as _redis_client + + +class TestEncoding(object): + @pytest.fixture() + def r(self, request): + return _redis_client(request=request, decode_responses=True) + + def test_simple_encoding(self, r): + unicode_string = unichr(3456) + u('abcd') + unichr(3421) + r['unicode-string'] = unicode_string + cached_val = r['unicode-string'] + assert isinstance(cached_val, unicode) + assert unicode_string == cached_val + + def test_list_encoding(self, r): + unicode_string = unichr(3456) + u('abcd') + unichr(3421) + result = [unicode_string, unicode_string, unicode_string] + r.rpush('a', *result) + assert r.lrange('a', 0, -1) == result + + +class TestCommandsAndTokensArentEncoded(object): + @pytest.fixture() + def r(self, request): + return _redis_client(request=request, charset='utf-16') + + def test_basic_command(self, r): + r.set('hello', 'world') diff --git a/client/ledis-py/tests/test_lock.py b/client/ledis-py/tests/test_lock.py new file mode 100644 index 0000000..d732ae1 --- /dev/null +++ b/client/ledis-py/tests/test_lock.py @@ -0,0 +1,167 @@ +from __future__ import with_statement +import pytest +import time + +from redis.exceptions import LockError, ResponseError +from redis.lock import Lock, LuaLock + + +class TestLock(object): + lock_class = Lock + + def get_lock(self, redis, *args, **kwargs): + kwargs['lock_class'] = self.lock_class + return redis.lock(*args, **kwargs) + + def test_lock(self, sr): + lock = self.get_lock(sr, 'foo') + assert lock.acquire(blocking=False) + assert sr.get('foo') == lock.local.token + assert sr.ttl('foo') == -1 + lock.release() + assert sr.get('foo') is None + + def test_competing_locks(self, sr): + lock1 = self.get_lock(sr, 'foo') + lock2 = self.get_lock(sr, 'foo') + assert lock1.acquire(blocking=False) + assert not lock2.acquire(blocking=False) + lock1.release() + assert lock2.acquire(blocking=False) + assert not lock1.acquire(blocking=False) + lock2.release() + + def test_timeout(self, sr): + lock = self.get_lock(sr, 'foo', timeout=10) + assert lock.acquire(blocking=False) + assert 8 < sr.ttl('foo') <= 10 + lock.release() + + def test_float_timeout(self, sr): + lock = self.get_lock(sr, 'foo', timeout=9.5) + assert lock.acquire(blocking=False) + assert 8 < sr.pttl('foo') <= 9500 + lock.release() + + def test_blocking_timeout(self, sr): + lock1 = self.get_lock(sr, 'foo') + assert lock1.acquire(blocking=False) + lock2 = self.get_lock(sr, 'foo', blocking_timeout=0.2) + start = time.time() + assert not lock2.acquire() + assert (time.time() - start) > 0.2 + lock1.release() + + def test_context_manager(self, sr): + # blocking_timeout prevents a deadlock if the lock can't be acquired + # for some reason + with self.get_lock(sr, 'foo', blocking_timeout=0.2) as lock: + assert sr.get('foo') == lock.local.token + assert sr.get('foo') is None + + def test_high_sleep_raises_error(self, sr): + "If sleep is higher than timeout, it should raise an error" + with pytest.raises(LockError): + self.get_lock(sr, 'foo', timeout=1, sleep=2) + + def test_releasing_unlocked_lock_raises_error(self, sr): + lock = self.get_lock(sr, 'foo') + with pytest.raises(LockError): + lock.release() + + def test_releasing_lock_no_longer_owned_raises_error(self, sr): + lock = self.get_lock(sr, 'foo') + lock.acquire(blocking=False) + # manually change the token + sr.set('foo', 'a') + with pytest.raises(LockError): + lock.release() + # even though we errored, the token is still cleared + assert lock.local.token is None + + def test_extend_lock(self, sr): + lock = self.get_lock(sr, 'foo', timeout=10) + assert lock.acquire(blocking=False) + assert 8000 < sr.pttl('foo') <= 10000 + assert lock.extend(10) + assert 16000 < sr.pttl('foo') <= 20000 + lock.release() + + def test_extend_lock_float(self, sr): + lock = self.get_lock(sr, 'foo', timeout=10.0) + assert lock.acquire(blocking=False) + assert 8000 < sr.pttl('foo') <= 10000 + assert lock.extend(10.0) + assert 16000 < sr.pttl('foo') <= 20000 + lock.release() + + def test_extending_unlocked_lock_raises_error(self, sr): + lock = self.get_lock(sr, 'foo', timeout=10) + with pytest.raises(LockError): + lock.extend(10) + + def test_extending_lock_with_no_timeout_raises_error(self, sr): + lock = self.get_lock(sr, 'foo') + assert lock.acquire(blocking=False) + with pytest.raises(LockError): + lock.extend(10) + lock.release() + + def test_extending_lock_no_longer_owned_raises_error(self, sr): + lock = self.get_lock(sr, 'foo') + assert lock.acquire(blocking=False) + sr.set('foo', 'a') + with pytest.raises(LockError): + lock.extend(10) + + +class TestLuaLock(TestLock): + lock_class = LuaLock + + +class TestLockClassSelection(object): + def test_lock_class_argument(self, sr): + lock = sr.lock('foo', lock_class=Lock) + assert type(lock) == Lock + lock = sr.lock('foo', lock_class=LuaLock) + assert type(lock) == LuaLock + + def test_cached_lualock_flag(self, sr): + try: + sr._use_lua_lock = True + lock = sr.lock('foo') + assert type(lock) == LuaLock + finally: + sr._use_lua_lock = None + + def test_cached_lock_flag(self, sr): + try: + sr._use_lua_lock = False + lock = sr.lock('foo') + assert type(lock) == Lock + finally: + sr._use_lua_lock = None + + def test_lua_compatible_server(self, sr, monkeypatch): + @classmethod + def mock_register(cls, redis): + return + monkeypatch.setattr(LuaLock, 'register_scripts', mock_register) + try: + lock = sr.lock('foo') + assert type(lock) == LuaLock + assert sr._use_lua_lock is True + finally: + sr._use_lua_lock = None + + def test_lua_unavailable(self, sr, monkeypatch): + @classmethod + def mock_register(cls, redis): + raise ResponseError() + monkeypatch.setattr(LuaLock, 'register_scripts', mock_register) + try: + lock = sr.lock('foo') + assert type(lock) == Lock + assert sr._use_lua_lock is False + finally: + sr._use_lua_lock = None diff --git a/client/ledis-py/tests/test_pipeline.py b/client/ledis-py/tests/test_pipeline.py new file mode 100644 index 0000000..46fc994 --- /dev/null +++ b/client/ledis-py/tests/test_pipeline.py @@ -0,0 +1,226 @@ +from __future__ import with_statement +import pytest + +import redis +from redis._compat import b, u, unichr, unicode + + +class TestPipeline(object): + def test_pipeline(self, r): + with r.pipeline() as pipe: + pipe.set('a', 'a1').get('a').zadd('z', z1=1).zadd('z', z2=4) + pipe.zincrby('z', 'z1').zrange('z', 0, 5, withscores=True) + assert pipe.execute() == \ + [ + True, + b('a1'), + True, + True, + 2.0, + [(b('z1'), 2.0), (b('z2'), 4)], + ] + + def test_pipeline_length(self, r): + with r.pipeline() as pipe: + # Initially empty. + assert len(pipe) == 0 + assert not pipe + + # Fill 'er up! + pipe.set('a', 'a1').set('b', 'b1').set('c', 'c1') + assert len(pipe) == 3 + assert pipe + + # Execute calls reset(), so empty once again. + pipe.execute() + assert len(pipe) == 0 + assert not pipe + + def test_pipeline_no_transaction(self, r): + with r.pipeline(transaction=False) as pipe: + pipe.set('a', 'a1').set('b', 'b1').set('c', 'c1') + assert pipe.execute() == [True, True, True] + assert r['a'] == b('a1') + assert r['b'] == b('b1') + assert r['c'] == b('c1') + + def test_pipeline_no_transaction_watch(self, r): + r['a'] = 0 + + with r.pipeline(transaction=False) as pipe: + pipe.watch('a') + a = pipe.get('a') + + pipe.multi() + pipe.set('a', int(a) + 1) + assert pipe.execute() == [True] + + def test_pipeline_no_transaction_watch_failure(self, r): + r['a'] = 0 + + with r.pipeline(transaction=False) as pipe: + pipe.watch('a') + a = pipe.get('a') + + r['a'] = 'bad' + + pipe.multi() + pipe.set('a', int(a) + 1) + + with pytest.raises(redis.WatchError): + pipe.execute() + + assert r['a'] == b('bad') + + def test_exec_error_in_response(self, r): + """ + an invalid pipeline command at exec time adds the exception instance + to the list of returned values + """ + r['c'] = 'a' + with r.pipeline() as pipe: + pipe.set('a', 1).set('b', 2).lpush('c', 3).set('d', 4) + result = pipe.execute(raise_on_error=False) + + assert result[0] + assert r['a'] == b('1') + assert result[1] + assert r['b'] == b('2') + + # we can't lpush to a key that's a string value, so this should + # be a ResponseError exception + assert isinstance(result[2], redis.ResponseError) + assert r['c'] == b('a') + + # since this isn't a transaction, the other commands after the + # error are still executed + assert result[3] + assert r['d'] == b('4') + + # make sure the pipe was restored to a working state + assert pipe.set('z', 'zzz').execute() == [True] + assert r['z'] == b('zzz') + + def test_exec_error_raised(self, r): + r['c'] = 'a' + with r.pipeline() as pipe: + pipe.set('a', 1).set('b', 2).lpush('c', 3).set('d', 4) + with pytest.raises(redis.ResponseError) as ex: + pipe.execute() + assert unicode(ex.value).startswith('Command # 3 (LPUSH c 3) of ' + 'pipeline caused error: ') + + # make sure the pipe was restored to a working state + assert pipe.set('z', 'zzz').execute() == [True] + assert r['z'] == b('zzz') + + def test_parse_error_raised(self, r): + with r.pipeline() as pipe: + # the zrem is invalid because we don't pass any keys to it + pipe.set('a', 1).zrem('b').set('b', 2) + with pytest.raises(redis.ResponseError) as ex: + pipe.execute() + + assert unicode(ex.value).startswith('Command # 2 (ZREM b) of ' + 'pipeline caused error: ') + + # make sure the pipe was restored to a working state + assert pipe.set('z', 'zzz').execute() == [True] + assert r['z'] == b('zzz') + + def test_watch_succeed(self, r): + r['a'] = 1 + r['b'] = 2 + + with r.pipeline() as pipe: + pipe.watch('a', 'b') + assert pipe.watching + a_value = pipe.get('a') + b_value = pipe.get('b') + assert a_value == b('1') + assert b_value == b('2') + pipe.multi() + + pipe.set('c', 3) + assert pipe.execute() == [True] + assert not pipe.watching + + def test_watch_failure(self, r): + r['a'] = 1 + r['b'] = 2 + + with r.pipeline() as pipe: + pipe.watch('a', 'b') + r['b'] = 3 + pipe.multi() + pipe.get('a') + with pytest.raises(redis.WatchError): + pipe.execute() + + assert not pipe.watching + + def test_unwatch(self, r): + r['a'] = 1 + r['b'] = 2 + + with r.pipeline() as pipe: + pipe.watch('a', 'b') + r['b'] = 3 + pipe.unwatch() + assert not pipe.watching + pipe.get('a') + assert pipe.execute() == [b('1')] + + def test_transaction_callable(self, r): + r['a'] = 1 + r['b'] = 2 + has_run = [] + + def my_transaction(pipe): + a_value = pipe.get('a') + assert a_value in (b('1'), b('2')) + b_value = pipe.get('b') + assert b_value == b('2') + + # silly run-once code... incr's "a" so WatchError should be raised + # forcing this all to run again. this should incr "a" once to "2" + if not has_run: + r.incr('a') + has_run.append('it has') + + pipe.multi() + pipe.set('c', int(a_value) + int(b_value)) + + result = r.transaction(my_transaction, 'a', 'b') + assert result == [True] + assert r['c'] == b('4') + + def test_exec_error_in_no_transaction_pipeline(self, r): + r['a'] = 1 + with r.pipeline(transaction=False) as pipe: + pipe.llen('a') + pipe.expire('a', 100) + + with pytest.raises(redis.ResponseError) as ex: + pipe.execute() + + assert unicode(ex.value).startswith('Command # 1 (LLEN a) of ' + 'pipeline caused error: ') + + assert r['a'] == b('1') + + def test_exec_error_in_no_transaction_pipeline_unicode_command(self, r): + key = unichr(3456) + u('abcd') + unichr(3421) + r[key] = 1 + with r.pipeline(transaction=False) as pipe: + pipe.llen(key) + pipe.expire(key, 100) + + with pytest.raises(redis.ResponseError) as ex: + pipe.execute() + + expected = unicode('Command # 1 (LLEN %s) of pipeline caused ' + 'error: ') % key + assert unicode(ex.value).startswith(expected) + + assert r[key] == b('1') diff --git a/client/ledis-py/tests/test_pubsub.py b/client/ledis-py/tests/test_pubsub.py new file mode 100644 index 0000000..5486b75 --- /dev/null +++ b/client/ledis-py/tests/test_pubsub.py @@ -0,0 +1,392 @@ +from __future__ import with_statement +import pytest +import time + +import redis +from redis.exceptions import ConnectionError +from redis._compat import basestring, u, unichr + +from .conftest import r as _redis_client + + +def wait_for_message(pubsub, timeout=0.1, ignore_subscribe_messages=False): + now = time.time() + timeout = now + timeout + while now < timeout: + message = pubsub.get_message( + ignore_subscribe_messages=ignore_subscribe_messages) + if message is not None: + return message + time.sleep(0.01) + now = time.time() + return None + + +def make_message(type, channel, data, pattern=None): + return { + 'type': type, + 'pattern': pattern and pattern.encode('utf-8') or None, + 'channel': channel.encode('utf-8'), + 'data': data.encode('utf-8') if isinstance(data, basestring) else data + } + + +def make_subscribe_test_data(pubsub, type): + if type == 'channel': + return { + 'p': pubsub, + 'sub_type': 'subscribe', + 'unsub_type': 'unsubscribe', + 'sub_func': pubsub.subscribe, + 'unsub_func': pubsub.unsubscribe, + 'keys': ['foo', 'bar', u('uni') + unichr(4456) + u('code')] + } + elif type == 'pattern': + return { + 'p': pubsub, + 'sub_type': 'psubscribe', + 'unsub_type': 'punsubscribe', + 'sub_func': pubsub.psubscribe, + 'unsub_func': pubsub.punsubscribe, + 'keys': ['f*', 'b*', u('uni') + unichr(4456) + u('*')] + } + assert False, 'invalid subscribe type: %s' % type + + +class TestPubSubSubscribeUnsubscribe(object): + + def _test_subscribe_unsubscribe(self, p, sub_type, unsub_type, sub_func, + unsub_func, keys): + for key in keys: + assert sub_func(key) is None + + # should be a message for each channel/pattern we just subscribed to + for i, key in enumerate(keys): + assert wait_for_message(p) == make_message(sub_type, key, i + 1) + + for key in keys: + assert unsub_func(key) is None + + # should be a message for each channel/pattern we just unsubscribed + # from + for i, key in enumerate(keys): + i = len(keys) - 1 - i + assert wait_for_message(p) == make_message(unsub_type, key, i) + + def test_channel_subscribe_unsubscribe(self, r): + kwargs = make_subscribe_test_data(r.pubsub(), 'channel') + self._test_subscribe_unsubscribe(**kwargs) + + def test_pattern_subscribe_unsubscribe(self, r): + kwargs = make_subscribe_test_data(r.pubsub(), 'pattern') + self._test_subscribe_unsubscribe(**kwargs) + + def _test_resubscribe_on_reconnection(self, p, sub_type, unsub_type, + sub_func, unsub_func, keys): + + for key in keys: + assert sub_func(key) is None + + # should be a message for each channel/pattern we just subscribed to + for i, key in enumerate(keys): + assert wait_for_message(p) == make_message(sub_type, key, i + 1) + + # manually disconnect + p.connection.disconnect() + + # calling get_message again reconnects and resubscribes + # note, we may not re-subscribe to channels in exactly the same order + # so we have to do some extra checks to make sure we got them all + messages = [] + for i in range(len(keys)): + messages.append(wait_for_message(p)) + + unique_channels = set() + assert len(messages) == len(keys) + for i, message in enumerate(messages): + assert message['type'] == sub_type + assert message['data'] == i + 1 + assert isinstance(message['channel'], bytes) + channel = message['channel'].decode('utf-8') + unique_channels.add(channel) + + assert len(unique_channels) == len(keys) + for channel in unique_channels: + assert channel in keys + + def test_resubscribe_to_channels_on_reconnection(self, r): + kwargs = make_subscribe_test_data(r.pubsub(), 'channel') + self._test_resubscribe_on_reconnection(**kwargs) + + def test_resubscribe_to_patterns_on_reconnection(self, r): + kwargs = make_subscribe_test_data(r.pubsub(), 'pattern') + self._test_resubscribe_on_reconnection(**kwargs) + + def _test_subscribed_property(self, p, sub_type, unsub_type, sub_func, + unsub_func, keys): + + assert p.subscribed is False + sub_func(keys[0]) + # we're now subscribed even though we haven't processed the + # reply from the server just yet + assert p.subscribed is True + assert wait_for_message(p) == make_message(sub_type, keys[0], 1) + # we're still subscribed + assert p.subscribed is True + + # unsubscribe from all channels + unsub_func() + # we're still technically subscribed until we process the + # response messages from the server + assert p.subscribed is True + assert wait_for_message(p) == make_message(unsub_type, keys[0], 0) + # now we're no longer subscribed as no more messages can be delivered + # to any channels we were listening to + assert p.subscribed is False + + # subscribing again flips the flag back + sub_func(keys[0]) + assert p.subscribed is True + assert wait_for_message(p) == make_message(sub_type, keys[0], 1) + + # unsubscribe again + unsub_func() + assert p.subscribed is True + # subscribe to another channel before reading the unsubscribe response + sub_func(keys[1]) + assert p.subscribed is True + # read the unsubscribe for key1 + assert wait_for_message(p) == make_message(unsub_type, keys[0], 0) + # we're still subscribed to key2, so subscribed should still be True + assert p.subscribed is True + # read the key2 subscribe message + assert wait_for_message(p) == make_message(sub_type, keys[1], 1) + unsub_func() + # haven't read the message yet, so we're still subscribed + assert p.subscribed is True + assert wait_for_message(p) == make_message(unsub_type, keys[1], 0) + # now we're finally unsubscribed + assert p.subscribed is False + + def test_subscribe_property_with_channels(self, r): + kwargs = make_subscribe_test_data(r.pubsub(), 'channel') + self._test_subscribed_property(**kwargs) + + def test_subscribe_property_with_patterns(self, r): + kwargs = make_subscribe_test_data(r.pubsub(), 'pattern') + self._test_subscribed_property(**kwargs) + + def test_ignore_all_subscribe_messages(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + + checks = ( + (p.subscribe, 'foo'), + (p.unsubscribe, 'foo'), + (p.psubscribe, 'f*'), + (p.punsubscribe, 'f*'), + ) + + assert p.subscribed is False + for func, channel in checks: + assert func(channel) is None + assert p.subscribed is True + assert wait_for_message(p) is None + assert p.subscribed is False + + def test_ignore_individual_subscribe_messages(self, r): + p = r.pubsub() + + checks = ( + (p.subscribe, 'foo'), + (p.unsubscribe, 'foo'), + (p.psubscribe, 'f*'), + (p.punsubscribe, 'f*'), + ) + + assert p.subscribed is False + for func, channel in checks: + assert func(channel) is None + assert p.subscribed is True + message = wait_for_message(p, ignore_subscribe_messages=True) + assert message is None + assert p.subscribed is False + + +class TestPubSubMessages(object): + def setup_method(self, method): + self.message = None + + def message_handler(self, message): + self.message = message + + def test_published_message_to_channel(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + p.subscribe('foo') + assert r.publish('foo', 'test message') == 1 + + message = wait_for_message(p) + assert isinstance(message, dict) + assert message == make_message('message', 'foo', 'test message') + + def test_published_message_to_pattern(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + p.subscribe('foo') + p.psubscribe('f*') + # 1 to pattern, 1 to channel + assert r.publish('foo', 'test message') == 2 + + message1 = wait_for_message(p) + message2 = wait_for_message(p) + assert isinstance(message1, dict) + assert isinstance(message2, dict) + + expected = [ + make_message('message', 'foo', 'test message'), + make_message('pmessage', 'foo', 'test message', pattern='f*') + ] + + assert message1 in expected + assert message2 in expected + assert message1 != message2 + + def test_channel_message_handler(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + p.subscribe(foo=self.message_handler) + assert r.publish('foo', 'test message') == 1 + assert wait_for_message(p) is None + assert self.message == make_message('message', 'foo', 'test message') + + def test_pattern_message_handler(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + p.psubscribe(**{'f*': self.message_handler}) + assert r.publish('foo', 'test message') == 1 + assert wait_for_message(p) is None + assert self.message == make_message('pmessage', 'foo', 'test message', + pattern='f*') + + def test_unicode_channel_message_handler(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + channel = u('uni') + unichr(4456) + u('code') + channels = {channel: self.message_handler} + p.subscribe(**channels) + assert r.publish(channel, 'test message') == 1 + assert wait_for_message(p) is None + assert self.message == make_message('message', channel, 'test message') + + def test_unicode_pattern_message_handler(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + pattern = u('uni') + unichr(4456) + u('*') + channel = u('uni') + unichr(4456) + u('code') + p.psubscribe(**{pattern: self.message_handler}) + assert r.publish(channel, 'test message') == 1 + assert wait_for_message(p) is None + assert self.message == make_message('pmessage', channel, + 'test message', pattern=pattern) + + +class TestPubSubAutoDecoding(object): + "These tests only validate that we get unicode values back" + + channel = u('uni') + unichr(4456) + u('code') + pattern = u('uni') + unichr(4456) + u('*') + data = u('abc') + unichr(4458) + u('123') + + def make_message(self, type, channel, data, pattern=None): + return { + 'type': type, + 'channel': channel, + 'pattern': pattern, + 'data': data + } + + def setup_method(self, method): + self.message = None + + def message_handler(self, message): + self.message = message + + @pytest.fixture() + def r(self, request): + return _redis_client(request=request, decode_responses=True) + + def test_channel_subscribe_unsubscribe(self, r): + p = r.pubsub() + p.subscribe(self.channel) + assert wait_for_message(p) == self.make_message('subscribe', + self.channel, 1) + + p.unsubscribe(self.channel) + assert wait_for_message(p) == self.make_message('unsubscribe', + self.channel, 0) + + def test_pattern_subscribe_unsubscribe(self, r): + p = r.pubsub() + p.psubscribe(self.pattern) + assert wait_for_message(p) == self.make_message('psubscribe', + self.pattern, 1) + + p.punsubscribe(self.pattern) + assert wait_for_message(p) == self.make_message('punsubscribe', + self.pattern, 0) + + def test_channel_publish(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + p.subscribe(self.channel) + r.publish(self.channel, self.data) + assert wait_for_message(p) == self.make_message('message', + self.channel, + self.data) + + def test_pattern_publish(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + p.psubscribe(self.pattern) + r.publish(self.channel, self.data) + assert wait_for_message(p) == self.make_message('pmessage', + self.channel, + self.data, + pattern=self.pattern) + + def test_channel_message_handler(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + p.subscribe(**{self.channel: self.message_handler}) + r.publish(self.channel, self.data) + assert wait_for_message(p) is None + assert self.message == self.make_message('message', self.channel, + self.data) + + # test that we reconnected to the correct channel + p.connection.disconnect() + assert wait_for_message(p) is None # should reconnect + new_data = self.data + u('new data') + r.publish(self.channel, new_data) + assert wait_for_message(p) is None + assert self.message == self.make_message('message', self.channel, + new_data) + + def test_pattern_message_handler(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + p.psubscribe(**{self.pattern: self.message_handler}) + r.publish(self.channel, self.data) + assert wait_for_message(p) is None + assert self.message == self.make_message('pmessage', self.channel, + self.data, + pattern=self.pattern) + + # test that we reconnected to the correct pattern + p.connection.disconnect() + assert wait_for_message(p) is None # should reconnect + new_data = self.data + u('new data') + r.publish(self.channel, new_data) + assert wait_for_message(p) is None + assert self.message == self.make_message('pmessage', self.channel, + new_data, + pattern=self.pattern) + + +class TestPubSubRedisDown(object): + + def test_channel_subscribe(self, r): + r = redis.Redis(host='localhost', port=6390) + p = r.pubsub() + with pytest.raises(ConnectionError): + p.subscribe('foo') diff --git a/client/ledis-py/tests/test_scripting.py b/client/ledis-py/tests/test_scripting.py new file mode 100644 index 0000000..4849c81 --- /dev/null +++ b/client/ledis-py/tests/test_scripting.py @@ -0,0 +1,82 @@ +from __future__ import with_statement +import pytest + +from redis import exceptions +from redis._compat import b + + +multiply_script = """ +local value = redis.call('GET', KEYS[1]) +value = tonumber(value) +return value * ARGV[1]""" + + +class TestScripting(object): + @pytest.fixture(autouse=True) + def reset_scripts(self, r): + r.script_flush() + + def test_eval(self, r): + r.set('a', 2) + # 2 * 3 == 6 + assert r.eval(multiply_script, 1, 'a', 3) == 6 + + def test_evalsha(self, r): + r.set('a', 2) + sha = r.script_load(multiply_script) + # 2 * 3 == 6 + assert r.evalsha(sha, 1, 'a', 3) == 6 + + def test_evalsha_script_not_loaded(self, r): + r.set('a', 2) + sha = r.script_load(multiply_script) + # remove the script from Redis's cache + r.script_flush() + with pytest.raises(exceptions.NoScriptError): + r.evalsha(sha, 1, 'a', 3) + + def test_script_loading(self, r): + # get the sha, then clear the cache + sha = r.script_load(multiply_script) + r.script_flush() + assert r.script_exists(sha) == [False] + r.script_load(multiply_script) + assert r.script_exists(sha) == [True] + + def test_script_object(self, r): + r.set('a', 2) + multiply = r.register_script(multiply_script) + assert not multiply.sha + # test evalsha fail -> script load + retry + assert multiply(keys=['a'], args=[3]) == 6 + assert multiply.sha + assert r.script_exists(multiply.sha) == [True] + # test first evalsha + assert multiply(keys=['a'], args=[3]) == 6 + + def test_script_object_in_pipeline(self, r): + multiply = r.register_script(multiply_script) + assert not multiply.sha + pipe = r.pipeline() + pipe.set('a', 2) + pipe.get('a') + multiply(keys=['a'], args=[3], client=pipe) + # even though the pipeline wasn't executed yet, we made sure the + # script was loaded and got a valid sha + assert multiply.sha + assert r.script_exists(multiply.sha) == [True] + # [SET worked, GET 'a', result of multiple script] + assert pipe.execute() == [True, b('2'), 6] + + # purge the script from redis's cache and re-run the pipeline + # the multiply script object knows it's sha, so it shouldn't get + # reloaded until pipe.execute() + r.script_flush() + pipe = r.pipeline() + pipe.set('a', 2) + pipe.get('a') + assert multiply.sha + multiply(keys=['a'], args=[3], client=pipe) + assert r.script_exists(multiply.sha) == [False] + # [SET worked, GET 'a', result of multiple script] + assert pipe.execute() == [True, b('2'), 6] diff --git a/client/ledis-py/tests/test_sentinel.py b/client/ledis-py/tests/test_sentinel.py new file mode 100644 index 0000000..0a6e98b --- /dev/null +++ b/client/ledis-py/tests/test_sentinel.py @@ -0,0 +1,173 @@ +from __future__ import with_statement +import pytest + +from redis import exceptions +from redis.sentinel import (Sentinel, SentinelConnectionPool, + MasterNotFoundError, SlaveNotFoundError) +from redis._compat import next +import redis.sentinel + + +class SentinelTestClient(object): + def __init__(self, cluster, id): + self.cluster = cluster + self.id = id + + def sentinel_masters(self): + self.cluster.connection_error_if_down(self) + return {self.cluster.service_name: self.cluster.master} + + def sentinel_slaves(self, master_name): + self.cluster.connection_error_if_down(self) + if master_name != self.cluster.service_name: + return [] + return self.cluster.slaves + + +class SentinelTestCluster(object): + def __init__(self, service_name='mymaster', ip='127.0.0.1', port=6379): + self.clients = {} + self.master = { + 'ip': ip, + 'port': port, + 'is_master': True, + 'is_sdown': False, + 'is_odown': False, + 'num-other-sentinels': 0, + } + self.service_name = service_name + self.slaves = [] + self.nodes_down = set() + + def connection_error_if_down(self, node): + if node.id in self.nodes_down: + raise exceptions.ConnectionError + + def client(self, host, port, **kwargs): + return SentinelTestClient(self, (host, port)) + + +@pytest.fixture() +def cluster(request): + def teardown(): + redis.sentinel.StrictRedis = saved_StrictRedis + cluster = SentinelTestCluster() + saved_StrictRedis = redis.sentinel.StrictRedis + redis.sentinel.StrictRedis = cluster.client + request.addfinalizer(teardown) + return cluster + + +@pytest.fixture() +def sentinel(request, cluster): + return Sentinel([('foo', 26379), ('bar', 26379)]) + + +def test_discover_master(sentinel): + address = sentinel.discover_master('mymaster') + assert address == ('127.0.0.1', 6379) + + +def test_discover_master_error(sentinel): + with pytest.raises(MasterNotFoundError): + sentinel.discover_master('xxx') + + +def test_discover_master_sentinel_down(cluster, sentinel): + # Put first sentinel 'foo' down + cluster.nodes_down.add(('foo', 26379)) + address = sentinel.discover_master('mymaster') + assert address == ('127.0.0.1', 6379) + # 'bar' is now first sentinel + assert sentinel.sentinels[0].id == ('bar', 26379) + + +def test_master_min_other_sentinels(cluster): + sentinel = Sentinel([('foo', 26379)], min_other_sentinels=1) + # min_other_sentinels + with pytest.raises(MasterNotFoundError): + sentinel.discover_master('mymaster') + cluster.master['num-other-sentinels'] = 2 + address = sentinel.discover_master('mymaster') + assert address == ('127.0.0.1', 6379) + + +def test_master_odown(cluster, sentinel): + cluster.master['is_odown'] = True + with pytest.raises(MasterNotFoundError): + sentinel.discover_master('mymaster') + + +def test_master_sdown(cluster, sentinel): + cluster.master['is_sdown'] = True + with pytest.raises(MasterNotFoundError): + sentinel.discover_master('mymaster') + + +def test_discover_slaves(cluster, sentinel): + assert sentinel.discover_slaves('mymaster') == [] + + cluster.slaves = [ + {'ip': 'slave0', 'port': 1234, 'is_odown': False, 'is_sdown': False}, + {'ip': 'slave1', 'port': 1234, 'is_odown': False, 'is_sdown': False}, + ] + assert sentinel.discover_slaves('mymaster') == [ + ('slave0', 1234), ('slave1', 1234)] + + # slave0 -> ODOWN + cluster.slaves[0]['is_odown'] = True + assert sentinel.discover_slaves('mymaster') == [ + ('slave1', 1234)] + + # slave1 -> SDOWN + cluster.slaves[1]['is_sdown'] = True + assert sentinel.discover_slaves('mymaster') == [] + + cluster.slaves[0]['is_odown'] = False + cluster.slaves[1]['is_sdown'] = False + + # node0 -> DOWN + cluster.nodes_down.add(('foo', 26379)) + assert sentinel.discover_slaves('mymaster') == [ + ('slave0', 1234), ('slave1', 1234)] + + +def test_master_for(cluster, sentinel): + master = sentinel.master_for('mymaster', db=9) + assert master.ping() + assert master.connection_pool.master_address == ('127.0.0.1', 6379) + + # Use internal connection check + master = sentinel.master_for('mymaster', db=9, check_connection=True) + assert master.ping() + + +def test_slave_for(cluster, sentinel): + cluster.slaves = [ + {'ip': '127.0.0.1', 'port': 6379, + 'is_odown': False, 'is_sdown': False}, + ] + slave = sentinel.slave_for('mymaster', db=9) + assert slave.ping() + + +def test_slave_for_slave_not_found_error(cluster, sentinel): + cluster.master['is_odown'] = True + slave = sentinel.slave_for('mymaster', db=9) + with pytest.raises(SlaveNotFoundError): + slave.ping() + + +def test_slave_round_robin(cluster, sentinel): + cluster.slaves = [ + {'ip': 'slave0', 'port': 6379, 'is_odown': False, 'is_sdown': False}, + {'ip': 'slave1', 'port': 6379, 'is_odown': False, 'is_sdown': False}, + ] + pool = SentinelConnectionPool('mymaster', sentinel) + rotator = pool.rotate_slaves() + assert next(rotator) in (('slave0', 6379), ('slave1', 6379)) + assert next(rotator) in (('slave0', 6379), ('slave1', 6379)) + # Fallback to master + assert next(rotator) == ('127.0.0.1', 6379) + with pytest.raises(SlaveNotFoundError): + next(rotator) From 13a2e0e69062468df90dda5ac07d545ee5d21b4e Mon Sep 17 00:00:00 2001 From: silentsai Date: Wed, 18 Jun 2014 15:47:26 +0800 Subject: [PATCH 4/9] fix bug - expire bk routine keep running after ledis close --- ledis/ledis.go | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/ledis/ledis.go b/ledis/ledis.go index baf152b..f011d65 100644 --- a/ledis/ledis.go +++ b/ledis/ledis.go @@ -46,6 +46,7 @@ type Ledis struct { binlog *BinLog quit chan struct{} + jobs *sync.WaitGroup } func Open(configJson json.RawMessage) (*Ledis, error) { @@ -75,6 +76,7 @@ func OpenWithConfig(cfg *Config) (*Ledis, error) { l := new(Ledis) l.quit = make(chan struct{}) + l.jobs = new(sync.WaitGroup) l.ldb = ldb @@ -118,6 +120,7 @@ func newDB(l *Ledis, index uint8) *DB { func (l *Ledis) Close() { close(l.quit) + l.jobs.Wait() l.ldb.Close() @@ -156,19 +159,23 @@ func (l *Ledis) activeExpireCycle() { executors[i] = db.newEliminator() } + l.jobs.Add(1) go func() { tick := time.NewTicker(1 * time.Second) - for { + end := false + for !end { select { case <-tick.C: for _, eli := range executors { eli.active() } case <-l.quit: + end = true break } } tick.Stop() + l.jobs.Done() }() } From 48e09a27273051b6407bec819dfc851336961586 Mon Sep 17 00:00:00 2001 From: siddontang Date: Thu, 19 Jun 2014 17:19:40 +0800 Subject: [PATCH 5/9] use our own leveldb --- bootstrap.sh | 1 - etc/ledis.json | 3 +- ledis/dump.go | 6 +- ledis/dump_test.go | 4 +- ledis/ledis.go | 2 +- ledis/replication_test.go | 4 +- ledis/t_hash.go | 27 ++- ledis/t_hash_test.go | 19 +- ledis/t_kv.go | 17 +- ledis/t_kv_test.go | 21 ++- ledis/t_list.go | 8 +- ledis/t_ttl.go | 6 +- ledis/t_zset.go | 22 +-- ledis/tx.go | 2 +- leveldb/batch.go | 59 ++++++ leveldb/cache.go | 18 ++ leveldb/db.go | 328 +++++++++++++++++++++++++++++++++ leveldb/filterpolicy.go | 19 ++ leveldb/iterator.go | 229 +++++++++++++++++++++++ leveldb/leveldb_test.go | 259 ++++++++++++++++++++++++++ leveldb/levigo-license | 7 + leveldb/options.go | 128 +++++++++++++ leveldb/snapshot.go | 54 ++++++ leveldb/util.go | 43 +++++ server/cmd_hash.go | 2 +- server/cmd_replication_test.go | 4 +- 26 files changed, 1233 insertions(+), 59 deletions(-) create mode 100644 leveldb/batch.go create mode 100644 leveldb/cache.go create mode 100644 leveldb/db.go create mode 100644 leveldb/filterpolicy.go create mode 100644 leveldb/iterator.go create mode 100644 leveldb/leveldb_test.go create mode 100644 leveldb/levigo-license create mode 100644 leveldb/options.go create mode 100644 leveldb/snapshot.go create mode 100644 leveldb/util.go diff --git a/bootstrap.sh b/bootstrap.sh index b8d0c41..c5aaefe 100644 --- a/bootstrap.sh +++ b/bootstrap.sh @@ -2,6 +2,5 @@ . ./dev.sh -go get -u github.com/siddontang/go-leveldb/leveldb go get -u github.com/siddontang/go-log/log go get -u github.com/garyburd/redigo/redis diff --git a/etc/ledis.json b/etc/ledis.json index 8f230c3..0d93e88 100644 --- a/etc/ledis.json +++ b/etc/ledis.json @@ -6,7 +6,8 @@ "compression": false, "block_size": 32768, "write_buffer_size": 67108864, - "cache_size": 524288000 + "cache_size": 524288000, + "max_open_files":1024 } }, diff --git a/ledis/dump.go b/ledis/dump.go index 47bca19..16354c8 100644 --- a/ledis/dump.go +++ b/ledis/dump.go @@ -4,7 +4,7 @@ import ( "bufio" "bytes" "encoding/binary" - "github.com/siddontang/go-leveldb/leveldb" + "github.com/siddontang/ledisdb/leveldb" "io" "os" ) @@ -73,7 +73,9 @@ func (l *Ledis) Dump(w io.Writer) error { return err } - it := sp.Iterator(nil, nil, leveldb.RangeClose, 0, -1) + it := sp.NewIterator() + it.SeekToFirst() + var key []byte var value []byte for ; it.Valid(); it.Next() { diff --git a/ledis/dump_test.go b/ledis/dump_test.go index f15f8f4..c8bfe36 100644 --- a/ledis/dump_test.go +++ b/ledis/dump_test.go @@ -2,7 +2,7 @@ package ledis import ( "bytes" - "github.com/siddontang/go-leveldb/leveldb" + "github.com/siddontang/ledisdb/leveldb" "os" "testing" ) @@ -59,7 +59,7 @@ func TestDump(t *testing.T) { t.Fatal(err) } - it := master.ldb.Iterator(nil, nil, leveldb.RangeClose, 0, -1) + it := master.ldb.RangeLimitIterator(nil, nil, leveldb.RangeClose, 0, -1) for ; it.Valid(); it.Next() { key := it.Key() value := it.Value() diff --git a/ledis/ledis.go b/ledis/ledis.go index f011d65..668098c 100644 --- a/ledis/ledis.go +++ b/ledis/ledis.go @@ -3,8 +3,8 @@ package ledis import ( "encoding/json" "fmt" - "github.com/siddontang/go-leveldb/leveldb" "github.com/siddontang/go-log/log" + "github.com/siddontang/ledisdb/leveldb" "path" "sync" "time" diff --git a/ledis/replication_test.go b/ledis/replication_test.go index 21d4dbc..7e9aa20 100644 --- a/ledis/replication_test.go +++ b/ledis/replication_test.go @@ -3,14 +3,14 @@ package ledis import ( "bytes" "fmt" - "github.com/siddontang/go-leveldb/leveldb" + "github.com/siddontang/ledisdb/leveldb" "os" "path" "testing" ) func checkLedisEqual(master *Ledis, slave *Ledis) error { - it := master.ldb.Iterator(nil, nil, leveldb.RangeClose, 0, -1) + it := master.ldb.RangeLimitIterator(nil, nil, leveldb.RangeClose, 0, -1) for ; it.Valid(); it.Next() { key := it.Key() value := it.Value() diff --git a/ledis/t_hash.go b/ledis/t_hash.go index d64a1f6..4fa4cdc 100644 --- a/ledis/t_hash.go +++ b/ledis/t_hash.go @@ -3,7 +3,7 @@ package ledis import ( "encoding/binary" "errors" - "github.com/siddontang/go-leveldb/leveldb" + "github.com/siddontang/ledisdb/leveldb" "time" ) @@ -132,7 +132,7 @@ func (db *DB) hDelete(t *tx, key []byte) int64 { stop := db.hEncodeStopKey(key) var num int64 = 0 - it := db.db.Iterator(start, stop, leveldb.RangeROpen, 0, -1) + it := db.db.RangeLimitIterator(start, stop, leveldb.RangeROpen, 0, -1) for ; it.Valid(); it.Next() { t.Delete(it.Key()) num++ @@ -232,10 +232,11 @@ func (db *DB) HMset(key []byte, args ...FVPair) error { return err } -func (db *DB) HMget(key []byte, args [][]byte) ([]interface{}, error) { +func (db *DB) HMget(key []byte, args ...[]byte) ([]interface{}, error) { var ek []byte - var v []byte - var err error + + it := db.db.NewIterator() + defer it.Close() r := make([]interface{}, len(args)) for i := 0; i < len(args); i++ { @@ -245,11 +246,7 @@ func (db *DB) HMget(key []byte, args [][]byte) ([]interface{}, error) { ek = db.hEncodeHashKey(key, args[i]) - if v, err = db.db.Get(ek); err != nil { - return nil, err - } - - r[i] = v + r[i] = it.Find(ek) } return r, nil @@ -355,7 +352,7 @@ func (db *DB) HGetAll(key []byte) ([]interface{}, error) { v := make([]interface{}, 0, 16) - it := db.db.Iterator(start, stop, leveldb.RangeROpen, 0, -1) + it := db.db.RangeLimitIterator(start, stop, leveldb.RangeROpen, 0, -1) for ; it.Valid(); it.Next() { _, k, err := db.hDecodeHashKey(it.Key()) if err != nil { @@ -380,7 +377,7 @@ func (db *DB) HKeys(key []byte) ([]interface{}, error) { v := make([]interface{}, 0, 16) - it := db.db.Iterator(start, stop, leveldb.RangeROpen, 0, -1) + it := db.db.RangeLimitIterator(start, stop, leveldb.RangeROpen, 0, -1) for ; it.Valid(); it.Next() { _, k, err := db.hDecodeHashKey(it.Key()) if err != nil { @@ -404,7 +401,7 @@ func (db *DB) HValues(key []byte) ([]interface{}, error) { v := make([]interface{}, 0, 16) - it := db.db.Iterator(start, stop, leveldb.RangeROpen, 0, -1) + it := db.db.RangeLimitIterator(start, stop, leveldb.RangeROpen, 0, -1) for ; it.Valid(); it.Next() { v = append(v, it.Value()) } @@ -443,7 +440,7 @@ func (db *DB) hFlush() (drop int64, err error) { maxKey[0] = db.index maxKey[1] = hSizeType + 1 - it := db.db.Iterator(minKey, maxKey, leveldb.RangeROpen, 0, -1) + it := db.db.RangeLimitIterator(minKey, maxKey, leveldb.RangeROpen, 0, -1) for ; it.Valid(); it.Next() { t.Delete(it.Key()) drop++ @@ -485,7 +482,7 @@ func (db *DB) HScan(key []byte, field []byte, count int, inclusive bool) ([]FVPa rangeType = leveldb.RangeOpen } - it := db.db.Iterator(minKey, maxKey, rangeType, 0, count) + it := db.db.RangeLimitIterator(minKey, maxKey, rangeType, 0, count) for ; it.Valid(); it.Next() { if _, f, err := db.hDecodeHashKey(it.Key()); err != nil { continue diff --git a/ledis/t_hash_test.go b/ledis/t_hash_test.go index 4648c09..a663c2e 100644 --- a/ledis/t_hash_test.go +++ b/ledis/t_hash_test.go @@ -32,11 +32,28 @@ func TestDBHash(t *testing.T) { key := []byte("testdb_hash_a") - if n, err := db.HSet(key, []byte("a"), []byte("hello world")); err != nil { + if n, err := db.HSet(key, []byte("a"), []byte("hello world 1")); err != nil { t.Fatal(err) } else if n != 1 { t.Fatal(n) } + + if n, err := db.HSet(key, []byte("b"), []byte("hello world 2")); err != nil { + t.Fatal(err) + } else if n != 1 { + t.Fatal(n) + } + + ay, _ := db.HMget(key, []byte("a"), []byte("b")) + + if v1, _ := ay[0].([]byte); string(v1) != "hello world 1" { + t.Fatal(string(v1)) + } + + if v2, _ := ay[1].([]byte); string(v2) != "hello world 2" { + t.Fatal(string(v2)) + } + } func TestDBHScan(t *testing.T) { diff --git a/ledis/t_kv.go b/ledis/t_kv.go index 2ae2a63..3008fb8 100644 --- a/ledis/t_kv.go +++ b/ledis/t_kv.go @@ -2,7 +2,7 @@ package ledis import ( "errors" - "github.com/siddontang/go-leveldb/leveldb" + "github.com/siddontang/ledisdb/leveldb" "time" ) @@ -204,18 +204,15 @@ func (db *DB) IncryBy(key []byte, increment int64) (int64, error) { func (db *DB) MGet(keys ...[]byte) ([]interface{}, error) { values := make([]interface{}, len(keys)) - var err error - var value []byte + it := db.db.NewIterator() + defer it.Close() + for i := range keys { if err := checkKeySize(keys[i]); err != nil { return nil, err } - if value, err = db.db.Get(db.encodeKVKey(keys[i])); err != nil { - return nil, err - } - - values[i] = value + values[i] = it.Find(db.encodeKVKey(keys[i])) } return values, nil @@ -319,7 +316,7 @@ func (db *DB) flush() (drop int64, err error) { minKey := db.encodeKVMinKey() maxKey := db.encodeKVMaxKey() - it := db.db.Iterator(minKey, maxKey, leveldb.RangeROpen, 0, -1) + it := db.db.RangeLimitIterator(minKey, maxKey, leveldb.RangeROpen, 0, -1) for ; it.Valid(); it.Next() { t.Delete(it.Key()) drop++ @@ -362,7 +359,7 @@ func (db *DB) Scan(key []byte, count int, inclusive bool) ([]KVPair, error) { rangeType = leveldb.RangeOpen } - it := db.db.Iterator(minKey, maxKey, rangeType, 0, count) + it := db.db.RangeLimitIterator(minKey, maxKey, rangeType, 0, count) for ; it.Valid(); it.Next() { if key, err := db.decodeKVKey(it.Key()); err != nil { continue diff --git a/ledis/t_kv_test.go b/ledis/t_kv_test.go index 0252421..9088bd1 100644 --- a/ledis/t_kv_test.go +++ b/ledis/t_kv_test.go @@ -19,11 +19,28 @@ func TestKVCodec(t *testing.T) { func TestDBKV(t *testing.T) { db := getTestDB() - key := []byte("testdb_kv_a") + key1 := []byte("testdb_kv_a") - if err := db.Set(key, []byte("hello world")); err != nil { + if err := db.Set(key1, []byte("hello world 1")); err != nil { t.Fatal(err) } + + key2 := []byte("testdb_kv_b") + + if err := db.Set(key2, []byte("hello world 2")); err != nil { + t.Fatal(err) + } + + ay, _ := db.MGet(key1, key2) + + if v1, _ := ay[0].([]byte); string(v1) != "hello world 1" { + t.Fatal(string(v1)) + } + + if v2, _ := ay[1].([]byte); string(v2) != "hello world 2" { + t.Fatal(string(v2)) + } + } func TestDBScan(t *testing.T) { diff --git a/ledis/t_list.go b/ledis/t_list.go index df5f576..e5d2f42 100644 --- a/ledis/t_list.go +++ b/ledis/t_list.go @@ -3,7 +3,7 @@ package ledis import ( "encoding/binary" "errors" - "github.com/siddontang/go-leveldb/leveldb" + "github.com/siddontang/ledisdb/leveldb" "time" ) @@ -200,7 +200,7 @@ func (db *DB) lDelete(t *tx, key []byte) int64 { startKey := db.lEncodeListKey(key, headSeq) stopKey := db.lEncodeListKey(key, tailSeq) - it := db.db.Iterator(startKey, stopKey, leveldb.RangeClose, 0, -1) + it := db.db.RangeLimitIterator(startKey, stopKey, leveldb.RangeClose, 0, -1) for ; it.Valid(); it.Next() { t.Delete(it.Key()) num++ @@ -361,7 +361,7 @@ func (db *DB) LRange(key []byte, start int32, stop int32) ([]interface{}, error) startKey := db.lEncodeListKey(key, startSeq) stopKey := db.lEncodeListKey(key, stopSeq) - it := db.db.Iterator(startKey, stopKey, leveldb.RangeClose, 0, -1) + it := db.db.RangeLimitIterator(startKey, stopKey, leveldb.RangeClose, 0, -1) for ; it.Valid(); it.Next() { v = append(v, it.Value()) } @@ -408,7 +408,7 @@ func (db *DB) lFlush() (drop int64, err error) { maxKey[0] = db.index maxKey[1] = lMetaType + 1 - it := db.db.Iterator(minKey, maxKey, leveldb.RangeROpen, 0, -1) + it := db.db.RangeLimitIterator(minKey, maxKey, leveldb.RangeROpen, 0, -1) for ; it.Valid(); it.Next() { t.Delete(it.Key()) drop++ diff --git a/ledis/t_ttl.go b/ledis/t_ttl.go index b353ce7..df64118 100644 --- a/ledis/t_ttl.go +++ b/ledis/t_ttl.go @@ -3,7 +3,7 @@ package ledis import ( "encoding/binary" "errors" - "github.com/siddontang/go-leveldb/leveldb" + "github.com/siddontang/ledisdb/leveldb" "time" ) @@ -119,7 +119,7 @@ func (db *DB) expFlush(t *tx, expType byte) (err error) { maxKey[0] = db.index maxKey[1] = expMetaType + 1 - it := db.db.Iterator(minKey, maxKey, leveldb.RangeROpen, 0, -1) + it := db.db.RangeLimitIterator(minKey, maxKey, leveldb.RangeROpen, 0, -1) for ; it.Valid(); it.Next() { t.Delete(it.Key()) drop++ @@ -173,7 +173,7 @@ func (eli *elimination) active() { continue } - it := db.db.Iterator(minKey, maxKey, leveldb.RangeROpen, 0, -1) + it := db.db.RangeLimitIterator(minKey, maxKey, leveldb.RangeROpen, 0, -1) for it.Valid() { for i := 1; i < 512 && it.Valid(); i++ { expKeys = append(expKeys, it.Key(), it.Value()) diff --git a/ledis/t_zset.go b/ledis/t_zset.go index 45f8cfa..a4ea355 100644 --- a/ledis/t_zset.go +++ b/ledis/t_zset.go @@ -4,7 +4,7 @@ import ( "bytes" "encoding/binary" "errors" - "github.com/siddontang/go-leveldb/leveldb" + "github.com/siddontang/ledisdb/leveldb" "time" ) @@ -434,7 +434,7 @@ func (db *DB) ZCount(key []byte, min int64, max int64) (int64, error) { rangeType := leveldb.RangeROpen - it := db.db.Iterator(minKey, maxKey, rangeType, 0, -1) + it := db.db.RangeLimitIterator(minKey, maxKey, rangeType, 0, -1) var n int64 = 0 for ; it.Valid(); it.Next() { n++ @@ -459,16 +459,16 @@ func (db *DB) zrank(key []byte, member []byte, reverse bool) (int64, error) { if s, err := Int64(v, err); err != nil { return 0, err } else { - var it *leveldb.Iterator + var it *leveldb.RangeLimitIterator sk := db.zEncodeScoreKey(key, member, s) if !reverse { minKey := db.zEncodeStartScoreKey(key, MinScore) - it = db.db.Iterator(minKey, sk, leveldb.RangeClose, 0, -1) + it = db.db.RangeLimitIterator(minKey, sk, leveldb.RangeClose, 0, -1) } else { maxKey := db.zEncodeStopScoreKey(key, MaxScore) - it = db.db.RevIterator(sk, maxKey, leveldb.RangeClose, 0, -1) + it = db.db.RevRangeLimitIterator(sk, maxKey, leveldb.RangeClose, 0, -1) } var lastKey []byte = nil @@ -492,14 +492,14 @@ func (db *DB) zrank(key []byte, member []byte, reverse bool) (int64, error) { return -1, nil } -func (db *DB) zIterator(key []byte, min int64, max int64, offset int, limit int, reverse bool) *leveldb.Iterator { +func (db *DB) zIterator(key []byte, min int64, max int64, offset int, limit int, reverse bool) *leveldb.RangeLimitIterator { minKey := db.zEncodeStartScoreKey(key, min) maxKey := db.zEncodeStopScoreKey(key, max) if !reverse { - return db.db.Iterator(minKey, maxKey, leveldb.RangeClose, offset, limit) + return db.db.RangeLimitIterator(minKey, maxKey, leveldb.RangeClose, offset, limit) } else { - return db.db.RevIterator(minKey, maxKey, leveldb.RangeClose, offset, limit) + return db.db.RevRangeLimitIterator(minKey, maxKey, leveldb.RangeClose, offset, limit) } } @@ -567,7 +567,7 @@ func (db *DB) zRange(key []byte, min int64, max int64, withScores bool, offset i } v := make([]interface{}, 0, nv) - var it *leveldb.Iterator + var it *leveldb.RangeLimitIterator //if reverse and offset is 0, limit < 0, we may use forward iterator then reverse //because leveldb iterator prev is slower than next @@ -745,7 +745,7 @@ func (db *DB) zFlush() (drop int64, err error) { maxKey[0] = db.index maxKey[1] = zScoreType + 1 - it := db.db.Iterator(minKey, maxKey, leveldb.RangeROpen, 0, -1) + it := db.db.RangeLimitIterator(minKey, maxKey, leveldb.RangeROpen, 0, -1) for ; it.Valid(); it.Next() { t.Delete(it.Key()) drop++ @@ -788,7 +788,7 @@ func (db *DB) ZScan(key []byte, member []byte, count int, inclusive bool) ([]Sco rangeType = leveldb.RangeOpen } - it := db.db.Iterator(minKey, maxKey, rangeType, 0, count) + it := db.db.RangeLimitIterator(minKey, maxKey, rangeType, 0, count) for ; it.Valid(); it.Next() { if _, m, err := db.zDecodeSetKey(it.Key()); err != nil { continue diff --git a/ledis/tx.go b/ledis/tx.go index fa7379b..0fe716a 100644 --- a/ledis/tx.go +++ b/ledis/tx.go @@ -1,7 +1,7 @@ package ledis import ( - "github.com/siddontang/go-leveldb/leveldb" + "github.com/siddontang/ledisdb/leveldb" "sync" ) diff --git a/leveldb/batch.go b/leveldb/batch.go new file mode 100644 index 0000000..f24ec65 --- /dev/null +++ b/leveldb/batch.go @@ -0,0 +1,59 @@ +package leveldb + +// #cgo LDFLAGS: -lleveldb +// #include "leveldb/c.h" +import "C" + +import ( + "unsafe" +) + +type WriteBatch struct { + db *DB + wbatch *C.leveldb_writebatch_t +} + +func (w *WriteBatch) Close() { + C.leveldb_writebatch_destroy(w.wbatch) +} + +func (w *WriteBatch) Put(key, value []byte) { + var k, v *C.char + if len(key) != 0 { + k = (*C.char)(unsafe.Pointer(&key[0])) + } + if len(value) != 0 { + v = (*C.char)(unsafe.Pointer(&value[0])) + } + + lenk := len(key) + lenv := len(value) + + C.leveldb_writebatch_put(w.wbatch, k, C.size_t(lenk), v, C.size_t(lenv)) +} + +func (w *WriteBatch) Delete(key []byte) { + C.leveldb_writebatch_delete(w.wbatch, + (*C.char)(unsafe.Pointer(&key[0])), C.size_t(len(key))) +} + +func (w *WriteBatch) Commit() error { + return w.commit(w.db.writeOpts) +} + +func (w *WriteBatch) SyncCommit() error { + return w.commit(w.db.syncWriteOpts) +} + +func (w *WriteBatch) Rollback() { + C.leveldb_writebatch_clear(w.wbatch) +} + +func (w *WriteBatch) commit(wb *WriteOptions) error { + var errStr *C.char + C.leveldb_write(w.db.db, wb.Opt, w.wbatch, &errStr) + if errStr != nil { + return saveError(errStr) + } + return nil +} diff --git a/leveldb/cache.go b/leveldb/cache.go new file mode 100644 index 0000000..3fbcf0d --- /dev/null +++ b/leveldb/cache.go @@ -0,0 +1,18 @@ +package leveldb + +// #cgo LDFLAGS: -lleveldb +// #include +// #include "leveldb/c.h" +import "C" + +type Cache struct { + Cache *C.leveldb_cache_t +} + +func NewLRUCache(capacity int) *Cache { + return &Cache{C.leveldb_cache_create_lru(C.size_t(capacity))} +} + +func (c *Cache) Close() { + C.leveldb_cache_destroy(c.Cache) +} diff --git a/leveldb/db.go b/leveldb/db.go new file mode 100644 index 0000000..eb68d5a --- /dev/null +++ b/leveldb/db.go @@ -0,0 +1,328 @@ +package leveldb + +/* +#cgo LDFLAGS: -lleveldb +#include +*/ +import "C" + +import ( + "encoding/json" + "os" + "unsafe" +) + +const defaultFilterBits int = 10 + +type Config struct { + Path string `json:"path"` + + Compression bool `json:"compression"` + BlockSize int `json:"block_size"` + WriteBufferSize int `json:"write_buffer_size"` + CacheSize int `json:"cache_size"` + MaxOpenFiles int `json:"max_open_files"` +} + +type DB struct { + cfg *Config + + db *C.leveldb_t + + opts *Options + + //for default read and write options + readOpts *ReadOptions + writeOpts *WriteOptions + iteratorOpts *ReadOptions + + syncWriteOpts *WriteOptions + + cache *Cache + + filter *FilterPolicy +} + +func Open(configJson json.RawMessage) (*DB, error) { + cfg := new(Config) + err := json.Unmarshal(configJson, cfg) + if err != nil { + return nil, err + } + + return OpenWithConfig(cfg) +} + +func OpenWithConfig(cfg *Config) (*DB, error) { + if err := os.MkdirAll(cfg.Path, os.ModePerm); err != nil { + return nil, err + } + + db := new(DB) + db.cfg = cfg + + if err := db.open(); err != nil { + return nil, err + } + + return db, nil +} + +func (db *DB) open() error { + db.opts = db.initOptions(db.cfg) + + db.readOpts = NewReadOptions() + db.writeOpts = NewWriteOptions() + + db.iteratorOpts = NewReadOptions() + db.iteratorOpts.SetFillCache(false) + + db.syncWriteOpts = NewWriteOptions() + db.syncWriteOpts.SetSync(true) + + var errStr *C.char + ldbname := C.CString(db.cfg.Path) + defer C.leveldb_free(unsafe.Pointer(ldbname)) + + db.db = C.leveldb_open(db.opts.Opt, ldbname, &errStr) + if errStr != nil { + return saveError(errStr) + } + return nil +} + +func (db *DB) initOptions(cfg *Config) *Options { + opts := NewOptions() + + opts.SetCreateIfMissing(true) + + if cfg.CacheSize <= 0 { + cfg.CacheSize = 4 * 1024 * 1024 + } + + db.cache = NewLRUCache(cfg.CacheSize) + opts.SetCache(db.cache) + + //we must use bloomfilter + db.filter = NewBloomFilter(defaultFilterBits) + opts.SetFilterPolicy(db.filter) + + if !cfg.Compression { + opts.SetCompression(NoCompression) + } + + if cfg.BlockSize <= 0 { + cfg.BlockSize = 4 * 1024 + } + + opts.SetBlockSize(cfg.BlockSize) + + if cfg.WriteBufferSize <= 0 { + cfg.WriteBufferSize = 4 * 1024 * 1024 + } + + opts.SetWriteBufferSize(cfg.WriteBufferSize) + + if cfg.MaxOpenFiles < 1024 { + cfg.MaxOpenFiles = 1024 + } + + opts.SetMaxOpenFiles(cfg.MaxOpenFiles) + + return opts +} + +func (db *DB) Close() { + C.leveldb_close(db.db) + db.db = nil + + db.opts.Close() + + if db.cache != nil { + db.cache.Close() + } + + if db.filter != nil { + db.filter.Close() + } + + db.readOpts.Close() + db.writeOpts.Close() + db.iteratorOpts.Close() + db.syncWriteOpts.Close() +} + +func (db *DB) Destroy() error { + path := db.cfg.Path + + db.Close() + + opts := NewOptions() + defer opts.Close() + + var errStr *C.char + ldbname := C.CString(path) + defer C.leveldb_free(unsafe.Pointer(ldbname)) + + C.leveldb_destroy_db(opts.Opt, ldbname, &errStr) + if errStr != nil { + return saveError(errStr) + } + return nil +} + +func (db *DB) Clear() error { + bc := db.NewWriteBatch() + defer bc.Close() + + var err error + it := db.NewIterator() + it.SeekToFirst() + + num := 0 + for ; it.Valid(); it.Next() { + bc.Delete(it.Key()) + num++ + if num == 1000 { + num = 0 + if err = bc.Commit(); err != nil { + return err + } + } + } + + err = bc.Commit() + + return err +} + +func (db *DB) Put(key, value []byte) error { + return db.put(db.writeOpts, key, value) +} + +func (db *DB) SyncPut(key, value []byte) error { + return db.put(db.syncWriteOpts, key, value) +} + +func (db *DB) Get(key []byte) ([]byte, error) { + return db.get(db.readOpts, key) +} + +func (db *DB) Delete(key []byte) error { + return db.delete(db.writeOpts, key) +} + +func (db *DB) SyncDelete(key []byte) error { + return db.delete(db.syncWriteOpts, key) +} + +func (db *DB) NewWriteBatch() *WriteBatch { + wb := &WriteBatch{ + db: db, + wbatch: C.leveldb_writebatch_create(), + } + return wb +} + +func (db *DB) NewSnapshot() *Snapshot { + s := &Snapshot{ + db: db, + snap: C.leveldb_create_snapshot(db.db), + readOpts: NewReadOptions(), + iteratorOpts: NewReadOptions(), + } + + s.readOpts.SetSnapshot(s) + s.iteratorOpts.SetSnapshot(s) + s.iteratorOpts.SetFillCache(false) + + return s +} + +func (db *DB) NewIterator() *Iterator { + it := new(Iterator) + + it.it = C.leveldb_create_iterator(db.db, db.iteratorOpts.Opt) + + return it +} + +func (db *DB) RangeIterator(min []byte, max []byte, rangeType uint8) *RangeLimitIterator { + return newRangeLimitIterator(db.NewIterator(), &Range{min, max, rangeType}, 0, -1, IteratorForward) +} + +func (db *DB) RevRangeIterator(min []byte, max []byte, rangeType uint8) *RangeLimitIterator { + return newRangeLimitIterator(db.NewIterator(), &Range{min, max, rangeType}, 0, -1, IteratorBackward) +} + +//limit < 0, unlimit +//offset must >= 0, if < 0, will get nothing +func (db *DB) RangeLimitIterator(min []byte, max []byte, rangeType uint8, offset int, limit int) *RangeLimitIterator { + return newRangeLimitIterator(db.NewIterator(), &Range{min, max, rangeType}, offset, limit, IteratorForward) +} + +//limit < 0, unlimit +//offset must >= 0, if < 0, will get nothing +func (db *DB) RevRangeLimitIterator(min []byte, max []byte, rangeType uint8, offset int, limit int) *RangeLimitIterator { + return newRangeLimitIterator(db.NewIterator(), &Range{min, max, rangeType}, offset, limit, IteratorBackward) +} + +func (db *DB) put(wo *WriteOptions, key, value []byte) error { + var errStr *C.char + var k, v *C.char + if len(key) != 0 { + k = (*C.char)(unsafe.Pointer(&key[0])) + } + if len(value) != 0 { + v = (*C.char)(unsafe.Pointer(&value[0])) + } + + lenk := len(key) + lenv := len(value) + C.leveldb_put( + db.db, wo.Opt, k, C.size_t(lenk), v, C.size_t(lenv), &errStr) + + if errStr != nil { + return saveError(errStr) + } + return nil +} + +func (db *DB) get(ro *ReadOptions, key []byte) ([]byte, error) { + var errStr *C.char + var vallen C.size_t + var k *C.char + if len(key) != 0 { + k = (*C.char)(unsafe.Pointer(&key[0])) + } + + value := C.leveldb_get( + db.db, ro.Opt, k, C.size_t(len(key)), &vallen, &errStr) + + if errStr != nil { + return nil, saveError(errStr) + } + + if value == nil { + return nil, nil + } + + defer C.leveldb_free(unsafe.Pointer(value)) + return C.GoBytes(unsafe.Pointer(value), C.int(vallen)), nil +} + +func (db *DB) delete(wo *WriteOptions, key []byte) error { + var errStr *C.char + var k *C.char + if len(key) != 0 { + k = (*C.char)(unsafe.Pointer(&key[0])) + } + + C.leveldb_delete( + db.db, wo.Opt, k, C.size_t(len(key)), &errStr) + + if errStr != nil { + return saveError(errStr) + } + return nil +} diff --git a/leveldb/filterpolicy.go b/leveldb/filterpolicy.go new file mode 100644 index 0000000..b007d58 --- /dev/null +++ b/leveldb/filterpolicy.go @@ -0,0 +1,19 @@ +package leveldb + +// #cgo LDFLAGS: -lleveldb +// #include +// #include "leveldb/c.h" +import "C" + +type FilterPolicy struct { + Policy *C.leveldb_filterpolicy_t +} + +func NewBloomFilter(bitsPerKey int) *FilterPolicy { + policy := C.leveldb_filterpolicy_create_bloom(C.int(bitsPerKey)) + return &FilterPolicy{policy} +} + +func (fp *FilterPolicy) Close() { + C.leveldb_filterpolicy_destroy(fp.Policy) +} diff --git a/leveldb/iterator.go b/leveldb/iterator.go new file mode 100644 index 0000000..6f22b92 --- /dev/null +++ b/leveldb/iterator.go @@ -0,0 +1,229 @@ +package leveldb + +// #cgo LDFLAGS: -lleveldb +// #include +// #include "leveldb/c.h" +import "C" + +import ( + "bytes" + "unsafe" +) + +const ( + IteratorForward uint8 = 0 + IteratorBackward uint8 = 1 +) + +const ( + RangeClose uint8 = 0x00 + RangeLOpen uint8 = 0x01 + RangeROpen uint8 = 0x10 + RangeOpen uint8 = 0x11 +) + +//min must less or equal than max +//range type: +//close: [min, max] +//open: (min, max) +//lopen: (min, max] +//ropen: [min, max) +type Range struct { + Min []byte + Max []byte + + Type uint8 +} + +type Iterator struct { + it *C.leveldb_iterator_t +} + +func (it *Iterator) Key() []byte { + var klen C.size_t + kdata := C.leveldb_iter_key(it.it, &klen) + if kdata == nil { + return nil + } + + return C.GoBytes(unsafe.Pointer(kdata), C.int(klen)) +} + +func (it *Iterator) Value() []byte { + var vlen C.size_t + vdata := C.leveldb_iter_value(it.it, &vlen) + if vdata == nil { + return nil + } + + return C.GoBytes(unsafe.Pointer(vdata), C.int(vlen)) +} + +func (it *Iterator) Close() { + C.leveldb_iter_destroy(it.it) + it.it = nil +} + +func (it *Iterator) Valid() bool { + return ucharToBool(C.leveldb_iter_valid(it.it)) +} + +func (it *Iterator) Next() { + C.leveldb_iter_next(it.it) +} + +func (it *Iterator) Prev() { + C.leveldb_iter_prev(it.it) +} + +func (it *Iterator) SeekToFirst() { + C.leveldb_iter_seek_to_first(it.it) +} + +func (it *Iterator) SeekToLast() { + C.leveldb_iter_seek_to_last(it.it) +} + +func (it *Iterator) Seek(key []byte) { + C.leveldb_iter_seek(it.it, (*C.char)(unsafe.Pointer(&key[0])), C.size_t(len(key))) +} + +func (it *Iterator) Find(key []byte) []byte { + it.Seek(key) + if it.Valid() && bytes.Equal(it.Key(), key) { + return it.Value() + } else { + return nil + } +} + +type RangeLimitIterator struct { + it *Iterator + + r *Range + + offset int + limit int + + step int + + //0 for IteratorForward, 1 for IteratorBackward + direction uint8 +} + +func (it *RangeLimitIterator) Key() []byte { + return it.it.Key() +} + +func (it *RangeLimitIterator) Value() []byte { + return it.it.Value() +} + +func (it *RangeLimitIterator) Valid() bool { + if it.offset < 0 { + return false + } else if !it.it.Valid() { + return false + } else if it.limit >= 0 && it.step >= it.limit { + return false + } + + if it.direction == IteratorForward { + if it.r.Max != nil { + r := bytes.Compare(it.it.Key(), it.r.Max) + if it.r.Type&RangeROpen > 0 { + return !(r >= 0) + } else { + return !(r > 0) + } + } + } else { + if it.r.Min != nil { + r := bytes.Compare(it.it.Key(), it.r.Min) + if it.r.Type&RangeLOpen > 0 { + return !(r <= 0) + } else { + return !(r < 0) + } + } + } + + return true +} + +func (it *RangeLimitIterator) Next() { + it.step++ + + if it.direction == IteratorForward { + it.it.Next() + } else { + it.it.Prev() + } +} + +func (it *RangeLimitIterator) Close() { + it.it.Close() +} + +func newRangeLimitIterator(i *Iterator, r *Range, offset int, limit int, direction uint8) *RangeLimitIterator { + it := new(RangeLimitIterator) + + it.it = i + + it.r = r + it.offset = offset + it.limit = limit + it.direction = direction + + it.step = 0 + + if offset < 0 { + return it + } + + if direction == IteratorForward { + if r.Min == nil { + it.it.SeekToFirst() + } else { + it.it.Seek(r.Min) + + if r.Type&RangeLOpen > 0 { + if it.it.Valid() && bytes.Equal(it.it.Key(), r.Min) { + it.it.Next() + } + } + } + } else { + if r.Max == nil { + it.it.SeekToLast() + } else { + it.it.Seek(r.Max) + + if !it.it.Valid() { + it.it.SeekToLast() + } else { + if !bytes.Equal(it.it.Key(), r.Max) { + it.it.Prev() + } + } + + if r.Type&RangeROpen > 0 { + if it.it.Valid() && bytes.Equal(it.it.Key(), r.Max) { + it.it.Prev() + } + } + } + } + + for i := 0; i < offset; i++ { + if it.it.Valid() { + if it.direction == IteratorForward { + it.it.Next() + } else { + it.it.Prev() + } + } + } + + return it +} diff --git a/leveldb/leveldb_test.go b/leveldb/leveldb_test.go new file mode 100644 index 0000000..7289ed9 --- /dev/null +++ b/leveldb/leveldb_test.go @@ -0,0 +1,259 @@ +package leveldb + +import ( + "bytes" + "fmt" + "os" + "sync" + "testing" +) + +var testConfigJson = []byte(` + { + "path" : "./testdb", + "compression":true, + "block_size" : 32768, + "write_buffer_size" : 2097152, + "cache_size" : 20971520 + } + `) + +var testOnce sync.Once +var testDB *DB + +func getTestDB() *DB { + f := func() { + var err error + testDB, err = Open(testConfigJson) + if err != nil { + println(err.Error()) + panic(err) + } + } + + testOnce.Do(f) + return testDB +} + +func TestSimple(t *testing.T) { + db := getTestDB() + + key := []byte("key") + value := []byte("hello world") + if err := db.Put(key, value); err != nil { + t.Fatal(err) + } + + if v, err := db.Get(key); err != nil { + t.Fatal(err) + } else if !bytes.Equal(v, value) { + t.Fatal("not equal") + } + + if err := db.Delete(key); err != nil { + t.Fatal(err) + } + if v, err := db.Get(key); err != nil { + t.Fatal(err) + } else if v != nil { + t.Fatal("must nil") + } +} + +func TestBatch(t *testing.T) { + db := getTestDB() + + key1 := []byte("key1") + key2 := []byte("key2") + + value := []byte("hello world") + + db.Put(key1, value) + db.Put(key2, value) + + wb := db.NewWriteBatch() + defer wb.Close() + + wb.Delete(key2) + wb.Put(key1, []byte("hello world2")) + + if err := wb.Commit(); err != nil { + t.Fatal(err) + } + + if v, err := db.Get(key2); err != nil { + t.Fatal(err) + } else if v != nil { + t.Fatal("must nil") + } + + if v, err := db.Get(key1); err != nil { + t.Fatal(err) + } else if string(v) != "hello world2" { + t.Fatal(string(v)) + } + + wb.Delete(key1) + + wb.Rollback() + + if v, err := db.Get(key1); err != nil { + t.Fatal(err) + } else if string(v) != "hello world2" { + t.Fatal(string(v)) + } + + db.Delete(key1) +} + +func checkIterator(it *RangeLimitIterator, cv ...int) error { + v := make([]string, 0, len(cv)) + for ; it.Valid(); it.Next() { + k := it.Key() + v = append(v, string(k)) + } + + it.Close() + + if len(v) != len(cv) { + return fmt.Errorf("len error %d != %d", len(v), len(cv)) + } + + for k, i := range cv { + if fmt.Sprintf("key_%d", i) != v[k] { + return fmt.Errorf("%s, %d", v[k], i) + } + } + + return nil +} + +func TestIterator(t *testing.T) { + db := getTestDB() + + db.Clear() + + for i := 0; i < 10; i++ { + key := []byte(fmt.Sprintf("key_%d", i)) + value := []byte("") + db.Put(key, value) + } + + var it *RangeLimitIterator + + k := func(i int) []byte { + return []byte(fmt.Sprintf("key_%d", i)) + } + + it = db.RangeLimitIterator(k(1), k(5), RangeClose, 0, -1) + if err := checkIterator(it, 1, 2, 3, 4, 5); err != nil { + t.Fatal(err) + } + + it = db.RangeLimitIterator(k(1), k(5), RangeClose, 1, 3) + if err := checkIterator(it, 2, 3, 4); err != nil { + t.Fatal(err) + } + + it = db.RangeLimitIterator(k(1), k(5), RangeLOpen, 0, -1) + if err := checkIterator(it, 2, 3, 4, 5); err != nil { + t.Fatal(err) + } + + it = db.RangeLimitIterator(k(1), k(5), RangeROpen, 0, -1) + if err := checkIterator(it, 1, 2, 3, 4); err != nil { + t.Fatal(err) + } + + it = db.RangeLimitIterator(k(1), k(5), RangeOpen, 0, -1) + if err := checkIterator(it, 2, 3, 4); err != nil { + t.Fatal(err) + } + + it = db.RevRangeLimitIterator(k(1), k(5), RangeClose, 0, -1) + if err := checkIterator(it, 5, 4, 3, 2, 1); err != nil { + t.Fatal(err) + } + + it = db.RevRangeLimitIterator(k(1), k(5), RangeClose, 1, 3) + if err := checkIterator(it, 4, 3, 2); err != nil { + t.Fatal(err) + } + + it = db.RevRangeLimitIterator(k(1), k(5), RangeLOpen, 0, -1) + if err := checkIterator(it, 5, 4, 3, 2); err != nil { + t.Fatal(err) + } + + it = db.RevRangeLimitIterator(k(1), k(5), RangeROpen, 0, -1) + if err := checkIterator(it, 4, 3, 2, 1); err != nil { + t.Fatal(err) + } + + it = db.RevRangeLimitIterator(k(1), k(5), RangeOpen, 0, -1) + if err := checkIterator(it, 4, 3, 2); err != nil { + t.Fatal(err) + } +} + +func TestSnapshot(t *testing.T) { + db := getTestDB() + + key := []byte("key") + value := []byte("hello world") + + db.Put(key, value) + + s := db.NewSnapshot() + defer s.Close() + + db.Put(key, []byte("hello world2")) + + if v, err := s.Get(key); err != nil { + t.Fatal(err) + } else if string(v) != string(value) { + t.Fatal(string(v)) + } +} + +func TestDestroy(t *testing.T) { + db := getTestDB() + + db.Put([]byte("a"), []byte("1")) + if err := db.Clear(); err != nil { + t.Fatal(err) + } + + if _, err := os.Stat(db.cfg.Path); err != nil { + t.Fatal("must exist ", err.Error()) + } + + if v, err := db.Get([]byte("a")); err != nil { + t.Fatal(err) + } else if string(v) == "1" { + t.Fatal(string(v)) + } + + db.Destroy() + + if _, err := os.Stat(db.cfg.Path); !os.IsNotExist(err) { + t.Fatal("must not exist") + } +} + +func TestCloseMore(t *testing.T) { + cfg := new(Config) + cfg.Path = "/tmp/testdb1234" + cfg.CacheSize = 4 * 1024 * 1024 + os.RemoveAll(cfg.Path) + for i := 0; i < 100; i++ { + db, err := OpenWithConfig(cfg) + if err != nil { + t.Fatal(err) + } + + db.Put([]byte("key"), []byte("value")) + + db.Close() + } +} diff --git a/leveldb/levigo-license b/leveldb/levigo-license new file mode 100644 index 0000000..c7c73be --- /dev/null +++ b/leveldb/levigo-license @@ -0,0 +1,7 @@ +Copyright (c) 2012 Jeffrey M Hodges + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/leveldb/options.go b/leveldb/options.go new file mode 100644 index 0000000..f080c24 --- /dev/null +++ b/leveldb/options.go @@ -0,0 +1,128 @@ +package leveldb + +// #cgo LDFLAGS: -lleveldb +// #include "leveldb/c.h" +import "C" + +type CompressionOpt int + +const ( + NoCompression = CompressionOpt(0) + SnappyCompression = CompressionOpt(1) +) + +type Options struct { + Opt *C.leveldb_options_t +} + +type ReadOptions struct { + Opt *C.leveldb_readoptions_t +} + +type WriteOptions struct { + Opt *C.leveldb_writeoptions_t +} + +func NewOptions() *Options { + opt := C.leveldb_options_create() + return &Options{opt} +} + +func NewReadOptions() *ReadOptions { + opt := C.leveldb_readoptions_create() + return &ReadOptions{opt} +} + +func NewWriteOptions() *WriteOptions { + opt := C.leveldb_writeoptions_create() + return &WriteOptions{opt} +} + +func (o *Options) Close() { + C.leveldb_options_destroy(o.Opt) +} + +func (o *Options) SetComparator(cmp *C.leveldb_comparator_t) { + C.leveldb_options_set_comparator(o.Opt, cmp) +} + +func (o *Options) SetErrorIfExists(error_if_exists bool) { + eie := boolToUchar(error_if_exists) + C.leveldb_options_set_error_if_exists(o.Opt, eie) +} + +func (o *Options) SetCache(cache *Cache) { + C.leveldb_options_set_cache(o.Opt, cache.Cache) +} + +// func (o *Options) SetEnv(env *Env) { +// C.leveldb_options_set_env(o.Opt, env.Env) +// } + +func (o *Options) SetInfoLog(log *C.leveldb_logger_t) { + C.leveldb_options_set_info_log(o.Opt, log) +} + +func (o *Options) SetWriteBufferSize(s int) { + C.leveldb_options_set_write_buffer_size(o.Opt, C.size_t(s)) +} + +func (o *Options) SetParanoidChecks(pc bool) { + C.leveldb_options_set_paranoid_checks(o.Opt, boolToUchar(pc)) +} + +func (o *Options) SetMaxOpenFiles(n int) { + C.leveldb_options_set_max_open_files(o.Opt, C.int(n)) +} + +func (o *Options) SetBlockSize(s int) { + C.leveldb_options_set_block_size(o.Opt, C.size_t(s)) +} + +func (o *Options) SetBlockRestartInterval(n int) { + C.leveldb_options_set_block_restart_interval(o.Opt, C.int(n)) +} + +func (o *Options) SetCompression(t CompressionOpt) { + C.leveldb_options_set_compression(o.Opt, C.int(t)) +} + +func (o *Options) SetCreateIfMissing(b bool) { + C.leveldb_options_set_create_if_missing(o.Opt, boolToUchar(b)) +} + +func (o *Options) SetFilterPolicy(fp *FilterPolicy) { + var policy *C.leveldb_filterpolicy_t + if fp != nil { + policy = fp.Policy + } + C.leveldb_options_set_filter_policy(o.Opt, policy) +} + +func (ro *ReadOptions) Close() { + C.leveldb_readoptions_destroy(ro.Opt) +} + +func (ro *ReadOptions) SetVerifyChecksums(b bool) { + C.leveldb_readoptions_set_verify_checksums(ro.Opt, boolToUchar(b)) +} + +func (ro *ReadOptions) SetFillCache(b bool) { + C.leveldb_readoptions_set_fill_cache(ro.Opt, boolToUchar(b)) +} + +func (ro *ReadOptions) SetSnapshot(snap *Snapshot) { + var s *C.leveldb_snapshot_t + if snap != nil { + s = snap.snap + } + C.leveldb_readoptions_set_snapshot(ro.Opt, s) +} + +func (wo *WriteOptions) Close() { + C.leveldb_writeoptions_destroy(wo.Opt) +} + +func (wo *WriteOptions) SetSync(b bool) { + C.leveldb_writeoptions_set_sync(wo.Opt, boolToUchar(b)) +} diff --git a/leveldb/snapshot.go b/leveldb/snapshot.go new file mode 100644 index 0000000..3c847e9 --- /dev/null +++ b/leveldb/snapshot.go @@ -0,0 +1,54 @@ +package leveldb + +// #cgo LDFLAGS: -lleveldb +// #include +// #include "leveldb/c.h" +import "C" + +type Snapshot struct { + db *DB + + snap *C.leveldb_snapshot_t + + readOpts *ReadOptions + iteratorOpts *ReadOptions +} + +func (s *Snapshot) Close() { + C.leveldb_release_snapshot(s.db.db, s.snap) + + s.iteratorOpts.Close() + s.readOpts.Close() +} + +func (s *Snapshot) Get(key []byte) ([]byte, error) { + return s.db.get(s.readOpts, key) +} + +func (s *Snapshot) NewIterator() *Iterator { + it := new(Iterator) + + it.it = C.leveldb_create_iterator(s.db.db, s.iteratorOpts.Opt) + + return it +} + +func (s *Snapshot) RangeIterator(min []byte, max []byte, rangeType uint8) *RangeLimitIterator { + return newRangeLimitIterator(s.NewIterator(), &Range{min, max, rangeType}, 0, -1, IteratorForward) +} + +func (s *Snapshot) RevRangeIterator(min []byte, max []byte, rangeType uint8) *RangeLimitIterator { + return newRangeLimitIterator(s.NewIterator(), &Range{min, max, rangeType}, 0, -1, IteratorBackward) +} + +//limit < 0, unlimit +//offset must >= 0, if < 0, will get nothing +func (s *Snapshot) RangeLimitIterator(min []byte, max []byte, rangeType uint8, offset int, limit int) *RangeLimitIterator { + return newRangeLimitIterator(s.NewIterator(), &Range{min, max, rangeType}, offset, limit, IteratorForward) +} + +//limit < 0, unlimit +//offset must >= 0, if < 0, will get nothing +func (s *Snapshot) RevRangeLimitIterator(min []byte, max []byte, rangeType uint8, offset int, limit int) *RangeLimitIterator { + return newRangeLimitIterator(s.NewIterator(), &Range{min, max, rangeType}, offset, limit, IteratorBackward) +} diff --git a/leveldb/util.go b/leveldb/util.go new file mode 100644 index 0000000..e1fd57d --- /dev/null +++ b/leveldb/util.go @@ -0,0 +1,43 @@ +package leveldb + +// #include "leveldb/c.h" +import "C" + +import ( + "fmt" + "reflect" + "unsafe" +) + +func boolToUchar(b bool) C.uchar { + uc := C.uchar(0) + if b { + uc = C.uchar(1) + } + return uc +} + +func ucharToBool(uc C.uchar) bool { + if uc == C.uchar(0) { + return false + } + return true +} + +func saveError(errStr *C.char) error { + if errStr != nil { + gs := C.GoString(errStr) + C.leveldb_free(unsafe.Pointer(errStr)) + return fmt.Errorf(gs) + } + return nil +} + +func slice(p unsafe.Pointer, n int) []byte { + var b []byte + pbyte := (*reflect.SliceHeader)(unsafe.Pointer(&b)) + pbyte.Data = uintptr(p) + pbyte.Len = n + pbyte.Cap = n + return b +} diff --git a/server/cmd_hash.go b/server/cmd_hash.go index bce1be2..a20fef6 100644 --- a/server/cmd_hash.go +++ b/server/cmd_hash.go @@ -138,7 +138,7 @@ func hmgetCommand(c *client) error { return ErrCmdParams } - if v, err := c.db.HMget(args[0], args[1:]); err != nil { + if v, err := c.db.HMget(args[0], args[1:]...); err != nil { return err } else { c.writeArray(v) diff --git a/server/cmd_replication_test.go b/server/cmd_replication_test.go index 645f4f7..290d94b 100644 --- a/server/cmd_replication_test.go +++ b/server/cmd_replication_test.go @@ -3,14 +3,14 @@ package server import ( "bytes" "fmt" - "github.com/siddontang/go-leveldb/leveldb" + "github.com/siddontang/ledisdb/leveldb" "os" "testing" "time" ) func checkDataEqual(master *App, slave *App) error { - it := master.ldb.DataDB().Iterator(nil, nil, leveldb.RangeClose, 0, -1) + it := master.ldb.DataDB().RangeLimitIterator(nil, nil, leveldb.RangeClose, 0, -1) for ; it.Valid(); it.Next() { key := it.Key() value := it.Value() From a71fbfbfd3ce055ade4c4063c55c9fd794de626c Mon Sep 17 00:00:00 2001 From: siddontang Date: Thu, 19 Jun 2014 17:45:52 +0800 Subject: [PATCH 6/9] iterator find optimize --- leveldb/iterator.go | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/leveldb/iterator.go b/leveldb/iterator.go index 6f22b92..4659a7b 100644 --- a/leveldb/iterator.go +++ b/leveldb/iterator.go @@ -90,11 +90,17 @@ func (it *Iterator) Seek(key []byte) { func (it *Iterator) Find(key []byte) []byte { it.Seek(key) - if it.Valid() && bytes.Equal(it.Key(), key) { - return it.Value() - } else { - return nil + if it.Valid() { + var klen C.size_t + kdata := C.leveldb_iter_key(it.it, &klen) + if kdata == nil { + return nil + } else if bytes.Equal(slice(unsafe.Pointer(kdata), int(C.int(klen))), key) { + return it.Value() + } } + + return nil } type RangeLimitIterator struct { From 50bd648e0cae7f93657a9bd57161b7ce58a231ce Mon Sep 17 00:00:00 2001 From: siddontang Date: Thu, 19 Jun 2014 17:46:20 +0800 Subject: [PATCH 7/9] copy garyburd redid lib to client --- bootstrap.sh | 1 - client/go/redis/commandinfo.go | 45 +++ client/go/redis/conn.go | 418 +++++++++++++++++++++++++++ client/go/redis/doc.go | 167 +++++++++++ client/go/redis/log.go | 117 ++++++++ client/go/redis/pool.go | 358 +++++++++++++++++++++++ client/go/redis/pubsub.go | 129 +++++++++ client/go/redis/redis.go | 44 +++ client/go/redis/reply.go | 271 +++++++++++++++++ client/go/redis/scan.go | 513 +++++++++++++++++++++++++++++++++ client/go/redis/script.go | 86 ++++++ server/app_test.go | 2 +- server/cmd_hash_test.go | 2 +- server/cmd_kv_test.go | 2 +- server/cmd_list_test.go | 2 +- server/cmd_ttl_test.go | 2 +- server/cmd_zset_test.go | 2 +- 17 files changed, 2154 insertions(+), 7 deletions(-) create mode 100644 client/go/redis/commandinfo.go create mode 100644 client/go/redis/conn.go create mode 100644 client/go/redis/doc.go create mode 100644 client/go/redis/log.go create mode 100644 client/go/redis/pool.go create mode 100644 client/go/redis/pubsub.go create mode 100644 client/go/redis/redis.go create mode 100644 client/go/redis/reply.go create mode 100644 client/go/redis/scan.go create mode 100644 client/go/redis/script.go diff --git a/bootstrap.sh b/bootstrap.sh index c5aaefe..c05f3db 100644 --- a/bootstrap.sh +++ b/bootstrap.sh @@ -3,4 +3,3 @@ . ./dev.sh go get -u github.com/siddontang/go-log/log -go get -u github.com/garyburd/redigo/redis diff --git a/client/go/redis/commandinfo.go b/client/go/redis/commandinfo.go new file mode 100644 index 0000000..014115d --- /dev/null +++ b/client/go/redis/commandinfo.go @@ -0,0 +1,45 @@ +// Copyright 2014 Gary Burd +// +// Licensed under the Apache License, Version 2.0 (the "License"): you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +package redis + +import ( + "strings" +) + +const ( + watchState = 1 << iota + multiState + subscribeState + monitorState +) + +type commandInfo struct { + set, clear int +} + +var commandInfos = map[string]commandInfo{ + "WATCH": commandInfo{set: watchState}, + "UNWATCH": commandInfo{clear: watchState}, + "MULTI": commandInfo{set: multiState}, + "EXEC": commandInfo{clear: watchState | multiState}, + "DISCARD": commandInfo{clear: watchState | multiState}, + "PSUBSCRIBE": commandInfo{set: subscribeState}, + "SUBSCRIBE": commandInfo{set: subscribeState}, + "MONITOR": commandInfo{set: monitorState}, +} + +func lookupCommandInfo(commandName string) commandInfo { + return commandInfos[strings.ToUpper(commandName)] +} diff --git a/client/go/redis/conn.go b/client/go/redis/conn.go new file mode 100644 index 0000000..331d3c6 --- /dev/null +++ b/client/go/redis/conn.go @@ -0,0 +1,418 @@ +// Copyright 2012 Gary Burd +// +// Licensed under the Apache License, Version 2.0 (the "License"): you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +package redis + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "io" + "net" + "strconv" + "sync" + "time" +) + +// conn is the low-level implementation of Conn +type conn struct { + + // Shared + mu sync.Mutex + pending int + err error + conn net.Conn + + // Read + readTimeout time.Duration + br *bufio.Reader + + // Write + writeTimeout time.Duration + bw *bufio.Writer + + // Scratch space for formatting argument length. + // '*' or '$', length, "\r\n" + lenScratch [32]byte + + // Scratch space for formatting integers and floats. + numScratch [40]byte +} + +// Dial connects to the Redis server at the given network and address. +func Dial(network, address string) (Conn, error) { + c, err := net.Dial(network, address) + if err != nil { + return nil, err + } + return NewConn(c, 0, 0), nil +} + +// DialTimeout acts like Dial but takes timeouts for establishing the +// connection to the server, writing a command and reading a reply. +func DialTimeout(network, address string, connectTimeout, readTimeout, writeTimeout time.Duration) (Conn, error) { + var c net.Conn + var err error + if connectTimeout > 0 { + c, err = net.DialTimeout(network, address, connectTimeout) + } else { + c, err = net.Dial(network, address) + } + if err != nil { + return nil, err + } + return NewConn(c, readTimeout, writeTimeout), nil +} + +// NewConn returns a new Redigo connection for the given net connection. +func NewConn(netConn net.Conn, readTimeout, writeTimeout time.Duration) Conn { + return &conn{ + conn: netConn, + bw: bufio.NewWriter(netConn), + br: bufio.NewReader(netConn), + readTimeout: readTimeout, + writeTimeout: writeTimeout, + } +} + +func (c *conn) Close() error { + c.mu.Lock() + err := c.err + if c.err == nil { + c.err = errors.New("redigo: closed") + err = c.conn.Close() + } + c.mu.Unlock() + return err +} + +func (c *conn) fatal(err error) error { + c.mu.Lock() + if c.err == nil { + c.err = err + // Close connection to force errors on subsequent calls and to unblock + // other reader or writer. + c.conn.Close() + } + c.mu.Unlock() + return err +} + +func (c *conn) Err() error { + c.mu.Lock() + err := c.err + c.mu.Unlock() + return err +} + +func (c *conn) writeLen(prefix byte, n int) error { + c.lenScratch[len(c.lenScratch)-1] = '\n' + c.lenScratch[len(c.lenScratch)-2] = '\r' + i := len(c.lenScratch) - 3 + for { + c.lenScratch[i] = byte('0' + n%10) + i -= 1 + n = n / 10 + if n == 0 { + break + } + } + c.lenScratch[i] = prefix + _, err := c.bw.Write(c.lenScratch[i:]) + return err +} + +func (c *conn) writeString(s string) error { + c.writeLen('$', len(s)) + c.bw.WriteString(s) + _, err := c.bw.WriteString("\r\n") + return err +} + +func (c *conn) writeBytes(p []byte) error { + c.writeLen('$', len(p)) + c.bw.Write(p) + _, err := c.bw.WriteString("\r\n") + return err +} + +func (c *conn) writeInt64(n int64) error { + return c.writeBytes(strconv.AppendInt(c.numScratch[:0], n, 10)) +} + +func (c *conn) writeFloat64(n float64) error { + return c.writeBytes(strconv.AppendFloat(c.numScratch[:0], n, 'g', -1, 64)) +} + +func (c *conn) writeCommand(cmd string, args []interface{}) (err error) { + c.writeLen('*', 1+len(args)) + err = c.writeString(cmd) + for _, arg := range args { + if err != nil { + break + } + switch arg := arg.(type) { + case string: + err = c.writeString(arg) + case []byte: + err = c.writeBytes(arg) + case int: + err = c.writeInt64(int64(arg)) + case int64: + err = c.writeInt64(arg) + case float64: + err = c.writeFloat64(arg) + case bool: + if arg { + err = c.writeString("1") + } else { + err = c.writeString("0") + } + case nil: + err = c.writeString("") + default: + var buf bytes.Buffer + fmt.Fprint(&buf, arg) + err = c.writeBytes(buf.Bytes()) + } + } + return err +} + +func (c *conn) readLine() ([]byte, error) { + p, err := c.br.ReadSlice('\n') + if err == bufio.ErrBufferFull { + return nil, errors.New("redigo: long response line") + } + if err != nil { + return nil, err + } + i := len(p) - 2 + if i < 0 || p[i] != '\r' { + return nil, errors.New("redigo: bad response line terminator") + } + return p[:i], nil +} + +// parseLen parses bulk string and array lengths. +func parseLen(p []byte) (int, error) { + if len(p) == 0 { + return -1, errors.New("redigo: malformed length") + } + + if p[0] == '-' && len(p) == 2 && p[1] == '1' { + // handle $-1 and $-1 null replies. + return -1, nil + } + + var n int + for _, b := range p { + n *= 10 + if b < '0' || b > '9' { + return -1, errors.New("redigo: illegal bytes in length") + } + n += int(b - '0') + } + + return n, nil +} + +// parseInt parses an integer reply. +func parseInt(p []byte) (interface{}, error) { + if len(p) == 0 { + return 0, errors.New("redigo: malformed integer") + } + + var negate bool + if p[0] == '-' { + negate = true + p = p[1:] + if len(p) == 0 { + return 0, errors.New("redigo: malformed integer") + } + } + + var n int64 + for _, b := range p { + n *= 10 + if b < '0' || b > '9' { + return 0, errors.New("redigo: illegal bytes in length") + } + n += int64(b - '0') + } + + if negate { + n = -n + } + return n, nil +} + +var ( + okReply interface{} = "OK" + pongReply interface{} = "PONG" +) + +func (c *conn) readReply() (interface{}, error) { + line, err := c.readLine() + if err != nil { + return nil, err + } + if len(line) == 0 { + return nil, errors.New("redigo: short response line") + } + switch line[0] { + case '+': + switch { + case len(line) == 3 && line[1] == 'O' && line[2] == 'K': + // Avoid allocation for frequent "+OK" response. + return okReply, nil + case len(line) == 5 && line[1] == 'P' && line[2] == 'O' && line[3] == 'N' && line[4] == 'G': + // Avoid allocation in PING command benchmarks :) + return pongReply, nil + default: + return string(line[1:]), nil + } + case '-': + return Error(string(line[1:])), nil + case ':': + return parseInt(line[1:]) + case '$': + n, err := parseLen(line[1:]) + if n < 0 || err != nil { + return nil, err + } + p := make([]byte, n) + _, err = io.ReadFull(c.br, p) + if err != nil { + return nil, err + } + if line, err := c.readLine(); err != nil { + return nil, err + } else if len(line) != 0 { + return nil, errors.New("redigo: bad bulk string format") + } + return p, nil + case '*': + n, err := parseLen(line[1:]) + if n < 0 || err != nil { + return nil, err + } + r := make([]interface{}, n) + for i := range r { + r[i], err = c.readReply() + if err != nil { + return nil, err + } + } + return r, nil + } + return nil, errors.New("redigo: unexpected response line") +} + +func (c *conn) Send(cmd string, args ...interface{}) error { + c.mu.Lock() + c.pending += 1 + c.mu.Unlock() + if c.writeTimeout != 0 { + c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)) + } + if err := c.writeCommand(cmd, args); err != nil { + return c.fatal(err) + } + return nil +} + +func (c *conn) Flush() error { + if c.writeTimeout != 0 { + c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)) + } + if err := c.bw.Flush(); err != nil { + return c.fatal(err) + } + return nil +} + +func (c *conn) Receive() (reply interface{}, err error) { + c.mu.Lock() + // There can be more receives than sends when using pub/sub. To allow + // normal use of the connection after unsubscribe from all channels, do not + // decrement pending to a negative value. + if c.pending > 0 { + c.pending -= 1 + } + c.mu.Unlock() + if c.readTimeout != 0 { + c.conn.SetReadDeadline(time.Now().Add(c.readTimeout)) + } + if reply, err = c.readReply(); err != nil { + return nil, c.fatal(err) + } + if err, ok := reply.(Error); ok { + return nil, err + } + return +} + +func (c *conn) Do(cmd string, args ...interface{}) (interface{}, error) { + c.mu.Lock() + pending := c.pending + c.pending = 0 + c.mu.Unlock() + + if cmd == "" && pending == 0 { + return nil, nil + } + + if c.writeTimeout != 0 { + c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)) + } + + if cmd != "" { + c.writeCommand(cmd, args) + } + + if err := c.bw.Flush(); err != nil { + return nil, c.fatal(err) + } + + if c.readTimeout != 0 { + c.conn.SetReadDeadline(time.Now().Add(c.readTimeout)) + } + + if cmd == "" { + reply := make([]interface{}, pending) + for i := range reply { + r, e := c.readReply() + if e != nil { + return nil, c.fatal(e) + } + reply[i] = r + } + return reply, nil + } + + var err error + var reply interface{} + for i := 0; i <= pending; i++ { + var e error + if reply, e = c.readReply(); e != nil { + return nil, c.fatal(e) + } + if e, ok := reply.(Error); ok && err == nil { + err = e + } + } + return reply, err +} diff --git a/client/go/redis/doc.go b/client/go/redis/doc.go new file mode 100644 index 0000000..63e6ffe --- /dev/null +++ b/client/go/redis/doc.go @@ -0,0 +1,167 @@ +// Copyright 2012 Gary Burd +// +// Licensed under the Apache License, Version 2.0 (the "License"): you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +// Package redis is a client for the Redis database. +// +// The Redigo FAQ (https://github.com/garyburd/redigo/wiki/FAQ) contains more +// documentation about this package. +// +// Connections +// +// The Conn interface is the primary interface for working with Redis. +// Applications create connections by calling the Dial, DialWithTimeout or +// NewConn functions. In the future, functions will be added for creating +// sharded and other types of connections. +// +// The application must call the connection Close method when the application +// is done with the connection. +// +// Executing Commands +// +// The Conn interface has a generic method for executing Redis commands: +// +// Do(commandName string, args ...interface{}) (reply interface{}, err error) +// +// The Redis command reference (http://redis.io/commands) lists the available +// commands. An example of using the Redis APPEND command is: +// +// n, err := conn.Do("APPEND", "key", "value") +// +// The Do method converts command arguments to binary strings for transmission +// to the server as follows: +// +// Go Type Conversion +// []byte Sent as is +// string Sent as is +// int, int64 strconv.FormatInt(v) +// float64 strconv.FormatFloat(v, 'g', -1, 64) +// bool true -> "1", false -> "0" +// nil "" +// all other types fmt.Print(v) +// +// Redis command reply types are represented using the following Go types: +// +// Redis type Go type +// error redis.Error +// integer int64 +// simple string string +// bulk string []byte or nil if value not present. +// array []interface{} or nil if value not present. +// +// Use type assertions or the reply helper functions to convert from +// interface{} to the specific Go type for the command result. +// +// Pipelining +// +// Connections support pipelining using the Send, Flush and Receive methods. +// +// Send(commandName string, args ...interface{}) error +// Flush() error +// Receive() (reply interface{}, err error) +// +// Send writes the command to the connection's output buffer. Flush flushes the +// connection's output buffer to the server. Receive reads a single reply from +// the server. The following example shows a simple pipeline. +// +// c.Send("SET", "foo", "bar") +// c.Send("GET", "foo") +// c.Flush() +// c.Receive() // reply from SET +// v, err = c.Receive() // reply from GET +// +// The Do method combines the functionality of the Send, Flush and Receive +// methods. The Do method starts by writing the command and flushing the output +// buffer. Next, the Do method receives all pending replies including the reply +// for the command just sent by Do. If any of the received replies is an error, +// then Do returns the error. If there are no errors, then Do returns the last +// reply. If the command argument to the Do method is "", then the Do method +// will flush the output buffer and receive pending replies without sending a +// command. +// +// Use the Send and Do methods to implement pipelined transactions. +// +// c.Send("MULTI") +// c.Send("INCR", "foo") +// c.Send("INCR", "bar") +// r, err := c.Do("EXEC") +// fmt.Println(r) // prints [1, 1] +// +// Concurrency +// +// Connections support a single concurrent caller to the write methods (Send, +// Flush) and a single concurrent caller to the read method (Receive). Because +// Do method combines the functionality of Send, Flush and Receive, the Do +// method cannot be called concurrently with the other methods. +// +// For full concurrent access to Redis, use the thread-safe Pool to get and +// release connections from within a goroutine. +// +// Publish and Subscribe +// +// Use the Send, Flush and Receive methods to implement Pub/Sub subscribers. +// +// c.Send("SUBSCRIBE", "example") +// c.Flush() +// for { +// reply, err := c.Receive() +// if err != nil { +// return err +// } +// // process pushed message +// } +// +// The PubSubConn type wraps a Conn with convenience methods for implementing +// subscribers. The Subscribe, PSubscribe, Unsubscribe and PUnsubscribe methods +// send and flush a subscription management command. The receive method +// converts a pushed message to convenient types for use in a type switch. +// +// psc := PubSubConn{c} +// psc.Subscribe("example") +// for { +// switch v := psc.Receive().(type) { +// case redis.Message: +// fmt.Printf("%s: message: %s\n", v.Channel, v.Data) +// case redis.Subscription: +// fmt.Printf("%s: %s %d\n", v.Channel, v.Kind, v.Count) +// case error: +// return v +// } +// } +// +// Reply Helpers +// +// The Bool, Int, Bytes, String, Strings and Values functions convert a reply +// to a value of a specific type. To allow convenient wrapping of calls to the +// connection Do and Receive methods, the functions take a second argument of +// type error. If the error is non-nil, then the helper function returns the +// error. If the error is nil, the function converts the reply to the specified +// type: +// +// exists, err := redis.Bool(c.Do("EXISTS", "foo")) +// if err != nil { +// // handle error return from c.Do or type conversion error. +// } +// +// The Scan function converts elements of a array reply to Go types: +// +// var value1 int +// var value2 string +// reply, err := redis.Values(c.Do("MGET", "key1", "key2")) +// if err != nil { +// // handle error +// } +// if _, err := redis.Scan(reply, &value1, &value2); err != nil { +// // handle error +// } +package redis diff --git a/client/go/redis/log.go b/client/go/redis/log.go new file mode 100644 index 0000000..129b86d --- /dev/null +++ b/client/go/redis/log.go @@ -0,0 +1,117 @@ +// Copyright 2012 Gary Burd +// +// Licensed under the Apache License, Version 2.0 (the "License"): you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +package redis + +import ( + "bytes" + "fmt" + "log" +) + +// NewLoggingConn returns a logging wrapper around a connection. +func NewLoggingConn(conn Conn, logger *log.Logger, prefix string) Conn { + if prefix != "" { + prefix = prefix + "." + } + return &loggingConn{conn, logger, prefix} +} + +type loggingConn struct { + Conn + logger *log.Logger + prefix string +} + +func (c *loggingConn) Close() error { + err := c.Conn.Close() + var buf bytes.Buffer + fmt.Fprintf(&buf, "%sClose() -> (%v)", c.prefix, err) + c.logger.Output(2, buf.String()) + return err +} + +func (c *loggingConn) printValue(buf *bytes.Buffer, v interface{}) { + const chop = 32 + switch v := v.(type) { + case []byte: + if len(v) > chop { + fmt.Fprintf(buf, "%q...", v[:chop]) + } else { + fmt.Fprintf(buf, "%q", v) + } + case string: + if len(v) > chop { + fmt.Fprintf(buf, "%q...", v[:chop]) + } else { + fmt.Fprintf(buf, "%q", v) + } + case []interface{}: + if len(v) == 0 { + buf.WriteString("[]") + } else { + sep := "[" + fin := "]" + if len(v) > chop { + v = v[:chop] + fin = "...]" + } + for _, vv := range v { + buf.WriteString(sep) + c.printValue(buf, vv) + sep = ", " + } + buf.WriteString(fin) + } + default: + fmt.Fprint(buf, v) + } +} + +func (c *loggingConn) print(method, commandName string, args []interface{}, reply interface{}, err error) { + var buf bytes.Buffer + fmt.Fprintf(&buf, "%s%s(", c.prefix, method) + if method != "Receive" { + buf.WriteString(commandName) + for _, arg := range args { + buf.WriteString(", ") + c.printValue(&buf, arg) + } + } + buf.WriteString(") -> (") + if method != "Send" { + c.printValue(&buf, reply) + buf.WriteString(", ") + } + fmt.Fprintf(&buf, "%v)", err) + c.logger.Output(3, buf.String()) +} + +func (c *loggingConn) Do(commandName string, args ...interface{}) (interface{}, error) { + reply, err := c.Conn.Do(commandName, args...) + c.print("Do", commandName, args, reply, err) + return reply, err +} + +func (c *loggingConn) Send(commandName string, args ...interface{}) error { + err := c.Conn.Send(commandName, args...) + c.print("Send", commandName, args, nil, err) + return err +} + +func (c *loggingConn) Receive() (interface{}, error) { + reply, err := c.Conn.Receive() + c.print("Receive", "", nil, reply, err) + return reply, err +} diff --git a/client/go/redis/pool.go b/client/go/redis/pool.go new file mode 100644 index 0000000..2e2dac9 --- /dev/null +++ b/client/go/redis/pool.go @@ -0,0 +1,358 @@ +// Copyright 2012 Gary Burd +// +// Licensed under the Apache License, Version 2.0 (the "License"): you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +package redis + +import ( + "bytes" + "container/list" + "crypto/rand" + "crypto/sha1" + "errors" + "io" + "strconv" + "sync" + "time" +) + +var nowFunc = time.Now // for testing + +// ErrPoolExhausted is returned from a pool connection method (Do, Send, +// Receive, Flush, Err) when the maximum number of database connections in the +// pool has been reached. +var ErrPoolExhausted = errors.New("redigo: connection pool exhausted") + +var errPoolClosed = errors.New("redigo: connection pool closed") + +// Pool maintains a pool of connections. The application calls the Get method +// to get a connection from the pool and the connection's Close method to +// return the connection's resources to the pool. +// +// The following example shows how to use a pool in a web application. The +// application creates a pool at application startup and makes it available to +// request handlers using a global variable. +// +// func newPool(server, password string) *redis.Pool { +// return &redis.Pool{ +// MaxIdle: 3, +// IdleTimeout: 240 * time.Second, +// Dial: func () (redis.Conn, error) { +// c, err := redis.Dial("tcp", server) +// if err != nil { +// return nil, err +// } +// if _, err := c.Do("AUTH", password); err != nil { +// c.Close() +// return nil, err +// } +// return c, err +// }, +// TestOnBorrow: func(c redis.Conn, t time.Time) error { +// _, err := c.Do("PING") +// return err +// }, +// } +// } +// +// var ( +// pool *redis.Pool +// redisServer = flag.String("redisServer", ":6379", "") +// redisPassword = flag.String("redisPassword", "", "") +// ) +// +// func main() { +// flag.Parse() +// pool = newPool(*redisServer, *redisPassword) +// ... +// } +// +// A request handler gets a connection from the pool and closes the connection +// when the handler is done: +// +// func serveHome(w http.ResponseWriter, r *http.Request) { +// conn := pool.Get() +// defer conn.Close() +// .... +// } +// +type Pool struct { + + // Dial is an application supplied function for creating new connections. + Dial func() (Conn, error) + + // TestOnBorrow is an optional application supplied function for checking + // the health of an idle connection before the connection is used again by + // the application. Argument t is the time that the connection was returned + // to the pool. If the function returns an error, then the connection is + // closed. + TestOnBorrow func(c Conn, t time.Time) error + + // Maximum number of idle connections in the pool. + MaxIdle int + + // Maximum number of connections allocated by the pool at a given time. + // When zero, there is no limit on the number of connections in the pool. + MaxActive int + + // Close connections after remaining idle for this duration. If the value + // is zero, then idle connections are not closed. Applications should set + // the timeout to a value less than the server's timeout. + IdleTimeout time.Duration + + // mu protects fields defined below. + mu sync.Mutex + closed bool + active int + + // Stack of idleConn with most recently used at the front. + idle list.List +} + +type idleConn struct { + c Conn + t time.Time +} + +// NewPool is a convenience function for initializing a pool. +func NewPool(newFn func() (Conn, error), maxIdle int) *Pool { + return &Pool{Dial: newFn, MaxIdle: maxIdle} +} + +// Get gets a connection. The application must close the returned connection. +// The connection acquires an underlying connection on the first call to the +// connection Do, Send, Receive, Flush or Err methods. An application can force +// the connection to acquire an underlying connection without executing a Redis +// command by calling the Err method. +func (p *Pool) Get() Conn { + return &pooledConnection{p: p} +} + +// ActiveCount returns the number of active connections in the pool. +func (p *Pool) ActiveCount() int { + p.mu.Lock() + active := p.active + p.mu.Unlock() + return active +} + +// Close releases the resources used by the pool. +func (p *Pool) Close() error { + p.mu.Lock() + idle := p.idle + p.idle.Init() + p.closed = true + p.active -= idle.Len() + p.mu.Unlock() + for e := idle.Front(); e != nil; e = e.Next() { + e.Value.(idleConn).c.Close() + } + return nil +} + +// get prunes stale connections and returns a connection from the idle list or +// creates a new connection. +func (p *Pool) get() (Conn, error) { + p.mu.Lock() + + if p.closed { + p.mu.Unlock() + return nil, errors.New("redigo: get on closed pool") + } + + // Prune stale connections. + + if timeout := p.IdleTimeout; timeout > 0 { + for i, n := 0, p.idle.Len(); i < n; i++ { + e := p.idle.Back() + if e == nil { + break + } + ic := e.Value.(idleConn) + if ic.t.Add(timeout).After(nowFunc()) { + break + } + p.idle.Remove(e) + p.active -= 1 + p.mu.Unlock() + ic.c.Close() + p.mu.Lock() + } + } + + // Get idle connection. + + for i, n := 0, p.idle.Len(); i < n; i++ { + e := p.idle.Front() + if e == nil { + break + } + ic := e.Value.(idleConn) + p.idle.Remove(e) + test := p.TestOnBorrow + p.mu.Unlock() + if test == nil || test(ic.c, ic.t) == nil { + return ic.c, nil + } + ic.c.Close() + p.mu.Lock() + p.active -= 1 + } + + if p.MaxActive > 0 && p.active >= p.MaxActive { + p.mu.Unlock() + return nil, ErrPoolExhausted + } + + // No idle connection, create new. + + dial := p.Dial + p.active += 1 + p.mu.Unlock() + c, err := dial() + if err != nil { + p.mu.Lock() + p.active -= 1 + p.mu.Unlock() + c = nil + } + return c, err +} + +func (p *Pool) put(c Conn, forceClose bool) error { + if c.Err() == nil && !forceClose { + p.mu.Lock() + if !p.closed { + p.idle.PushFront(idleConn{t: nowFunc(), c: c}) + if p.idle.Len() > p.MaxIdle { + c = p.idle.Remove(p.idle.Back()).(idleConn).c + } else { + c = nil + } + } + p.mu.Unlock() + } + if c != nil { + p.mu.Lock() + p.active -= 1 + p.mu.Unlock() + return c.Close() + } + return nil +} + +type pooledConnection struct { + c Conn + err error + p *Pool + state int +} + +func (c *pooledConnection) get() error { + if c.err == nil && c.c == nil { + c.c, c.err = c.p.get() + } + return c.err +} + +var ( + sentinel []byte + sentinelOnce sync.Once +) + +func initSentinel() { + p := make([]byte, 64) + if _, err := rand.Read(p); err == nil { + sentinel = p + } else { + h := sha1.New() + io.WriteString(h, "Oops, rand failed. Use time instead.") + io.WriteString(h, strconv.FormatInt(time.Now().UnixNano(), 10)) + sentinel = h.Sum(nil) + } +} + +func (c *pooledConnection) Close() (err error) { + if c.c != nil { + if c.state&multiState != 0 { + c.c.Send("DISCARD") + c.state &^= (multiState | watchState) + } else if c.state&watchState != 0 { + c.c.Send("UNWATCH") + c.state &^= watchState + } + if c.state&subscribeState != 0 { + c.c.Send("UNSUBSCRIBE") + c.c.Send("PUNSUBSCRIBE") + // To detect the end of the message stream, ask the server to echo + // a sentinel value and read until we see that value. + sentinelOnce.Do(initSentinel) + c.c.Send("ECHO", sentinel) + c.c.Flush() + for { + p, err := c.c.Receive() + if err != nil { + break + } + if p, ok := p.([]byte); ok && bytes.Equal(p, sentinel) { + c.state &^= subscribeState + break + } + } + } + c.c.Do("") + c.p.put(c.c, c.state != 0) + c.c = nil + c.err = errPoolClosed + } + return err +} + +func (c *pooledConnection) Err() error { + if err := c.get(); err != nil { + return err + } + return c.c.Err() +} + +func (c *pooledConnection) Do(commandName string, args ...interface{}) (reply interface{}, err error) { + if err := c.get(); err != nil { + return nil, err + } + ci := lookupCommandInfo(commandName) + c.state = (c.state | ci.set) &^ ci.clear + return c.c.Do(commandName, args...) +} + +func (c *pooledConnection) Send(commandName string, args ...interface{}) error { + if err := c.get(); err != nil { + return err + } + ci := lookupCommandInfo(commandName) + c.state = (c.state | ci.set) &^ ci.clear + return c.c.Send(commandName, args...) +} + +func (c *pooledConnection) Flush() error { + if err := c.get(); err != nil { + return err + } + return c.c.Flush() +} + +func (c *pooledConnection) Receive() (reply interface{}, err error) { + if err := c.get(); err != nil { + return nil, err + } + return c.c.Receive() +} diff --git a/client/go/redis/pubsub.go b/client/go/redis/pubsub.go new file mode 100644 index 0000000..f079042 --- /dev/null +++ b/client/go/redis/pubsub.go @@ -0,0 +1,129 @@ +// Copyright 2012 Gary Burd +// +// Licensed under the Apache License, Version 2.0 (the "License"): you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +package redis + +import ( + "errors" +) + +// Subscription represents a subscribe or unsubscribe notification. +type Subscription struct { + + // Kind is "subscribe", "unsubscribe", "psubscribe" or "punsubscribe" + Kind string + + // The channel that was changed. + Channel string + + // The current number of subscriptions for connection. + Count int +} + +// Message represents a message notification. +type Message struct { + + // The originating channel. + Channel string + + // The message data. + Data []byte +} + +// PMessage represents a pmessage notification. +type PMessage struct { + + // The matched pattern. + Pattern string + + // The originating channel. + Channel string + + // The message data. + Data []byte +} + +// PubSubConn wraps a Conn with convenience methods for subscribers. +type PubSubConn struct { + Conn Conn +} + +// Close closes the connection. +func (c PubSubConn) Close() error { + return c.Conn.Close() +} + +// Subscribe subscribes the connection to the specified channels. +func (c PubSubConn) Subscribe(channel ...interface{}) error { + c.Conn.Send("SUBSCRIBE", channel...) + return c.Conn.Flush() +} + +// PSubscribe subscribes the connection to the given patterns. +func (c PubSubConn) PSubscribe(channel ...interface{}) error { + c.Conn.Send("PSUBSCRIBE", channel...) + return c.Conn.Flush() +} + +// Unsubscribe unsubscribes the connection from the given channels, or from all +// of them if none is given. +func (c PubSubConn) Unsubscribe(channel ...interface{}) error { + c.Conn.Send("UNSUBSCRIBE", channel...) + return c.Conn.Flush() +} + +// PUnsubscribe unsubscribes the connection from the given patterns, or from all +// of them if none is given. +func (c PubSubConn) PUnsubscribe(channel ...interface{}) error { + c.Conn.Send("PUNSUBSCRIBE", channel...) + return c.Conn.Flush() +} + +// Receive returns a pushed message as a Subscription, Message, PMessage or +// error. The return value is intended to be used directly in a type switch as +// illustrated in the PubSubConn example. +func (c PubSubConn) Receive() interface{} { + reply, err := Values(c.Conn.Receive()) + if err != nil { + return err + } + + var kind string + reply, err = Scan(reply, &kind) + if err != nil { + return err + } + + switch kind { + case "message": + var m Message + if _, err := Scan(reply, &m.Channel, &m.Data); err != nil { + return err + } + return m + case "pmessage": + var pm PMessage + if _, err := Scan(reply, &pm.Pattern, &pm.Channel, &pm.Data); err != nil { + return err + } + return pm + case "subscribe", "psubscribe", "unsubscribe", "punsubscribe": + s := Subscription{Kind: kind} + if _, err := Scan(reply, &s.Channel, &s.Count); err != nil { + return err + } + return s + } + return errors.New("redigo: unknown pubsub notification") +} diff --git a/client/go/redis/redis.go b/client/go/redis/redis.go new file mode 100644 index 0000000..c90a48e --- /dev/null +++ b/client/go/redis/redis.go @@ -0,0 +1,44 @@ +// Copyright 2012 Gary Burd +// +// Licensed under the Apache License, Version 2.0 (the "License"): you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +package redis + +// Error represents an error returned in a command reply. +type Error string + +func (err Error) Error() string { return string(err) } + +// Conn represents a connection to a Redis server. +type Conn interface { + // Close closes the connection. + Close() error + + // Err returns a non-nil value if the connection is broken. The returned + // value is either the first non-nil value returned from the underlying + // network connection or a protocol parsing error. Applications should + // close broken connections. + Err() error + + // Do sends a command to the server and returns the received reply. + Do(commandName string, args ...interface{}) (reply interface{}, err error) + + // Send writes the command to the client's output buffer. + Send(commandName string, args ...interface{}) error + + // Flush flushes the output buffer to the Redis server. + Flush() error + + // Receive receives a single reply from the Redis server + Receive() (reply interface{}, err error) +} diff --git a/client/go/redis/reply.go b/client/go/redis/reply.go new file mode 100644 index 0000000..161a147 --- /dev/null +++ b/client/go/redis/reply.go @@ -0,0 +1,271 @@ +// Copyright 2012 Gary Burd +// +// Licensed under the Apache License, Version 2.0 (the "License"): you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +package redis + +import ( + "errors" + "fmt" + "strconv" +) + +// ErrNil indicates that a reply value is nil. +var ErrNil = errors.New("redigo: nil returned") + +// Int is a helper that converts a command reply to an integer. If err is not +// equal to nil, then Int returns 0, err. Otherwise, Int converts the +// reply to an int as follows: +// +// Reply type Result +// integer int(reply), nil +// bulk string parsed reply, nil +// nil 0, ErrNil +// other 0, error +func Int(reply interface{}, err error) (int, error) { + if err != nil { + return 0, err + } + switch reply := reply.(type) { + case int64: + x := int(reply) + if int64(x) != reply { + return 0, strconv.ErrRange + } + return x, nil + case []byte: + n, err := strconv.ParseInt(string(reply), 10, 0) + return int(n), err + case nil: + return 0, ErrNil + case Error: + return 0, reply + } + return 0, fmt.Errorf("redigo: unexpected type for Int, got type %T", reply) +} + +// Int64 is a helper that converts a command reply to 64 bit integer. If err is +// not equal to nil, then Int returns 0, err. Otherwise, Int64 converts the +// reply to an int64 as follows: +// +// Reply type Result +// integer reply, nil +// bulk string parsed reply, nil +// nil 0, ErrNil +// other 0, error +func Int64(reply interface{}, err error) (int64, error) { + if err != nil { + return 0, err + } + switch reply := reply.(type) { + case int64: + return reply, nil + case []byte: + n, err := strconv.ParseInt(string(reply), 10, 64) + return n, err + case nil: + return 0, ErrNil + case Error: + return 0, reply + } + return 0, fmt.Errorf("redigo: unexpected type for Int64, got type %T", reply) +} + +var errNegativeInt = errors.New("redigo: unexpected value for Uint64") + +// Uint64 is a helper that converts a command reply to 64 bit integer. If err is +// not equal to nil, then Int returns 0, err. Otherwise, Int64 converts the +// reply to an int64 as follows: +// +// Reply type Result +// integer reply, nil +// bulk string parsed reply, nil +// nil 0, ErrNil +// other 0, error +func Uint64(reply interface{}, err error) (uint64, error) { + if err != nil { + return 0, err + } + switch reply := reply.(type) { + case int64: + if reply < 0 { + return 0, errNegativeInt + } + return uint64(reply), nil + case []byte: + n, err := strconv.ParseUint(string(reply), 10, 64) + return n, err + case nil: + return 0, ErrNil + case Error: + return 0, reply + } + return 0, fmt.Errorf("redigo: unexpected type for Uint64, got type %T", reply) +} + +// Float64 is a helper that converts a command reply to 64 bit float. If err is +// not equal to nil, then Float64 returns 0, err. Otherwise, Float64 converts +// the reply to an int as follows: +// +// Reply type Result +// bulk string parsed reply, nil +// nil 0, ErrNil +// other 0, error +func Float64(reply interface{}, err error) (float64, error) { + if err != nil { + return 0, err + } + switch reply := reply.(type) { + case []byte: + n, err := strconv.ParseFloat(string(reply), 64) + return n, err + case nil: + return 0, ErrNil + case Error: + return 0, reply + } + return 0, fmt.Errorf("redigo: unexpected type for Float64, got type %T", reply) +} + +// String is a helper that converts a command reply to a string. If err is not +// equal to nil, then String returns "", err. Otherwise String converts the +// reply to a string as follows: +// +// Reply type Result +// bulk string string(reply), nil +// simple string reply, nil +// nil "", ErrNil +// other "", error +func String(reply interface{}, err error) (string, error) { + if err != nil { + return "", err + } + switch reply := reply.(type) { + case []byte: + return string(reply), nil + case string: + return reply, nil + case nil: + return "", ErrNil + case Error: + return "", reply + } + return "", fmt.Errorf("redigo: unexpected type for String, got type %T", reply) +} + +// Bytes is a helper that converts a command reply to a slice of bytes. If err +// is not equal to nil, then Bytes returns nil, err. Otherwise Bytes converts +// the reply to a slice of bytes as follows: +// +// Reply type Result +// bulk string reply, nil +// simple string []byte(reply), nil +// nil nil, ErrNil +// other nil, error +func Bytes(reply interface{}, err error) ([]byte, error) { + if err != nil { + return nil, err + } + switch reply := reply.(type) { + case []byte: + return reply, nil + case string: + return []byte(reply), nil + case nil: + return nil, ErrNil + case Error: + return nil, reply + } + return nil, fmt.Errorf("redigo: unexpected type for Bytes, got type %T", reply) +} + +// Bool is a helper that converts a command reply to a boolean. If err is not +// equal to nil, then Bool returns false, err. Otherwise Bool converts the +// reply to boolean as follows: +// +// Reply type Result +// integer value != 0, nil +// bulk string strconv.ParseBool(reply) +// nil false, ErrNil +// other false, error +func Bool(reply interface{}, err error) (bool, error) { + if err != nil { + return false, err + } + switch reply := reply.(type) { + case int64: + return reply != 0, nil + case []byte: + return strconv.ParseBool(string(reply)) + case nil: + return false, ErrNil + case Error: + return false, reply + } + return false, fmt.Errorf("redigo: unexpected type for Bool, got type %T", reply) +} + +// MultiBulk is deprecated. Use Values. +func MultiBulk(reply interface{}, err error) ([]interface{}, error) { return Values(reply, err) } + +// Values is a helper that converts an array command reply to a []interface{}. +// If err is not equal to nil, then Values returns nil, err. Otherwise, Values +// converts the reply as follows: +// +// Reply type Result +// array reply, nil +// nil nil, ErrNil +// other nil, error +func Values(reply interface{}, err error) ([]interface{}, error) { + if err != nil { + return nil, err + } + switch reply := reply.(type) { + case []interface{}: + return reply, nil + case nil: + return nil, ErrNil + case Error: + return nil, reply + } + return nil, fmt.Errorf("redigo: unexpected type for Values, got type %T", reply) +} + +// Strings is a helper that converts an array command reply to a []string. If +// err is not equal to nil, then Strings returns nil, err. If one of the array +// items is not a bulk string or nil, then Strings returns an error. +func Strings(reply interface{}, err error) ([]string, error) { + if err != nil { + return nil, err + } + switch reply := reply.(type) { + case []interface{}: + result := make([]string, len(reply)) + for i := range reply { + if reply[i] == nil { + continue + } + p, ok := reply[i].([]byte) + if !ok { + return nil, fmt.Errorf("redigo: unexpected element type for Strings, got type %T", reply[i]) + } + result[i] = string(p) + } + return result, nil + case nil: + return nil, ErrNil + case Error: + return nil, reply + } + return nil, fmt.Errorf("redigo: unexpected type for Strings, got type %T", reply) +} diff --git a/client/go/redis/scan.go b/client/go/redis/scan.go new file mode 100644 index 0000000..8c9cfa1 --- /dev/null +++ b/client/go/redis/scan.go @@ -0,0 +1,513 @@ +// Copyright 2012 Gary Burd +// +// Licensed under the Apache License, Version 2.0 (the "License"): you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +package redis + +import ( + "errors" + "fmt" + "reflect" + "strconv" + "strings" + "sync" +) + +func ensureLen(d reflect.Value, n int) { + if n > d.Cap() { + d.Set(reflect.MakeSlice(d.Type(), n, n)) + } else { + d.SetLen(n) + } +} + +func cannotConvert(d reflect.Value, s interface{}) error { + return fmt.Errorf("redigo: Scan cannot convert from %s to %s", + reflect.TypeOf(s), d.Type()) +} + +func convertAssignBytes(d reflect.Value, s []byte) (err error) { + switch d.Type().Kind() { + case reflect.Float32, reflect.Float64: + var x float64 + x, err = strconv.ParseFloat(string(s), d.Type().Bits()) + d.SetFloat(x) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + var x int64 + x, err = strconv.ParseInt(string(s), 10, d.Type().Bits()) + d.SetInt(x) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + var x uint64 + x, err = strconv.ParseUint(string(s), 10, d.Type().Bits()) + d.SetUint(x) + case reflect.Bool: + var x bool + x, err = strconv.ParseBool(string(s)) + d.SetBool(x) + case reflect.String: + d.SetString(string(s)) + case reflect.Slice: + if d.Type().Elem().Kind() != reflect.Uint8 { + err = cannotConvert(d, s) + } else { + d.SetBytes(s) + } + default: + err = cannotConvert(d, s) + } + return +} + +func convertAssignInt(d reflect.Value, s int64) (err error) { + switch d.Type().Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + d.SetInt(s) + if d.Int() != s { + err = strconv.ErrRange + d.SetInt(0) + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if s < 0 { + err = strconv.ErrRange + } else { + x := uint64(s) + d.SetUint(x) + if d.Uint() != x { + err = strconv.ErrRange + d.SetUint(0) + } + } + case reflect.Bool: + d.SetBool(s != 0) + default: + err = cannotConvert(d, s) + } + return +} + +func convertAssignValue(d reflect.Value, s interface{}) (err error) { + switch s := s.(type) { + case []byte: + err = convertAssignBytes(d, s) + case int64: + err = convertAssignInt(d, s) + default: + err = cannotConvert(d, s) + } + return err +} + +func convertAssignValues(d reflect.Value, s []interface{}) error { + if d.Type().Kind() != reflect.Slice { + return cannotConvert(d, s) + } + ensureLen(d, len(s)) + for i := 0; i < len(s); i++ { + if err := convertAssignValue(d.Index(i), s[i]); err != nil { + return err + } + } + return nil +} + +func convertAssign(d interface{}, s interface{}) (err error) { + // Handle the most common destination types using type switches and + // fall back to reflection for all other types. + switch s := s.(type) { + case nil: + // ingore + case []byte: + switch d := d.(type) { + case *string: + *d = string(s) + case *int: + *d, err = strconv.Atoi(string(s)) + case *bool: + *d, err = strconv.ParseBool(string(s)) + case *[]byte: + *d = s + case *interface{}: + *d = s + case nil: + // skip value + default: + if d := reflect.ValueOf(d); d.Type().Kind() != reflect.Ptr { + err = cannotConvert(d, s) + } else { + err = convertAssignBytes(d.Elem(), s) + } + } + case int64: + switch d := d.(type) { + case *int: + x := int(s) + if int64(x) != s { + err = strconv.ErrRange + x = 0 + } + *d = x + case *bool: + *d = s != 0 + case *interface{}: + *d = s + case nil: + // skip value + default: + if d := reflect.ValueOf(d); d.Type().Kind() != reflect.Ptr { + err = cannotConvert(d, s) + } else { + err = convertAssignInt(d.Elem(), s) + } + } + case []interface{}: + switch d := d.(type) { + case *[]interface{}: + *d = s + case *interface{}: + *d = s + case nil: + // skip value + default: + if d := reflect.ValueOf(d); d.Type().Kind() != reflect.Ptr { + err = cannotConvert(d, s) + } else { + err = convertAssignValues(d.Elem(), s) + } + } + case Error: + err = s + default: + err = cannotConvert(reflect.ValueOf(d), s) + } + return +} + +// Scan copies from src to the values pointed at by dest. +// +// The values pointed at by dest must be an integer, float, boolean, string, +// []byte, interface{} or slices of these types. Scan uses the standard strconv +// package to convert bulk strings to numeric and boolean types. +// +// If a dest value is nil, then the corresponding src value is skipped. +// +// If a src element is nil, then the corresponding dest value is not modified. +// +// To enable easy use of Scan in a loop, Scan returns the slice of src +// following the copied values. +func Scan(src []interface{}, dest ...interface{}) ([]interface{}, error) { + if len(src) < len(dest) { + return nil, errors.New("redigo: Scan array short") + } + var err error + for i, d := range dest { + err = convertAssign(d, src[i]) + if err != nil { + break + } + } + return src[len(dest):], err +} + +type fieldSpec struct { + name string + index []int + //omitEmpty bool +} + +type structSpec struct { + m map[string]*fieldSpec + l []*fieldSpec +} + +func (ss *structSpec) fieldSpec(name []byte) *fieldSpec { + return ss.m[string(name)] +} + +func compileStructSpec(t reflect.Type, depth map[string]int, index []int, ss *structSpec) { + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + switch { + case f.PkgPath != "": + // Ignore unexported fields. + case f.Anonymous: + // TODO: Handle pointers. Requires change to decoder and + // protection against infinite recursion. + if f.Type.Kind() == reflect.Struct { + compileStructSpec(f.Type, depth, append(index, i), ss) + } + default: + fs := &fieldSpec{name: f.Name} + tag := f.Tag.Get("redis") + p := strings.Split(tag, ",") + if len(p) > 0 { + if p[0] == "-" { + continue + } + if len(p[0]) > 0 { + fs.name = p[0] + } + for _, s := range p[1:] { + switch s { + //case "omitempty": + // fs.omitempty = true + default: + panic(errors.New("redigo: unknown field flag " + s + " for type " + t.Name())) + } + } + } + d, found := depth[fs.name] + if !found { + d = 1 << 30 + } + switch { + case len(index) == d: + // At same depth, remove from result. + delete(ss.m, fs.name) + j := 0 + for i := 0; i < len(ss.l); i++ { + if fs.name != ss.l[i].name { + ss.l[j] = ss.l[i] + j += 1 + } + } + ss.l = ss.l[:j] + case len(index) < d: + fs.index = make([]int, len(index)+1) + copy(fs.index, index) + fs.index[len(index)] = i + depth[fs.name] = len(index) + ss.m[fs.name] = fs + ss.l = append(ss.l, fs) + } + } + } +} + +var ( + structSpecMutex sync.RWMutex + structSpecCache = make(map[reflect.Type]*structSpec) + defaultFieldSpec = &fieldSpec{} +) + +func structSpecForType(t reflect.Type) *structSpec { + + structSpecMutex.RLock() + ss, found := structSpecCache[t] + structSpecMutex.RUnlock() + if found { + return ss + } + + structSpecMutex.Lock() + defer structSpecMutex.Unlock() + ss, found = structSpecCache[t] + if found { + return ss + } + + ss = &structSpec{m: make(map[string]*fieldSpec)} + compileStructSpec(t, make(map[string]int), nil, ss) + structSpecCache[t] = ss + return ss +} + +var errScanStructValue = errors.New("redigo: ScanStruct value must be non-nil pointer to a struct") + +// ScanStruct scans alternating names and values from src to a struct. The +// HGETALL and CONFIG GET commands return replies in this format. +// +// ScanStruct uses exported field names to match values in the response. Use +// 'redis' field tag to override the name: +// +// Field int `redis:"myName"` +// +// Fields with the tag redis:"-" are ignored. +// +// Integer, float, boolean, string and []byte fields are supported. Scan uses the +// standard strconv package to convert bulk string values to numeric and +// boolean types. +// +// If a src element is nil, then the corresponding field is not modified. +func ScanStruct(src []interface{}, dest interface{}) error { + d := reflect.ValueOf(dest) + if d.Kind() != reflect.Ptr || d.IsNil() { + return errScanStructValue + } + d = d.Elem() + if d.Kind() != reflect.Struct { + return errScanStructValue + } + ss := structSpecForType(d.Type()) + + if len(src)%2 != 0 { + return errors.New("redigo: ScanStruct expects even number of values in values") + } + + for i := 0; i < len(src); i += 2 { + s := src[i+1] + if s == nil { + continue + } + name, ok := src[i].([]byte) + if !ok { + return errors.New("redigo: ScanStruct key not a bulk string value") + } + fs := ss.fieldSpec(name) + if fs == nil { + continue + } + if err := convertAssignValue(d.FieldByIndex(fs.index), s); err != nil { + return err + } + } + return nil +} + +var ( + errScanSliceValue = errors.New("redigo: ScanSlice dest must be non-nil pointer to a struct") +) + +// ScanSlice scans src to the slice pointed to by dest. The elements the dest +// slice must be integer, float, boolean, string, struct or pointer to struct +// values. +// +// Struct fields must be integer, float, boolean or string values. All struct +// fields are used unless a subset is specified using fieldNames. +func ScanSlice(src []interface{}, dest interface{}, fieldNames ...string) error { + d := reflect.ValueOf(dest) + if d.Kind() != reflect.Ptr || d.IsNil() { + return errScanSliceValue + } + d = d.Elem() + if d.Kind() != reflect.Slice { + return errScanSliceValue + } + + isPtr := false + t := d.Type().Elem() + if t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct { + isPtr = true + t = t.Elem() + } + + if t.Kind() != reflect.Struct { + ensureLen(d, len(src)) + for i, s := range src { + if s == nil { + continue + } + if err := convertAssignValue(d.Index(i), s); err != nil { + return err + } + } + return nil + } + + ss := structSpecForType(t) + fss := ss.l + if len(fieldNames) > 0 { + fss = make([]*fieldSpec, len(fieldNames)) + for i, name := range fieldNames { + fss[i] = ss.m[name] + if fss[i] == nil { + return errors.New("redigo: ScanSlice bad field name " + name) + } + } + } + + if len(fss) == 0 { + return errors.New("redigo: ScanSlice no struct fields") + } + + n := len(src) / len(fss) + if n*len(fss) != len(src) { + return errors.New("redigo: ScanSlice length not a multiple of struct field count") + } + + ensureLen(d, n) + for i := 0; i < n; i++ { + d := d.Index(i) + if isPtr { + if d.IsNil() { + d.Set(reflect.New(t)) + } + d = d.Elem() + } + for j, fs := range fss { + s := src[i*len(fss)+j] + if s == nil { + continue + } + if err := convertAssignValue(d.FieldByIndex(fs.index), s); err != nil { + return err + } + } + } + return nil +} + +// Args is a helper for constructing command arguments from structured values. +type Args []interface{} + +// Add returns the result of appending value to args. +func (args Args) Add(value ...interface{}) Args { + return append(args, value...) +} + +// AddFlat returns the result of appending the flattened value of v to args. +// +// Maps are flattened by appending the alternating keys and map values to args. +// +// Slices are flattened by appending the slice elements to args. +// +// Structs are flattened by appending the alternating names and values of +// exported fields to args. If v is a nil struct pointer, then nothing is +// appended. The 'redis' field tag overrides struct field names. See ScanStruct +// for more information on the use of the 'redis' field tag. +// +// Other types are appended to args as is. +func (args Args) AddFlat(v interface{}) Args { + rv := reflect.ValueOf(v) + switch rv.Kind() { + case reflect.Struct: + args = flattenStruct(args, rv) + case reflect.Slice: + for i := 0; i < rv.Len(); i++ { + args = append(args, rv.Index(i).Interface()) + } + case reflect.Map: + for _, k := range rv.MapKeys() { + args = append(args, k.Interface(), rv.MapIndex(k).Interface()) + } + case reflect.Ptr: + if rv.Type().Elem().Kind() == reflect.Struct { + if !rv.IsNil() { + args = flattenStruct(args, rv.Elem()) + } + } else { + args = append(args, v) + } + default: + args = append(args, v) + } + return args +} + +func flattenStruct(args Args, v reflect.Value) Args { + ss := structSpecForType(v.Type()) + for _, fs := range ss.l { + fv := v.FieldByIndex(fs.index) + args = append(args, fs.name, fv.Interface()) + } + return args +} diff --git a/client/go/redis/script.go b/client/go/redis/script.go new file mode 100644 index 0000000..2417753 --- /dev/null +++ b/client/go/redis/script.go @@ -0,0 +1,86 @@ +// Copyright 2012 Gary Burd +// +// Licensed under the Apache License, Version 2.0 (the "License"): you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +package redis + +import ( + "crypto/sha1" + "encoding/hex" + "io" + "strings" +) + +// Script encapsulates the source, hash and key count for a Lua script. See +// http://redis.io/commands/eval for information on scripts in Redis. +type Script struct { + keyCount int + src string + hash string +} + +// NewScript returns a new script object. If keyCount is greater than or equal +// to zero, then the count is automatically inserted in the EVAL command +// argument list. If keyCount is less than zero, then the application supplies +// the count as the first value in the keysAndArgs argument to the Do, Send and +// SendHash methods. +func NewScript(keyCount int, src string) *Script { + h := sha1.New() + io.WriteString(h, src) + return &Script{keyCount, src, hex.EncodeToString(h.Sum(nil))} +} + +func (s *Script) args(spec string, keysAndArgs []interface{}) []interface{} { + var args []interface{} + if s.keyCount < 0 { + args = make([]interface{}, 1+len(keysAndArgs)) + args[0] = spec + copy(args[1:], keysAndArgs) + } else { + args = make([]interface{}, 2+len(keysAndArgs)) + args[0] = spec + args[1] = s.keyCount + copy(args[2:], keysAndArgs) + } + return args +} + +// Do evalutes the script. Under the covers, Do optimistically evaluates the +// script using the EVALSHA command. If the command fails because the script is +// not loaded, then Do evaluates the script using the EVAL command (thus +// causing the script to load). +func (s *Script) Do(c Conn, keysAndArgs ...interface{}) (interface{}, error) { + v, err := c.Do("EVALSHA", s.args(s.hash, keysAndArgs)...) + if e, ok := err.(Error); ok && strings.HasPrefix(string(e), "NOSCRIPT ") { + v, err = c.Do("EVAL", s.args(s.src, keysAndArgs)...) + } + return v, err +} + +// SendHash evaluates the script without waiting for the reply. The script is +// evaluated with the EVALSHA command. The application must ensure that the +// script is loaded by a previous call to Send, Do or Load methods. +func (s *Script) SendHash(c Conn, keysAndArgs ...interface{}) error { + return c.Send("EVALSHA", s.args(s.hash, keysAndArgs)...) +} + +// Send evaluates the script without waiting for the reply. +func (s *Script) Send(c Conn, keysAndArgs ...interface{}) error { + return c.Send("EVAL", s.args(s.src, keysAndArgs)...) +} + +// Load loads the script without evaluating it. +func (s *Script) Load(c Conn) error { + _, err := c.Do("SCRIPT", "LOAD", s.src) + return err +} diff --git a/server/app_test.go b/server/app_test.go index 3357ddd..6ae909a 100644 --- a/server/app_test.go +++ b/server/app_test.go @@ -1,7 +1,7 @@ package server import ( - "github.com/garyburd/redigo/redis" + "github.com/siddontang/ledisdb/client/go/redis" "os" "sync" "testing" diff --git a/server/cmd_hash_test.go b/server/cmd_hash_test.go index 26b0b47..3894386 100644 --- a/server/cmd_hash_test.go +++ b/server/cmd_hash_test.go @@ -2,7 +2,7 @@ package server import ( "fmt" - "github.com/garyburd/redigo/redis" + "github.com/siddontang/ledisdb/client/go/redis" "strconv" "testing" ) diff --git a/server/cmd_kv_test.go b/server/cmd_kv_test.go index 8318a97..91e0126 100644 --- a/server/cmd_kv_test.go +++ b/server/cmd_kv_test.go @@ -1,7 +1,7 @@ package server import ( - "github.com/garyburd/redigo/redis" + "github.com/siddontang/ledisdb/client/go/redis" "testing" ) diff --git a/server/cmd_list_test.go b/server/cmd_list_test.go index 58bab33..c970052 100644 --- a/server/cmd_list_test.go +++ b/server/cmd_list_test.go @@ -2,7 +2,7 @@ package server import ( "fmt" - "github.com/garyburd/redigo/redis" + "github.com/siddontang/ledisdb/client/go/redis" "strconv" "testing" ) diff --git a/server/cmd_ttl_test.go b/server/cmd_ttl_test.go index da90d6c..7a1f1f4 100644 --- a/server/cmd_ttl_test.go +++ b/server/cmd_ttl_test.go @@ -1,7 +1,7 @@ package server import ( - "github.com/garyburd/redigo/redis" + "github.com/siddontang/ledisdb/client/go/redis" "testing" "time" ) diff --git a/server/cmd_zset_test.go b/server/cmd_zset_test.go index 8ac59ab..fc5b9e0 100644 --- a/server/cmd_zset_test.go +++ b/server/cmd_zset_test.go @@ -2,7 +2,7 @@ package server import ( "fmt" - "github.com/garyburd/redigo/redis" + "github.com/siddontang/ledisdb/client/go/redis" "strconv" "testing" ) From c5af770387bc493183855b51057faf0582f0d238 Mon Sep 17 00:00:00 2001 From: siddontang Date: Thu, 19 Jun 2014 17:50:27 +0800 Subject: [PATCH 8/9] use my own log --- bootstrap.sh | 2 +- cmd/ledis-benchmark/main.go | 2 +- ledis/binlog.go | 2 +- ledis/ledis.go | 2 +- ledis/replication.go | 2 +- log/filehandler.go | 193 ++++++++++++++++++++++++++++++ log/handler.go | 45 +++++++ log/log.go | 226 ++++++++++++++++++++++++++++++++++++ log/log_test.go | 52 +++++++++ log/sockethandler.go | 62 ++++++++++ server/accesslog.go | 2 +- server/client.go | 2 +- server/replication.go | 2 +- 13 files changed, 586 insertions(+), 8 deletions(-) create mode 100644 log/filehandler.go create mode 100644 log/handler.go create mode 100644 log/log.go create mode 100644 log/log_test.go create mode 100644 log/sockethandler.go diff --git a/bootstrap.sh b/bootstrap.sh index c05f3db..29344e6 100644 --- a/bootstrap.sh +++ b/bootstrap.sh @@ -2,4 +2,4 @@ . ./dev.sh -go get -u github.com/siddontang/go-log/log +#nothing to do now \ No newline at end of file diff --git a/cmd/ledis-benchmark/main.go b/cmd/ledis-benchmark/main.go index ac67411..9209dd7 100644 --- a/cmd/ledis-benchmark/main.go +++ b/cmd/ledis-benchmark/main.go @@ -3,7 +3,7 @@ package main import ( "flag" "fmt" - "github.com/garyburd/redigo/redis" + "github.com/siddontang/ledisdb/client/go/redis" "math/rand" "sync" "time" diff --git a/ledis/binlog.go b/ledis/binlog.go index d6e99f0..4785370 100644 --- a/ledis/binlog.go +++ b/ledis/binlog.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "encoding/json" "fmt" - "github.com/siddontang/go-log/log" + "github.com/siddontang/ledisdb/log" "io/ioutil" "os" "path" diff --git a/ledis/ledis.go b/ledis/ledis.go index 668098c..6bf333a 100644 --- a/ledis/ledis.go +++ b/ledis/ledis.go @@ -3,8 +3,8 @@ package ledis import ( "encoding/json" "fmt" - "github.com/siddontang/go-log/log" "github.com/siddontang/ledisdb/leveldb" + "github.com/siddontang/ledisdb/log" "path" "sync" "time" diff --git a/ledis/replication.go b/ledis/replication.go index e19da6a..fce8fd4 100644 --- a/ledis/replication.go +++ b/ledis/replication.go @@ -5,7 +5,7 @@ import ( "bytes" "encoding/binary" "errors" - "github.com/siddontang/go-log/log" + "github.com/siddontang/ledisdb/log" "io" "os" ) diff --git a/log/filehandler.go b/log/filehandler.go new file mode 100644 index 0000000..e77eefe --- /dev/null +++ b/log/filehandler.go @@ -0,0 +1,193 @@ +package log + +import ( + "fmt" + "os" + "path" + "time" +) + +type FileHandler struct { + fd *os.File +} + +func NewFileHandler(fileName string, flag int) (*FileHandler, error) { + dir := path.Dir(fileName) + os.Mkdir(dir, 0777) + + f, err := os.OpenFile(fileName, flag, 0) + if err != nil { + return nil, err + } + + h := new(FileHandler) + + h.fd = f + + return h, nil +} + +func (h *FileHandler) Write(b []byte) (n int, err error) { + return h.fd.Write(b) +} + +func (h *FileHandler) Close() error { + return h.fd.Close() +} + +type RotatingFileHandler struct { + fd *os.File + + fileName string + maxBytes int + backupCount int +} + +func NewRotatingFileHandler(fileName string, maxBytes int, backupCount int) (*RotatingFileHandler, error) { + dir := path.Dir(fileName) + os.Mkdir(dir, 0777) + + h := new(RotatingFileHandler) + + if maxBytes <= 0 { + return nil, fmt.Errorf("invalid max bytes") + } + + h.fileName = fileName + h.maxBytes = maxBytes + h.backupCount = backupCount + + var err error + h.fd, err = os.OpenFile(fileName, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) + if err != nil { + return nil, err + } + + return h, nil +} + +func (h *RotatingFileHandler) Write(p []byte) (n int, err error) { + h.doRollover() + return h.fd.Write(p) +} + +func (h *RotatingFileHandler) Close() error { + if h.fd != nil { + return h.fd.Close() + } + return nil +} + +func (h *RotatingFileHandler) doRollover() { + f, err := h.fd.Stat() + if err != nil { + return + } + + if h.maxBytes <= 0 { + return + } else if f.Size() < int64(h.maxBytes) { + return + } + + if h.backupCount > 0 { + h.fd.Close() + + for i := h.backupCount - 1; i > 0; i-- { + sfn := fmt.Sprintf("%s.%d", h.fileName, i) + dfn := fmt.Sprintf("%s.%d", h.fileName, i+1) + + os.Rename(sfn, dfn) + } + + dfn := fmt.Sprintf("%s.1", h.fileName) + os.Rename(h.fileName, dfn) + + h.fd, _ = os.OpenFile(h.fileName, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) + } +} + +//refer: http://docs.python.org/2/library/logging.handlers.html +//same like python TimedRotatingFileHandler + +type TimeRotatingFileHandler struct { + fd *os.File + + baseName string + interval int64 + suffix string + rolloverAt int64 +} + +const ( + WhenSecond = iota + WhenMinute + WhenHour + WhenDay +) + +func NewTimeRotatingFileHandler(baseName string, when int8, interval int) (*TimeRotatingFileHandler, error) { + dir := path.Dir(baseName) + os.Mkdir(dir, 0777) + + h := new(TimeRotatingFileHandler) + + h.baseName = baseName + + switch when { + case WhenSecond: + h.interval = 1 + h.suffix = "2006-01-02_15-04-05" + case WhenMinute: + h.interval = 60 + h.suffix = "2006-01-02_15-04" + case WhenHour: + h.interval = 3600 + h.suffix = "2006-01-02_15" + case WhenDay: + h.interval = 3600 * 24 + h.suffix = "2006-01-02" + default: + return nil, fmt.Errorf("invalid when_rotate: %d", when) + } + + h.interval = h.interval * int64(interval) + + var err error + h.fd, err = os.OpenFile(h.baseName, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) + if err != nil { + return nil, err + } + + fInfo, _ := h.fd.Stat() + h.rolloverAt = fInfo.ModTime().Unix() + h.interval + + return h, nil +} + +func (h *TimeRotatingFileHandler) doRollover() { + //refer http://hg.python.org/cpython/file/2.7/Lib/logging/handlers.py + now := time.Now() + + if h.rolloverAt <= now.Unix() { + fName := h.baseName + now.Format(h.suffix) + h.fd.Close() + e := os.Rename(h.baseName, fName) + if e != nil { + panic(e) + } + + h.fd, _ = os.OpenFile(h.baseName, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) + + h.rolloverAt = time.Now().Unix() + h.interval + } +} + +func (h *TimeRotatingFileHandler) Write(b []byte) (n int, err error) { + h.doRollover() + return h.fd.Write(b) +} + +func (h *TimeRotatingFileHandler) Close() error { + return h.fd.Close() +} diff --git a/log/handler.go b/log/handler.go new file mode 100644 index 0000000..66257ad --- /dev/null +++ b/log/handler.go @@ -0,0 +1,45 @@ +package log + +import ( + "io" +) + +type Handler interface { + Write(p []byte) (n int, err error) + Close() error +} + +type StreamHandler struct { + w io.Writer +} + +func NewStreamHandler(w io.Writer) (*StreamHandler, error) { + h := new(StreamHandler) + + h.w = w + + return h, nil +} + +func (h *StreamHandler) Write(b []byte) (n int, err error) { + return h.w.Write(b) +} + +func (h *StreamHandler) Close() error { + return nil +} + +type NullHandler struct { +} + +func NewNullHandler() (*NullHandler, error) { + return new(NullHandler), nil +} + +func (h *NullHandler) Write(b []byte) (n int, err error) { + return len(b), nil +} + +func (h *NullHandler) Close() { + +} diff --git a/log/log.go b/log/log.go new file mode 100644 index 0000000..3a59df8 --- /dev/null +++ b/log/log.go @@ -0,0 +1,226 @@ +package log + +import ( + "fmt" + "os" + "runtime" + "strconv" + "sync" + "time" +) + +const ( + LevelTrace = iota + LevelDebug + LevelInfo + LevelWarn + LevelError + LevelFatal +) + +const ( + Ltime = 1 << iota //time format "2006/01/02 15:04:05" + Lfile //file.go:123 + Llevel //[Trace|Debug|Info...] +) + +var LevelName [6]string = [6]string{"Trace", "Debug", "Info", "Warn", "Error", "Fatal"} + +const TimeFormat = "2006/01/02 15:04:05" + +const maxBufPoolSize = 16 + +type Logger struct { + sync.Mutex + + level int + flag int + + handler Handler + + quit chan struct{} + msg chan []byte + + bufs [][]byte +} + +func New(handler Handler, flag int) *Logger { + var l = new(Logger) + + l.level = LevelInfo + l.handler = handler + + l.flag = flag + + l.quit = make(chan struct{}) + + l.msg = make(chan []byte, 1024) + + l.bufs = make([][]byte, 0, 16) + + go l.run() + + return l +} + +func NewDefault(handler Handler) *Logger { + return New(handler, Ltime|Lfile|Llevel) +} + +func newStdHandler() *StreamHandler { + h, _ := NewStreamHandler(os.Stdout) + return h +} + +var std = NewDefault(newStdHandler()) + +func (l *Logger) run() { + for { + select { + case msg := <-l.msg: + l.handler.Write(msg) + l.putBuf(msg) + case <-l.quit: + l.handler.Close() + } + } +} + +func (l *Logger) popBuf() []byte { + l.Lock() + var buf []byte + if len(l.bufs) == 0 { + buf = make([]byte, 0, 1024) + } else { + buf = l.bufs[len(l.bufs)-1] + l.bufs = l.bufs[0 : len(l.bufs)-1] + } + l.Unlock() + + return buf +} + +func (l *Logger) putBuf(buf []byte) { + l.Lock() + if len(l.bufs) < maxBufPoolSize { + buf = buf[0:0] + l.bufs = append(l.bufs, buf) + } + l.Unlock() +} + +func (l *Logger) Close() { + if l.quit == nil { + return + } + + close(l.quit) + l.quit = nil +} + +func (l *Logger) SetLevel(level int) { + l.level = level +} + +func (l *Logger) Output(callDepth int, level int, format string, v ...interface{}) { + if l.level > level { + return + } + + buf := l.popBuf() + + if l.flag&Ltime > 0 { + now := time.Now().Format(TimeFormat) + buf = append(buf, '[') + buf = append(buf, now...) + buf = append(buf, "] "...) + } + + if l.flag&Lfile > 0 { + _, file, line, ok := runtime.Caller(callDepth) + if !ok { + file = "???" + line = 0 + } else { + for i := len(file) - 1; i > 0; i-- { + if file[i] == '/' { + file = file[i+1:] + break + } + } + } + + buf = append(buf, file...) + buf = append(buf, ':') + + strconv.AppendInt(buf, int64(line), 10) + } + + if l.flag&Llevel > 0 { + buf = append(buf, '[') + buf = append(buf, LevelName[level]...) + buf = append(buf, "] "...) + } + + s := fmt.Sprintf(format, v...) + + buf = append(buf, s...) + + if s[len(s)-1] != '\n' { + buf = append(buf, '\n') + } + + l.msg <- buf +} + +func (l *Logger) Trace(format string, v ...interface{}) { + l.Output(2, LevelTrace, format, v...) +} + +func (l *Logger) Debug(format string, v ...interface{}) { + l.Output(2, LevelDebug, format, v...) +} + +func (l *Logger) Info(format string, v ...interface{}) { + l.Output(2, LevelInfo, format, v...) +} + +func (l *Logger) Warn(format string, v ...interface{}) { + l.Output(2, LevelWarn, format, v...) +} + +func (l *Logger) Error(format string, v ...interface{}) { + l.Output(2, LevelError, format, v...) +} + +func (l *Logger) Fatal(format string, v ...interface{}) { + l.Output(2, LevelFatal, format, v...) +} + +func SetLevel(level int) { + std.SetLevel(level) +} + +func Trace(format string, v ...interface{}) { + std.Output(2, LevelTrace, format, v...) +} + +func Debug(format string, v ...interface{}) { + std.Output(2, LevelDebug, format, v...) +} + +func Info(format string, v ...interface{}) { + std.Output(2, LevelInfo, format, v...) +} + +func Warn(format string, v ...interface{}) { + std.Output(2, LevelWarn, format, v...) +} + +func Error(format string, v ...interface{}) { + std.Output(2, LevelError, format, v...) +} + +func Fatal(format string, v ...interface{}) { + std.Output(2, LevelFatal, format, v...) +} diff --git a/log/log_test.go b/log/log_test.go new file mode 100644 index 0000000..67d3b0b --- /dev/null +++ b/log/log_test.go @@ -0,0 +1,52 @@ +package log + +import ( + "os" + "testing" +) + +func TestStdStreamLog(t *testing.T) { + h, _ := NewStreamHandler(os.Stdout) + s := NewDefault(h) + s.Info("hello world") + + s.Close() + + Info("hello world") +} + +func TestRotatingFileLog(t *testing.T) { + path := "./test_log" + os.RemoveAll(path) + + os.Mkdir(path, 0777) + fileName := path + "/test" + + h, err := NewRotatingFileHandler(fileName, 10, 2) + if err != nil { + t.Fatal(err) + } + + buf := make([]byte, 10) + + h.Write(buf) + + h.Write(buf) + + if _, err := os.Stat(fileName + ".1"); err != nil { + t.Fatal(err) + } + + if _, err := os.Stat(fileName + ".2"); err == nil { + t.Fatal(err) + } + + h.Write(buf) + if _, err := os.Stat(fileName + ".2"); err != nil { + t.Fatal(err) + } + + h.Close() + + os.RemoveAll(path) +} diff --git a/log/sockethandler.go b/log/sockethandler.go new file mode 100644 index 0000000..f19db05 --- /dev/null +++ b/log/sockethandler.go @@ -0,0 +1,62 @@ +package log + +import ( + "encoding/binary" + "net" + "time" +) + +type SocketHandler struct { + c net.Conn + protocol string + addr string +} + +func NewSocketHandler(protocol string, addr string) (*SocketHandler, error) { + s := new(SocketHandler) + + s.protocol = protocol + s.addr = addr + + return s, nil +} + +func (h *SocketHandler) Write(p []byte) (n int, err error) { + if err = h.connect(); err != nil { + return + } + + buf := make([]byte, len(p)+4) + + binary.BigEndian.PutUint32(buf, uint32(len(p))) + + copy(buf[4:], p) + + n, err = h.c.Write(buf) + if err != nil { + h.c.Close() + h.c = nil + } + return +} + +func (h *SocketHandler) Close() error { + if h.c != nil { + h.c.Close() + } + return nil +} + +func (h *SocketHandler) connect() error { + if h.c != nil { + return nil + } + + var err error + h.c, err = net.DialTimeout(h.protocol, h.addr, 20*time.Second) + if err != nil { + return err + } + + return nil +} diff --git a/server/accesslog.go b/server/accesslog.go index 9e517a8..2190899 100644 --- a/server/accesslog.go +++ b/server/accesslog.go @@ -1,7 +1,7 @@ package server import ( - "github.com/siddontang/go-log/log" + "github.com/siddontang/ledisdb/log" ) const ( diff --git a/server/client.go b/server/client.go index 3f5c43c..21412a3 100644 --- a/server/client.go +++ b/server/client.go @@ -4,8 +4,8 @@ import ( "bufio" "bytes" "errors" - "github.com/siddontang/go-log/log" "github.com/siddontang/ledisdb/ledis" + "github.com/siddontang/ledisdb/log" "io" "net" "runtime" diff --git a/server/replication.go b/server/replication.go index 383a244..bb3cc49 100644 --- a/server/replication.go +++ b/server/replication.go @@ -7,8 +7,8 @@ import ( "encoding/json" "errors" "fmt" - "github.com/siddontang/go-log/log" "github.com/siddontang/ledisdb/ledis" + "github.com/siddontang/ledisdb/log" "io/ioutil" "net" "os" From 9e09e607c8983958dd9a1a51e6dfa168b9bfcbf8 Mon Sep 17 00:00:00 2001 From: siddontang Date: Fri, 20 Jun 2014 10:12:50 +0800 Subject: [PATCH 9/9] heel use iterator to multi find and delete --- ledis/t_hash.go | 10 ++++++---- server/cmd_hash.go | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/ledis/t_hash.go b/ledis/t_hash.go index 4fa4cdc..5186c04 100644 --- a/ledis/t_hash.go +++ b/ledis/t_hash.go @@ -252,7 +252,7 @@ func (db *DB) HMget(key []byte, args ...[]byte) ([]interface{}, error) { return r, nil } -func (db *DB) HDel(key []byte, args [][]byte) (int64, error) { +func (db *DB) HDel(key []byte, args ...[]byte) (int64, error) { t := db.hashTx var ek []byte @@ -262,6 +262,9 @@ func (db *DB) HDel(key []byte, args [][]byte) (int64, error) { t.Lock() defer t.Unlock() + it := db.db.NewIterator() + defer it.Close() + var num int64 = 0 for i := 0; i < len(args); i++ { if err := checkHashKFSize(key, args[i]); err != nil { @@ -270,9 +273,8 @@ func (db *DB) HDel(key []byte, args [][]byte) (int64, error) { ek = db.hEncodeHashKey(key, args[i]) - if v, err = db.db.Get(ek); err != nil { - return 0, err - } else if v == nil { + v = it.Find(ek) + if v == nil { continue } else { num++ diff --git a/server/cmd_hash.go b/server/cmd_hash.go index a20fef6..23873b1 100644 --- a/server/cmd_hash.go +++ b/server/cmd_hash.go @@ -59,7 +59,7 @@ func hdelCommand(c *client) error { return ErrCmdParams } - if n, err := c.db.HDel(args[0], args[1:]); err != nil { + if n, err := c.db.HDel(args[0], args[1:]...); err != nil { return err } else { c.writeInteger(n)