From cb5a61240d62797ac934b3d692f426cf27facb89 Mon Sep 17 00:00:00 2001 From: siddontang Date: Tue, 26 Aug 2014 23:21:45 +0800 Subject: [PATCH] server add scan support --- ledis/scan.go | 2 +- server/client_resp.go | 2 + server/cmd_bit.go | 22 +++++++++ server/cmd_hash.go | 22 +++++++++ server/cmd_kv.go | 78 ++++++++++++++++++++++++++++++ server/cmd_list.go | 22 +++++++++ server/cmd_replication_test.go | 2 + server/cmd_set.go | 22 +++++++++ server/cmd_zset.go | 22 +++++++++ server/scan_test.go | 88 ++++++++++++++++++++++++++++++++++ 10 files changed, 281 insertions(+), 1 deletion(-) create mode 100644 server/scan_test.go diff --git a/ledis/scan.go b/ledis/scan.go index 64a3f97..09e2b5c 100644 --- a/ledis/scan.go +++ b/ledis/scan.go @@ -20,7 +20,7 @@ func (db *DB) scan(dataType byte, key []byte, count int, inclusive bool, match s } } - if key != nil { + if len(key) > 0 { if err = checkKeySize(key); err != nil { return nil, err } diff --git a/server/client_resp.go b/server/client_resp.go index 22e6e69..e8fb1ff 100644 --- a/server/client_resp.go +++ b/server/client_resp.go @@ -203,6 +203,8 @@ func (w *respWriter) writeArray(lst []interface{}) { switch v := lst[i].(type) { case []interface{}: w.writeArray(v) + case [][]byte: + w.writeSliceArray(v) case []byte: w.writeBulk(v) case nil: diff --git a/server/cmd_bit.go b/server/cmd_bit.go index 5845f28..39762e9 100644 --- a/server/cmd_bit.go +++ b/server/cmd_bit.go @@ -272,6 +272,27 @@ func bpersistCommand(c *client) error { return nil } +func bscanCommand(c *client) error { + key, inclusive, match, count, err := parseScanArgs(c) + if err != nil { + return err + } + + if ay, err := c.db.BScan(key, count, inclusive, match); err != nil { + return err + } else { + data := make([]interface{}, 2) + if len(ay) < count { + data[0] = "" + } else { + data[0] = append([]byte{'('}, ay[len(ay)-1]...) + } + data[1] = ay + c.resp.writeArray(data) + } + return nil +} + func init() { register("bget", bgetCommand) register("bdelete", bdeleteCommand) @@ -284,4 +305,5 @@ func init() { register("bexpireat", bexpireAtCommand) register("bttl", bttlCommand) register("bpersist", bpersistCommand) + register("bscan", bscanCommand) } diff --git a/server/cmd_hash.go b/server/cmd_hash.go index 5c336a3..f292e4b 100644 --- a/server/cmd_hash.go +++ b/server/cmd_hash.go @@ -292,6 +292,27 @@ func hpersistCommand(c *client) error { return nil } +func hscanCommand(c *client) error { + key, inclusive, match, count, err := parseScanArgs(c) + if err != nil { + return err + } + + if ay, err := c.db.HScan(key, count, inclusive, match); err != nil { + return err + } else { + data := make([]interface{}, 2) + if len(ay) < count { + data[0] = "" + } else { + data[0] = append([]byte{'('}, ay[len(ay)-1]...) + } + data[1] = ay + c.resp.writeArray(data) + } + return nil +} + func init() { register("hdel", hdelCommand) register("hexists", hexistsCommand) @@ -313,4 +334,5 @@ func init() { register("hexpireat", hexpireAtCommand) register("httl", httlCommand) register("hpersist", hpersistCommand) + register("hscan", hscanCommand) } diff --git a/server/cmd_kv.go b/server/cmd_kv.go index cd0cac0..44fd28a 100644 --- a/server/cmd_kv.go +++ b/server/cmd_kv.go @@ -2,6 +2,8 @@ package server import ( "github.com/siddontang/ledisdb/ledis" + "strconv" + "strings" ) func getCommand(c *client) error { @@ -273,6 +275,81 @@ func persistCommand(c *client) error { return nil } +func parseScanArgs(c *client) (key []byte, inclusive bool, match string, count int, err error) { + args := c.args + count = 10 + inclusive = false + + switch len(args) { + case 0: + key = nil + return + case 1, 3, 5: + key = args[0] + break + case 2, 4, 6: + key = args[0] + if strings.ToLower(ledis.String(args[len(args)-1])) != "inclusive" { + err = ErrCmdParams + return + } + inclusive = true + args = args[0 : len(args)-1] + default: + err = ErrCmdParams + return + } + + if len(args) == 3 { + switch strings.ToLower(ledis.String(args[1])) { + case "match": + match = ledis.String(args[2]) + return + case "count": + count, err = strconv.Atoi(ledis.String(args[2])) + return + default: + err = ErrCmdParams + return + } + } else if len(args) == 5 { + if strings.ToLower(ledis.String(args[1])) != "match" { + err = ErrCmdParams + return + } else if strings.ToLower(ledis.String(args[3])) != "count" { + err = ErrCmdParams + return + } + + match = ledis.String(args[2]) + count, err = strconv.Atoi(ledis.String(args[4])) + return + } + + return +} + +func scanCommand(c *client) error { + key, inclusive, match, count, err := parseScanArgs(c) + if err != nil { + return err + } + + if ay, err := c.db.Scan(key, count, inclusive, match); err != nil { + return err + } else { + data := make([]interface{}, 2) + if len(ay) < count { + data[0] = []byte("") + } else { + data[0] = ay[len(ay)-1] + } + data[1] = ay + c.resp.writeArray(data) + } + return nil +} + func init() { register("decr", decrCommand) register("decrby", decrbyCommand) @@ -290,4 +367,5 @@ func init() { register("expireat", expireAtCommand) register("ttl", ttlCommand) register("persist", persistCommand) + register("scan", scanCommand) } diff --git a/server/cmd_list.go b/server/cmd_list.go index f893643..bd4cc6f 100644 --- a/server/cmd_list.go +++ b/server/cmd_list.go @@ -228,6 +228,27 @@ func lpersistCommand(c *client) error { return nil } +func lscanCommand(c *client) error { + key, inclusive, match, count, err := parseScanArgs(c) + if err != nil { + return err + } + + if ay, err := c.db.LScan(key, count, inclusive, match); err != nil { + return err + } else { + data := make([]interface{}, 2) + if len(ay) < count { + data[0] = "" + } else { + data[0] = append([]byte{'('}, ay[len(ay)-1]...) + } + data[1] = ay + c.resp.writeArray(data) + } + return nil +} + func init() { register("lindex", lindexCommand) register("llen", llenCommand) @@ -245,4 +266,5 @@ func init() { register("lexpireat", lexpireAtCommand) register("lttl", lttlCommand) register("lpersist", lpersistCommand) + register("lscan", lscanCommand) } diff --git a/server/cmd_replication_test.go b/server/cmd_replication_test.go index 5a7aa54..3e7e285 100644 --- a/server/cmd_replication_test.go +++ b/server/cmd_replication_test.go @@ -43,6 +43,7 @@ func TestReplication(t *testing.T) { if err != nil { t.Fatal(err) } + defer master.Close() slaveCfg := new(config.Config) slaveCfg.DataDir = fmt.Sprintf("%s/slave", data_dir) @@ -53,6 +54,7 @@ func TestReplication(t *testing.T) { if err != nil { t.Fatal(err) } + defer slave.Close() go master.Run() diff --git a/server/cmd_set.go b/server/cmd_set.go index 335d8cc..052389d 100644 --- a/server/cmd_set.go +++ b/server/cmd_set.go @@ -262,6 +262,27 @@ func spersistCommand(c *client) error { return nil } +func sscanCommand(c *client) error { + key, inclusive, match, count, err := parseScanArgs(c) + if err != nil { + return err + } + + if ay, err := c.db.SScan(key, count, inclusive, match); err != nil { + return err + } else { + data := make([]interface{}, 2) + if len(ay) < count { + data[0] = "" + } else { + data[0] = append([]byte{'('}, ay[len(ay)-1]...) + } + data[1] = ay + c.resp.writeArray(data) + } + return nil +} + func init() { register("sadd", saddCommand) register("scard", scardCommand) @@ -280,4 +301,5 @@ func init() { register("sexpireat", sexpireAtCommand) register("sttl", sttlCommand) register("spersist", spersistCommand) + register("sscan", sscanCommand) } diff --git a/server/cmd_zset.go b/server/cmd_zset.go index 2964af5..162c411 100644 --- a/server/cmd_zset.go +++ b/server/cmd_zset.go @@ -638,6 +638,27 @@ func zinterstoreCommand(c *client) error { return err } +func zscanCommand(c *client) error { + key, inclusive, match, count, err := parseScanArgs(c) + if err != nil { + return err + } + + if ay, err := c.db.ZScan(key, count, inclusive, match); err != nil { + return err + } else { + data := make([]interface{}, 2) + if len(ay) < count { + data[0] = "" + } else { + data[0] = append([]byte{'('}, ay[len(ay)-1]...) + } + data[1] = ay + c.resp.writeArray(data) + } + return nil +} + func init() { register("zadd", zaddCommand) register("zcard", zcardCommand) @@ -665,4 +686,5 @@ func init() { register("zexpireat", zexpireAtCommand) register("zttl", zttlCommand) register("zpersist", zpersistCommand) + register("zscan", zscanCommand) } diff --git a/server/scan_test.go b/server/scan_test.go new file mode 100644 index 0000000..e2c671f --- /dev/null +++ b/server/scan_test.go @@ -0,0 +1,88 @@ +package server + +import ( + "fmt" + "github.com/siddontang/ledisdb/client/go/ledis" + "github.com/siddontang/ledisdb/config" + "os" + "testing" +) + +func TestScan(t *testing.T) { + cfg := new(config.Config) + cfg.DataDir = "/tmp/test_scan" + cfg.Addr = "127.0.0.1:11185" + + os.RemoveAll(cfg.DataDir) + + s, err := NewApp(cfg) + if err != nil { + t.Fatal(err) + } + go s.Run() + defer s.Close() + + cc := new(ledis.Config) + cc.Addr = cfg.Addr + cc.MaxIdleConns = 1 + c := ledis.NewClient(cc) + defer c.Close() + + testKVScan(t, c) +} + +func checkScanValues(t *testing.T, ay interface{}, values ...int) { + a, err := ledis.Strings(ay, nil) + if err != nil { + t.Fatal(err) + } + + if len(a) != len(values) { + t.Fatal(fmt.Sprintf("len %d != %d", len(a), len(values))) + } + + for i, v := range a { + if string(v) != fmt.Sprintf("%d", values[i]) { + t.Fatal(fmt.Sprintf("%d %s != %d", string(v), values[i])) + } + } +} + +func testKVScan(t *testing.T, c *ledis.Client) { + for i := 0; i < 10; i++ { + if _, err := c.Do("set", fmt.Sprintf("%d", i), []byte("value")); err != nil { + t.Fatal(err) + } + } + + if ay, err := ledis.Values(c.Do("scan", "", "count", 5)); err != nil { + t.Fatal(err) + } else if len(ay) != 2 { + t.Fatal(len(ay)) + } else if n := ay[0].([]byte); string(n) != "4" { + t.Fatal(string(n)) + } else { + checkScanValues(t, ay[1], 0, 1, 2, 3, 4) + } + + if ay, err := ledis.Values(c.Do("scan", "4", "count", 6)); err != nil { + t.Fatal(err) + } else if len(ay) != 2 { + t.Fatal(len(ay)) + } else if n := ay[0].([]byte); string(n) != "" { + t.Fatal(string(n)) + } else { + checkScanValues(t, ay[1], 5, 6, 7, 8, 9) + } + + if ay, err := ledis.Values(c.Do("scan", "4", "count", 6, "inclusive")); err != nil { + t.Fatal(err) + } else if len(ay) != 2 { + t.Fatal(len(ay)) + } else if n := ay[0].([]byte); string(n) != "9" { + t.Fatal(string(n)) + } else { + checkScanValues(t, ay[1], 4, 5, 6, 7, 8, 9) + } + +}