diff --git a/bootstrap.sh b/bootstrap.sh index a62b20e..6dd225f 100755 --- a/bootstrap.sh +++ b/bootstrap.sh @@ -7,4 +7,5 @@ go get github.com/siddontang/go-snappy/snappy go get github.com/siddontang/copier go get github.com/siddontang/goleveldb/leveldb -go get github.com/influxdb/gomdb \ No newline at end of file + +go get -d github.com/siddontang/gomdb diff --git a/dev.sh b/dev.sh index 9709cc6..6d753f2 100644 --- a/dev.sh +++ b/dev.sh @@ -35,7 +35,7 @@ CGO_LDFLAGS= # check dependent libray, now we only check simply, maybe later add proper checking way. # check snappy -if [ -f $LEVELDB_DIR/include/snappy.h ]; then +if [ -f $SNAPPY_DIR/include/snappy.h ]; then CGO_CFLAGS="$CGO_CFLAGS -I$SNAPPY_DIR/include" CGO_CXXFLAGS="$CGO_CXXFLAGS -I$SNAPPY_DIR/include" CGO_LDFLAGS="$CGO_LDFLAGS -L$SNAPPY_DIR/lib -lsnappy" diff --git a/doc/commands.json b/doc/commands.json new file mode 100644 index 0000000..521f86e --- /dev/null +++ b/doc/commands.json @@ -0,0 +1,418 @@ +{ + "BCOUNT": { + "arguments": "key [start end]", + "group": "Bitmap", + "readonly": true + }, + "BDELETE": { + "arguments": "key", + "group": "ZSet", + "readonly": false + }, + "BEXPIRE": { + "arguments": "key seconds", + "group": "Bitmap", + "readonly": false + }, + "BEXPIREAT": { + "arguments": "key timestamp", + "group": "Bitmap", + "readonly": false, + }, + "BGET": { + "arguments": "key", + "group": "Bitmap", + "readonly": true + }, + "BGETBIT": { + "arguments": "key offset", + "group": "Bitmap", + "readonly": true + }, + "BMSETBIT": { + "arguments": "key offset value [offset value ...]", + "group": "Bitmap", + "readonly": false + }, + "BOPT": { + "arguments": "operation destkey key [key ...]", + "group": "Bitmap", + "readonly": false + }, + "BPERSIST": { + "arguments": "key", + "group": "Bitmap", + "readonly": false + }, + "BSETBIT": { + "arguments": "key offset value", + "group": "Bitmap", + "readonly": false + }, + "BTTL": { + "arguments": "key", + "group": "Bitmap", + "readonly": true + }, + "DECR": { + "arguments": "key", + "group": "KV", + "readonly": false + }, + "DECRBY": { + "arguments": "key decrement", + "group": "KV", + "readonly": false + }, + "DEL": { + "arguments": "key [key ...]", + "group": "KV", + "readonly": false + }, + "ECHO": { + "arguments": "message", + "group": "Server", + "readonly": true + }, + "EXISTS": { + "arguments": "key", + "group": "KV", + "readonly": true + }, + "EXPIRE": { + "arguments": "key seconds", + "group": "KV", + "readonly": false + }, + "EXPIREAT": { + "arguments": "key timestamp", + "group": "KV", + "readonly": false + }, + "FULLSYNC": { + "arguments": "-", + "group": "Replication", + "readonly": false + + }, + "GET": { + "arguments": "key", + "group": "KV", + "readonly": true + }, + "GETSET": { + "arguments": " key value", + "group": "KV", + "readonly": false + }, + "HCLEAR": { + "arguments": "key", + "group": "Hash", + "readonly": false + }, + "HDEL": { + "arguments": "key field [field ...]", + "group": "Hash", + "readonly": false + }, + "HEXISTS": { + "arguments": "key field", + "group": "Hash", + "readonly": true + }, + "HEXPIRE": { + "arguments": "key seconds", + "group": "Hash", + "readonly": false + }, + "HEXPIREAT": { + "arguments": "key timestamp", + "group": "Hash", + "readonly": false + }, + "HGET": { + "arguments": "key field", + "group": "Hash", + "readonly": true + }, + "HGETALL": { + "arguments": "key", + "group": "Hash", + "readonly": true + }, + "HINCRBY": { + "arguments": "key field increment", + "group": "Hash", + "readonly": false + }, + "HKEYS": { + "arguments": "key", + "group": "Hash", + "readonly": true + }, + "HLEN": { + "arguments": "key", + "group": "Hash", + "readonly": true + }, + "HMCLEAR": { + "arguments": "key [key ...]", + "group": "Hash", + "readonly": false + }, + "HMGET": { + "arguments": "key field [field ...]", + "group": "Hash", + "readonly": true + }, + "HMSET": { + "arguments": "key field value [field value ...]", + "group": "Hash", + "readonly": false + }, + "HPERSIST": { + "arguments": "key", + "group": "Hash", + "readonly": false + }, + "HSET": { + "arguments": "key field value", + "group": "Hash", + "readonly": false + }, + "HTTL": { + "arguments": "key", + "group": "Hash", + "readonly": true + }, + "HVALS": { + "arguments": "key", + "group": "Hash", + "readonly": true + }, + "INCR": { + "arguments": "key", + "group": "KV", + "readonly": false + }, + "INCRBY": { + "arguments": "key increment", + "group": "KV", + "readonly": false + }, + "LCLEAR": { + "arguments": "key", + "group": "List", + "readonly": false + }, + "LEXPIRE": { + "arguments": "key seconds", + "group": "List", + "readonly": false + }, + "LEXPIREAT": { + "arguments": "key timestamp", + "group": "List", + "readonly": false + }, + "LINDEX": { + "arguments": "key index", + "group": "List", + "readonly": true + }, + "LLEN": { + "arguments": "key", + "group": "List", + "readonly": true + }, + "LMCLEAR": { + "arguments": "key [key ...]", + "group": "List", + "readonly": false + }, + "LPERSIST": { + "arguments": "key", + "group": "List", + "readonly": false + }, + "LPOP": { + "arguments": "key", + "group": "List", + "readonly": false + }, + "LPUSH": { + "arguments": "key value [value ...]", + "group": "List", + "readonly": false + }, + "LRANGE": { + "arguments": "key start stop", + "group": "List", + "readonly": true + }, + "LTTL": { + "arguments": "key", + "group": "List", + "readonly": true + }, + "MGET": { + "arguments": "key [key ...]", + "group": "KV", + "readonly": true + }, + "MSET": { + "arguments": "key value [key value ...]", + "group": "KV", + "readonly": false + }, + "PERSIST": { + "arguments": "key", + "group": "KV", + "readonly": false + }, + "PING": { + "arguments": "-", + "group": "Server", + "readonly": true + }, + "RPOP": { + "arguments": "key", + "group": "List", + "readonly": false + }, + "RPUSH": { + "arguments": "key value [value ...]", + "group": "List", + "readonly": false + }, + "SELECT": { + "arguments": "index", + "group": "Server", + "readonly": false + }, + "SET": { + "arguments": "key value", + "group": "KV", + "readonly": false + }, + "SETNX": { + "arguments": "key value", + "group": "KV", + "readonly": false + }, + "SLAVEOF": { + "arguments": "host port", + "group": "Replication", + "readonly": false + }, + "SYNC": { + "arguments": "index offset", + "group": "Replication", + "readonly": false + }, + "TTL": { + "arguments": "key", + "group": "KV", + "readonly": true + }, + "ZADD": { + "arguments": "key score member [score member ...]", + "group": "ZSet", + "readonly": false + }, + "ZCARD": { + "arguments": "key", + "group": "ZSet", + "readonly": true + }, + "ZCLEAR": { + "arguments": "key", + "group": "ZSet", + "readonly": false + }, + "ZCOUNT": { + "arguments": "key min max", + "group": "ZSet", + "readonly": true + }, + "ZEXPIRE": { + "arguments": "key seconds", + "group": "ZSet", + "readonly": false + }, + "ZEXPIREAT": { + "arguments": "key timestamp", + "group": "ZSet", + "readonly": false + }, + "ZINCRBY": { + "arguments": "key increment member", + "group": "ZSet", + "readonly": false + }, + "ZMCLEAR": { + "arguments": "key [key ...]", + "group": "ZSet", + "readonly": false + }, + "ZPERSIST": { + "arguments": "key", + "group": "ZSet", + "readonly": false + }, + "ZRANGE": { + "arguments": "key start stop [WITHSCORES]", + "group": "ZSet", + "readonly": false + }, + "ZRANGEBYSCORE": { + "arguments": "key min max [WITHSCORES] [LIMIT offset count]", + "group": "ZSet", + "readonly": true + }, + "ZRANK": { + "arguments": "key member", + "group": "ZSet", + "readonly": true + }, + "ZREM": { + "arguments": "key member [member ...]", + "group": "ZSet", + "readonly": false + }, + "ZREMRANGEBYRANK": { + "arguments": "key start stop", + "group": "ZSet", + "readonly": false + }, + "ZREMRANGEBYSCORE": { + "arguments": "key min max", + "group": "ZSet", + "readonly": false + }, + "ZREVRANGE": { + "arguments": "key start stop [WITHSCORES]", + "group": "ZSet", + "readonly": true + }, + "ZREVRANGEBYSCORE": { + "arguments": "key max min [WITHSCORES][LIMIT offset count]", + "group": "ZSet", + "readonly": true + }, + "ZREVRANK": { + "arguments": "key member", + "group": "ZSet", + "readonly": true + }, + "ZSCORE": { + "arguments": "key member", + "group": "ZSet", + "readonly": true + }, + "ZTTL": { + "arguments": "key", + "group": "ZSet", + "readonly": true + } +} \ No newline at end of file diff --git a/generate.py b/generate.py new file mode 100644 index 0000000..beba9e5 --- /dev/null +++ b/generate.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python + +import json +import time +from collections import OrderedDict as dict + + +def go_array_to_json(path): + """Convert `./cmd/ledis-cli/const.go` to commands.json""" + fp = open(path).read() + commands_str = fp.split("string")[1] + _commands_str = commands_str.splitlines()[1:len(commands_str.splitlines())-1] + commands_d = dict() + values_d = dict() + for i in _commands_str: + t = i.split('"') + values_d.update( + { + "arguments": "%s" % t[3], + "group": "%s" % t[5] + }) + values_d = dict(sorted(values_d.items())) + d = { + "%s" % t[1]: values_d + } + commands_d.update(d) + + fp = open("commands.json", "w") + json.dump(commands_d, fp, indent=4) + fp.close() + + +def json_to_js(json_path, js_path): + """Convert `commands.json` to `commands.js`""" + keys = [] + with open(json_path) as fp: + _json = json.load(fp) + for k in _json.keys(): + keys.append(k.encode('utf-8')) + with open(js_path, "w") as fp: + generate_time(fp) + fp.write("module.exports = [\n" ) + for k in sorted(keys): + fp.write('\t"%s",\n' % k.lower()) + fp.write("]") + + +def json_to_go_array(json_path, go_path): + g_fp = open(go_path, "w") + with open(json_path) as fp: + _json = json.load(fp) + generate_time(g_fp) + g_fp.write("package main\n\nvar helpCommands = [][]string{\n") + _json_sorted = dict(sorted(_json.items(), key=lambda x: x[0])) + for k, v in _json_sorted.iteritems(): + print k, v + g_fp.write('\t{"%s", "%s", "%s"},\n' % (k, v["arguments"], v["group"])) + g_fp.write("}\n") + g_fp.close() + + +def generate_time(fp): + fp.write("//This file was generated by ./generate.py on %s \n" % \ + time.strftime('%a %b %d %Y %H:%M:%S %z')) + +if __name__ == "__main__": + path = "./cmd/ledis-cli/const.go" + # go_array_to_json(path) + json_path = "./commands.json" + js_path = "./commands.js" + json_to_js(json_path, js_path) + go_path = "const.go" + + json_to_go_array(json_path, path) diff --git a/ledis/tx.go b/ledis/tx.go index f49cd34..09e4cea 100644 --- a/ledis/tx.go +++ b/ledis/tx.go @@ -9,7 +9,7 @@ type tx struct { m sync.Mutex l *Ledis - wb *store.WriteBatch + wb store.WriteBatch binlog *BinLog batch [][]byte @@ -27,7 +27,7 @@ func newTx(l *Ledis) *tx { } func (t *tx) Close() { - t.wb.Close() + t.wb = nil } func (t *tx) Put(key []byte, value []byte) { diff --git a/server/cmd_bit.go b/server/cmd_bit.go index 8fdd678..cb3d593 100644 --- a/server/cmd_bit.go +++ b/server/cmd_bit.go @@ -44,13 +44,18 @@ func bsetbitCommand(c *client) error { var val int8 offset, err = ledis.StrInt32(args[1], nil) + if err != nil { - return err + return ErrOffset } val, err = ledis.StrInt8(args[2], nil) + if val != 0 && val != 1 { + return ErrBool + } + if err != nil { - return err + return ErrBool } if ori, err := c.db.BSetBit(args[0], offset, uint8(val)); err != nil { @@ -68,8 +73,9 @@ func bgetbitCommand(c *client) error { } offset, err := ledis.StrInt32(args[1], nil) + if err != nil { - return err + return ErrOffset } if v, err := c.db.BGetBit(args[0], offset); err != nil { @@ -100,13 +106,18 @@ func bmsetbitCommand(c *client) error { pairs := make([]ledis.BitPair, len(args)>>1) for i := 0; i < len(pairs); i++ { offset, err = ledis.StrInt32(args[i<<1], nil) + if err != nil { - return err + return ErrOffset } val, err = ledis.StrInt8(args[i<<1+1], nil) + if val != 0 && val != 1 { + return ErrBool + } + if err != nil { - return err + return ErrBool } pairs[i].Pos = offset @@ -137,14 +148,14 @@ func bcountCommand(c *client) error { if argCnt > 1 { start, err = ledis.StrInt32(args[1], nil) if err != nil { - return err + return ErrValue } } if argCnt > 2 { end, err = ledis.StrInt32(args[2], nil) if err != nil { - return err + return ErrValue } } @@ -180,6 +191,9 @@ func boptCommand(c *client) error { return ErrCmdParams } + if len(srcKeys) == 0 { + return ErrCmdParams + } if blen, err := c.db.BOperation(op, dstKey, srcKeys...); err != nil { return err } else { @@ -190,7 +204,7 @@ func boptCommand(c *client) error { func bexpireCommand(c *client) error { args := c.args - if len(args) == 0 { + if len(args) != 2 { return ErrCmdParams } @@ -208,9 +222,9 @@ func bexpireCommand(c *client) error { return nil } -func bexpireatCommand(c *client) error { +func bexpireAtCommand(c *client) error { args := c.args - if len(args) == 0 { + if len(args) != 2 { return ErrCmdParams } @@ -230,7 +244,7 @@ func bexpireatCommand(c *client) error { func bttlCommand(c *client) error { args := c.args - if len(args) == 0 { + if len(args) != 1 { return ErrCmdParams } @@ -267,7 +281,7 @@ func init() { register("bcount", bcountCommand) register("bopt", boptCommand) register("bexpire", bexpireCommand) - register("bexpireat", bexpireatCommand) + register("bexpireat", bexpireAtCommand) register("bttl", bttlCommand) register("bpersist", bpersistCommand) } diff --git a/server/cmd_bit_test.go b/server/cmd_bit_test.go index f5438f9..7e45d4e 100644 --- a/server/cmd_bit_test.go +++ b/server/cmd_bit_test.go @@ -221,3 +221,163 @@ func testBitOpt(t *testing.T) { return } + +func TestBitErrorParams(t *testing.T) { + c := getTestConn() + defer c.Close() + + if _, err := c.Do("bget"); err == nil { + t.Fatal("invalid err of %v", err) + } + + if _, err := c.Do("bdelete"); err == nil { + t.Fatal("invalid err of %v", err) + } + + // bsetbit + if _, err := c.Do("bsetbit"); err == nil { + t.Fatal("invalid err of %v", err) + } + + if _, err := c.Do("bsetbit", "test_bsetbit"); err == nil { + t.Fatal("invalid err of %v", err) + } + + if _, err := c.Do("bsetbit", "test_bsetbit", "o", "v"); err == nil { + t.Fatal("invalid err of %v", err) + } + + if _, err := c.Do("bsetbit", "test_bsetbit", "o", 1); err == nil { + t.Fatal("invalid err of %v", err) + } + + // if _, err := c.Do("bsetbit", "test_bsetbit", -1, 1); err == nil { + // t.Fatal("invalid err of %v", err) + // } + + if _, err := c.Do("bsetbit", "test_bsetbit", 1, "v"); err == nil { + t.Fatal("invalid err of %v", err) + } + + if _, err := c.Do("bsetbit", "test_bsetbit", 1, 2); err == nil { + t.Fatal("invalid err of %v", err) + } + + //bgetbit + if _, err := c.Do("bgetbit", "test_bgetbit"); err == nil { + t.Fatal("invalid err of %v", err) + } + + if _, err := c.Do("bgetbit", "test_bgetbit", "o"); err == nil { + t.Fatal("invalid err of %v", err) + } + + // if _, err := c.Do("bgetbit", "test_bgetbit", -1); err == nil { + // t.Fatal("invalid err of %v", err) + // } + + //bmsetbit + if _, err := c.Do("bmsetbit", "test_bmsetbit"); err == nil { + t.Fatal("invalid err of %v", err) + } + + if _, err := c.Do("bmsetbit", "test_bmsetbit", 0, 1, 2); err == nil { + t.Fatal("invalid err of %v", err) + } + + if _, err := c.Do("bmsetbit", "test_bmsetbit", "o", "v"); err == nil { + t.Fatal("invalid err of %v", err) + } + + if _, err := c.Do("bmsetbit", "test_bmsetbit", "o", 1); err == nil { + t.Fatal("invalid err of %v", err) + } + + // if _, err := c.Do("bmsetbit", "test_bmsetbit", -1, 1); err == nil { + // t.Fatal("invalid err of %v", err) + // } + + if _, err := c.Do("bmsetbit", "test_bmsetbit", 1, "v"); err == nil { + t.Fatal("invalid err of %v", err) + } + + if _, err := c.Do("bmsetbit", "test_bmsetbit", 1, 2); err == nil { + t.Fatal("invalid err of %v", err) + } + + if _, err := c.Do("bmsetbit", "test_bmsetbit", 1, 0.1); err == nil { + t.Fatal("invalid err of %v", err) + } + + //bcount + + if _, err := c.Do("bcount"); err == nil { + t.Fatal("invalid err of %v", err) + } + + if _, err := c.Do("bcount", "a", "b", "c"); err == nil { + t.Fatal("invalid err of %v", err) + } + + if _, err := c.Do("bcount", 1, "a"); err == nil { + t.Fatal("invalid err of %v", err) + } + + // if _, err := c.Do("bcount", 1); err == nil { + // t.Fatal("invalid err of %v", err) + // } + + //bopt + if _, err := c.Do("bopt"); err == nil { + t.Fatal("invalid err of %v", err) + } + + if _, err := c.Do("bopt", "and", 1); err == nil { + t.Fatal("invalid err of %v", err) + } + + if _, err := c.Do("bopt", "x", 1); err == nil { + t.Fatal("invalid err of %v", err) + } + + if _, err := c.Do("bopt", ""); err == nil { + t.Fatal("invalid err of %v", err) + } + + if _, err := c.Do("bexpire", "test_bexpire"); err == nil { + t.Fatal("invalid err of %v", err) + } + + if _, err := c.Do("bexpireat", "test_bexpireat"); err == nil { + t.Fatal("invalid err of %v", err) + } + + if _, err := c.Do("bttl"); err == nil { + t.Fatal("invalid err of %v", err) + } + + if _, err := c.Do("bpersist"); err == nil { + t.Fatal("invalid err of %v", err) + } + + //bexpire + if _, err := c.Do("bexpire", "test_bexpire"); err == nil { + t.Fatal("invalid err of %v", err) + } + + //bexpireat + if _, err := c.Do("bexpireat", "test_bexpireat"); err == nil { + t.Fatal("invalid err of %v", err) + } + + //bttl + if _, err := c.Do("bttl"); err == nil { + t.Fatal("invalid err of %v", err) + } + + //bpersist + if _, err := c.Do("bpersist"); err == nil { + t.Fatal("invalid err of %v", err) + } + +} diff --git a/server/cmd_hash_test.go b/server/cmd_hash_test.go index 61f513f..04388cc 100644 --- a/server/cmd_hash_test.go +++ b/server/cmd_hash_test.go @@ -221,79 +221,79 @@ func TestHashErrorParams(t *testing.T) { c := getTestConn() defer c.Close() - if _, err := c.Do("hset", "test_hset"); err == nil || err.Error() != SErrCmdParams { + if _, err := c.Do("hset", "test_hset"); err == nil { t.Fatal("invalid err of %v", err) } - if _, err := c.Do("hget", "test_hget"); err == nil || err.Error() != SErrCmdParams { + if _, err := c.Do("hget", "test_hget"); err == nil { t.Fatal("invalid err of %v", err) } - if _, err := c.Do("hexists", "test_hexists"); err == nil || err.Error() != SErrCmdParams { + if _, err := c.Do("hexists", "test_hexists"); err == nil { t.Fatal("invalid err of %v", err) } - if _, err := c.Do("hdel", "test_hdel"); err == nil || err.Error() != SErrCmdParams { + if _, err := c.Do("hdel", "test_hdel"); err == nil { t.Fatal("invalid err of %v", err) } - if _, err := c.Do("hlen", "test_hlen", "a"); err == nil || err.Error() != SErrCmdParams { + if _, err := c.Do("hlen", "test_hlen", "a"); err == nil { t.Fatal("invalid err of %v", err) } - if _, err := c.Do("hincrby", "test_hincrby"); err == nil || err.Error() != SErrCmdParams { + if _, err := c.Do("hincrby", "test_hincrby"); err == nil { t.Fatal("invalid err of %v", err) } - if _, err := c.Do("hmset", "test_hmset"); err == nil || err.Error() != SErrCmdParams { + if _, err := c.Do("hmset", "test_hmset"); err == nil { t.Fatal("invalid err of %v", err) } - if _, err := c.Do("hmset", "test_hmset", "f1", "v1", "f2"); err == nil || err.Error() != SErrCmdParams { + if _, err := c.Do("hmset", "test_hmset", "f1", "v1", "f2"); err == nil { t.Fatal("invalid err of %v", err) } - if _, err := c.Do("hmget", "test_hget"); err == nil || err.Error() != SErrCmdParams { + if _, err := c.Do("hmget", "test_hget"); err == nil { t.Fatal("invalid err of %v", err) } - if _, err := c.Do("hgetall"); err == nil || err.Error() != SErrCmdParams { + if _, err := c.Do("hgetall"); err == nil { t.Fatal("invalid err of %v", err) } - if _, err := c.Do("hkeys"); err == nil || err.Error() != SErrCmdParams { + if _, err := c.Do("hkeys"); err == nil { t.Fatal("invalid err of %v", err) } - if _, err := c.Do("hvals"); err == nil || err.Error() != SErrCmdParams { + if _, err := c.Do("hvals"); err == nil { t.Fatal("invalid err of %v", err) } - if _, err := c.Do("hclear"); err == nil || err.Error() != SErrCmdParams { + if _, err := c.Do("hclear"); err == nil { t.Fatal("invalid err of %v", err) } - if _, err := c.Do("hclear", "test_hclear", "a"); err == nil || err.Error() != SErrCmdParams { + if _, err := c.Do("hclear", "test_hclear", "a"); err == nil { t.Fatal("invalid err of %v", err) } - if _, err := c.Do("hmclear"); err == nil || err.Error() != SErrCmdParams { + if _, err := c.Do("hmclear"); err == nil { t.Fatal("invalid err of %v", err) } - if _, err := c.Do("hexpire", "test_hexpire"); err == nil || err.Error() != SErrCmdParams { + if _, err := c.Do("hexpire", "test_hexpire"); err == nil { t.Fatal("invalid err of %v", err) } - if _, err := c.Do("hexpireat", "test_hexpireat"); err == nil || err.Error() != SErrCmdParams { + if _, err := c.Do("hexpireat", "test_hexpireat"); err == nil { t.Fatal("invalid err of %v", err) } - if _, err := c.Do("httl"); err == nil || err.Error() != SErrCmdParams { + if _, err := c.Do("httl"); err == nil { t.Fatal("invalid err of %v", err) } - if _, err := c.Do("hpersist"); err == nil || err.Error() != SErrCmdParams { + if _, err := c.Do("hpersist"); err == nil { t.Fatal("invalid err of %v", err) } diff --git a/server/cmd_kv_test.go b/server/cmd_kv_test.go index 45ac11e..d24fd25 100644 --- a/server/cmd_kv_test.go +++ b/server/cmd_kv_test.go @@ -142,75 +142,75 @@ func TestKVErrorParams(t *testing.T) { c := getTestConn() defer c.Close() - if _, err := c.Do("get", "a", "b", "c"); err == nil || err.Error() != SErrCmdParams { + if _, err := c.Do("get", "a", "b", "c"); err == nil { t.Fatalf("invalid err %v", err) } - if _, err := c.Do("set", "a", "b", "c"); err == nil || err.Error() != SErrCmdParams { + if _, err := c.Do("set", "a", "b", "c"); err == nil { t.Fatalf("invalid err %v", err) } - if _, err := c.Do("getset", "a", "b", "c"); err == nil || err.Error() != SErrCmdParams { + if _, err := c.Do("getset", "a", "b", "c"); err == nil { t.Fatalf("invalid err %v", err) } - if _, err := c.Do("setnx", "a", "b", "c"); err == nil || err.Error() != SErrCmdParams { + if _, err := c.Do("setnx", "a", "b", "c"); err == nil { t.Fatalf("invalid err %v", err) } - if _, err := c.Do("exists", "a", "b"); err == nil || err.Error() != SErrCmdParams { + if _, err := c.Do("exists", "a", "b"); err == nil { t.Fatalf("invalid err %v", err) } - if _, err := c.Do("incr", "a", "b"); err == nil || err.Error() != SErrCmdParams { + if _, err := c.Do("incr", "a", "b"); err == nil { t.Fatalf("invalid err %v", err) } - if _, err := c.Do("incrby", "a"); err == nil || err.Error() != SErrCmdParams { + if _, err := c.Do("incrby", "a"); err == nil { t.Fatalf("invalid err %v", err) } - if _, err := c.Do("decrby", "a"); err == nil || err.Error() != SErrCmdParams { + if _, err := c.Do("decrby", "a"); err == nil { t.Fatalf("invalid err %v", err) } - if _, err := c.Do("del"); err == nil || err.Error() != SErrCmdParams { + if _, err := c.Do("del"); err == nil { t.Fatal("invalid err of %v", err) } - if _, err := c.Do("mset"); err == nil || err.Error() != SErrCmdParams { + if _, err := c.Do("mset"); err == nil { t.Fatal("invalid err of %v", err) } - if _, err := c.Do("mset", "a", "b", "c"); err == nil || err.Error() != SErrCmdParams { + if _, err := c.Do("mset", "a", "b", "c"); err == nil { t.Fatal("invalid err of %v", err) } - if _, err := c.Do("mget"); err == nil || err.Error() != SErrCmdParams { + if _, err := c.Do("mget"); err == nil { t.Fatal("invalid err of %v", err) } - if _, err := c.Do("expire"); err == nil || err.Error() != SErrCmdParams { + if _, err := c.Do("expire"); err == nil { t.Fatal("invalid err of %v", err) } - if _, err := c.Do("expire", "a", "b"); err == nil || err.Error() != SErrValue { + if _, err := c.Do("expire", "a", "b"); err == nil { t.Fatal("invalid err of %v", err) } - if _, err := c.Do("expireat"); err == nil || err.Error() != SErrCmdParams { + if _, err := c.Do("expireat"); err == nil { t.Fatal("invalid err of %v", err) } - if _, err := c.Do("expireat", "a", "b"); err == nil || err.Error() != SErrValue { + if _, err := c.Do("expireat", "a", "b"); err == nil { t.Fatal("invalid err of %v", err) } - if _, err := c.Do("ttl"); err == nil || err.Error() != SErrCmdParams { + if _, err := c.Do("ttl"); err == nil { t.Fatal("invalid err of %v", err) } - if _, err := c.Do("persist"); err == nil || err.Error() != SErrCmdParams { + if _, err := c.Do("persist"); err == nil { t.Fatal("invalid err of %v", err) } diff --git a/server/cmd_list_test.go b/server/cmd_list_test.go index 739020d..5a7cf51 100644 --- a/server/cmd_list_test.go +++ b/server/cmd_list_test.go @@ -289,55 +289,55 @@ func TestListErrorParams(t *testing.T) { c := getTestConn() defer c.Close() - if _, err := c.Do("lpush", "test_lpush"); err == nil || err.Error() != SErrCmdParams { + if _, err := c.Do("lpush", "test_lpush"); err == nil { t.Fatal("invalid err of %v", err) } - if _, err := c.Do("rpush", "test_rpush"); err == nil || err.Error() != SErrCmdParams { + if _, err := c.Do("rpush", "test_rpush"); err == nil { t.Fatal("invalid err of %v", err) } - if _, err := c.Do("lpop", "test_lpop", "a"); err == nil || err.Error() != SErrCmdParams { + if _, err := c.Do("lpop", "test_lpop", "a"); err == nil { t.Fatal("invalid err of %v", err) } - if _, err := c.Do("rpop", "test_rpop", "a"); err == nil || err.Error() != SErrCmdParams { + if _, err := c.Do("rpop", "test_rpop", "a"); err == nil { t.Fatal("invalid err of %v", err) } - if _, err := c.Do("llen", "test_llen", "a"); err == nil || err.Error() != SErrCmdParams { + if _, err := c.Do("llen", "test_llen", "a"); err == nil { t.Fatal("invalid err of %v", err) } - if _, err := c.Do("lindex", "test_lindex"); err == nil || err.Error() != SErrCmdParams { + if _, err := c.Do("lindex", "test_lindex"); err == nil { t.Fatal("invalid err of %v", err) } - if _, err := c.Do("lrange", "test_lrange"); err == nil || err.Error() != SErrCmdParams { + if _, err := c.Do("lrange", "test_lrange"); err == nil { t.Fatal("invalid err of %v", err) } - if _, err := c.Do("lclear"); err == nil || err.Error() != SErrCmdParams { + if _, err := c.Do("lclear"); err == nil { t.Fatal("invalid err of %v", err) } - if _, err := c.Do("lmclear"); err == nil || err.Error() != SErrCmdParams { + if _, err := c.Do("lmclear"); err == nil { t.Fatal("invalid err of %v", err) } - if _, err := c.Do("lexpire"); err == nil || err.Error() != SErrCmdParams { + if _, err := c.Do("lexpire"); err == nil { t.Fatal("invalid err of %v", err) } - if _, err := c.Do("lexpireat"); err == nil || err.Error() != SErrCmdParams { + if _, err := c.Do("lexpireat"); err == nil { t.Fatal("invalid err of %v", err) } - if _, err := c.Do("lttl"); err == nil || err.Error() != SErrCmdParams { + if _, err := c.Do("lttl"); err == nil { t.Fatal("invalid err of %v", err) } - if _, err := c.Do("lpersist"); err == nil || err.Error() != SErrCmdParams { + if _, err := c.Do("lpersist"); err == nil { t.Fatal("invalid err of %v", err) } diff --git a/server/cmd_zset.go b/server/cmd_zset.go index 8dc465e..868b6bd 100644 --- a/server/cmd_zset.go +++ b/server/cmd_zset.go @@ -134,6 +134,7 @@ func zparseScoreRange(minBuf []byte, maxBuf []byte) (min int64, max int64, err e min, err = ledis.StrInt64(minBuf, nil) if err != nil { + err = ErrValue return } @@ -164,6 +165,7 @@ func zparseScoreRange(minBuf []byte, maxBuf []byte) (min int64, max int64, err e max, err = ledis.StrInt64(maxBuf, nil) if err != nil { + err = ErrValue return } @@ -188,7 +190,7 @@ func zcountCommand(c *client) error { min, max, err := zparseScoreRange(args[1], args[2]) if err != nil { - return err + return ErrValue } if min > max { @@ -249,7 +251,7 @@ func zremrangebyrankCommand(c *client) error { start, stop, err := zparseRange(c, args[1], args[2]) if err != nil { - return err + return ErrValue } if n, err := c.db.ZRemRangeByRank(key, start, stop); err != nil { @@ -304,14 +306,21 @@ func zrangeGeneric(c *client, reverse bool) error { start, stop, err := zparseRange(c, args[1], args[2]) if err != nil { - return err + return ErrValue } args = args[3:] var withScores bool = false - if len(args) > 0 && strings.ToLower(ledis.String(args[0])) == "withscores" { - withScores = true + if len(args) > 0 { + if len(args) != 1 { + return ErrCmdParams + } + if strings.ToLower(ledis.String(args[0])) == "withscores" { + withScores = true + } else { + return ErrSyntax + } } if datas, err := c.db.ZRangeGeneric(key, start, stop, reverse); err != nil { @@ -356,9 +365,11 @@ func zrangebyscoreGeneric(c *client, reverse bool) error { var withScores bool = false - if len(args) > 0 && strings.ToLower(ledis.String(args[0])) == "withscores" { - withScores = true - args = args[1:] + if len(args) > 0 { + if strings.ToLower(ledis.String(args[0])) == "withscores" { + withScores = true + args = args[1:] + } } var offset int = 0 @@ -370,15 +381,15 @@ func zrangebyscoreGeneric(c *client, reverse bool) error { } if strings.ToLower(ledis.String(args[0])) != "limit" { - return ErrCmdParams + return ErrSyntax } if offset, err = strconv.Atoi(ledis.String(args[1])); err != nil { - return ErrCmdParams + return ErrValue } if count, err = strconv.Atoi(ledis.String(args[2])); err != nil { - return ErrCmdParams + return ErrValue } } @@ -472,7 +483,6 @@ func zexpireAtCommand(c *client) error { } else { c.writeInteger(v) } - return nil } diff --git a/server/cmd_zset_test.go b/server/cmd_zset_test.go index cf512c7..d9b1272 100644 --- a/server/cmd_zset_test.go +++ b/server/cmd_zset_test.go @@ -429,3 +429,173 @@ func TestZSetRange(t *testing.T) { } } + +func TestZsetErrorParams(t *testing.T) { + c := getTestConn() + defer c.Close() + + //zadd + if _, err := c.Do("zadd", "test_zadd"); err == nil { + t.Fatal("invalid err of %v", err) + } + + if _, err := c.Do("zadd", "test_zadd", "a", "b", "c"); err == nil { + t.Fatal("invalid err of %v", err) + } + + if _, err := c.Do("zadd", "test_zadd", "-a", "a"); err == nil { + t.Fatal("invalid err of %v", err) + } + + if _, err := c.Do("zadd", "test_zad", "0.1", "a"); err == nil { + t.Fatal("invalid err of %v", err) + } + + //zcard + if _, err := c.Do("zcard"); err == nil { + t.Fatal("invalid err of %v", err) + } + + //zscore + if _, err := c.Do("zscore", "test_zscore"); err == nil { + t.Fatal("invalid err of %v", err) + } + + //zrem + if _, err := c.Do("zrem", "test_zrem"); err == nil { + t.Fatal("invalid err of %v", err) + } + + //zincrby + if _, err := c.Do("zincrby", "test_zincrby"); err == nil { + t.Fatal("invalid err of %v", err) + } + + if _, err := c.Do("zincrby", "test_zincrby", 0.1, "a"); err == nil { + t.Fatal("invalid err of %v", err) + } + + //zcount + if _, err := c.Do("zcount", "test_zcount"); err == nil { + t.Fatal("invalid err of %v", err) + } + + if _, err := c.Do("zcount", "test_zcount", "-inf", "=inf"); err == nil { + t.Fatal("invalid err of %v", err) + } + + if _, err := c.Do("zcount", "test_zcount", 0.1, 0.1); err == nil { + t.Fatal("invalid err of %v", err) + } + + //zrank + if _, err := c.Do("zrank", "test_zrank"); err == nil { + t.Fatal("invalid err of %v", err) + } + + //zrevzrank + if _, err := c.Do("zrevrank", "test_zrevrank"); err == nil { + t.Fatal("invalid err of %v", err) + } + + //zremrangebyrank + if _, err := c.Do("zremrangebyrank", "test_zremrangebyrank"); err == nil { + t.Fatal("invalid err of %v", err) + } + + if _, err := c.Do("zremrangebyrank", "test_zremrangebyrank", 0.1, 0.1); err == nil { + t.Fatal("invalid err of %v", err) + } + + //zremrangebyscore + if _, err := c.Do("zremrangebyscore", "test_zremrangebyscore"); err == nil { + t.Fatal("invalid err of %v", err) + } + + if _, err := c.Do("zremrangebyscore", "test_zremrangebyscore", "-inf", "a"); err == nil { + t.Fatal("invalid err of %v", err) + } + + if _, err := c.Do("zremrangebyscore", "test_zremrangebyscore", 0, "a"); err == nil { + t.Fatal("invalid err of %v", err) + } + + //zrange + if _, err := c.Do("zrange", "test_zrange"); err == nil { + t.Fatal("invalid err of %v", err) + } + + if _, err := c.Do("zrange", "test_zrange", 0, 1, "withscore"); err == nil { + t.Fatal("invalid err of %v", err) + } + + if _, err := c.Do("zrange", "test_zrange", 0, 1, "withscores", "a"); err == nil { + t.Fatal("invalid err of %v", err) + } + + //zrevrange, almost same as zrange + if _, err := c.Do("zrevrange", "test_zrevrange"); err == nil { + t.Fatal("invalid err of %v", err) + } + + //zrangebyscore + if _, err := c.Do("zrangebyscore", "test_zrangebyscore"); err == nil { + t.Fatal("invalid err of %v", err) + } + + if _, err := c.Do("zrangebyscore", "test_zrangebyscore", 0, 1, "withscore"); err == nil { + t.Fatal("invalid err of %v", err) + } + + if _, err := c.Do("zrangebyscore", "test_zrangebyscore", 0, 1, "withscores", "limit"); err == nil { + t.Fatal("invalid err of %v", err) + } + + if _, err := c.Do("zrangebyscore", "test_zrangebyscore", 0, 1, "withscores", "limi", 1, 1); err == nil { + t.Fatal("invalid err of %v", err) + } + + if _, err := c.Do("zrangebyscore", "test_zrangebyscore", 0, 1, "withscores", "limit", "a", 1); err == nil { + t.Fatal("invalid err of %v", err) + } + + if _, err := c.Do("zrangebyscore", "test_zrangebyscore", 0, 1, "withscores", "limit", 1, "a"); err == nil { + t.Fatal("invalid err of %v", err) + } + + //zrevrangebyscore, almost same as zrangebyscore + if _, err := c.Do("zrevrangebyscore", "test_zrevrangebyscore"); err == nil { + t.Fatal("invalid err of %v", err) + } + + //zclear + if _, err := c.Do("zclear"); err == nil { + t.Fatal("invalid err of %v", err) + } + + //zmclear + if _, err := c.Do("zmclear"); err == nil { + t.Fatal("invalid err of %v", err) + } + + //zexpire + if _, err := c.Do("zexpire", "test_zexpire"); err == nil { + t.Fatal("invalid err of %v", err) + } + + //zexpireat + if _, err := c.Do("zexpireat", "test_zexpireat"); err == nil { + t.Fatal("invalid err of %v", err) + } + + //zttl + if _, err := c.Do("zttl"); err == nil { + t.Fatal("invalid err of %v", err) + } + + //zpersist + if _, err := c.Do("zpersist"); err == nil { + t.Fatal("invalid err of %v", err) + } + +} diff --git a/server/const.go b/server/const.go index f1424d2..cfb7595 100644 --- a/server/const.go +++ b/server/const.go @@ -10,6 +10,8 @@ var ( ErrCmdParams = errors.New("invalid command param") ErrValue = errors.New("value is not an integer or out of range") ErrSyntax = errors.New("syntax error") + ErrOffset = errors.New("offset bit is not an natural number") + ErrBool = errors.New("value is not 0 or 1") ) var ( @@ -18,9 +20,6 @@ var ( NullBulk = []byte("-1") NullArray = []byte("-1") - PONG = "PONG" - OK = "OK" - SErrCmdParams = "ERR invalid command param" - SErrValue = "ERR value is not an integer or out of range" - SErrSyntax = "ERR syntax error" + PONG = "PONG" + OK = "OK" ) diff --git a/store/db.go b/store/db.go index adad004..f32b1c3 100644 --- a/store/db.go +++ b/store/db.go @@ -48,8 +48,8 @@ func (db *DB) NewIterator() *Iterator { return it } -func (db *DB) NewWriteBatch() *WriteBatch { - return &WriteBatch{db.db.NewWriteBatch()} +func (db *DB) NewWriteBatch() WriteBatch { + return db.db.NewWriteBatch() } func (db *DB) RangeIterator(min []byte, max []byte, rangeType uint8) *RangeLimitIterator { @@ -73,3 +73,12 @@ func (db *DB) RangeLimitIterator(min []byte, max []byte, rangeType uint8, offset func (db *DB) RevRangeLimitIterator(min []byte, max []byte, rangeType uint8, offset int, count int) *RangeLimitIterator { return NewRevRangeLimitIterator(db.NewIterator(), &Range{min, max, rangeType}, &Limit{offset, count}) } + +func (db *DB) Begin() (Tx, error) { + tx, err := db.db.Begin() + if err != nil { + return nil, err + } + + return tx, nil +} diff --git a/store/driver/driver.go b/store/driver/driver.go index 3be1c7e..a557d3e 100644 --- a/store/driver/driver.go +++ b/store/driver/driver.go @@ -1,5 +1,13 @@ package driver +import ( + "errors" +) + +var ( + ErrTxSupport = errors.New("transaction is not supported") +) + type IDB interface { Close() error @@ -11,6 +19,8 @@ type IDB interface { NewIterator() IIterator NewWriteBatch() IWriteBatch + + Begin() (Tx, error) } type IIterator interface { @@ -30,10 +40,20 @@ type IIterator interface { } type IWriteBatch interface { - Close() error - Put(key []byte, value []byte) Delete(key []byte) Commit() error Rollback() error } + +type Tx interface { + Get(key []byte) ([]byte, error) + Put(key []byte, value []byte) error + Delete(key []byte) error + + NewIterator() IIterator + NewWriteBatch() IWriteBatch + + Commit() error + Rollback() error +} diff --git a/store/goleveldb/batch.go b/store/goleveldb/batch.go index bdfb457..74902b2 100644 --- a/store/goleveldb/batch.go +++ b/store/goleveldb/batch.go @@ -9,10 +9,6 @@ type WriteBatch struct { wbatch *leveldb.Batch } -func (w *WriteBatch) Close() error { - return nil -} - func (w *WriteBatch) Put(key, value []byte) { w.wbatch.Put(key, value) } diff --git a/store/goleveldb/db.go b/store/goleveldb/db.go index c9e393f..261cf0d 100644 --- a/store/goleveldb/db.go +++ b/store/goleveldb/db.go @@ -135,3 +135,7 @@ func (db *DB) NewIterator() driver.IIterator { return it } + +func (db *DB) Begin() (driver.Tx, error) { + return nil, driver.ErrTxSupport +} diff --git a/store/leveldb/batch.go b/store/leveldb/batch.go index 87600cb..08a7d46 100644 --- a/store/leveldb/batch.go +++ b/store/leveldb/batch.go @@ -17,6 +17,8 @@ type WriteBatch struct { func (w *WriteBatch) Close() error { C.leveldb_writebatch_destroy(w.wbatch) + w.wbatch = nil + return nil } diff --git a/store/leveldb/db.go b/store/leveldb/db.go index 6e4022d..049c9f4 100644 --- a/store/leveldb/db.go +++ b/store/leveldb/db.go @@ -13,6 +13,7 @@ import "C" import ( "github.com/siddontang/ledisdb/store/driver" "os" + "runtime" "unsafe" ) @@ -187,6 +188,10 @@ func (db *DB) NewWriteBatch() driver.IWriteBatch { db: db, wbatch: C.leveldb_writebatch_create(), } + + runtime.SetFinalizer(wb, func(w *WriteBatch) { + w.Close() + }) return wb } @@ -260,3 +265,7 @@ func (db *DB) delete(wo *WriteOptions, key []byte) error { } return nil } + +func (db *DB) Begin() (driver.Tx, error) { + return nil, driver.ErrTxSupport +} diff --git a/store/mdb.go b/store/mdb.go index 364065a..ffbf010 100644 --- a/store/mdb.go +++ b/store/mdb.go @@ -1,3 +1,5 @@ +// +build !windows + package store import ( diff --git a/store/mdb/batch.go b/store/mdb/batch.go index d68ee43..7c99be2 100644 --- a/store/mdb/batch.go +++ b/store/mdb/batch.go @@ -1,17 +1,17 @@ package mdb +type batchPut interface { + BatchPut([]Write) error +} + type Write struct { Key []byte Value []byte } type WriteBatch struct { - db *MDB - wb []Write -} - -func (w *WriteBatch) Close() error { - return nil + batch batchPut + wb []Write } func (w *WriteBatch) Put(key, value []byte) { @@ -26,7 +26,7 @@ func (w *WriteBatch) Delete(key []byte) { } func (w *WriteBatch) Commit() error { - return w.db.BatchPut(w.wb) + return w.batch.BatchPut(w.wb) } func (w *WriteBatch) Rollback() error { diff --git a/store/mdb/mdb.go b/store/mdb/mdb.go index a10d0d4..d06f413 100644 --- a/store/mdb/mdb.go +++ b/store/mdb/mdb.go @@ -1,7 +1,7 @@ package mdb import ( - mdb "github.com/influxdb/gomdb" + mdb "github.com/siddontang/gomdb" "github.com/siddontang/ledisdb/store/driver" "os" ) @@ -32,6 +32,7 @@ func Open(c *Config) (MDB, error) { if err := env.SetMaxDBs(1); err != nil { return MDB{}, err } + if err := env.SetMapSize(uint64(c.MapSize)); err != nil { return MDB{}, err } @@ -78,11 +79,10 @@ func Repair(c *Config) error { func (db MDB) Put(key, value []byte) error { itr := db.iterator(false) - defer itr.Close() itr.err = itr.c.Put(key, value, 0) itr.setState() - return itr.Error() + return itr.Close() } func (db MDB) BatchPut(writes []Write) error { @@ -140,6 +140,8 @@ type MDBIterator struct { tx *mdb.Txn valid bool err error + + closeAutoCommit bool } func (itr *MDBIterator) Key() []byte { @@ -201,6 +203,11 @@ func (itr *MDBIterator) Close() error { itr.tx.Abort() return err } + + if !itr.closeAutoCommit { + return itr.err + } + if itr.err != nil { itr.tx.Abort() return itr.err @@ -226,16 +233,16 @@ func (db MDB) iterator(rdonly bool) *MDBIterator { } tx, err := db.env.BeginTxn(nil, flags) if err != nil { - return &MDBIterator{nil, nil, nil, nil, false, err} + return &MDBIterator{nil, nil, nil, nil, false, err, true} } c, err := tx.CursorOpen(db.db) if err != nil { tx.Abort() - return &MDBIterator{nil, nil, nil, nil, false, err} + return &MDBIterator{nil, nil, nil, nil, false, err, true} } - return &MDBIterator{nil, nil, c, tx, true, nil} + return &MDBIterator{nil, nil, c, tx, true, nil, true} } func (db MDB) Close() error { @@ -253,3 +260,7 @@ func (db MDB) NewIterator() driver.IIterator { func (db MDB) NewWriteBatch() driver.IWriteBatch { return &WriteBatch{&db, []Write{}} } + +func (db MDB) Begin() (driver.Tx, error) { + return newTx(db) +} diff --git a/store/mdb/tx.go b/store/mdb/tx.go new file mode 100644 index 0000000..4fb29b4 --- /dev/null +++ b/store/mdb/tx.go @@ -0,0 +1,81 @@ +package mdb + +import ( + mdb "github.com/siddontang/gomdb" + "github.com/siddontang/ledisdb/store/driver" +) + +type Tx struct { + db mdb.DBI + tx *mdb.Txn +} + +func newTx(db MDB) (*Tx, error) { + tx, err := db.env.BeginTxn(nil, uint(0)) + if err != nil { + return nil, err + } + + return &Tx{db.db, tx}, nil +} + +func (t *Tx) Get(key []byte) ([]byte, error) { + return t.tx.Get(t.db, key) +} + +func (t *Tx) Put(key []byte, value []byte) error { + return t.tx.Put(t.db, key, value, mdb.NODUPDATA) +} + +func (t *Tx) Delete(key []byte) error { + return t.tx.Del(t.db, key, nil) +} + +func (t *Tx) NewIterator() driver.IIterator { + return t.newIterator() +} + +func (t *Tx) newIterator() *MDBIterator { + c, err := t.tx.CursorOpen(t.db) + if err != nil { + return &MDBIterator{nil, nil, nil, nil, false, err, false} + } + + return &MDBIterator{nil, nil, c, t.tx, true, nil, false} +} + +func (t *Tx) NewWriteBatch() driver.IWriteBatch { + return &WriteBatch{t, []Write{}} +} + +func (t *Tx) BatchPut(writes []Write) error { + itr := t.newIterator() + + for _, w := range writes { + if w.Value == nil { + itr.key, itr.value, itr.err = itr.c.Get(w.Key, mdb.SET) + if itr.err == nil { + itr.err = itr.c.Del(0) + } + } else { + itr.err = itr.c.Put(w.Key, w.Value, 0) + } + + if itr.err != nil && itr.err != mdb.NotFound { + break + } + } + itr.setState() + + return itr.Close() + +} + +func (t *Tx) Rollback() error { + t.tx.Abort() + return nil +} + +func (t *Tx) Commit() error { + return t.tx.Commit() +} diff --git a/store/mdb_test.go b/store/mdb_test.go index 8e70d13..512b375 100644 --- a/store/mdb_test.go +++ b/store/mdb_test.go @@ -1,3 +1,5 @@ +// +build !windows + package store import ( @@ -30,3 +32,11 @@ func TestLMDB(t *testing.T) { db.Close() } + +func TestLMDBTx(t *testing.T) { + db := newTestLMDB() + + testTx(db, t) + + db.Close() +} diff --git a/store/rocksdb/batch.go b/store/rocksdb/batch.go index 46e700b..b69c383 100644 --- a/store/rocksdb/batch.go +++ b/store/rocksdb/batch.go @@ -17,6 +17,7 @@ type WriteBatch struct { func (w *WriteBatch) Close() error { C.rocksdb_writebatch_destroy(w.wbatch) + w.wbatch = nil return nil } diff --git a/store/rocksdb/db.go b/store/rocksdb/db.go index 0faeaec..2d8a56d 100644 --- a/store/rocksdb/db.go +++ b/store/rocksdb/db.go @@ -207,6 +207,11 @@ func (db *DB) NewWriteBatch() driver.IWriteBatch { db: db, wbatch: C.rocksdb_writebatch_create(), } + + runtime.SetFinalizer(wb, func(w *WriteBatch) { + w.Close() + }) + return wb } @@ -277,3 +282,7 @@ func (db *DB) delete(wo *WriteOptions, key []byte) error { } return nil } + +func (db *DB) Begin() (driver.Tx, error) { + return nil, driver.ErrTxSupport +} diff --git a/store/store_test.go b/store/store_test.go index c79bddd..a008d03 100644 --- a/store/store_test.go +++ b/store/store_test.go @@ -59,7 +59,6 @@ func testBatch(db *DB, t *testing.T) { db.Put(key2, value) wb := db.NewWriteBatch() - defer wb.Close() wb.Delete(key2) wb.Put(key1, []byte("hello world2")) diff --git a/store/tx.go b/store/tx.go new file mode 100644 index 0000000..1a074c7 --- /dev/null +++ b/store/tx.go @@ -0,0 +1,9 @@ +package store + +import ( + "github.com/siddontang/ledisdb/store/driver" +) + +type Tx interface { + driver.Tx +} diff --git a/store/tx_test.go b/store/tx_test.go new file mode 100644 index 0000000..cc1115b --- /dev/null +++ b/store/tx_test.go @@ -0,0 +1,123 @@ +package store + +import ( + "testing" +) + +func TestTx(t *testing.T) { + +} + +func testTx(db *DB, t *testing.T) { + key1 := []byte("1") + key2 := []byte("2") + key3 := []byte("3") + key4 := []byte("4") + + db.Put(key1, []byte("1")) + db.Put(key2, []byte("2")) + + tx, err := db.Begin() + if err != nil { + t.Fatal(err) + } + + if err := tx.Put(key1, []byte("a")); err != nil { + t.Fatal(err) + } + + if err := tx.Put(key2, []byte("b")); err != nil { + t.Fatal(err) + } + + if err := tx.Put(key3, []byte("c")); err != nil { + t.Fatal(err) + } + + if err := tx.Put(key4, []byte("d")); err != nil { + t.Fatal(err) + } + + it := tx.NewIterator() + + it.Seek(key1) + + if !it.Valid() { + t.Fatal("must valid") + } else if string(it.Value()) != "a" { + t.Fatal(string(it.Value())) + } + + it.First() + + if !it.Valid() { + t.Fatal("must valid") + } else if string(it.Value()) != "a" { + t.Fatal(string(it.Value())) + } + + it.Seek(key2) + + if !it.Valid() { + t.Fatal("must valid") + } else if string(it.Value()) != "b" { + t.Fatal(string(it.Value())) + } + + it.Next() + + if !it.Valid() { + t.Fatal("must valid") + } else if string(it.Value()) != "c" { + t.Fatal(string(it.Value())) + } + + it.Last() + + if !it.Valid() { + t.Fatal("must valid") + } else if string(it.Value()) != "d" { + t.Fatal(string(it.Value())) + } + + it.Close() + + tx.Rollback() + + if v, err := db.Get(key1); err != nil { + t.Fatal(err) + } else if string(v) != "1" { + t.Fatal(string(v)) + } + + tx, err = db.Begin() + if err != nil { + t.Fatal(err) + } + + if err := tx.Put(key1, []byte("a")); err != nil { + t.Fatal(err) + } + + it = tx.NewIterator() + + it.Seek(key2) + + if !it.Valid() { + t.Fatal("must valid") + } else if string(it.Value()) != "2" { + t.Fatal(string(it.Value())) + } + + it.Close() + + if err := tx.Commit(); err != nil { + t.Fatal(err) + } + + if v, err := db.Get(key1); err != nil { + t.Fatal(err) + } else if string(v) != "a" { + t.Fatal(string(v)) + } +} diff --git a/store/writebatch.go b/store/writebatch.go index d898a03..9fe21ac 100644 --- a/store/writebatch.go +++ b/store/writebatch.go @@ -4,6 +4,6 @@ import ( "github.com/siddontang/ledisdb/store/driver" ) -type WriteBatch struct { +type WriteBatch interface { driver.IWriteBatch }