diff --git a/tools/redis_import/README.md b/tools/redis_import/README.md index 3f97436..e36b629 100644 --- a/tools/redis_import/README.md +++ b/tools/redis_import/README.md @@ -1,7 +1,7 @@ ## Notice 1. We don't support `set` data type. -2. Our `zset` use integer instead of float, so the zset float score in Redis +2. Our `zset` use integer instead of double, so the zset float score in Redis will be **converted to integer**. 3. Only Support Redis version greater than `2.8.0`, because we use `scan` command to scan data. Also, you need `redis-py` greater than `2.9.0` diff --git a/tools/redis_import/redis_import.py b/tools/redis_import/redis_import.py index 7bc7ed3..820487b 100644 --- a/tools/redis_import/redis_import.py +++ b/tools/redis_import/redis_import.py @@ -10,6 +10,7 @@ import os from collections import OrderedDict as od import redis +import ledis total = 0 entries = 0 @@ -29,23 +30,39 @@ def scan_available(redis_client): return False -def copy_key(redis_client, ledis_client, key): +def set_ttl(redis_client, ledis_client, key, k_type): + k_types = { + "string": ledis_client.expire, + "list": ledis_client.lexpire, + "hash": ledis_client.hexpire, + "set": ledis_client.zexpire, + "zset": ledis_client.zexpire + } + timeout = redis_client.ttl(key) + if timeout > 0: + k_types[k_type](key, timeout) + + +def copy_key(redis_client, ledis_client, key, convert=False): global entries k_type = redis_client.type(key) if k_type == "string": value = redis_client.get(key) ledis_client.set(key, value) + set_ttl(redis_client, ledis_client, key, k_type) entries += 1 elif k_type == "list": _list = redis_client.lrange(key, 0, -1) for value in _list: ledis_client.rpush(key, value) + set_ttl(redis_client, ledis_client, key, k_type) entries += 1 elif k_type == "hash": mapping = od(redis_client.hgetall(key)) ledis_client.hmset(key, mapping) + set_ttl(redis_client, ledis_client, key, k_type) entries += 1 elif k_type == "zset": @@ -54,24 +71,47 @@ def copy_key(redis_client, ledis_client, key): for i in od(out).iteritems(): pieces[i[0]] = int(i[1]) ledis_client.zadd(key, **pieces) + set_ttl(redis_client, ledis_client, key, k_type) entries += 1 + elif k_type == "set": + if convert: + print "Convert set %s to zset\n" % key + members = redis_client.smembers(key) + set_to_zset(ledis_client, key, members) + entries += 1 + else: + print "KEY %s of TYPE %s will not be converted to Zset" % (key, k_type) + else: - print "%s is not supported by LedisDB." % k_type + print "KEY %s of TYPE %s is not supported by LedisDB." % (key, k_type) -def copy_keys(redis_client, ledis_client, keys): +def copy_keys(redis_client, ledis_client, keys, convert=False): for key in keys: - copy_key(redis_client, ledis_client, key) + copy_key(redis_client, ledis_client, key, convert=convert) -def copy(redis_client, ledis_client, redis_db): - global total +def scan(redis_client, count=1000): + keys = [] + total = redis_client.dbsize() + + first = True + cursor = 0 + while cursor != 0 or first: + cursor, data = redis_client.scan(cursor, count=count) + keys.extend(data) + first = False + print len(keys) + print total + assert len(keys) == total + return keys, total + + +def copy(redis_client, ledis_client, count=1000, convert=False): if scan_available(redis_client): - total = redis_client.dbsize() - # scan return a - keys = redis_client.scan(cursor=0, count=total)[1] - copy_keys(redis_client, ledis_client, keys) + keys, total = scan(redis_client, count=count) + copy_keys(redis_client, ledis_client, keys, convert=convert) else: msg = """We do not support Redis version less than 2.8.0. @@ -83,23 +123,33 @@ def copy(redis_client, ledis_client, redis_db): print "%d keys, %d entries copied" % (total, entries) +def set_to_zset(ledis_client, key, members): + d = {} + for m in members: + d[m] = int(0) + ledis_client.zadd(key, **d) + + def usage(): usage = """ Usage: - python %s redis_host redis_port redis_db ledis_host ledis_port + python %s redis_host redis_port redis_db ledis_host ledis_port [True] """ print usage % os.path.basename(sys.argv[0]) def main(): - if len(sys.argv) != 6: + if len(sys.argv) < 6: usage() sys.exit() - - (redis_host, redis_port, redis_db, ledis_host, ledis_port) = sys.argv[1:] + convert = False + if len(sys.argv) >= 6: + (redis_host, redis_port, redis_db, ledis_host, ledis_port) = sys.argv[1:6] + if len(sys.argv) == 7 and sys.argv[-1] == "True" or sys.argv[-1] == "true": + convert = True redis_c = redis.Redis(host=redis_host, port=int(redis_port), db=int(redis_db)) - ledis_c = redis.Redis(host=ledis_host, port=int(ledis_port), db=int(redis_db)) + ledis_c = ledis.Ledis(host=ledis_host, port=int(ledis_port), db=int(redis_db)) try: redis_c.ping() except redis.ConnectionError: @@ -112,8 +162,8 @@ def main(): print "Could not connect to LedisDB Server" sys.exit() - copy(redis_c, ledis_c, redis_db) - print "done\n" + copy(redis_c, ledis_c, convert=convert) + print "done\n" if __name__ == "__main__": diff --git a/tools/redis_import/test.py b/tools/redis_import/test.py index 9395321..a7f7f09 100644 --- a/tools/redis_import/test.py +++ b/tools/redis_import/test.py @@ -3,8 +3,12 @@ import random, string import redis +import ledis -from redis_import import copy +from redis_import import copy, scan, set_ttl + +rds = redis.Redis() +lds = ledis.Ledis(port=6380) def random_word(words, length): @@ -14,7 +18,7 @@ def random_word(words, length): def get_words(): word_file = "/usr/share/dict/words" words = open(word_file).read().splitlines() - return words[:10] + return words[:1000] def get_mapping(words, length=1000): @@ -24,44 +28,57 @@ def get_mapping(words, length=1000): return d -def random_set(client, words, length=1000): +def random_string(client, words, length=1000): d = get_mapping(words, length) client.mset(d) -def random_hset(client, words, length=1000): +def random_hash(client, words, length=1000): d = get_mapping(words, length) client.hmset("hashName", d) -def random_lpush(client, words, length=1000): +def random_list(client, words, length=1000): client.lpush("listName", *words) -def random_zadd(client, words, length=1000): +def random_zset(client, words, length=1000): d = get_mapping(words, length) - client.zadd("myset", **d) + client.zadd("zsetName", **d) + + +def random_set(client, words, length=1000): + client.sadd("setName", *words) def test(): words = get_words() - rds = redis.Redis() print "Flush all redis data before insert new." rds.flushall() + random_string(rds, words) + print "random_string done" + + random_hash(rds, words) + print "random_hash done" + + random_list(rds, words) + print "random_list done" + + random_zset(rds, words) + print "random_zset done" + random_set(rds, words) print "random_set done" - random_hset(rds, words) - print "random_hset done" - random_lpush(rds, words) - print "random_lpush done" - random_zadd(rds, words) - lds = redis.Redis(port=6380) - copy(rds, lds, 0) + lds.lclear("listName") + lds.hclear("hashName") + lds.zclear("zsetName") + lds.zclear("setName") + copy(rds, lds, convert=True) # for all keys - keys = rds.scan(0, count=rds.dbsize()) + keys = scan(rds, 1000) for key in keys: if rds.type(key) == "string" and not lds.exists(key): print key @@ -73,17 +90,51 @@ def test(): assert l1 == l2 #for hash - for key in keys: - if rds.type(key) == "hash" and not lds.hexists("hashName", key): - print "List data not consistent" + if rds.type(key) == "hash": + assert rds.hgetall(key) == lds.hgetall(key) + assert sorted(rds.hkeys(key)) == sorted(lds.hkeys(key)) + assert sorted(rds.hvals(key)) == sorted(lds.hvals(key)) # for zset - z1 = rds.zrange("myset", 0, -1, withscores=True) - z2 = lds.zrange("myset", 0, -1, withscores=True) + z1 = rds.zrange("zsetName", 0, -1, withscores=True) + z2 = lds.zrange("zsetName", 0, -1, withscores=True) assert z1 == z2 - + + # fo set + assert set(rds.smembers("setName")) == set(lds.zrange("setName", 0, -1)) + for key in lds.zrange("setName", 0, -1): + assert int(lds.zscore("setName", key)) == 0 + + +def ledis_ttl(ledis_client, key, k_type): + ttls = { + "string": lds.ttl, + "list": lds.lttl, + "hash": lds.httl, + "zset": lds.zttl, + "set": lds.zttl + } + return ttls[k_type](key) + + +def test_ttl(): + keys, total = scan(rds, 1000) + invalid = [] + for key in keys: + k_type = rds.type(key) + rds.expire(key, 100) + set_ttl(rds, lds, key, k_type) + # if rds.ttl(key) != ledis_ttl(lds, key, k_type): + # print key + # print rds.ttl(key) + # print ledis_ttl(lds, key, k_type) + # invalid.append(key) + + assert rds.ttl(key) == ledis_ttl(lds, key, k_type) + print len(invalid) if __name__ == "__main__": test() + test_ttl() print "Test passed."