diff --git a/Makefile b/Makefile index e2725be..f5b6dcd 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,6 @@ INSTALL_PATH ?= $(CURDIR) -$(shell ./bootstrap.sh) +$(shell ./bootstrap.sh >> /dev/null 2>&1) $(shell ./tools/build_config.sh build_config.mk $INSTALL_PATH) @@ -23,3 +23,6 @@ clean: test: go test -tags '$(GO_BUILD_TAGS)' ./... + +pytest: + sh client/ledis-py/tests/all.sh diff --git a/README.md b/README.md index 4524884..07437da 100644 --- a/README.md +++ b/README.md @@ -8,8 +8,9 @@ LedisDB now supports multiple databases as backend to store data, you can test a + Rich data structure: KV, List, Hash, ZSet, Bitmap, Set. + Stores lots of data, over the memory limit. -+ Various backend database to use: LevelDB, goleveldb, LMDB, RocksDB, BoltDB, HyperLevelDB. ++ Various backend database to use: LevelDB, goleveldb, LMDB, RocksDB, BoltDB, HyperLevelDB, Memory. + Supports transaction using LMDB or BotlDB. ++ Supports lua scripting. + Supports expiration and ttl. + Redis clients, like redis-cli, are supported directly. + Multiple client API supports, including Go, Python, Lua(Openresty), C/C++, Node.js. @@ -92,6 +93,7 @@ Choosing a store database to use is very simple, you have two ways: + You must known that changing store database runtime is very dangerous, LedisDB will not guarantee the data validation if you do it. + Begin a transaction will block any other write operators before you call `commit` or `rollback`. Don't use long-time transaction. ++ `pcall` and `xpcall` are not supported in lua, you can see the readme in [golua](https://github.com/aarzilli/golua). ## Configuration diff --git a/bootstrap.sh b/bootstrap.sh index ee260b7..ffb4c46 100755 --- a/bootstrap.sh +++ b/bootstrap.sh @@ -15,3 +15,5 @@ go get github.com/ugorji/go/codec go get github.com/BurntSushi/toml go get github.com/siddontang/go-bson/bson + +go get github.com/siddontang/golua/lua \ No newline at end of file diff --git a/client/ledis-py/Makefile b/client/ledis-py/Makefile new file mode 100644 index 0000000..9f53bed --- /dev/null +++ b/client/ledis-py/Makefile @@ -0,0 +1,5 @@ +.PHONY: test + + +test: + sh tests/all.sh diff --git a/client/ledis-py/ledis/client.py b/client/ledis-py/ledis/client.py index a083ca6..6615c91 100644 --- a/client/ledis-py/ledis/client.py +++ b/client/ledis-py/ledis/client.py @@ -4,7 +4,7 @@ import time as mod_time from itertools import chain, starmap from ledis._compat import (b, izip, imap, iteritems, basestring, long, nativestr, bytes) -from ledis.connection import ConnectionPool, UnixDomainSocketConnection +from ledis.connection import ConnectionPool, UnixDomainSocketConnection, Token from ledis.exceptions import ( ConnectionError, DataError, @@ -64,6 +64,31 @@ def int_or_none(response): return int(response) +def parse_info(response): + + info = {} + response = nativestr(response) + + def get_value(value): + if ',' not in value or '=' not in value: + try: + if '.' in value: + return float(value) + else: + return int(value) + except ValueError: + return value + + for line in response.splitlines(): + if line and not line.startswith('#'): + if line.find(':') != -1: + key, value = line.split(':', 1) + info[key] = get_value(value) + + return info + +# def parse_lscan(response, ) + class Ledis(object): """ Implementation of the Redis protocol. @@ -111,7 +136,9 @@ class Ledis(object): 'HGETALL': lambda r: r and pairs_to_dict(r) or {}, 'PING': lambda r: nativestr(r) == 'PONG', 'SET': lambda r: r and nativestr(r) == 'OK', - }, + 'INFO': parse_info, + } + ) @@ -219,8 +246,15 @@ class Ledis(object): db = 0 return self.execute_command('SELECT', db) - def info(self, section): - return self.execute_command('PING', section) + def info(self, section=None): + """ + Return + """ + + if section is None: + return self.execute_command("INFO") + else: + return self.execute_command('INFO', section) def flushall(self): return self.execute_command('FLUSHALL') @@ -350,8 +384,21 @@ class Ledis(object): "Removes an expiration on name" return self.execute_command('PERSIST', name) - def scan(self, key, match = "", count = 10): - return self.execute_command("SCAN", key, match, count) + def scan(self, key="" , match=None, count=10): + pieces = [key] + if match is not None: + pieces.extend(["MATCH", match]) + + pieces.extend(["COUNT", count]) + + return self.execute_command("SCAN", *pieces) + + def scan_iter(self, match=None, count=10): + key = "" + while key != "": + key, data = self.scan(key=key, match=match, count=count) + for item in data: + yield item #### LIST COMMANDS #### def lindex(self, name, index): @@ -428,8 +475,8 @@ class Ledis(object): "Removes an expiration on ``name``" return self.execute_command('LPERSIST', name) - def lscan(self, key, match = "", count = 10): - return self.execute_command("LSCAN", key, match, count) + def lscan(self, key="", match=None, count=10): + return self.scan_generic("LSCAN", key=key, match=match, count=count) #### SET COMMANDS #### @@ -528,8 +575,8 @@ class Ledis(object): "Removes an expiration on name" return self.execute_command('SPERSIST', name) - def sscan(self, key, match = "", count = 10): - return self.execute_command("SSCAN", key, match, count) + def sscan(self, key="", match=None, count = 10): + return self.scan_generic("SSCAN", key=key, match=match, count=count) #### SORTED SET COMMANDS #### @@ -727,9 +774,17 @@ class Ledis(object): "Removes an expiration on name" return self.execute_command('ZPERSIST', name) - def zscan(self, key, match = "", count = 10): - return self.execute_command("ZSCAN", key, match, count) + def scan_generic(self, scan_type, key="", match=None, count=10): + pieces = [key] + if match is not None: + pieces.extend([Token("MATCH"), match]) + pieces.extend([Token("count"), count]) + scan_type = scan_type.upper() + return self.execute_command(scan_type, *pieces) + + def zscan(self, key="", match=None, count=10): + return self.scan_generic("ZSCAN", key=key, match=match, count=count) #### HASH COMMANDS #### def hdel(self, name, *keys): @@ -823,8 +878,8 @@ class Ledis(object): "Removes an expiration on name" return self.execute_command('HPERSIST', name) - def hscan(self, key, match = "", count = 10): - return self.execute_command("HSCAN", key, match, count) + def hscan(self, key="", match=None, count=10): + return self.scan_generic("HSCAN", key=key, match=match, count=count) ### BIT COMMANDS @@ -902,8 +957,28 @@ class Ledis(object): "Removes an expiration on name" return self.execute_command('BPERSIST', name) - def bscan(self, key, match = "", count = 10): - return self.execute_command("BSCAN", key, match, count) + def bscan(self, key="", match=None, count=10): + return self.scan_generic("BSCAN", key=key, match=match, count=count) + + def eval(self, script, keys, *args): + n = len(keys) + args = list_or_args(keys, args) + return self.execute_command('EVAL', script, n, *args) + + def evalsha(self, sha1, keys, *args): + n = len(keys) + args = list_or_args(keys, args) + return self.execute_command('EVALSHA', sha1, n, *args) + + def scriptload(self, script): + return self.execute_command('SCRIPT', 'LOAD', script) + + def scriptexists(self, *args): + return self.execute_command('SCRIPT', 'EXISTS', *args) + + def scriptflush(self): + return self.execute_command('SCRIPT', 'FLUSH') + class Transaction(Ledis): def __init__(self, connection_pool, response_callbacks): diff --git a/client/ledis-py/ledis/connection.py b/client/ledis-py/ledis/connection.py index 4a39317..5372838 100644 --- a/client/ledis-py/ledis/connection.py +++ b/client/ledis-py/ledis/connection.py @@ -588,3 +588,23 @@ class BlockingConnectionPool(object): timeout=self.timeout, connection_class=self.connection_class, queue_class=self.queue_class, **self.connection_kwargs) + + +class Token(object): + """ + Literal strings in Redis commands, such as the command names and any + hard-coded arguments are wrapped in this class so we know not to apply + and encoding rules on them. + """ + def __init__(self, value): + if isinstance(value, Token): + value = value.value + self.value = value + + def __repr__(self): + return self.value + + def __str__(self): + return self.value + + diff --git a/client/ledis-py/tests/all.sh b/client/ledis-py/tests/all.sh new file mode 100644 index 0000000..8b7ae0f --- /dev/null +++ b/client/ledis-py/tests/all.sh @@ -0,0 +1,7 @@ +dbs=(leveldb rocksdb hyperleveldb goleveldb boltdb lmdb) +for db in "${dbs[@]}" +do + killall ledis-server + ledis-server -db_name=$db & + py.test +done diff --git a/client/ledis-py/tests/test_cmd_bit.py b/client/ledis-py/tests/test_cmd_bit.py index c1976a7..3d8d0c5 100644 --- a/client/ledis-py/tests/test_cmd_bit.py +++ b/client/ledis-py/tests/test_cmd_bit.py @@ -17,8 +17,7 @@ class TestCmdBit(unittest.TestCase): pass def tearDown(self): - l.bdelete('a') - l.bdelete('non_exists_key') + l.flushdb() def test_bget(self): "bget is the same as get in K/V commands" diff --git a/client/ledis-py/tests/test_cmd_hash.py b/client/ledis-py/tests/test_cmd_hash.py index 8a89af2..5efc86f 100644 --- a/client/ledis-py/tests/test_cmd_hash.py +++ b/client/ledis-py/tests/test_cmd_hash.py @@ -19,7 +19,7 @@ class TestCmdHash(unittest.TestCase): pass def tearDown(self): - l.hmclear('myhash', 'a') + l.flushdb() def test_hdel(self): diff --git a/client/ledis-py/tests/test_cmd_kv.py b/client/ledis-py/tests/test_cmd_kv.py index 774b800..b556c7a 100644 --- a/client/ledis-py/tests/test_cmd_kv.py +++ b/client/ledis-py/tests/test_cmd_kv.py @@ -18,7 +18,7 @@ class TestCmdKv(unittest.TestCase): pass def tearDown(self): - l.delete('a', 'b', 'c', 'non_exist_key') + l.flushdb() def test_decr(self): assert l.delete('a') == 1 diff --git a/client/ledis-py/tests/test_cmd_list.py b/client/ledis-py/tests/test_cmd_list.py index f87b8a1..065cee5 100644 --- a/client/ledis-py/tests/test_cmd_list.py +++ b/client/ledis-py/tests/test_cmd_list.py @@ -18,7 +18,7 @@ class TestCmdList(unittest.TestCase): pass def tearDown(self): - l.lmclear('mylist', 'mylist1', 'mylist2') + l.flushdb() def test_lindex(self): l.rpush('mylist', '1', '2', '3') diff --git a/client/ledis-py/tests/test_cmd_script.py b/client/ledis-py/tests/test_cmd_script.py new file mode 100644 index 0000000..4a08cb5 --- /dev/null +++ b/client/ledis-py/tests/test_cmd_script.py @@ -0,0 +1,55 @@ +# coding: utf-8 +# Test Cases for bit commands + +import unittest +import sys +sys.path.append('..') + +import ledis +from ledis._compat import b +from util import expire_at, expire_at_seconds + +l = ledis.Ledis(port=6380) + + +simple_script = "return {KEYS[1], KEYS[2], ARGV[1], ARGV[2]}" + + +class TestCmdScript(unittest.TestCase): + def setUp(self): + pass + + def tearDown(self): + l.flushdb() + + def test_eval(self): + assert l.eval(simple_script, ["key1", "key2"], "first", "second") == ["key1", "key2", "first", "second"] + + def test_evalsha(self): + sha1 = l.scriptload(simple_script) + assert len(sha1) == 40 + + assert l.evalsha(sha1, ["key1", "key2"], "first", "second") == ["key1", "key2", "first", "second"] + + def test_scriptload(self): + sha1 = l.scriptload(simple_script) + assert len(sha1) == 40 + + def test_scriptexists(self): + sha1 = l.scriptload(simple_script) + assert l.scriptexists(sha1) == [1L] + + def test_scriptflush(self): + sha1 = l.scriptload(simple_script) + assert l.scriptexists(sha1) == [1L] + assert l.scriptflush() == 'OK' + + assert l.scriptexists(sha1) == [0L] + + + + + + + + \ No newline at end of file diff --git a/client/ledis-py/tests/test_cmd_set.py b/client/ledis-py/tests/test_cmd_set.py index 0d2eec9..e98a762 100644 --- a/client/ledis-py/tests/test_cmd_set.py +++ b/client/ledis-py/tests/test_cmd_set.py @@ -20,7 +20,7 @@ class TestCmdSet(unittest.TestCase): pass def tearDown(self): - l.smclear('a', 'b', 'c') + l.flushdb() def test_sadd(self): members = set([b('1'), b('2'), b('3')]) diff --git a/client/ledis-py/tests/test_cmd_zset.py b/client/ledis-py/tests/test_cmd_zset.py index 08233fc..9277fce 100644 --- a/client/ledis-py/tests/test_cmd_zset.py +++ b/client/ledis-py/tests/test_cmd_zset.py @@ -17,7 +17,7 @@ class TestCmdZset(unittest.TestCase): pass def tearDown(self): - l.zclear('a') + l.flushdb() def test_zadd(self): l.zadd('a', a1=1, a2=2, a3=3) diff --git a/client/ledis-py/tests/test_others.py b/client/ledis-py/tests/test_others.py index d57d332..2cd7110 100644 --- a/client/ledis-py/tests/test_others.py +++ b/client/ledis-py/tests/test_others.py @@ -10,13 +10,14 @@ from ledis._compat import b from ledis import ResponseError l = ledis.Ledis(port=6380) +dbs = ["leveldb", "rocksdb", "goleveldb", "hyperleveldb", "lmdb", "boltdb"] class TestOtherCommands(unittest.TestCase): def setUp(self): pass def tearDown(self): - pass + l.flushdb() # server information def test_echo(self): @@ -28,4 +29,93 @@ class TestOtherCommands(unittest.TestCase): def test_select(self): assert l.select('1') assert l.select('15') - self.assertRaises(ResponseError, lambda: l.select('16')) \ No newline at end of file + self.assertRaises(ResponseError, lambda: l.select('16')) + + + def test_info(self): + info1 = l.info() + assert info1.get("db_name") in dbs + info2 = l.info(section="server") + assert info2.get("os") in ["linux", "darwin"] + + def test_flushdb(self): + l.set("a", 1) + assert l.flushdb() == "OK" + assert l.get("a") is None + + def test_flushall(self): + l.select(1) + l.set("a", 1) + assert l.get("a") == b("1") + + l.select(10) + l.set("a", 1) + assert l.get("a") == b("1") + + assert l.flushall() == "OK" + + assert l.get("a") is None + l.select(1) + assert l.get("a") is None + + + # test *scan commands + + def check_keys(self, scan_type): + d = { + "scan": l.scan, + "sscan": l.sscan, + "lscan": l.lscan, + "hscan": l.hscan, + "zscan": l.zscan, + "bscan": l.bscan + } + + key, keys = d[scan_type]() + assert key == "" + assert set(keys) == set([b("a"), b("b"), b("c")]) + + _, keys = d[scan_type](match="a") + assert set(keys) == set([b("a")]) + + _, keys = d[scan_type](key="a") + assert set(keys) == set([b("b"), b("c")]) + + + def test_scan(self): + d = {"a":1, "b":2, "c": 3} + l.mset(d) + self.check_keys("scan") + + + def test_lscan(self): + l.rpush("a", 1) + l.rpush("b", 1) + l.rpush("c", 1) + self.check_keys("lscan") + + + def test_hscan(self): + l.hset("a", "hello", "world") + l.hset("b", "hello", "world") + l.hset("c", "hello", "world") + self.check_keys("hscan") + + def test_sscan(self): + l.sadd("a", 1) + l.sadd("b", 2) + l.sadd("c", 3) + self.check_keys("sscan") + + def test_zscan(self): + l.zadd("a", 1, "a") + l.zadd("b", 1, "a") + l.zadd("c", 1, "a") + self.check_keys("zscan") + + def test_bscan(self): + l.bsetbit("a", 1, 1) + l.bsetbit("b", 1, 1) + l.bsetbit("c", 1, 1) + self.check_keys("bscan") + diff --git a/client/ledis-py/tests/test_tx.py b/client/ledis-py/tests/test_tx.py index b589dc7..cfbab20 100644 --- a/client/ledis-py/tests/test_tx.py +++ b/client/ledis-py/tests/test_tx.py @@ -4,14 +4,21 @@ sys.path.append("..") import ledis +global_l = ledis.Ledis() + +#db that do not support transaction +dbs = ["leveldb", "rocksdb", "hyperleveldb", "goleveldb"] +check = global_l.info().get("db_name") in dbs + class TestTx(unittest.TestCase): def setUp(self): self.l = ledis.Ledis(port=6380) def tearDown(self): - self.l.delete("a") - + self.l.flushdb() + + @unittest.skipIf(check, reason="db not support transaction") def test_commit(self): tx = self.l.tx() self.l.set("a", "no-tx") @@ -24,6 +31,7 @@ class TestTx(unittest.TestCase): tx.commit() assert self.l.get("a") == "tx" + @unittest.skipIf(check, reason="db not support transaction") def test_rollback(self): tx = self.l.tx() self.l.set("a", "no-tx") diff --git a/client/nodejs/ledis/lib/commands.js b/client/nodejs/ledis/lib/commands.js index 696bc7e..7271f3d 100644 --- a/client/nodejs/ledis/lib/commands.js +++ b/client/nodejs/ledis/lib/commands.js @@ -129,4 +129,7 @@ module.exports = [ "rollback", "commit", + "eval", + "evalsha", + "script", ]; diff --git a/client/openresty/ledis.lua b/client/openresty/ledis.lua index a25319e..973d489 100644 --- a/client/openresty/ledis.lua +++ b/client/openresty/ledis.lua @@ -151,7 +151,12 @@ local commands = { -- [[transaction]] "begin", "commit", - "rollback" + "rollback", + + -- [[script]] + "eval", + "evalsha", + "script" } @@ -391,7 +396,6 @@ function _M.hmset(self, hashname, ...) return _do_cmd(self, "hmset", hashname, ...) end - function _M.array_to_hash(self, t) local n = #t -- print("n = ", n) diff --git a/cmd/ledis-benchmark/main.go b/cmd/ledis-benchmark/main.go index 0084df1..f91c98c 100644 --- a/cmd/ledis-benchmark/main.go +++ b/cmd/ledis-benchmark/main.go @@ -1,11 +1,13 @@ package main import ( + crand "crypto/rand" "flag" "fmt" "github.com/siddontang/ledisdb/client/go/ledis" "math/rand" "sync" + "sync/atomic" "time" ) @@ -13,6 +15,8 @@ var ip = flag.String("ip", "127.0.0.1", "redis/ledis/ssdb server ip") var port = flag.Int("port", 6380, "redis/ledis/ssdb server port") var number = flag.Int("n", 1000, "request number") var clients = flag.Int("c", 50, "number of clients") +var reverse = flag.Bool("rev", false, "enable zset rev benchmark") +var round = flag.Int("r", 1, "benchmark round number") var wg sync.WaitGroup @@ -21,17 +25,13 @@ var client *ledis.Client var loop int = 0 func waitBench(cmd string, args ...interface{}) { - defer wg.Done() - c := client.Get() defer c.Close() - for i := 0; i < loop; i++ { - _, err := c.Do(cmd, args...) - if err != nil { - fmt.Printf("do %s error %s", cmd, err.Error()) - return - } + _, err := c.Do(cmd, args...) + if err != nil { + fmt.Printf("do %s error %s", cmd, err.Error()) + return } } @@ -40,7 +40,12 @@ func bench(cmd string, f func()) { t1 := time.Now().UnixNano() for i := 0; i < *clients; i++ { - go f() + go func() { + for i := 0; i < loop; i++ { + f() + } + wg.Done() + }() } wg.Wait() @@ -52,10 +57,17 @@ func bench(cmd string, f func()) { fmt.Printf("%s: %0.2f requests per second\n", cmd, (float64(*number) / delta)) } +var kvSetBase int64 = 0 +var kvGetBase int64 = 0 +var kvIncrBase int64 = 0 +var kvDelBase int64 = 0 + func benchSet() { f := func() { - n := rand.Int() - waitBench("set", n, n) + value := make([]byte, 100) + crand.Read(value) + n := atomic.AddInt64(&kvSetBase, 1) + waitBench("set", n, value) } bench("set", f) @@ -63,26 +75,36 @@ func benchSet() { func benchGet() { f := func() { - n := rand.Int() + n := atomic.AddInt64(&kvGetBase, 1) waitBench("get", n) } bench("get", f) } -func benchIncr() { +func benchRandGet() { f := func() { n := rand.Int() - waitBench("incr", n) + waitBench("get", n) } - bench("incr", f) + bench("randget", f) +} + +func benchDel() { + f := func() { + n := atomic.AddInt64(&kvDelBase, 1) + waitBench("del", n) + } + + bench("del", f) } func benchPushList() { f := func() { - n := rand.Int() - waitBench("rpush", "mytestlist", n) + value := make([]byte, 10) + crand.Read(value) + waitBench("rpush", "mytestlist", value) } bench("rpush", f) @@ -93,7 +115,7 @@ func benchRangeList10() { waitBench("lrange", "mytestlist", 0, 10) } - bench("lrange", f) + bench("lrange10", f) } func benchRangeList50() { @@ -101,7 +123,7 @@ func benchRangeList50() { waitBench("lrange", "mytestlist", 0, 50) } - bench("lrange", f) + bench("lrange50", f) } func benchRangeList100() { @@ -109,7 +131,7 @@ func benchRangeList100() { waitBench("lrange", "mytestlist", 0, 100) } - bench("lrange", f) + bench("lrange100", f) } func benchPopList() { @@ -120,46 +142,60 @@ func benchPopList() { bench("lpop", f) } +var hashSetBase int64 = 0 +var hashIncrBase int64 = 0 +var hashGetBase int64 = 0 +var hashDelBase int64 = 0 + func benchHset() { f := func() { - n := rand.Int() - waitBench("hset", "myhashkey", n, n) + value := make([]byte, 100) + crand.Read(value) + + n := atomic.AddInt64(&hashSetBase, 1) + waitBench("hset", "myhashkey", n, value) } bench("hset", f) } -func benchHIncr() { - f := func() { - n := rand.Int() - waitBench("hincrby", "myhashkey", n, 1) - } - - bench("hincrby", f) -} - func benchHGet() { f := func() { - n := rand.Int() + n := atomic.AddInt64(&hashGetBase, 1) waitBench("hget", "myhashkey", n) } bench("hget", f) } -func benchHDel() { +func benchHRandGet() { f := func() { n := rand.Int() + waitBench("hget", "myhashkey", n) + } + + bench("hrandget", f) +} + +func benchHDel() { + f := func() { + n := atomic.AddInt64(&hashDelBase, 1) waitBench("hdel", "myhashkey", n) } bench("hdel", f) } +var zsetAddBase int64 = 0 +var zsetDelBase int64 = 0 +var zsetIncrBase int64 = 0 + func benchZAdd() { f := func() { - n := rand.Int() - waitBench("zadd", "myzsetkey", n, n) + member := make([]byte, 16) + crand.Read(member) + n := atomic.AddInt64(&zsetAddBase, 1) + waitBench("zadd", "myzsetkey", n, member) } bench("zadd", f) @@ -167,7 +203,7 @@ func benchZAdd() { func benchZDel() { f := func() { - n := rand.Int() + n := atomic.AddInt64(&zsetDelBase, 1) waitBench("zrem", "myzsetkey", n) } @@ -176,7 +212,7 @@ func benchZDel() { func benchZIncr() { f := func() { - n := rand.Int() + n := atomic.AddInt64(&zsetIncrBase, 1) waitBench("zincrby", "myzsetkey", 1, n) } @@ -234,28 +270,44 @@ func main() { cfg := new(ledis.Config) cfg.Addr = addr + cfg.MaxIdleConns = *clients client = ledis.NewClient(cfg) - benchSet() - benchIncr() - benchGet() + if *round <= 0 { + *round = 1 + } - benchPushList() - benchRangeList10() - benchRangeList50() - benchRangeList100() - benchPopList() + for i := 0; i < *round; i++ { + benchSet() + benchGet() + benchRandGet() + benchDel() - benchHset() - benchHGet() - benchHIncr() - benchHDel() + benchPushList() + benchRangeList10() + benchRangeList50() + benchRangeList100() + benchPopList() - benchZAdd() - benchZIncr() - benchZRangeByRank() - benchZRangeByScore() - benchZRevRangeByRank() - benchZRevRangeByScore() - benchZDel() + benchHset() + benchHGet() + benchHRandGet() + benchHDel() + + benchZAdd() + benchZIncr() + benchZRangeByRank() + benchZRangeByScore() + + //rev is too slow in leveldb, rocksdb or other + //maybe disable for huge data benchmark + if *reverse == true { + benchZRevRangeByRank() + benchZRevRangeByScore() + } + + benchZDel() + + println("") + } } diff --git a/cmd/ledis-binlog/main.go b/cmd/ledis-binlog/main.go index 7212b0e..3725920 100644 --- a/cmd/ledis-binlog/main.go +++ b/cmd/ledis-binlog/main.go @@ -63,12 +63,12 @@ func main() { } } -func printEvent(createTime uint32, event []byte) error { - if createTime < startTime || createTime > stopTime { +func printEvent(head *ledis.BinLogHead, event []byte) error { + if head.CreateTime < startTime || head.CreateTime > stopTime { return nil } - t := time.Unix(int64(createTime), 0) + t := time.Unix(int64(head.CreateTime), 0) fmt.Printf("%s ", t.Format(TimeFormat)) diff --git a/cmd/ledis-cli/const.go b/cmd/ledis-cli/const.go index 48b78aa..9560c44 100644 --- a/cmd/ledis-cli/const.go +++ b/cmd/ledis-cli/const.go @@ -1,4 +1,4 @@ -//This file was generated by .tools/generate_commands.py on Wed Aug 27 2014 11:14:50 +0800 +//This file was generated by .tools/generate_commands.py on Tue Sep 02 2014 22:27:45 +0800 package main var helpCommands = [][]string{ @@ -20,6 +20,8 @@ var helpCommands = [][]string{ {"DECRBY", "key decrement", "KV"}, {"DEL", "key [key ...]", "KV"}, {"ECHO", "message", "Server"}, + {"EVAL", "script numkeys key [key ...] arg [arg ...]", "Script"}, + {"EVALSHA", "sha1 numkeys key [key ...] arg [arg ...]", "Script"}, {"EXISTS", "key", "KV"}, {"EXPIRE", "key seconds", "KV"}, {"EXPIREAT", "key timestamp", "KV"}, @@ -72,6 +74,9 @@ var helpCommands = [][]string{ {"SCAN", "key [MATCH match] [COUNT count]", "KV"}, {"SCARD", "key", "Set"}, {"SCLEAR", "key", "Set"}, + {"SCRIPT EXISTS", "script [script ...]", "Script"}, + {"SCRIPT FLUSH", "-", "Script"}, + {"SCRIPT LOAD", "script", "Script"}, {"SDIFF", "key [key ...]", "Set"}, {"SDIFFSTORE", "destination key [key ...]", "Set"}, {"SELECT", "index", "Server"}, diff --git a/cmd/ledis-cli/main.go b/cmd/ledis-cli/main.go index 70f0b93..ba97b30 100644 --- a/cmd/ledis-cli/main.go +++ b/cmd/ledis-cli/main.go @@ -65,22 +65,19 @@ func main() { if cmd == "help" || cmd == "?" { printHelp(cmds) } else { - if len(cmds) == 2 && strings.ToLower(cmds[0]) == "select" { - if db, _ := strconv.Atoi(cmds[1]); db < 16 && db >= 0 { - *dbn = db - } - - } - r, err := c.Do(cmds[0], args...) + if err == nil && strings.ToLower(cmds[0]) == "select" { + *dbn, _ = strconv.Atoi(cmds[1]) + } + if err != nil { fmt.Printf("%s", err.Error()) } else { if cmd == "info" { printInfo(r.([]byte)) } else { - printReply(cmd, r) + printReply(0, r) } } @@ -95,7 +92,7 @@ func printInfo(s []byte) { fmt.Printf("%s", s) } -func printReply(cmd string, reply interface{}) { +func printReply(level int, reply interface{}) { switch reply := reply.(type) { case int64: fmt.Printf("(integer) %d", reply) @@ -109,12 +106,14 @@ func printReply(cmd string, reply interface{}) { fmt.Printf("%s", string(reply)) case []interface{}: for i, v := range reply { - fmt.Printf("%d) ", i+1) - if v == nil { - fmt.Printf("(nil)") - } else { - fmt.Printf("%q", v) + if i != 0 { + fmt.Printf("%s", strings.Repeat(" ", level*4)) } + + s := fmt.Sprintf("%d) ", i+1) + fmt.Printf("%-4s", s) + + printReply(level+1, v) if i != len(reply)-1 { fmt.Printf("\n") } diff --git a/config/config.go b/config/config.go index 48a45f7..ca93d29 100644 --- a/config/config.go +++ b/config/config.go @@ -102,6 +102,7 @@ func NewConfigDefault() *Config { // disable access log cfg.AccessLog = "" + cfg.LMDB.MapSize = 20 * 1024 * 1024 cfg.LMDB.NoSync = true return cfg diff --git a/config/config.toml b/config/config.toml index 573db9a..2a3a246 100644 --- a/config/config.toml +++ b/config/config.toml @@ -22,6 +22,8 @@ slaveof = "" # goleveldb # lmdb # boltdb +# hyperleveldb +# memory # db_name = "leveldb" diff --git a/dev.sh b/dev.sh index 798ffab..a9be046 100644 --- a/dev.sh +++ b/dev.sh @@ -74,6 +74,13 @@ if [ -f $HYPERLEVELDB_DIR/include/hyperleveldb/c.h ]; then GO_BUILD_TAGS="$GO_BUILD_TAGS hyperleveldb" fi +#check lua +CHECK_LUA_FILE="$LEDISTOP/tools/check_lua.go" +go run $CHECK_LUA_FILE >> /dev/null 2>&1 +if [ "$?" = 0 ]; then + GO_BUILD_TAGS="$GO_BUILD_TAGS lua" +fi + export CGO_CFLAGS export CGO_CXXFLAGS export CGO_LDFLAGS diff --git a/doc/commands.json b/doc/commands.json index e6ad2e2..370d588 100644 --- a/doc/commands.json +++ b/doc/commands.json @@ -580,5 +580,36 @@ "arguments": "[section]", "group": "Server", "readonly": true + }, + + "EVAL": { + "arguments": "script numkeys key [key ...] arg [arg ...]", + "group": "Script", + "readonly": false + }, + + "EVALSHA": { + "arguments": "sha1 numkeys key [key ...] arg [arg ...]", + "group": "Script", + "readonly": false + }, + + "SCRIPT LOAD": { + "arguments": "script", + "group": "Script", + "readonly": false + }, + + "SCRIPT EXISTS": { + "arguments": "script [script ...]", + "group": "Script", + "readonly": false + }, + + "SCRIPT FLUSH": { + "arguments" : "-", + "group": "Script", + "readonly": false } + } diff --git a/doc/commands.md b/doc/commands.md index 535602d..d3c3b39 100644 --- a/doc/commands.md +++ b/doc/commands.md @@ -70,7 +70,7 @@ Table of Contents - [SINTERSTORE destination key [key ...]](#sinterstore-destination-key-key-) - [SISMEMBER key member](#sismember-key-member) - [SMEMBERS key](#smembers-key) - - [SREM key member [member]](#srem-key-member-member-) + - [SREM key member [member ...]](#srem-key-member-member-) - [SUNION key [key ...]](#sunion-key-key-) - [SUNIONSTORE destination key [key ...]](#sunionstore-destination-key-key-) - [SCLEAR key](#sclear-key) @@ -133,7 +133,12 @@ Table of Contents - [BEGIN](#begin) - [ROLLBACK](#rollback) - [COMMIT](#commit) - +- [Script](#script) + - [EVAL script numkeys key [key ...] arg [arg ...]](#eval-script-numkeys-key-key--arg-arg-) + - [EVALSHA sha1 numkeys key [key ...] arg [arg ...]](#evalsha-sha1-numkeys-key-key--arg-arg-) + - [SCRIPT LOAD script](#script-load-script) + - [SCRIPT EXISTS script [script ...]](#script-exists-script-script-) + - [SCRIPT FLUSH](#script-flush) ## KV @@ -869,7 +874,7 @@ ledis> HPERSIST not_exists_key Iterate Hash keys incrementally. -See `SCAN` for more information. +See [Scan](#scan-key-match-match-count-count) for more information. ## List @@ -1166,7 +1171,7 @@ ledis> LPERSIST b Iterate list keys incrementally. -See `SCAN` for more information. +See [Scan](#scan-key-match-match-count-count) for more information. ## Set @@ -1594,7 +1599,7 @@ ledis> STTL key Iterate Set keys incrementally. -See `SCAN` for more information. +See [Scan](#scan-key-match-match-count-count) for more information. ## ZSet @@ -2220,7 +2225,7 @@ ledis> ZRANGE out 0 -1 WITHSCORES Iterate ZSet keys incrementally. -See `SCAN` for more information. +See [Scan](#scan-key-match-match-count-count) for more information. ## Bitmap @@ -2386,7 +2391,7 @@ ledis> BCOUNT flag 5 6 Iterate Bitmap keys incrementally. -See `SCAN` for more information. +See [Scan](#scan-key-match-match-count-count) for more information. ## Replication @@ -2562,4 +2567,21 @@ ledis> GET HELLO "WORLD" ``` +## Script + +LedisDB's script is refer to Redis, you can see more [http://redis.io/commands/eval](http://redis.io/commands/eval) + +You must notice that executing lua will block any other write operations. + +### EVAL script numkeys key [key ...] arg [arg ...] + +### EVALSHA sha1 numkeys key [key ...] arg [arg ...] + +### SCRIPT LOAD script + +### SCRIPT EXISTS script [script ...] + +### SCRIPT FLUSH + + Thanks [doctoc](http://doctoc.herokuapp.com/) diff --git a/etc/ledis.conf b/etc/ledis.conf index 2097f65..d3adbd8 100644 --- a/etc/ledis.conf +++ b/etc/ledis.conf @@ -24,6 +24,8 @@ slaveof = "" # goleveldb # lmdb # boltdb +# hyperleveldb +# memory # db_name = "leveldb" diff --git a/ledis/batch.go b/ledis/batch.go new file mode 100644 index 0000000..b23cc47 --- /dev/null +++ b/ledis/batch.go @@ -0,0 +1,105 @@ +package ledis + +import ( + "github.com/siddontang/ledisdb/store" + "sync" +) + +type batch struct { + l *Ledis + + store.WriteBatch + + sync.Locker + + logs [][]byte + + tx *Tx +} + +func (b *batch) Commit() error { + b.l.commitLock.Lock() + defer b.l.commitLock.Unlock() + + err := b.WriteBatch.Commit() + + if b.l.binlog != nil { + if err == nil { + if b.tx == nil { + b.l.binlog.Log(b.logs...) + } else { + b.tx.logs = append(b.tx.logs, b.logs...) + } + } + b.logs = [][]byte{} + } + + return err +} + +func (b *batch) Lock() { + b.Locker.Lock() +} + +func (b *batch) Unlock() { + if b.l.binlog != nil { + b.logs = [][]byte{} + } + b.WriteBatch.Rollback() + b.Locker.Unlock() +} + +func (b *batch) Put(key []byte, value []byte) { + if b.l.binlog != nil { + buf := encodeBinLogPut(key, value) + b.logs = append(b.logs, buf) + } + b.WriteBatch.Put(key, value) +} + +func (b *batch) Delete(key []byte) { + if b.l.binlog != nil { + buf := encodeBinLogDelete(key) + b.logs = append(b.logs, buf) + } + b.WriteBatch.Delete(key) +} + +type dbBatchLocker struct { + l *sync.Mutex + wrLock *sync.RWMutex +} + +func (l *dbBatchLocker) Lock() { + l.wrLock.RLock() + l.l.Lock() +} + +func (l *dbBatchLocker) Unlock() { + l.l.Unlock() + l.wrLock.RUnlock() +} + +type txBatchLocker struct { +} + +func (l *txBatchLocker) Lock() {} +func (l *txBatchLocker) Unlock() {} + +type multiBatchLocker struct { +} + +func (l *multiBatchLocker) Lock() {} +func (l *multiBatchLocker) Unlock() {} + +func (l *Ledis) newBatch(wb store.WriteBatch, locker sync.Locker, tx *Tx) *batch { + b := new(batch) + b.l = l + b.WriteBatch = wb + + b.tx = tx + b.Locker = locker + + b.logs = [][]byte{} + return b +} diff --git a/ledis/binlog.go b/ledis/binlog.go index 087c13f..6eb0c30 100644 --- a/ledis/binlog.go +++ b/ledis/binlog.go @@ -6,14 +6,75 @@ import ( "fmt" "github.com/siddontang/go-log/log" "github.com/siddontang/ledisdb/config" + "io" "io/ioutil" "os" "path" "strconv" "strings" + "sync" "time" ) +type BinLogHead struct { + CreateTime uint32 + BatchId uint32 + PayloadLen uint32 +} + +func (h *BinLogHead) Len() int { + return 12 +} + +func (h *BinLogHead) Write(w io.Writer) error { + if err := binary.Write(w, binary.BigEndian, h.CreateTime); err != nil { + return err + } + + if err := binary.Write(w, binary.BigEndian, h.BatchId); err != nil { + return err + } + + if err := binary.Write(w, binary.BigEndian, h.PayloadLen); err != nil { + return err + } + + return nil +} + +func (h *BinLogHead) handleReadError(err error) error { + if err == io.EOF { + return io.ErrUnexpectedEOF + } else { + return err + } +} + +func (h *BinLogHead) Read(r io.Reader) error { + var err error + if err = binary.Read(r, binary.BigEndian, &h.CreateTime); err != nil { + return err + } + + if err = binary.Read(r, binary.BigEndian, &h.BatchId); err != nil { + return h.handleReadError(err) + } + + if err = binary.Read(r, binary.BigEndian, &h.PayloadLen); err != nil { + return h.handleReadError(err) + } + + return nil +} + +func (h *BinLogHead) InSameBatch(ho *BinLogHead) bool { + if h.CreateTime == ho.CreateTime && h.BatchId == ho.BatchId { + return true + } else { + return false + } +} + /* index file format: ledis-bin.00001 @@ -22,11 +83,15 @@ ledis-bin.00003 log file format -timestamp(bigendian uint32, seconds)|PayloadLen(bigendian uint32)|PayloadData +Log: Head|PayloadData + +Head: createTime|batchId|payloadData */ type BinLog struct { + sync.Mutex + path string cfg *config.BinLogConfig @@ -38,6 +103,8 @@ type BinLog struct { indexName string logNames []string lastLogIndex int64 + + batchId uint32 } func NewBinLog(cfg *config.Config) (*BinLog, error) { @@ -46,7 +113,7 @@ func NewBinLog(cfg *config.Config) (*BinLog, error) { l.cfg = &cfg.BinLog l.cfg.Adjust() - l.path = path.Join(cfg.DataDir, "bin_log") + l.path = path.Join(cfg.DataDir, "binlog") if err := os.MkdirAll(l.path, os.ModePerm); err != nil { return nil, err @@ -177,16 +244,20 @@ func (l *BinLog) checkLogFileSize() bool { st, _ := l.logFile.Stat() if st.Size() >= int64(l.cfg.MaxFileSize) { - l.lastLogIndex++ - - l.logFile.Close() - l.logFile = nil + l.closeLog() return true } return false } +func (l *BinLog) closeLog() { + l.lastLogIndex++ + + l.logFile.Close() + l.logFile = nil +} + func (l *BinLog) purge(n int) { for i := 0; i < n; i++ { logPath := path.Join(l.path, l.logNames[i]) @@ -238,6 +309,9 @@ func (l *BinLog) LogPath() string { } func (l *BinLog) Purge(n int) error { + l.Lock() + defer l.Unlock() + if len(l.logNames) == 0 { return nil } @@ -255,7 +329,18 @@ func (l *BinLog) Purge(n int) error { return l.flushIndex() } +func (l *BinLog) PurgeAll() error { + l.Lock() + defer l.Unlock() + + l.closeLog() + return l.openNewLogFile() +} + func (l *BinLog) Log(args ...[]byte) error { + l.Lock() + defer l.Unlock() + var err error if l.logFile == nil { @@ -264,17 +349,17 @@ func (l *BinLog) Log(args ...[]byte) error { } } - //we treat log many args as a batch, so use same createTime - createTime := uint32(time.Now().Unix()) + head := &BinLogHead{} + + head.CreateTime = uint32(time.Now().Unix()) + head.BatchId = l.batchId + + l.batchId++ for _, data := range args { - payLoadLen := uint32(len(data)) + head.PayloadLen = uint32(len(data)) - if err := binary.Write(l.logWb, binary.BigEndian, createTime); err != nil { - return err - } - - if err := binary.Write(l.logWb, binary.BigEndian, payLoadLen); err != nil { + if err := head.Write(l.logWb); err != nil { return err } diff --git a/ledis/binlog_util.go b/ledis/binlog_util.go index 5167b40..da058bd 100644 --- a/ledis/binlog_util.go +++ b/ledis/binlog_util.go @@ -54,15 +54,6 @@ func decodeBinLogPut(sz []byte) ([]byte, []byte, error) { return sz[3 : 3+keyLen], sz[3+keyLen:], nil } -func encodeBinLogCommand(commandType uint8, args ...[]byte) []byte { - //to do - return nil -} - -func decodeBinLogCommand(sz []byte) (uint8, [][]byte, error) { - return 0, nil, errBinLogCommandType -} - func FormatBinLogEvent(event []byte) (string, error) { logType := uint8(event[0]) diff --git a/ledis/const.go b/ledis/const.go index ef416de..e889f4e 100644 --- a/ledis/const.go +++ b/ledis/const.go @@ -86,3 +86,9 @@ const ( BinLogTypePut uint8 = 0x1 BinLogTypeCommand uint8 = 0x2 ) + +const ( + DBAutoCommit uint8 = 0x0 + DBInTransaction uint8 = 0x1 + DBInMulti uint8 = 0x2 +) diff --git a/ledis/dump.go b/ledis/dump.go index 14d7ff7..63c1d58 100644 --- a/ledis/dump.go +++ b/ledis/dump.go @@ -57,16 +57,17 @@ func (l *Ledis) DumpFile(path string) error { func (l *Ledis) Dump(w io.Writer) error { var m *MasterInfo = new(MasterInfo) - l.Lock() - defer l.Unlock() + + var err error + + l.wLock.Lock() + defer l.wLock.Unlock() if l.binlog != nil { m.LogFileIndex = l.binlog.LogFileIndex() m.LogPos = l.binlog.LogFilePos() } - var err error - wb := bufio.NewWriterSize(w, 4096) if err = m.WriteTo(wb); err != nil { return err @@ -80,8 +81,8 @@ func (l *Ledis) Dump(w io.Writer) error { var key []byte var value []byte for ; it.Valid(); it.Next() { - key = it.Key() - value = it.Value() + key = it.RawKey() + value = it.RawValue() if key, err = snappy.Encode(compressBuf, key); err != nil { return err @@ -128,8 +129,8 @@ func (l *Ledis) LoadDumpFile(path string) (*MasterInfo, error) { } func (l *Ledis) LoadDump(r io.Reader) (*MasterInfo, error) { - l.Lock() - defer l.Unlock() + l.wLock.Lock() + defer l.wLock.Unlock() info := new(MasterInfo) @@ -182,10 +183,6 @@ func (l *Ledis) LoadDump(r io.Reader) (*MasterInfo, error) { return nil, err } - if l.binlog != nil { - err = l.binlog.Log(encodeBinLogPut(key, value)) - } - keyBuf.Reset() valueBuf.Reset() } @@ -193,5 +190,10 @@ func (l *Ledis) LoadDump(r io.Reader) (*MasterInfo, error) { deKeyBuf = nil deValueBuf = nil + //if binlog enable, we will delete all binlogs and open a new one for handling simply + if l.binlog != nil { + l.binlog.PurgeAll() + } + return info, nil } diff --git a/ledis/ledis.go b/ledis/ledis.go index 70d22d1..f3c1c8c 100644 --- a/ledis/ledis.go +++ b/ledis/ledis.go @@ -10,8 +10,6 @@ import ( ) type Ledis struct { - sync.Mutex - cfg *config.Config ldb *store.DB @@ -21,11 +19,13 @@ type Ledis struct { jobs *sync.WaitGroup binlog *BinLog + + wLock sync.RWMutex //allow one write at same time + commitLock sync.Mutex //allow one write commit at same time } func Open(cfg *config.Config) (*Ledis, error) { if len(cfg.DataDir) == 0 { - fmt.Printf("no datadir set, use default %s\n", config.DefaultDataDir) cfg.DataDir = config.DefaultDataDir } @@ -42,7 +42,6 @@ func Open(cfg *config.Config) (*Ledis, error) { l.ldb = ldb if cfg.BinLog.MaxFileNum > 0 && cfg.BinLog.MaxFileSize > 0 { - println("binlog will be refactored later, use your own risk!!!") l.binlog, err = NewBinLog(cfg) if err != nil { return nil, err diff --git a/ledis/ledis_db.go b/ledis/ledis_db.go index c774d03..dd8ff74 100644 --- a/ledis/ledis_db.go +++ b/ledis/ledis_db.go @@ -29,8 +29,6 @@ type DB struct { bucket ibucket - dbLock *sync.RWMutex - index uint8 kvBatch *batch @@ -40,7 +38,7 @@ type DB struct { binBatch *batch setBatch *batch - isTx bool + status uint8 } func (l *Ledis) newDB(index uint8) *DB { @@ -52,9 +50,8 @@ func (l *Ledis) newDB(index uint8) *DB { d.bucket = d.sdb - d.isTx = false + d.status = DBAutoCommit d.index = index - d.dbLock = &sync.RWMutex{} d.kvBatch = d.newBatch() d.listBatch = d.newBatch() @@ -66,10 +63,18 @@ func (l *Ledis) newDB(index uint8) *DB { return d } +func (db *DB) newBatch() *batch { + return db.l.newBatch(db.bucket.NewWriteBatch(), &dbBatchLocker{l: &sync.Mutex{}, wrLock: &db.l.wLock}, nil) +} + func (db *DB) Index() int { return int(db.index) } +func (db *DB) IsAutoCommit() bool { + return db.status == DBAutoCommit +} + func (db *DB) FlushAll() (drop int64, err error) { all := [...](func() (int64, error)){ db.flush, diff --git a/ledis/ledis_test.go b/ledis/ledis_test.go index aff4ebe..d5a5476 100644 --- a/ledis/ledis_test.go +++ b/ledis/ledis_test.go @@ -14,8 +14,8 @@ func getTestDB() *DB { f := func() { cfg := new(config.Config) cfg.DataDir = "/tmp/test_ledis" - cfg.BinLog.MaxFileSize = 1073741824 - cfg.BinLog.MaxFileNum = 3 + // cfg.BinLog.MaxFileSize = 1073741824 + // cfg.BinLog.MaxFileNum = 3 os.RemoveAll(cfg.DataDir) diff --git a/ledis/multi.go b/ledis/multi.go new file mode 100644 index 0000000..a549c2c --- /dev/null +++ b/ledis/multi.go @@ -0,0 +1,73 @@ +package ledis + +import ( + "errors" + "fmt" +) + +var ( + ErrNestMulti = errors.New("nest multi not supported") + ErrMultiDone = errors.New("multi has been closed") +) + +type Multi struct { + *DB +} + +func (db *DB) IsInMulti() bool { + return db.status == DBInMulti +} + +// begin a mutli to execute commands, +// it will block any other write operations before you close the multi, unlike transaction, mutli can not rollback +func (db *DB) Multi() (*Multi, error) { + if db.IsInMulti() { + return nil, ErrNestMulti + } + + m := new(Multi) + + m.DB = new(DB) + m.DB.status = DBInMulti + + m.DB.l = db.l + + m.l.wLock.Lock() + + m.DB.sdb = db.sdb + + m.DB.bucket = db.sdb + + m.DB.index = db.index + + m.DB.kvBatch = m.newBatch() + m.DB.listBatch = m.newBatch() + m.DB.hashBatch = m.newBatch() + m.DB.zsetBatch = m.newBatch() + m.DB.binBatch = m.newBatch() + m.DB.setBatch = m.newBatch() + + return m, nil +} + +func (m *Multi) newBatch() *batch { + return m.l.newBatch(m.bucket.NewWriteBatch(), &multiBatchLocker{}, nil) +} + +func (m *Multi) Close() error { + if m.bucket == nil { + return ErrMultiDone + } + m.l.wLock.Unlock() + m.bucket = nil + return nil +} + +func (m *Multi) Select(index int) error { + if index < 0 || index >= int(MaxDBNumber) { + return fmt.Errorf("invalid db index %d", index) + } + + m.DB.index = uint8(index) + return nil +} diff --git a/ledis/multi_test.go b/ledis/multi_test.go new file mode 100644 index 0000000..936c141 --- /dev/null +++ b/ledis/multi_test.go @@ -0,0 +1,51 @@ +package ledis + +import ( + "sync" + "testing" +) + +func TestMulti(t *testing.T) { + db := getTestDB() + + key := []byte("test_multi_1") + v1 := []byte("v1") + v2 := []byte("v2") + + m, err := db.Multi() + if err != nil { + t.Fatal(err) + } + + wg := sync.WaitGroup{} + + wg.Add(1) + + go func() { + if err := db.Set(key, v2); err != nil { + t.Fatal(err) + } + wg.Done() + }() + + if err := m.Set(key, v1); err != nil { + t.Fatal(err) + } + + if v, err := m.Get(key); err != nil { + t.Fatal(err) + } else if string(v) != string(v1) { + t.Fatal(string(v)) + } + + m.Close() + + wg.Wait() + + if v, err := db.Get(key); err != nil { + t.Fatal(err) + } else if string(v) != string(v2) { + t.Fatal(string(v)) + } + +} diff --git a/ledis/replication.go b/ledis/replication.go index bd6c192..421a5ab 100644 --- a/ledis/replication.go +++ b/ledis/replication.go @@ -3,13 +3,18 @@ package ledis import ( "bufio" "bytes" - "encoding/binary" "errors" "github.com/siddontang/go-log/log" + "github.com/siddontang/ledisdb/store/driver" "io" "os" ) +const ( + maxReplBatchNum = 100 + maxReplLogSize = 1 * 1024 * 1024 +) + var ( ErrSkipEvent = errors.New("skip to next event") ) @@ -19,70 +24,90 @@ var ( errInvalidBinLogFile = errors.New("invalid binlog file") ) -func (l *Ledis) ReplicateEvent(event []byte) error { +type replBatch struct { + wb driver.IWriteBatch + events [][]byte + l *Ledis + + lastHead *BinLogHead +} + +func (b *replBatch) Commit() error { + b.l.commitLock.Lock() + defer b.l.commitLock.Unlock() + + err := b.wb.Commit() + if err != nil { + b.Rollback() + return err + } + + if b.l.binlog != nil { + if err = b.l.binlog.Log(b.events...); err != nil { + b.Rollback() + return err + } + } + + b.events = [][]byte{} + b.lastHead = nil + + return nil +} + +func (b *replBatch) Rollback() error { + b.wb.Rollback() + b.events = [][]byte{} + b.lastHead = nil + return nil +} + +func (l *Ledis) replicateEvent(b *replBatch, event []byte) error { if len(event) == 0 { return errInvalidBinLogEvent } + b.events = append(b.events, event) + logType := uint8(event[0]) switch logType { case BinLogTypePut: - return l.replicatePutEvent(event) + return l.replicatePutEvent(b, event) case BinLogTypeDeletion: - return l.replicateDeleteEvent(event) - case BinLogTypeCommand: - return l.replicateCommandEvent(event) + return l.replicateDeleteEvent(b, event) default: return errInvalidBinLogEvent } } -func (l *Ledis) replicatePutEvent(event []byte) error { +func (l *Ledis) replicatePutEvent(b *replBatch, event []byte) error { key, value, err := decodeBinLogPut(event) if err != nil { return err } - if err = l.ldb.Put(key, value); err != nil { - return err - } + b.wb.Put(key, value) - if l.binlog != nil { - err = l.binlog.Log(event) - } - - return err + return nil } -func (l *Ledis) replicateDeleteEvent(event []byte) error { +func (l *Ledis) replicateDeleteEvent(b *replBatch, event []byte) error { key, err := decodeBinLogDelete(event) if err != nil { return err } - if err = l.ldb.Delete(key); err != nil { - return err - } + b.wb.Delete(key) - if l.binlog != nil { - err = l.binlog.Log(event) - } - - return err + return nil } -func (l *Ledis) replicateCommandEvent(event []byte) error { - return errors.New("command event not supported now") -} - -func ReadEventFromReader(rb io.Reader, f func(createTime uint32, event []byte) error) error { - var createTime uint32 - var dataLen uint32 - var dataBuf bytes.Buffer +func ReadEventFromReader(rb io.Reader, f func(head *BinLogHead, event []byte) error) error { + head := &BinLogHead{} var err error for { - if err = binary.Read(rb, binary.BigEndian, &createTime); err != nil { + if err = head.Read(rb); err != nil { if err == io.EOF { break } else { @@ -90,28 +115,39 @@ func ReadEventFromReader(rb io.Reader, f func(createTime uint32, event []byte) e } } - if err = binary.Read(rb, binary.BigEndian, &dataLen); err != nil { + var dataBuf bytes.Buffer + + if _, err = io.CopyN(&dataBuf, rb, int64(head.PayloadLen)); err != nil { return err } - if _, err = io.CopyN(&dataBuf, rb, int64(dataLen)); err != nil { - return err - } - - err = f(createTime, dataBuf.Bytes()) + err = f(head, dataBuf.Bytes()) if err != nil && err != ErrSkipEvent { return err } - - dataBuf.Reset() } return nil } func (l *Ledis) ReplicateFromReader(rb io.Reader) error { - f := func(createTime uint32, event []byte) error { - err := l.ReplicateEvent(event) + b := new(replBatch) + + b.wb = l.ldb.NewWriteBatch() + b.l = l + + f := func(head *BinLogHead, event []byte) error { + if b.lastHead == nil { + b.lastHead = head + } else if !b.lastHead.InSameBatch(head) { + if err := b.Commit(); err != nil { + log.Fatal("replication error %s, skip to next", err.Error()) + return ErrSkipEvent + } + b.lastHead = head + } + + err := l.replicateEvent(b, event) if err != nil { log.Fatal("replication error %s, skip to next", err.Error()) return ErrSkipEvent @@ -119,15 +155,18 @@ func (l *Ledis) ReplicateFromReader(rb io.Reader) error { return nil } - return ReadEventFromReader(rb, f) + err := ReadEventFromReader(rb, f) + if err != nil { + b.Rollback() + return err + } + return b.Commit() } func (l *Ledis) ReplicateFromData(data []byte) error { rb := bytes.NewReader(data) - l.Lock() err := l.ReplicateFromReader(rb) - l.Unlock() return err } @@ -140,17 +179,13 @@ func (l *Ledis) ReplicateFromBinLog(filePath string) error { rb := bufio.NewReaderSize(f, 4096) - l.Lock() err = l.ReplicateFromReader(rb) - l.Unlock() f.Close() return err } -const maxSyncEvents = 64 - func (l *Ledis) ReadEventsTo(info *MasterInfo, w io.Writer) (n int, err error) { n = 0 if l.binlog == nil { @@ -201,14 +236,14 @@ func (l *Ledis) ReadEventsTo(info *MasterInfo, w io.Writer) (n int, err error) { return } - var lastCreateTime uint32 = 0 - var createTime uint32 - var dataLen uint32 + var lastHead *BinLogHead = nil - var eventsNum int = 0 + head := &BinLogHead{} + + batchNum := 0 for { - if err = binary.Read(f, binary.BigEndian, &createTime); err != nil { + if err = head.Read(f); err != nil { if err == io.EOF { //we will try to use next binlog if index < l.binlog.LogFileIndex() { @@ -220,35 +255,30 @@ func (l *Ledis) ReadEventsTo(info *MasterInfo, w io.Writer) (n int, err error) { } else { return } + } - eventsNum++ - if lastCreateTime == 0 { - lastCreateTime = createTime - } else if lastCreateTime != createTime { - return - } else if eventsNum > maxSyncEvents { + if lastHead == nil { + lastHead = head + batchNum++ + } else if !lastHead.InSameBatch(head) { + lastHead = head + batchNum++ + if batchNum > maxReplBatchNum || n > maxReplLogSize { + return + } + } + + if err = head.Write(w); err != nil { return } - if err = binary.Read(f, binary.BigEndian, &dataLen); err != nil { + if _, err = io.CopyN(w, f, int64(head.PayloadLen)); err != nil { return } - if err = binary.Write(w, binary.BigEndian, createTime); err != nil { - return - } - - if err = binary.Write(w, binary.BigEndian, dataLen); err != nil { - return - } - - if _, err = io.CopyN(w, f, int64(dataLen)); err != nil { - return - } - - n += (8 + int(dataLen)) - info.LogPos = info.LogPos + 8 + int64(dataLen) + n += (head.Len() + int(head.PayloadLen)) + info.LogPos = info.LogPos + int64(head.Len()) + int64(head.PayloadLen) } return diff --git a/ledis/replication_test.go b/ledis/replication_test.go index 96bb10a..2a64a11 100644 --- a/ledis/replication_test.go +++ b/ledis/replication_test.go @@ -70,6 +70,12 @@ func TestReplication(t *testing.T) { db.HSet([]byte("c"), []byte("3"), []byte("value")) } + m, _ := db.Multi() + m.Set([]byte("a1"), []byte("value")) + m.Set([]byte("b1"), []byte("value")) + m.Set([]byte("c1"), []byte("value")) + m.Close() + for _, name := range master.binlog.LogNames() { p := path.Join(master.binlog.LogPath(), name) diff --git a/ledis/t_bit.go b/ledis/t_bit.go index fadab4d..496c37a 100644 --- a/ledis/t_bit.go +++ b/ledis/t_bit.go @@ -506,6 +506,7 @@ func (db *DB) BSetBit(key []byte, offset int32, val uint8) (ori uint8, err error if setBit(segment, off, val) { t := db.binBatch t.Lock() + defer t.Unlock() t.Put(bk, segment) if _, _, e := db.bUpdateMeta(t, key, seq, off); e != nil { @@ -514,7 +515,6 @@ func (db *DB) BSetBit(key []byte, offset int32, val uint8) (ori uint8, err error } err = t.Commit() - t.Unlock() } } diff --git a/ledis/tx.go b/ledis/tx.go index 38eb626..6339bae 100644 --- a/ledis/tx.go +++ b/ledis/tx.go @@ -2,8 +2,8 @@ package ledis import ( "errors" + "fmt" "github.com/siddontang/ledisdb/store" - "sync" ) var ( @@ -11,143 +11,44 @@ var ( ErrTxDone = errors.New("Transaction has already been committed or rolled back") ) -type batch struct { - l *Ledis - - store.WriteBatch - - sync.Locker - - logs [][]byte - - tx *Tx -} - -type dbBatchLocker struct { - sync.Mutex - dbLock *sync.RWMutex -} - -type txBatchLocker struct { -} - -func (l *txBatchLocker) Lock() { -} - -func (l *txBatchLocker) Unlock() { -} - -func (l *dbBatchLocker) Lock() { - l.dbLock.RLock() - l.Mutex.Lock() -} - -func (l *dbBatchLocker) Unlock() { - l.Mutex.Unlock() - l.dbLock.RUnlock() -} - -func (db *DB) newBatch() *batch { - b := new(batch) - - b.WriteBatch = db.bucket.NewWriteBatch() - b.Locker = &dbBatchLocker{dbLock: db.dbLock} - b.l = db.l - - return b -} - -func (b *batch) Commit() error { - b.l.Lock() - defer b.l.Unlock() - - err := b.WriteBatch.Commit() - - if b.l.binlog != nil { - if err == nil { - if b.tx == nil { - b.l.binlog.Log(b.logs...) - } else { - b.tx.logs = append(b.tx.logs, b.logs...) - } - } - b.logs = [][]byte{} - } - - return err -} - -func (b *batch) Lock() { - b.Locker.Lock() -} - -func (b *batch) Unlock() { - if b.l.binlog != nil { - b.logs = [][]byte{} - } - b.Rollback() - b.Locker.Unlock() -} - -func (b *batch) Put(key []byte, value []byte) { - if b.l.binlog != nil { - buf := encodeBinLogPut(key, value) - b.logs = append(b.logs, buf) - } - b.WriteBatch.Put(key, value) -} - -func (b *batch) Delete(key []byte) { - if b.l.binlog != nil { - buf := encodeBinLogDelete(key) - b.logs = append(b.logs, buf) - } - b.WriteBatch.Delete(key) -} - type Tx struct { *DB tx *store.Tx logs [][]byte - - index uint8 } func (db *DB) IsTransaction() bool { - return db.isTx + return db.status == DBInTransaction } // Begin a transaction, it will block all other write operations before calling Commit or Rollback. // You must be very careful to prevent long-time transaction. func (db *DB) Begin() (*Tx, error) { - if db.isTx { + if db.IsTransaction() { return nil, ErrNestTx } tx := new(Tx) tx.DB = new(DB) - tx.DB.dbLock = db.dbLock - - tx.DB.dbLock.Lock() - tx.DB.l = db.l - tx.index = db.index + + tx.l.wLock.Lock() tx.DB.sdb = db.sdb var err error tx.tx, err = db.sdb.Begin() if err != nil { - tx.DB.dbLock.Unlock() + tx.l.wLock.Unlock() return nil, err } tx.DB.bucket = tx.tx - tx.DB.isTx = true + tx.DB.status = DBInTransaction tx.DB.index = db.index @@ -166,7 +67,7 @@ func (tx *Tx) Commit() error { return ErrTxDone } - tx.l.Lock() + tx.l.commitLock.Lock() err := tx.tx.Commit() tx.tx = nil @@ -174,10 +75,12 @@ func (tx *Tx) Commit() error { tx.l.binlog.Log(tx.logs...) } - tx.l.Unlock() + tx.l.commitLock.Unlock() + + tx.l.wLock.Unlock() + + tx.DB.bucket = nil - tx.DB.dbLock.Unlock() - tx.DB = nil return err } @@ -189,22 +92,21 @@ func (tx *Tx) Rollback() error { err := tx.tx.Rollback() tx.tx = nil - tx.DB.dbLock.Unlock() - tx.DB = nil + tx.l.wLock.Unlock() + tx.DB.bucket = nil + return err } func (tx *Tx) newBatch() *batch { - b := new(batch) - - b.l = tx.l - b.WriteBatch = tx.tx.NewWriteBatch() - b.Locker = &txBatchLocker{} - b.tx = tx - - return b + return tx.l.newBatch(tx.tx.NewWriteBatch(), &txBatchLocker{}, tx) } -func (tx *Tx) Index() int { - return int(tx.index) +func (tx *Tx) Select(index int) error { + if index < 0 || index >= int(MaxDBNumber) { + return fmt.Errorf("invalid db index %d", index) + } + + tx.DB.index = uint8(index) + return nil } diff --git a/ledis/tx_test.go b/ledis/tx_test.go index bf06012..026b70d 100644 --- a/ledis/tx_test.go +++ b/ledis/tx_test.go @@ -144,6 +144,51 @@ func testTxCommit(t *testing.T, db *DB) { } } +func testTxSelect(t *testing.T, db *DB) { + tx, err := db.Begin() + if err != nil { + t.Fatal(err) + } + + defer tx.Rollback() + + tx.Set([]byte("tx_select_1"), []byte("a")) + + tx.Select(1) + + tx.Set([]byte("tx_select_2"), []byte("b")) + + if err = tx.Commit(); err != nil { + t.Fatal(err) + } + + if v, err := db.Get([]byte("tx_select_1")); err != nil { + t.Fatal(err) + } else if string(v) != "a" { + t.Fatal(string(v)) + } + + if v, err := db.Get([]byte("tx_select_2")); err != nil { + t.Fatal(err) + } else if v != nil { + t.Fatal("must nil") + } + + db, _ = db.l.Select(1) + + if v, err := db.Get([]byte("tx_select_2")); err != nil { + t.Fatal(err) + } else if string(v) != "b" { + t.Fatal(string(v)) + } + + if v, err := db.Get([]byte("tx_select_1")); err != nil { + t.Fatal(err) + } else if v != nil { + t.Fatal("must nil") + } +} + func testTx(t *testing.T, name string) { cfg := new(config.Config) cfg.DataDir = "/tmp/ledis_test_tx" @@ -164,6 +209,7 @@ func testTx(t *testing.T, name string) { testTxRollback(t, db) testTxCommit(t, db) + testTxSelect(t, db) } //only lmdb, boltdb support Transaction diff --git a/server/app.go b/server/app.go index d5c77c9..edd65c8 100644 --- a/server/app.go +++ b/server/app.go @@ -27,6 +27,8 @@ type App struct { m *master info *info + + s *script } func netType(s string) string { @@ -85,6 +87,8 @@ func NewApp(cfg *config.Config) (*App, error) { app.m = newMaster(app) + app.openScript() + return app, nil } @@ -103,6 +107,8 @@ func (app *App) Close() { app.httpListener.Close() } + app.closeScript() + app.m.Close() if app.access != nil { diff --git a/server/client.go b/server/client.go index f28a930..27e08b1 100644 --- a/server/client.go +++ b/server/client.go @@ -16,6 +16,18 @@ var txUnsupportedCmds = map[string]struct{}{ "begin": struct{}{}, "flushall": struct{}{}, "flushdb": struct{}{}, + "eval": struct{}{}, +} + +var scriptUnsupportedCmds = map[string]struct{}{ + "slaveof": struct{}{}, + "fullsync": struct{}{}, + "sync": struct{}{}, + "begin": struct{}{}, + "commit": struct{}{}, + "rollback": struct{}{}, + "flushall": struct{}{}, + "flushdb": struct{}{}, } type responseWriter interface { @@ -34,7 +46,8 @@ type responseWriter interface { type client struct { app *App ldb *ledis.Ledis - db *ledis.DB + + db *ledis.DB remoteAddr string cmd string @@ -49,7 +62,8 @@ type client struct { buf bytes.Buffer - tx *ledis.Tx + tx *ledis.Tx + script *ledis.Multi } func newClient(app *App) *client { @@ -59,16 +73,12 @@ func newClient(app *App) *client { c.ldb = app.ldb c.db, _ = app.ldb.Select(0) //use default db - c.compressBuf = make([]byte, 256) + c.compressBuf = []byte{} c.reqErr = make(chan error) return c } -func (c *client) isInTransaction() bool { - return c.tx != nil -} - func (c *client) perform() { var err error @@ -79,10 +89,14 @@ func (c *client) perform() { } else if exeCmd, ok := regCmds[c.cmd]; !ok { err = ErrNotFound } else { - if c.isInTransaction() { + if c.db.IsTransaction() { if _, ok := txUnsupportedCmds[c.cmd]; ok { err = fmt.Errorf("%s not supported in transaction", c.cmd) } + } else if c.db.IsInMulti() { + if _, ok := scriptUnsupportedCmds[c.cmd]; ok { + err = fmt.Errorf("%s not supported in multi", c.cmd) + } } if err == nil { @@ -128,3 +142,22 @@ func (c *client) catGenericCommand() []byte { return buffer.Bytes() } + +func writeValue(w responseWriter, value interface{}) { + switch v := value.(type) { + case []interface{}: + w.writeArray(v) + case [][]byte: + w.writeSliceArray(v) + case []byte: + w.writeBulk(v) + case string: + w.writeStatus(v) + case nil: + w.writeBulk(nil) + case int64: + w.writeInteger(v) + default: + panic("invalid value type") + } +} diff --git a/server/cmd_script.go b/server/cmd_script.go new file mode 100644 index 0000000..de1844f --- /dev/null +++ b/server/cmd_script.go @@ -0,0 +1,226 @@ +// +build lua + +package server + +import ( + "crypto/sha1" + "encoding/hex" + "fmt" + "github.com/siddontang/golua/lua" + "github.com/siddontang/ledisdb/ledis" + "strconv" + "strings" +) + +func parseEvalArgs(l *lua.State, c *client) error { + args := c.args + if len(args) < 2 { + return ErrCmdParams + } + + args = args[1:] + + n, err := strconv.Atoi(ledis.String(args[0])) + if err != nil { + return err + } + + if n > len(args)-1 { + return ErrCmdParams + } + + luaSetGlobalArray(l, "KEYS", args[1:n+1]) + luaSetGlobalArray(l, "ARGV", args[n+1:]) + + return nil +} + +func evalGenericCommand(c *client, evalSha1 bool) error { + m, err := c.db.Multi() + if err != nil { + return err + } + + s := c.app.s + luaClient := s.c + l := s.l + + s.Lock() + + base := l.GetTop() + + defer func() { + l.SetTop(base) + luaClient.db = nil + luaClient.script = nil + + s.Unlock() + + m.Close() + }() + + luaClient.db = m.DB + luaClient.script = m + luaClient.remoteAddr = c.remoteAddr + + if err := parseEvalArgs(l, c); err != nil { + return err + } + + var key string + if !evalSha1 { + h := sha1.Sum(c.args[0]) + key = hex.EncodeToString(h[0:20]) + } else { + key = strings.ToLower(ledis.String(c.args[0])) + } + + l.GetGlobal(key) + + if l.IsNil(-1) { + l.Pop(1) + + if evalSha1 { + return fmt.Errorf("missing %s script", key) + } + + if r := l.LoadString(ledis.String(c.args[0])); r != 0 { + err := fmt.Errorf("%s", l.ToString(-1)) + l.Pop(1) + return err + } else { + l.PushValue(-1) + l.SetGlobal(key) + + s.chunks[key] = struct{}{} + } + } + + if err := l.Call(0, lua.LUA_MULTRET); err != nil { + return err + } else { + r := luaReplyToLedisReply(l) + m.Close() + + if v, ok := r.(error); ok { + return v + } + + writeValue(c.resp, r) + } + + return nil +} + +func evalCommand(c *client) error { + return evalGenericCommand(c, false) +} + +func evalshaCommand(c *client) error { + return evalGenericCommand(c, true) +} + +func scriptCommand(c *client) error { + s := c.app.s + l := s.l + + s.Lock() + + base := l.GetTop() + + defer func() { + l.SetTop(base) + s.Unlock() + }() + + args := c.args + + if len(args) < 1 { + return ErrCmdParams + } + + switch strings.ToLower(ledis.String(args[0])) { + case "load": + return scriptLoadCommand(c) + case "exists": + return scriptExistsCommand(c) + case "flush": + return scriptFlushCommand(c) + default: + return fmt.Errorf("invalid script %s", args[0]) + } + + return nil +} + +func scriptLoadCommand(c *client) error { + s := c.app.s + l := s.l + + if len(c.args) != 2 { + return ErrCmdParams + } + + h := sha1.Sum(c.args[1]) + key := hex.EncodeToString(h[0:20]) + + if r := l.LoadString(ledis.String(c.args[1])); r != 0 { + err := fmt.Errorf("%s", l.ToString(-1)) + l.Pop(1) + return err + } else { + l.PushValue(-1) + l.SetGlobal(key) + + s.chunks[key] = struct{}{} + } + + c.resp.writeBulk(ledis.Slice(key)) + return nil +} + +func scriptExistsCommand(c *client) error { + s := c.app.s + + if len(c.args) < 2 { + return ErrCmdParams + } + + ay := make([]interface{}, len(c.args[1:])) + for i, n := range c.args[1:] { + if _, ok := s.chunks[ledis.String(n)]; ok { + ay[i] = int64(1) + } else { + ay[i] = int64(0) + } + } + + c.resp.writeArray(ay) + return nil +} + +func scriptFlushCommand(c *client) error { + s := c.app.s + l := s.l + + if len(c.args) != 1 { + return ErrCmdParams + } + + for n, _ := range s.chunks { + l.PushNil() + l.SetGlobal(n) + } + + s.chunks = map[string]struct{}{} + + c.resp.writeStatus(OK) + + return nil +} + +func init() { + register("eval", evalCommand) + register("evalsha", evalshaCommand) + register("script", scriptCommand) +} diff --git a/server/cmd_script_test.go b/server/cmd_script_test.go new file mode 100644 index 0000000..017e527 --- /dev/null +++ b/server/cmd_script_test.go @@ -0,0 +1,59 @@ +// +build lua + +package server + +import ( + "fmt" + "github.com/siddontang/ledisdb/client/go/ledis" + "reflect" + "testing" +) + +func TestCmdEval(t *testing.T) { + c := getTestConn() + defer c.Close() + + if v, err := ledis.Strings(c.Do("eval", "return {KEYS[1],KEYS[2],ARGV[1],ARGV[2]}", 2, "key1", "key2", "first", "second")); err != nil { + t.Fatal(err) + } else if !reflect.DeepEqual(v, []string{"key1", "key2", "first", "second"}) { + t.Fatal(fmt.Sprintf("%v", v)) + } + + if v, err := ledis.Strings(c.Do("eval", "return {KEYS[1],KEYS[2],ARGV[1],ARGV[2]}", 2, "key1", "key2", "first", "second")); err != nil { + t.Fatal(err) + } else if !reflect.DeepEqual(v, []string{"key1", "key2", "first", "second"}) { + t.Fatal(fmt.Sprintf("%v", v)) + } + + var sha1 string + var err error + if sha1, err = ledis.String(c.Do("script", "load", "return {KEYS[1],KEYS[2],ARGV[1],ARGV[2]}")); err != nil { + t.Fatal(err) + } else if len(sha1) != 40 { + t.Fatal(sha1) + } + + if v, err := ledis.Strings(c.Do("evalsha", sha1, 2, "key1", "key2", "first", "second")); err != nil { + t.Fatal(err) + } else if !reflect.DeepEqual(v, []string{"key1", "key2", "first", "second"}) { + t.Fatal(fmt.Sprintf("%v", v)) + } + + if ay, err := ledis.Values(c.Do("script", "exists", sha1, "01234567890123456789")); err != nil { + t.Fatal(err) + } else if !reflect.DeepEqual(ay, []interface{}{int64(1), int64(0)}) { + t.Fatal(fmt.Sprintf("%v", ay)) + } + + if ok, err := ledis.String(c.Do("script", "flush")); err != nil { + t.Fatal(err) + } else if ok != "OK" { + t.Fatal(ok) + } + + if ay, err := ledis.Values(c.Do("script", "exists", sha1)); err != nil { + t.Fatal(err) + } else if !reflect.DeepEqual(ay, []interface{}{int64(0)}) { + t.Fatal(fmt.Sprintf("%v", ay)) + } +} diff --git a/server/cmd_ttl_test.go b/server/cmd_ttl_test.go index 702348d..c9d388c 100644 --- a/server/cmd_ttl_test.go +++ b/server/cmd_ttl_test.go @@ -1,6 +1,7 @@ package server import ( + "fmt" "github.com/siddontang/ledisdb/client/go/ledis" "testing" "time" @@ -10,422 +11,119 @@ func now() int64 { return time.Now().Unix() } -func TestKVExpire(t *testing.T) { +func TestExpire(t *testing.T) { + // test for kv, list, hash, set, zset, bitmap in all + ttlType := []string{"k", "l", "h", "s", "z", "b"} + + var ( + expire string + expireat string + ttl string + persist string + key string + ) + c := getTestConn() defer c.Close() - k := "a_ttl" - c.Do("set", k, "123") + idx := 1 + for _, tt := range ttlType { + if tt == "k" { + expire = "expire" + expireat = "expireat" + ttl = "ttl" + persist = "persist" - // expire + ttl - exp := int64(10) - if n, err := ledis.Int(c.Do("expire", k, exp)); err != nil { - t.Fatal(err) - } else if n != 1 { - t.Fatal(n) - } + } else { + expire = fmt.Sprintf("%sexpire", tt) + expireat = fmt.Sprintf("%sexpireat", tt) + ttl = fmt.Sprintf("%sttl", tt) + persist = fmt.Sprintf("%spersist", tt) + } - if ttl, err := ledis.Int64(c.Do("ttl", k)); err != nil { - t.Fatal(err) - } else if ttl != exp { - t.Fatal(ttl) - } + switch tt { + case "k": + key = "kv_ttl" + c.Do("set", key, "123") + case "l": + key = "list_ttl" + c.Do("rpush", key, "123") + case "h": + key = "hash_ttl" + c.Do("hset", key, "a", "123") + case "s": + key = "set_ttl" + c.Do("sadd", key, "123") + case "z": + key = "zset_ttl" + c.Do("zadd", key, 123, "a") + case "b": + key = "bitmap_ttl" + c.Do("bsetbit", key, 0, 1) + } - // expireat + ttl - tm := now() + 3 - if n, err := ledis.Int(c.Do("expireat", k, tm)); err != nil { - t.Fatal(err) - } else if n != 1 { - t.Fatal(n) - } + // expire + ttl + exp := int64(10) + if n, err := ledis.Int(c.Do(expire, key, exp)); err != nil { + t.Fatal(err) + } else if n != 1 { + t.Fatal(n) + } - if ttl, err := ledis.Int64(c.Do("ttl", k)); err != nil { - t.Fatal(err) - } else if ttl != 3 { - t.Fatal(ttl) - } + if ttl, err := ledis.Int64(c.Do(ttl, key)); err != nil { + t.Fatal(err) + } else if ttl != exp { + t.Fatal(ttl) + } - kErr := "not_exist_ttl" + // expireat + ttl + tm := now() + 3 + if n, err := ledis.Int(c.Do(expireat, key, tm)); err != nil { + t.Fatal(err) + } else if n != 1 { + t.Fatal(n) + } - // err - expire, expireat - if n, err := ledis.Int(c.Do("expire", kErr, tm)); err != nil || n != 0 { - t.Fatal(false) - } + if ttl, err := ledis.Int64(c.Do(ttl, key)); err != nil { + t.Fatal(err) + } else if ttl != 3 { + t.Fatal(ttl) + } - if n, err := ledis.Int(c.Do("expireat", kErr, tm)); err != nil || n != 0 { - t.Fatal(false) - } + kErr := "not_exist_ttl" - if n, err := ledis.Int(c.Do("ttl", kErr)); err != nil || n != -1 { - t.Fatal(false) - } + // err - expire, expireat + if n, err := ledis.Int(c.Do(expire, kErr, tm)); err != nil || n != 0 { + t.Fatal(false) + } - if n, err := ledis.Int(c.Do("persist", k)); err != nil { - t.Fatal(err) - } else if n != 1 { - t.Fatal(n) - } + if n, err := ledis.Int(c.Do(expireat, kErr, tm)); err != nil || n != 0 { + t.Fatal(false) + } - if n, err := ledis.Int(c.Do("expire", k, 10)); err != nil { - t.Fatal(err) - } else if n != 1 { - t.Fatal(n) - } + if n, err := ledis.Int(c.Do(ttl, kErr)); err != nil || n != -1 { + t.Fatal(false) + } - if n, err := ledis.Int(c.Do("persist", k)); err != nil { - t.Fatal(err) - } else if n != 1 { - t.Fatal(n) - } - -} - -func TestSetExpire(t *testing.T) { - c := getTestConn() - defer c.Close() - - k := "set_ttl" - c.Do("sadd", k, "123") - - // expire + ttl - exp := int64(10) - if n, err := ledis.Int(c.Do("sexpire", k, exp)); err != nil { - t.Fatal(err) - } else if n != 1 { - t.Fatal(n) - } - - if ttl, err := ledis.Int64(c.Do("sttl", k)); err != nil { - t.Fatal(err) - } else if ttl != exp { - t.Fatal(ttl) - } - - // expireat + ttl - tm := now() + 3 - if n, err := ledis.Int(c.Do("sexpireat", k, tm)); err != nil { - t.Fatal(err) - } else if n != 1 { - t.Fatal(n) - } - - if ttl, err := ledis.Int64(c.Do("sttl", k)); err != nil { - t.Fatal(err) - } else if ttl != 3 { - t.Fatal(ttl) - } - - kErr := "not_exist_ttl" - - // err - expire, expireat - if n, err := ledis.Int(c.Do("sexpire", kErr, tm)); err != nil || n != 0 { - t.Fatal(false) - } - - if n, err := ledis.Int(c.Do("sexpireat", kErr, tm)); err != nil || n != 0 { - t.Fatal(false) - } - - if n, err := ledis.Int(c.Do("sttl", kErr)); err != nil || n != -1 { - t.Fatal(false) - } - - if n, err := ledis.Int(c.Do("spersist", k)); err != nil { - t.Fatal(err) - } else if n != 1 { - t.Fatal(n) - } - - if n, err := ledis.Int(c.Do("sexpire", k, 10)); err != nil { - t.Fatal(err) - } else if n != 1 { - t.Fatal(n) - } - - if n, err := ledis.Int(c.Do("spersist", k)); err != nil { - t.Fatal(err) - } else if n != 1 { - t.Fatal(n) - } - -} - -func TestListExpire(t *testing.T) { - c := getTestConn() - defer c.Close() - - k := "list_ttl" - c.Do("rpush", k, "123") - - // expire + ttl - exp := int64(10) - if n, err := ledis.Int(c.Do("lexpire", k, exp)); err != nil { - t.Fatal(err) - } else if n != 1 { - t.Fatal(n) - } - - if ttl, err := ledis.Int64(c.Do("lttl", k)); err != nil { - t.Fatal(err) - } else if ttl != exp { - t.Fatal(ttl) - } - - // expireat + ttl - tm := now() + 3 - if n, err := ledis.Int(c.Do("lexpireat", k, tm)); err != nil { - t.Fatal(err) - } else if n != 1 { - t.Fatal(n) - } - - if ttl, err := ledis.Int64(c.Do("lttl", k)); err != nil { - t.Fatal(err) - } else if ttl != 3 { - t.Fatal(ttl) - } - - kErr := "not_exist_ttl" - - // err - expire, expireat - if n, err := ledis.Int(c.Do("lexpire", kErr, tm)); err != nil || n != 0 { - t.Fatal(false) - } - - if n, err := ledis.Int(c.Do("lexpireat", kErr, tm)); err != nil || n != 0 { - t.Fatal(false) - } - - if n, err := ledis.Int(c.Do("lttl", kErr)); err != nil || n != -1 { - t.Fatal(false) - } - - if n, err := ledis.Int(c.Do("lpersist", k)); err != nil { - t.Fatal(err) - } else if n != 1 { - t.Fatal(n) - } - - if n, err := ledis.Int(c.Do("lexpire", k, 10)); err != nil { - t.Fatal(err) - } else if n != 1 { - t.Fatal(n) - } - - if n, err := ledis.Int(c.Do("lpersist", k)); err != nil { - t.Fatal(err) - } else if n != 1 { - t.Fatal(n) - } - -} - -func TestHashExpire(t *testing.T) { - c := getTestConn() - defer c.Close() - - k := "hash_ttl" - c.Do("hset", k, "f", 123) - - // expire + ttl - exp := int64(10) - if n, err := ledis.Int(c.Do("hexpire", k, exp)); err != nil { - t.Fatal(err) - } else if n != 1 { - t.Fatal(n) - } - - if ttl, err := ledis.Int64(c.Do("httl", k)); err != nil { - t.Fatal(err) - } else if ttl != exp { - t.Fatal(ttl) - } - - // expireat + ttl - tm := now() + 3 - if n, err := ledis.Int(c.Do("hexpireat", k, tm)); err != nil { - t.Fatal(err) - } else if n != 1 { - t.Fatal(n) - } - - if ttl, err := ledis.Int64(c.Do("httl", k)); err != nil { - t.Fatal(err) - } else if ttl != 3 { - t.Fatal(ttl) - } - - kErr := "not_exist_ttl" - - // err - expire, expireat - if n, err := ledis.Int(c.Do("hexpire", kErr, tm)); err != nil || n != 0 { - t.Fatal(false) - } - - if n, err := ledis.Int(c.Do("hexpireat", kErr, tm)); err != nil || n != 0 { - t.Fatal(false) - } - - if n, err := ledis.Int(c.Do("httl", kErr)); err != nil || n != -1 { - t.Fatal(false) - } - - if n, err := ledis.Int(c.Do("hpersist", k)); err != nil { - t.Fatal(err) - } else if n != 1 { - t.Fatal(n) - } - - if n, err := ledis.Int(c.Do("hexpire", k, 10)); err != nil { - t.Fatal(err) - } else if n != 1 { - t.Fatal(n) - } - - if n, err := ledis.Int(c.Do("hpersist", k)); err != nil { - t.Fatal(err) - } else if n != 1 { - t.Fatal(n) - } - -} - -func TestZsetExpire(t *testing.T) { - c := getTestConn() - defer c.Close() - - k := "zset_ttl" - c.Do("zadd", k, 123, "m") - - // expire + ttl - exp := int64(10) - if n, err := ledis.Int(c.Do("zexpire", k, exp)); err != nil { - t.Fatal(err) - } else if n != 1 { - t.Fatal(n) - } - - if ttl, err := ledis.Int64(c.Do("zttl", k)); err != nil { - t.Fatal(err) - } else if ttl != exp { - t.Fatal(ttl) - } - - // expireat + ttl - tm := now() + 3 - if n, err := ledis.Int(c.Do("zexpireat", k, tm)); err != nil { - t.Fatal(err) - } else if n != 1 { - t.Fatal(n) - } - - if ttl, err := ledis.Int64(c.Do("zttl", k)); err != nil { - t.Fatal(err) - } else if ttl != 3 { - t.Fatal(ttl) - } - - kErr := "not_exist_ttl" - - // err - expire, expireat - if n, err := ledis.Int(c.Do("zexpire", kErr, tm)); err != nil || n != 0 { - t.Fatal(false) - } - - if n, err := ledis.Int(c.Do("zexpireat", kErr, tm)); err != nil || n != 0 { - t.Fatal(false) - } - - if n, err := ledis.Int(c.Do("zttl", kErr)); err != nil || n != -1 { - t.Fatal(false) - } - - if n, err := ledis.Int(c.Do("zpersist", k)); err != nil { - t.Fatal(err) - } else if n != 1 { - t.Fatal(n) - } - - if n, err := ledis.Int(c.Do("zexpire", k, 10)); err != nil { - t.Fatal(err) - } else if n != 1 { - t.Fatal(n) - } - - if n, err := ledis.Int(c.Do("zpersist", k)); err != nil { - t.Fatal(err) - } else if n != 1 { - t.Fatal(n) - } - -} - -func TestBitmapExpire(t *testing.T) { - c := getTestConn() - defer c.Close() - - k := "bit_ttl" - c.Do("bsetbit", k, 0, 1) - - // expire + ttl - exp := int64(10) - if n, err := ledis.Int(c.Do("bexpire", k, exp)); err != nil { - t.Fatal(err) - } else if n != 1 { - t.Fatal(n) - } - - if ttl, err := ledis.Int64(c.Do("bttl", k)); err != nil { - t.Fatal(err) - } else if ttl != exp { - t.Fatal(ttl) - } - - // expireat + ttl - tm := now() + 3 - if n, err := ledis.Int(c.Do("bexpireat", k, tm)); err != nil { - t.Fatal(err) - } else if n != 1 { - t.Fatal(n) - } - - if ttl, err := ledis.Int64(c.Do("bttl", k)); err != nil { - t.Fatal(err) - } else if ttl != 3 { - t.Fatal(ttl) - } - - kErr := "not_exist_ttl" - - // err - expire, expireat - if n, err := ledis.Int(c.Do("bexpire", kErr, tm)); err != nil || n != 0 { - t.Fatal(false) - } - - if n, err := ledis.Int(c.Do("bexpireat", kErr, tm)); err != nil || n != 0 { - t.Fatal(false) - } - - if n, err := ledis.Int(c.Do("bttl", kErr)); err != nil || n != -1 { - t.Fatal(false) - } - - if n, err := ledis.Int(c.Do("bpersist", k)); err != nil { - t.Fatal(err) - } else if n != 1 { - t.Fatal(n) - } - - if n, err := ledis.Int(c.Do("bexpire", k, 10)); err != nil { - t.Fatal(err) - } else if n != 1 { - t.Fatal(n) - } - - if n, err := ledis.Int(c.Do("bpersist", k)); err != nil { - t.Fatal(err) - } else if n != 1 { - t.Fatal(n) + if n, err := ledis.Int(c.Do(persist, key)); err != nil { + t.Fatal(err) + } else if n != 1 { + t.Fatal(n) + } + + if n, err := ledis.Int(c.Do(expire, key, 10)); err != nil { + t.Fatal(err) + } else if n != 1 { + t.Fatal(n) + } + + if n, err := ledis.Int(c.Do(persist, key)); err != nil { + t.Fatal(err) + } else if n != 1 { + t.Fatal(n) + } + + idx++ } } diff --git a/server/command.go b/server/command.go index af95244..458343b 100644 --- a/server/command.go +++ b/server/command.go @@ -41,12 +41,26 @@ func selectCommand(c *client) error { if index, err := strconv.Atoi(ledis.String(c.args[0])); err != nil { return err } else { - if db, err := c.ldb.Select(index); err != nil { - return err + if c.db.IsTransaction() { + if err := c.tx.Select(index); err != nil { + return err + } else { + c.db = c.tx.DB + } + } else if c.db.IsInMulti() { + if err := c.script.Select(index); err != nil { + return err + } else { + c.db = c.script.DB + } } else { - c.db = db - c.resp.writeStatus(OK) + if db, err := c.ldb.Select(index); err != nil { + return err + } else { + c.db = db + } } + c.resp.writeStatus(OK) } return nil diff --git a/server/script.go b/server/script.go new file mode 100644 index 0000000..4f230fa --- /dev/null +++ b/server/script.go @@ -0,0 +1,393 @@ +// +build lua + +package server + +import ( + "encoding/hex" + "fmt" + "github.com/siddontang/golua/lua" + "github.com/siddontang/ledisdb/ledis" + "io" + "sync" +) + +//ledis <-> lua type conversion, same as http://redis.io/commands/eval + +type luaWriter struct { + l *lua.State +} + +func (w *luaWriter) writeError(err error) { + panic(err) +} + +func (w *luaWriter) writeStatus(status string) { + w.l.NewTable() + top := w.l.GetTop() + + w.l.PushString("ok") + w.l.PushString(status) + w.l.SetTable(top) +} + +func (w *luaWriter) writeInteger(n int64) { + w.l.PushInteger(n) +} + +func (w *luaWriter) writeBulk(b []byte) { + if b == nil { + w.l.PushBoolean(false) + } else { + w.l.PushString(ledis.String(b)) + } +} + +func (w *luaWriter) writeArray(lst []interface{}) { + if lst == nil { + w.l.PushBoolean(false) + return + } + + w.l.CreateTable(len(lst), 0) + top := w.l.GetTop() + + for i, _ := range lst { + w.l.PushInteger(int64(i) + 1) + + switch v := lst[i].(type) { + case []interface{}: + w.writeArray(v) + case [][]byte: + w.writeSliceArray(v) + case []byte: + w.writeBulk(v) + case nil: + w.writeBulk(nil) + case int64: + w.writeInteger(v) + default: + panic("invalid array type") + } + + w.l.SetTable(top) + } +} + +func (w *luaWriter) writeSliceArray(lst [][]byte) { + if lst == nil { + w.l.PushBoolean(false) + return + } + + w.l.CreateTable(len(lst), 0) + for i, v := range lst { + w.l.PushString(ledis.String(v)) + w.l.RawSeti(-2, i+1) + } +} + +func (w *luaWriter) writeFVPairArray(lst []ledis.FVPair) { + if lst == nil { + w.l.PushBoolean(false) + return + } + + w.l.CreateTable(len(lst)*2, 0) + for i, v := range lst { + w.l.PushString(ledis.String(v.Field)) + w.l.RawSeti(-2, 2*i+1) + + w.l.PushString(ledis.String(v.Value)) + w.l.RawSeti(-2, 2*i+2) + } +} + +func (w *luaWriter) writeScorePairArray(lst []ledis.ScorePair, withScores bool) { + if lst == nil { + w.l.PushBoolean(false) + return + } + + if withScores { + w.l.CreateTable(len(lst)*2, 0) + for i, v := range lst { + w.l.PushString(ledis.String(v.Member)) + w.l.RawSeti(-2, 2*i+1) + + w.l.PushString(ledis.String(ledis.StrPutInt64(v.Score))) + w.l.RawSeti(-2, 2*i+2) + } + } else { + w.l.CreateTable(len(lst), 0) + for i, v := range lst { + w.l.PushString(ledis.String(v.Member)) + w.l.RawSeti(-2, i+1) + } + } +} + +func (w *luaWriter) writeBulkFrom(n int64, rb io.Reader) { + w.writeError(fmt.Errorf("unsupport")) +} + +func (w *luaWriter) flush() { + +} + +type script struct { + sync.Mutex + + app *App + l *lua.State + c *client + + chunks map[string]struct{} +} + +func (app *App) openScript() { + s := new(script) + s.app = app + + s.chunks = make(map[string]struct{}) + + app.s = s + + l := lua.NewState() + + l.OpenBase() + l.OpenLibs() + l.OpenMath() + l.OpenString() + l.OpenTable() + l.OpenPackage() + + l.OpenCJson() + l.OpenCMsgpack() + l.OpenStruct() + + l.Register("error", luaErrorHandler) + + s.l = l + s.c = newClient(app) + s.c.db = nil + + w := new(luaWriter) + w.l = l + s.c.resp = w + + l.NewTable() + l.PushString("call") + l.PushGoFunction(luaCall) + l.SetTable(-3) + + l.PushString("pcall") + l.PushGoFunction(luaPCall) + l.SetTable(-3) + + l.PushString("sha1hex") + l.PushGoFunction(luaSha1Hex) + l.SetTable(-3) + + l.PushString("error_reply") + l.PushGoFunction(luaErrorReply) + l.SetTable(-3) + + l.PushString("status_reply") + l.PushGoFunction(luaStatusReply) + l.SetTable(-3) + + l.SetGlobal("ledis") + + setMapState(l, s) +} + +func (app *App) closeScript() { + app.s.l.Close() + delMapState(app.s.l) + app.s = nil +} + +var mapState = map[*lua.State]*script{} +var stateLock sync.Mutex + +func setMapState(l *lua.State, s *script) { + stateLock.Lock() + defer stateLock.Unlock() + + mapState[l] = s +} + +func getMapState(l *lua.State) *script { + stateLock.Lock() + defer stateLock.Unlock() + + return mapState[l] +} + +func delMapState(l *lua.State) { + stateLock.Lock() + defer stateLock.Unlock() + + delete(mapState, l) +} + +func luaErrorHandler(l *lua.State) int { + msg := l.ToString(1) + panic(fmt.Errorf(msg)) +} + +func luaCall(l *lua.State) int { + return luaCallGenericCommand(l) +} + +func luaPCall(l *lua.State) (n int) { + defer func() { + if e := recover(); e != nil { + luaPushError(l, fmt.Sprintf("%v", e)) + n = 1 + } + return + }() + return luaCallGenericCommand(l) +} + +func luaErrorReply(l *lua.State) int { + return luaReturnSingleFieldTable(l, "err") +} + +func luaStatusReply(l *lua.State) int { + return luaReturnSingleFieldTable(l, "ok") +} + +func luaReturnSingleFieldTable(l *lua.State, filed string) int { + if l.GetTop() != 1 || l.Type(-1) != lua.LUA_TSTRING { + luaPushError(l, "wrong number or type of arguments") + return 1 + } + + l.NewTable() + l.PushString(filed) + l.PushValue(-3) + l.SetTable(-3) + return 1 +} + +func luaSha1Hex(l *lua.State) int { + argc := l.GetTop() + if argc != 1 { + luaPushError(l, "wrong number of arguments") + return 1 + } + + s := l.ToString(1) + s = hex.EncodeToString(ledis.Slice(s)) + + l.PushString(s) + return 1 +} + +func luaPushError(l *lua.State, msg string) { + l.NewTable() + l.PushString("err") + err := l.NewError(msg) + l.PushString(err.Error()) + l.SetTable(-3) +} + +func luaCallGenericCommand(l *lua.State) int { + s := getMapState(l) + if s == nil { + panic("Invalid lua call") + } else if s.c.db == nil { + panic("Invalid lua call, not prepared") + } + + c := s.c + + argc := l.GetTop() + if argc < 1 { + panic("Please specify at least one argument for ledis.call()") + } + + c.cmd = l.ToString(1) + + c.args = make([][]byte, argc-1) + + for i := 2; i <= argc; i++ { + switch l.Type(i) { + case lua.LUA_TNUMBER: + c.args[i-2] = []byte(fmt.Sprintf("%.17g", l.ToNumber(i))) + case lua.LUA_TSTRING: + c.args[i-2] = []byte(l.ToString(i)) + default: + panic("Lua ledis() command arguments must be strings or integers") + } + } + + c.perform() + + return 1 +} + +func luaSetGlobalArray(l *lua.State, name string, ay [][]byte) { + l.NewTable() + + for i := 0; i < len(ay); i++ { + l.PushString(ledis.String(ay[i])) + l.RawSeti(-2, i+1) + } + + l.SetGlobal(name) +} + +func luaReplyToLedisReply(l *lua.State) interface{} { + base := l.GetTop() + defer func() { + l.SetTop(base - 1) + }() + + switch l.Type(-1) { + case lua.LUA_TSTRING: + return ledis.Slice(l.ToString(-1)) + case lua.LUA_TBOOLEAN: + if l.ToBoolean(-1) { + return int64(1) + } else { + return nil + } + case lua.LUA_TNUMBER: + return int64(l.ToInteger(-1)) + case lua.LUA_TTABLE: + l.PushString("err") + l.GetTable(-2) + if l.Type(-1) == lua.LUA_TSTRING { + return fmt.Errorf("%s", l.ToString(-1)) + } + + l.Pop(1) + l.PushString("ok") + l.GetTable(-2) + if l.Type(-1) == lua.LUA_TSTRING { + return l.ToString(-1) + } else { + l.Pop(1) + + ay := make([]interface{}, 0) + + for i := 1; ; i++ { + l.PushInteger(int64(i)) + l.GetTable(-2) + if l.Type(-1) == lua.LUA_TNIL { + l.Pop(1) + break + } + + ay = append(ay, luaReplyToLedisReply(l)) + } + return ay + + } + default: + return nil + } +} diff --git a/server/script_dummy.go b/server/script_dummy.go new file mode 100644 index 0000000..f19f3b8 --- /dev/null +++ b/server/script_dummy.go @@ -0,0 +1,10 @@ +// +build !lua + +package server + +type script struct { +} + +func (app *App) openScript() {} + +func (app *App) closeScript() {} diff --git a/server/script_test.go b/server/script_test.go new file mode 100644 index 0000000..74160d3 --- /dev/null +++ b/server/script_test.go @@ -0,0 +1,179 @@ +// +build lua + +package server + +import ( + "fmt" + "github.com/siddontang/golua/lua" + "github.com/siddontang/ledisdb/config" + + "testing" +) + +var testLuaWriter = &luaWriter{} + +func testLuaWriteError(l *lua.State) int { + testLuaWriter.writeError(fmt.Errorf("test error")) + return 1 +} + +func testLuaWriteArray(l *lua.State) int { + ay := make([]interface{}, 2) + ay[0] = []byte("1") + b := make([]interface{}, 2) + b[0] = int64(10) + b[1] = []byte("11") + + ay[1] = b + + testLuaWriter.writeArray(ay) + + return 1 +} + +func TestLuaWriter(t *testing.T) { + l := lua.NewState() + + l.OpenBase() + + testLuaWriter.l = l + + l.Register("WriteError", testLuaWriteError) + + str := ` + WriteError() + ` + + err := l.DoString(str) + + if err == nil { + t.Fatal("must error") + } + + l.Register("WriteArray", testLuaWriteArray) + + str = ` + local a = WriteArray() + + if #a ~= 2 then + error("len a must 2") + elseif a[1] ~= "1" then + error("a[1] must 1") + elseif #a[2] ~= 2 then + error("len a[2] must 2") + elseif a[2][1] ~= 10 then + error("a[2][1] must 10") + elseif a[2][2] ~= "11" then + error("a[2][2] must 11") + end + ` + + err = l.DoString(str) + if err != nil { + t.Fatal(err) + } + + l.Close() +} + +var testScript1 = ` + return {1,2,3} +` + +var testScript2 = ` + return ledis.call("ping") +` + +var testScript3 = ` + ledis.call("set", 1, "a") + + local a = ledis.call("get", 1) + if type(a) ~= "string" then + error("must string") + elseif a ~= "a" then + error("must a") + end +` + +var testScript4 = ` + ledis.call("select", 2) + ledis.call("set", 2, "b") +` + +func TestLuaCall(t *testing.T) { + cfg := new(config.Config) + cfg.Addr = ":11188" + cfg.DataDir = "/tmp/testscript" + cfg.DBName = "memory" + + app, e := NewApp(cfg) + if e != nil { + t.Fatal(e) + } + go app.Run() + + defer app.Close() + + db, _ := app.ldb.Select(0) + m, _ := db.Multi() + defer m.Close() + + luaClient := app.s.c + luaClient.db = m.DB + luaClient.script = m + + l := app.s.l + + err := app.s.l.DoString(testScript1) + if err != nil { + t.Fatal(err) + } + + v := luaReplyToLedisReply(l) + if vv, ok := v.([]interface{}); ok { + if len(vv) != 3 { + t.Fatal(len(vv)) + } + } else { + t.Fatal(fmt.Sprintf("%v %T", v, v)) + } + + err = app.s.l.DoString(testScript2) + if err != nil { + t.Fatal(err) + } + + v = luaReplyToLedisReply(l) + if vv := v.(string); vv != "PONG" { + t.Fatal(fmt.Sprintf("%v %T", v, v)) + } + + err = app.s.l.DoString(testScript3) + if err != nil { + t.Fatal(err) + } + + if v, err := db.Get([]byte("1")); err != nil { + t.Fatal(err) + } else if string(v) != "a" { + t.Fatal(string(v)) + } + + err = app.s.l.DoString(testScript4) + if err != nil { + t.Fatal(err) + } + + if luaClient.db.Index() != 2 { + t.Fatal(luaClient.db.Index()) + } + + db2, _ := app.ldb.Select(2) + if v, err := db2.Get([]byte("2")); err != nil { + t.Fatal(err) + } else if string(v) != "b" { + t.Fatal(string(v)) + } + + luaClient.db = nil +} diff --git a/store/db.go b/store/db.go index 76c63e3..ca57326 100644 --- a/store/db.go +++ b/store/db.go @@ -19,6 +19,16 @@ func (db *DB) NewWriteBatch() WriteBatch { return db.IDB.NewWriteBatch() } +func (db *DB) NewSnapshot() (*Snapshot, error) { + var err error + s := &Snapshot{} + if s.ISnapshot, err = db.IDB.NewSnapshot(); err != nil { + return nil, err + } + + return s, nil +} + func (db *DB) RangeIterator(min []byte, max []byte, rangeType uint8) *RangeLimitIterator { return NewRangeLimitIterator(db.NewIterator(), &Range{min, max, rangeType}, &Limit{0, -1}) } diff --git a/store/goleveldb/const.go b/store/goleveldb/const.go index 6486e3f..2fffa7c 100644 --- a/store/goleveldb/const.go +++ b/store/goleveldb/const.go @@ -1,3 +1,4 @@ package goleveldb const DBName = "goleveldb" +const MemDBName = "memory" diff --git a/store/goleveldb/db.go b/store/goleveldb/db.go index ca2de60..e873feb 100644 --- a/store/goleveldb/db.go +++ b/store/goleveldb/db.go @@ -5,6 +5,8 @@ import ( "github.com/siddontang/goleveldb/leveldb/cache" "github.com/siddontang/goleveldb/leveldb/filter" "github.com/siddontang/goleveldb/leveldb/opt" + "github.com/siddontang/goleveldb/leveldb/storage" + "github.com/siddontang/ledisdb/config" "github.com/siddontang/ledisdb/store/driver" @@ -20,6 +22,13 @@ func (s Store) String() string { return DBName } +type MemStore struct { +} + +func (s MemStore) String() string { + return MemDBName +} + type DB struct { path string @@ -45,7 +54,12 @@ func (s Store) Open(path string, cfg *config.Config) (driver.IDB, error) { db.path = path db.cfg = &cfg.LevelDB - if err := db.open(); err != nil { + db.initOpts() + + var err error + db.db, err = leveldb.OpenFile(db.path, db.opts) + + if err != nil { return nil, err } @@ -62,16 +76,31 @@ func (s Store) Repair(path string, cfg *config.Config) error { return nil } -func (db *DB) open() error { +func (s MemStore) Open(path string, cfg *config.Config) (driver.IDB, error) { + db := new(DB) + db.path = path + db.cfg = &cfg.LevelDB + + db.initOpts() + + var err error + db.db, err = leveldb.Open(storage.NewMemStorage(), db.opts) + if err != nil { + return nil, err + } + + return db, nil +} + +func (s MemStore) Repair(path string, cfg *config.Config) error { + return nil +} + +func (db *DB) initOpts() { db.opts = newOptions(db.cfg) db.iteratorOpts = &opt.ReadOptions{} db.iteratorOpts.DontFillCache = true - - var err error - db.db, err = leveldb.OpenFile(db.path, db.opts) - - return err } func newOptions(cfg *config.LevelDBConfig) *opt.Options { @@ -153,4 +182,5 @@ func (db *DB) NewSnapshot() (driver.ISnapshot, error) { func init() { driver.Register(Store{}) + driver.Register(MemStore{}) } diff --git a/store/hyperleveldb/db.go b/store/hyperleveldb/db.go index e8ec944..6d0e176 100644 --- a/store/hyperleveldb/db.go +++ b/store/hyperleveldb/db.go @@ -234,10 +234,8 @@ func (db *DB) get(ro *ReadOptions, key []byte) ([]byte, error) { k = (*C.char)(unsafe.Pointer(&key[0])) } - var value *C.char - - c := C.hyperleveldb_get_ext( - db.db, ro.Opt, k, C.size_t(len(key)), &value, &vallen, &errStr) + value := C.leveldb_get( + db.db, ro.Opt, k, C.size_t(len(key)), &vallen, &errStr) if errStr != nil { return nil, saveError(errStr) @@ -247,7 +245,7 @@ func (db *DB) get(ro *ReadOptions, key []byte) ([]byte, error) { return nil, nil } - defer C.hyperleveldb_get_free_ext(unsafe.Pointer(c)) + defer C.leveldb_free(unsafe.Pointer(value)) return C.GoBytes(unsafe.Pointer(value), C.int(vallen)), nil } diff --git a/store/hyperleveldb/hyperleveldb_ext.cc b/store/hyperleveldb/hyperleveldb_ext.cc index dab687c..f775ee9 100644 --- a/store/hyperleveldb/hyperleveldb_ext.cc +++ b/store/hyperleveldb/hyperleveldb_ext.cc @@ -3,60 +3,60 @@ #include "hyperleveldb_ext.h" #include -#include +//#include -#include "hyperleveldb/db.h" +//#include "hyperleveldb/db.h" -using namespace leveldb; +//using namespace leveldb; extern "C" { -static bool SaveError(char** errptr, const Status& s) { - assert(errptr != NULL); - if (s.ok()) { - return false; - } else if (*errptr == NULL) { - *errptr = strdup(s.ToString().c_str()); - } else { - free(*errptr); - *errptr = strdup(s.ToString().c_str()); - } - return true; -} +// static bool SaveError(char** errptr, const Status& s) { +// assert(errptr != NULL); +// if (s.ok()) { +// return false; +// } else if (*errptr == NULL) { +// *errptr = strdup(s.ToString().c_str()); +// } else { +// free(*errptr); +// *errptr = strdup(s.ToString().c_str()); +// } +// return true; +// } -void* hyperleveldb_get_ext( - leveldb_t* db, - const leveldb_readoptions_t* options, - const char* key, size_t keylen, - char** valptr, - size_t* vallen, - char** errptr) { +// void* hyperleveldb_get_ext( +// leveldb_t* db, +// const leveldb_readoptions_t* options, +// const char* key, size_t keylen, +// char** valptr, +// size_t* vallen, +// char** errptr) { - std::string *tmp = new(std::string); +// std::string *tmp = new(std::string); - //very tricky, maybe changed with c++ leveldb upgrade - Status s = (*(DB**)db)->Get(*(ReadOptions*)options, Slice(key, keylen), tmp); +// //very tricky, maybe changed with c++ leveldb upgrade +// Status s = (*(DB**)db)->Get(*(ReadOptions*)options, Slice(key, keylen), tmp); - if (s.ok()) { - *valptr = (char*)tmp->data(); - *vallen = tmp->size(); - } else { - delete(tmp); - tmp = NULL; - *valptr = NULL; - *vallen = 0; - if (!s.IsNotFound()) { - SaveError(errptr, s); - } - } - return tmp; -} +// if (s.ok()) { +// *valptr = (char*)tmp->data(); +// *vallen = tmp->size(); +// } else { +// delete(tmp); +// tmp = NULL; +// *valptr = NULL; +// *vallen = 0; +// if (!s.IsNotFound()) { +// SaveError(errptr, s); +// } +// } +// return tmp; +// } -void hyperleveldb_get_free_ext(void* context) { - std::string* s = (std::string*)context; +// void hyperleveldb_get_free_ext(void* context) { +// std::string* s = (std::string*)context; - delete(s); -} +// delete(s); +// } unsigned char hyperleveldb_iter_seek_to_first_ext(leveldb_iterator_t* iter) { diff --git a/store/hyperleveldb/hyperleveldb_ext.h b/store/hyperleveldb/hyperleveldb_ext.h index 940a090..9182768 100644 --- a/store/hyperleveldb/hyperleveldb_ext.h +++ b/store/hyperleveldb/hyperleveldb_ext.h @@ -10,19 +10,19 @@ extern "C" { #include "hyperleveldb/c.h" -/* Returns NULL if not found. Otherwise stores the value in **valptr. - Stores the length of the value in *vallen. - Returns a context must be later to free*/ -extern void* hyperleveldb_get_ext( - leveldb_t* db, - const leveldb_readoptions_t* options, - const char* key, size_t keylen, - char** valptr, - size_t* vallen, - char** errptr); +// /* Returns NULL if not found. Otherwise stores the value in **valptr. +// Stores the length of the value in *vallen. +// Returns a context must be later to free*/ +// extern void* hyperleveldb_get_ext( +// leveldb_t* db, +// const leveldb_readoptions_t* options, +// const char* key, size_t keylen, +// char** valptr, +// size_t* vallen, +// char** errptr); -// Free context returns by hyperleveldb_get_ext -extern void hyperleveldb_get_free_ext(void* context); +// // Free context returns by hyperleveldb_get_ext +// extern void hyperleveldb_get_free_ext(void* context); // Below iterator functions like leveldb iterator but returns valid status for iterator diff --git a/store/leveldb/db.go b/store/leveldb/db.go index 0a40953..43ee0c2 100644 --- a/store/leveldb/db.go +++ b/store/leveldb/db.go @@ -234,10 +234,8 @@ func (db *DB) get(ro *ReadOptions, key []byte) ([]byte, error) { k = (*C.char)(unsafe.Pointer(&key[0])) } - var value *C.char - - c := C.leveldb_get_ext( - db.db, ro.Opt, k, C.size_t(len(key)), &value, &vallen, &errStr) + value := C.leveldb_get( + db.db, ro.Opt, k, C.size_t(len(key)), &vallen, &errStr) if errStr != nil { return nil, saveError(errStr) @@ -247,7 +245,7 @@ func (db *DB) get(ro *ReadOptions, key []byte) ([]byte, error) { return nil, nil } - defer C.leveldb_get_free_ext(unsafe.Pointer(c)) + defer C.leveldb_free(unsafe.Pointer(value)) return C.GoBytes(unsafe.Pointer(value), C.int(vallen)), nil } diff --git a/store/leveldb/leveldb_ext.cc b/store/leveldb/leveldb_ext.cc index 96d6541..a362ab5 100644 --- a/store/leveldb/leveldb_ext.cc +++ b/store/leveldb/leveldb_ext.cc @@ -3,60 +3,60 @@ #include "leveldb_ext.h" #include -#include +//#include -#include "leveldb/db.h" +//#include "leveldb/db.h" -using namespace leveldb; +//using namespace leveldb; extern "C" { -static bool SaveError(char** errptr, const Status& s) { - assert(errptr != NULL); - if (s.ok()) { - return false; - } else if (*errptr == NULL) { - *errptr = strdup(s.ToString().c_str()); - } else { - free(*errptr); - *errptr = strdup(s.ToString().c_str()); - } - return true; -} +// static bool SaveError(char** errptr, const Status& s) { +// assert(errptr != NULL); +// if (s.ok()) { +// return false; +// } else if (*errptr == NULL) { +// *errptr = strdup(s.ToString().c_str()); +// } else { +// free(*errptr); +// *errptr = strdup(s.ToString().c_str()); +// } +// return true; +// } -void* leveldb_get_ext( - leveldb_t* db, - const leveldb_readoptions_t* options, - const char* key, size_t keylen, - char** valptr, - size_t* vallen, - char** errptr) { +// void* leveldb_get_ext( +// leveldb_t* db, +// const leveldb_readoptions_t* options, +// const char* key, size_t keylen, +// char** valptr, +// size_t* vallen, +// char** errptr) { - std::string *tmp = new(std::string); +// std::string *tmp = new(std::string); - //very tricky, maybe changed with c++ leveldb upgrade - Status s = (*(DB**)db)->Get(*(ReadOptions*)options, Slice(key, keylen), tmp); +// //very tricky, maybe changed with c++ leveldb upgrade +// Status s = (*(DB**)db)->Get(*(ReadOptions*)options, Slice(key, keylen), tmp); - if (s.ok()) { - *valptr = (char*)tmp->data(); - *vallen = tmp->size(); - } else { - delete(tmp); - tmp = NULL; - *valptr = NULL; - *vallen = 0; - if (!s.IsNotFound()) { - SaveError(errptr, s); - } - } - return tmp; -} +// if (s.ok()) { +// *valptr = (char*)tmp->data(); +// *vallen = tmp->size(); +// } else { +// delete(tmp); +// tmp = NULL; +// *valptr = NULL; +// *vallen = 0; +// if (!s.IsNotFound()) { +// SaveError(errptr, s); +// } +// } +// return tmp; +// } -void leveldb_get_free_ext(void* context) { - std::string* s = (std::string*)context; +// void leveldb_get_free_ext(void* context) { +// std::string* s = (std::string*)context; - delete(s); -} +// delete(s); +// } unsigned char leveldb_iter_seek_to_first_ext(leveldb_iterator_t* iter) { diff --git a/store/leveldb/leveldb_ext.h b/store/leveldb/leveldb_ext.h index 8222ae3..1c5f986 100644 --- a/store/leveldb/leveldb_ext.h +++ b/store/leveldb/leveldb_ext.h @@ -10,19 +10,19 @@ extern "C" { #include "leveldb/c.h" -/* Returns NULL if not found. Otherwise stores the value in **valptr. - Stores the length of the value in *vallen. - Returns a context must be later to free*/ -extern void* leveldb_get_ext( - leveldb_t* db, - const leveldb_readoptions_t* options, - const char* key, size_t keylen, - char** valptr, - size_t* vallen, - char** errptr); +// /* Returns NULL if not found. Otherwise stores the value in **valptr. +// Stores the length of the value in *vallen. +// Returns a context must be later to free*/ +// extern void* leveldb_get_ext( +// leveldb_t* db, +// const leveldb_readoptions_t* options, +// const char* key, size_t keylen, +// char** valptr, +// size_t* vallen, +// char** errptr); -// Free context returns by leveldb_get_ext -extern void leveldb_get_free_ext(void* context); +// // Free context returns by leveldb_get_ext +// extern void leveldb_get_free_ext(void* context); // Below iterator functions like leveldb iterator but returns valid status for iterator diff --git a/store/mdb/mdb.go b/store/mdb/mdb.go index 171c088..d5c3987 100644 --- a/store/mdb/mdb.go +++ b/store/mdb/mdb.go @@ -92,14 +92,16 @@ func (s Store) Repair(path string, c *config.Config) error { func (db MDB) Put(key, value []byte) error { itr := db.iterator(false) + defer itr.Close() itr.err = itr.c.Put(key, value, 0) itr.setState() - return itr.Close() + return itr.err } func (db MDB) BatchPut(writes []driver.Write) error { itr := db.iterator(false) + defer itr.Close() for _, w := range writes { if w.Value == nil { @@ -117,7 +119,7 @@ func (db MDB) BatchPut(writes []driver.Write) error { } itr.setState() - return itr.Close() + return itr.err } func (db MDB) Get(key []byte) ([]byte, error) { @@ -208,6 +210,8 @@ func (itr *MDBIterator) setState() { itr.err = nil } itr.valid = false + } else { + itr.valid = true } } diff --git a/store/snapshot.go b/store/snapshot.go new file mode 100644 index 0000000..3f7538a --- /dev/null +++ b/store/snapshot.go @@ -0,0 +1,16 @@ +package store + +import ( + "github.com/siddontang/ledisdb/store/driver" +) + +type Snapshot struct { + driver.ISnapshot +} + +func (s *Snapshot) NewIterator() *Iterator { + it := new(Iterator) + it.it = s.ISnapshot.NewIterator() + + return it +} diff --git a/store/store.go b/store/store.go index 50d2744..e2a6b85 100644 --- a/store/store.go +++ b/store/store.go @@ -7,12 +7,12 @@ import ( "os" "path" - "github.com/siddontang/ledisdb/store/boltdb" - "github.com/siddontang/ledisdb/store/goleveldb" - "github.com/siddontang/ledisdb/store/hyperleveldb" - "github.com/siddontang/ledisdb/store/leveldb" - "github.com/siddontang/ledisdb/store/mdb" - "github.com/siddontang/ledisdb/store/rocksdb" + _ "github.com/siddontang/ledisdb/store/boltdb" + _ "github.com/siddontang/ledisdb/store/goleveldb" + _ "github.com/siddontang/ledisdb/store/hyperleveldb" + _ "github.com/siddontang/ledisdb/store/leveldb" + _ "github.com/siddontang/ledisdb/store/mdb" + _ "github.com/siddontang/ledisdb/store/rocksdb" ) func getStorePath(cfg *config.Config) string { @@ -53,10 +53,4 @@ func Repair(cfg *config.Config) error { } func init() { - _ = boltdb.DBName - _ = goleveldb.DBName - _ = hyperleveldb.DBName - _ = leveldb.DBName - _ = mdb.DBName - _ = rocksdb.DBName } diff --git a/tools/check_lua.go b/tools/check_lua.go new file mode 100644 index 0000000..bc82c04 --- /dev/null +++ b/tools/check_lua.go @@ -0,0 +1,10 @@ +// +build ignore + +package main + +import "github.com/siddontang/golua/lua" + +func main() { + L := lua.NewState() + L.Close() +}