diff --git a/ledis/scan.go b/ledis/scan.go index 5b32133..c465168 100644 --- a/ledis/scan.go +++ b/ledis/scan.go @@ -48,11 +48,9 @@ func getDataStoreType(dataType DataType) (byte, error) { return storeDataType, nil } -func (db *DB) scanGeneric(storeDataType byte, key []byte, count int, - inclusive bool, match string, reverse bool) ([][]byte, error) { - var minKey, maxKey []byte +func buildMatchRegexp(match string) (*regexp.Regexp, error) { var err error - var r *regexp.Regexp + var r *regexp.Regexp = nil if len(match) > 0 { if r, err = regexp.Compile(match); err != nil { @@ -60,6 +58,17 @@ func (db *DB) scanGeneric(storeDataType byte, key []byte, count int, } } + return r, nil +} + +func (db *DB) scanGeneric(storeDataType byte, key []byte, count int, + inclusive bool, match string, reverse bool) ([][]byte, error) { + var minKey, maxKey []byte + r, err := buildMatchRegexp(match) + if err != nil { + return nil, err + } + tp := store.RangeOpen if !reverse { @@ -151,10 +160,10 @@ func (db *DB) encodeScanKey(storeDataType byte, key []byte) ([]byte, error) { return db.hEncodeSizeKey(key), nil case ZSizeType: return db.zEncodeSizeKey(key), nil - case BitMetaType: - return db.bEncodeMetaKey(key), nil case SSizeType: return db.sEncodeSizeKey(key), nil + case BitMetaType: + return db.bEncodeMetaKey(key), nil default: return nil, errDataType } @@ -165,3 +174,123 @@ func (db *DB) decodeScanKey(storeDataType byte, ek []byte) ([]byte, error) { } return ek[2:], nil } + +// for specail data scan + +func (db *DB) buildDataScanIterator(start []byte, stop []byte, inclusive bool) *store.RangeLimitIterator { + tp := store.RangeROpen + + if !inclusive { + tp = store.RangeOpen + } + it := db.bucket.RangeIterator(start, stop, tp) + return it + +} + +func (db *DB) HScan(key []byte, cursor []byte, count int, inclusive bool, match string) ([]FVPair, error) { + if err := checkKeySize(key); err != nil { + return nil, err + } + + start := db.hEncodeHashKey(key, cursor) + stop := db.hEncodeStopKey(key) + + v := make([]FVPair, 0, 16) + + r, err := buildMatchRegexp(match) + if err != nil { + return nil, err + } + + it := db.buildDataScanIterator(start, stop, inclusive) + defer it.Close() + + for i := 0; it.Valid() && i < count; it.Next() { + _, f, err := db.hDecodeHashKey(it.Key()) + if err != nil { + return nil, err + } else if r != nil && !r.Match(f) { + continue + } + + v = append(v, FVPair{Field: f, Value: it.Value()}) + + i++ + } + + return v, nil +} + +func (db *DB) SScan(key []byte, cursor []byte, count int, inclusive bool, match string) ([][]byte, error) { + if err := checkKeySize(key); err != nil { + return nil, err + } + + start := db.sEncodeSetKey(key, cursor) + stop := db.sEncodeStopKey(key) + + v := make([][]byte, 0, 16) + + r, err := buildMatchRegexp(match) + if err != nil { + return nil, err + } + + it := db.buildDataScanIterator(start, stop, inclusive) + defer it.Close() + + for i := 0; it.Valid() && i < count; it.Next() { + _, m, err := db.sDecodeSetKey(it.Key()) + if err != nil { + return nil, err + } else if r != nil && !r.Match(m) { + continue + } + + v = append(v, m) + + i++ + } + + return v, nil +} + +func (db *DB) ZScan(key []byte, cursor []byte, count int, inclusive bool, match string) ([]ScorePair, error) { + if err := checkKeySize(key); err != nil { + return nil, err + } + + start := db.zEncodeSetKey(key, cursor) + stop := db.zEncodeStopSetKey(key) + + v := make([]ScorePair, 0, 16) + + r, err := buildMatchRegexp(match) + if err != nil { + return nil, err + } + + it := db.buildDataScanIterator(start, stop, inclusive) + defer it.Close() + + for i := 0; it.Valid() && i < count; it.Next() { + _, m, err := db.zDecodeSetKey(it.Key()) + if err != nil { + return nil, err + } else if r != nil && !r.Match(m) { + continue + } + + score, err := Int64(it.Value(), nil) + if err != nil { + return nil, err + } + + v = append(v, ScorePair{Score: score, Member: m}) + + i++ + } + + return v, nil +} diff --git a/ledis/scan_test.go b/ledis/scan_test.go index ded982c..9505964 100644 --- a/ledis/scan_test.go +++ b/ledis/scan_test.go @@ -111,7 +111,7 @@ func TestDBScan(t *testing.T) { } -func TestDBHScan(t *testing.T) { +func TestDBHKeyScan(t *testing.T) { db := getTestDB() db.hFlush() @@ -155,7 +155,7 @@ func TestDBHScan(t *testing.T) { } -func TestDBZScan(t *testing.T) { +func TestDBZKeyScan(t *testing.T) { db := getTestDB() db.zFlush() @@ -199,7 +199,7 @@ func TestDBZScan(t *testing.T) { } -func TestDBLScan(t *testing.T) { +func TestDBLKeyScan(t *testing.T) { db := getTestDB() db.lFlush() @@ -249,10 +249,10 @@ func TestDBLScan(t *testing.T) { } -func TestDBSScan(t *testing.T) { +func TestDBSKeyScan(t *testing.T) { db := getTestDB() - db.bFlush() + db.sFlush() k1 := []byte("k1") if _, err := db.SAdd(k1, []byte("1")); err != nil { @@ -296,5 +296,77 @@ func TestDBSScan(t *testing.T) { } else if string(v[1]) != "k3" { t.Fatal("invalid value ", string(v[1])) } - +} + +func TestDBHScan(t *testing.T) { + db := getTestDB() + + key := []byte("scan_h_key") + value := []byte("hello world") + db.HSet(key, []byte("1"), value) + db.HSet(key, []byte("222"), value) + db.HSet(key, []byte("19"), value) + db.HSet(key, []byte("1234"), value) + + v, err := db.HScan(key, nil, 100, true, "") + if err != nil { + t.Fatal(err) + } else if len(v) != 4 { + t.Fatal("invalid count", len(v)) + } + + v, err = db.HScan(key, []byte("19"), 1, false, "") + if err != nil { + t.Fatal(err) + } else if len(v) != 1 { + t.Fatal("invalid count", len(v)) + } else if string(v[0].Field) != "222" { + t.Fatal(string(v[0].Field)) + } +} + +func TestDBSScan(t *testing.T) { + db := getTestDB() + key := []byte("scan_s_key") + + db.SAdd(key, []byte("1"), []byte("222"), []byte("19"), []byte("1234")) + + v, err := db.SScan(key, nil, 100, true, "") + if err != nil { + t.Fatal(err) + } else if len(v) != 4 { + t.Fatal("invalid count", len(v)) + } + + v, err = db.SScan(key, []byte("19"), 1, false, "") + if err != nil { + t.Fatal(err) + } else if len(v) != 1 { + t.Fatal("invalid count", len(v)) + } else if string(v[0]) != "222" { + t.Fatal(string(v[0])) + } +} + +func TestDBZScan(t *testing.T) { + db := getTestDB() + key := []byte("scan_z_key") + + db.ZAdd(key, ScorePair{1, []byte("1")}, ScorePair{2, []byte("222")}, ScorePair{3, []byte("19")}, ScorePair{4, []byte("1234")}) + + v, err := db.ZScan(key, nil, 100, true, "") + if err != nil { + t.Fatal(err) + } else if len(v) != 4 { + t.Fatal("invalid count", len(v)) + } + + v, err = db.ZScan(key, []byte("19"), 1, false, "") + if err != nil { + t.Fatal(err) + } else if len(v) != 1 { + t.Fatal("invalid count", len(v)) + } else if string(v[0].Member) != "222" { + t.Fatal(string(v[0].Member)) + } } diff --git a/ledis/t_hash.go b/ledis/t_hash.go index 3a0133a..c7b92dd 100644 --- a/ledis/t_hash.go +++ b/ledis/t_hash.go @@ -354,6 +354,8 @@ func (db *DB) HGetAll(key []byte) ([]FVPair, error) { v := make([]FVPair, 0, 16) it := db.bucket.RangeLimitIterator(start, stop, store.RangeROpen, 0, -1) + defer it.Close() + for ; it.Valid(); it.Next() { _, f, err := db.hDecodeHashKey(it.Key()) if err != nil { @@ -363,8 +365,6 @@ func (db *DB) HGetAll(key []byte) ([]FVPair, error) { v = append(v, FVPair{Field: f, Value: it.Value()}) } - it.Close() - return v, nil } @@ -379,6 +379,8 @@ func (db *DB) HKeys(key []byte) ([][]byte, error) { v := make([][]byte, 0, 16) it := db.bucket.RangeLimitIterator(start, stop, store.RangeROpen, 0, -1) + defer it.Close() + for ; it.Valid(); it.Next() { _, f, err := db.hDecodeHashKey(it.Key()) if err != nil { @@ -387,8 +389,6 @@ func (db *DB) HKeys(key []byte) ([][]byte, error) { v = append(v, f) } - it.Close() - return v, nil } @@ -403,6 +403,8 @@ func (db *DB) HValues(key []byte) ([][]byte, error) { v := make([][]byte, 0, 16) it := db.bucket.RangeLimitIterator(start, stop, store.RangeROpen, 0, -1) + defer it.Close() + for ; it.Valid(); it.Next() { _, _, err := db.hDecodeHashKey(it.Key()) if err != nil { @@ -412,8 +414,6 @@ func (db *DB) HValues(key []byte) ([][]byte, error) { v = append(v, it.Value()) } - it.Close() - return v, nil } diff --git a/server/cmd_scan.go b/server/cmd_scan.go index e4f6f7f..1e62d8f 100644 --- a/server/cmd_scan.go +++ b/server/cmd_scan.go @@ -3,12 +3,51 @@ package server import ( "fmt" "github.com/siddontang/go/hack" + "github.com/siddontang/go/num" "github.com/siddontang/ledisdb/ledis" "strconv" "strings" ) -// XSCAN type cursor [MATCH match] [COUNT count] [ASC|DESC] +func parseScanArgs(args [][]byte) (cursor []byte, match string, count int, err error) { + cursor = args[0] + + args = args[1:] + + count = 10 + + for i := 0; i < len(args); { + switch strings.ToUpper(hack.String(args[i])) { + case "MATCH": + if i+1 >= len(args) { + err = ErrCmdParams + return + } + + match = hack.String(args[i+1]) + i = i + 2 + case "COUNT": + if i+1 >= len(args) { + err = ErrCmdParams + return + } + + count, err = strconv.Atoi(hack.String(args[i+1])) + if err != nil { + return + } + + i = i + 2 + default: + err = fmt.Errorf("invalid argument %s", args[i]) + return + } + } + + return +} + +// XSCAN type cursor [MATCH match] [COUNT count] func xscanCommand(c *client) error { args := c.args @@ -32,55 +71,13 @@ func xscanCommand(c *client) error { return fmt.Errorf("invalid key type %s", args[0]) } - cursor := args[1] + cursor, match, count, err := parseScanArgs(args[1:]) - args = args[2:] - - match := "" - count := 10 - - desc := false - - var err error - - for i := 0; i < len(args); { - switch strings.ToUpper(hack.String(args[i])) { - case "MATCH": - if i+1 >= len(args) { - return ErrCmdParams - } - - match = hack.String(args[i+1]) - i = i + 2 - case "COUNT": - if i+1 >= len(args) { - return ErrCmdParams - } - - count, err = strconv.Atoi(hack.String(args[i+1])) - if err != nil { - return err - } - - i = i + 2 - case "ASC": - desc = false - i++ - case "DESC": - desc = true - i++ - default: - return fmt.Errorf("invalid argument %s", args[i]) - } - } - - var ay [][]byte - if !desc { - ay, err = c.db.Scan(dataType, cursor, count, false, match) - } else { - ay, err = c.db.RevScan(dataType, cursor, count, false, match) + if err != nil { + return err } + ay, err := c.db.Scan(dataType, cursor, count, false, match) if err != nil { return err } @@ -96,6 +93,123 @@ func xscanCommand(c *client) error { return nil } +// XHSCAN key cursor [MATCH match] [COUNT count] +func xhscanCommand(c *client) error { + args := c.args + + if len(args) < 2 { + return ErrCmdParams + } + + key := args[0] + + cursor, match, count, err := parseScanArgs(args[1:]) + + if err != nil { + return err + } + + ay, err := c.db.HScan(key, cursor, count, false, match) + if err != nil { + return err + } + + data := make([]interface{}, 2) + if len(ay) < count { + data[0] = []byte("") + } else { + data[0] = ay[len(ay)-1].Field + } + + vv := make([][]byte, 0, len(ay)*2) + + for _, v := range ay { + vv = append(vv, v.Field, v.Value) + } + + data[1] = vv + + c.resp.writeArray(data) + return nil +} + +// XSSCAN key cursor [MATCH match] [COUNT count] +func xsscanCommand(c *client) error { + args := c.args + + if len(args) < 2 { + return ErrCmdParams + } + + key := args[0] + + cursor, match, count, err := parseScanArgs(args[1:]) + + if err != nil { + return err + } + + ay, err := c.db.SScan(key, cursor, count, false, match) + if err != nil { + return err + } + + data := make([]interface{}, 2) + if len(ay) < count { + data[0] = []byte("") + } else { + data[0] = ay[len(ay)-1] + } + + data[1] = ay + + c.resp.writeArray(data) + return nil +} + +// XZSCAN key cursor [MATCH match] [COUNT count] +func xzscanCommand(c *client) error { + args := c.args + + if len(args) < 2 { + return ErrCmdParams + } + + key := args[0] + + cursor, match, count, err := parseScanArgs(args[1:]) + + if err != nil { + return err + } + + ay, err := c.db.ZScan(key, cursor, count, false, match) + if err != nil { + return err + } + + data := make([]interface{}, 2) + if len(ay) < count { + data[0] = []byte("") + } else { + data[0] = ay[len(ay)-1].Member + } + + vv := make([][]byte, 0, len(ay)*2) + + for _, v := range ay { + vv = append(vv, v.Member, num.FormatInt64ToSlice(v.Score)) + } + + data[1] = vv + + c.resp.writeArray(data) + return nil +} + func init() { register("xscan", xscanCommand) + register("xhscan", xhscanCommand) + register("xsscan", xsscanCommand) + register("xzscan", xzscanCommand) } diff --git a/server/cmd_scan_test.go b/server/cmd_scan_test.go index 66d34d3..7ebb5ba 100644 --- a/server/cmd_scan_test.go +++ b/server/cmd_scan_test.go @@ -29,14 +29,14 @@ func TestScan(t *testing.T) { defer c.Close() testKVScan(t, c) - testHashScan(t, c) - testListScan(t, c) - testZSetScan(t, c) - testSetScan(t, c) + testHashKeyScan(t, c) + testListKeyScan(t, c) + testZSetKeyScan(t, c) + testSetKeyScan(t, c) } -func checkScanValues(t *testing.T, ay interface{}, values ...int) { +func checkScanValues(t *testing.T, ay interface{}, values ...interface{}) { a, err := ledis.Strings(ay, nil) if err != nil { t.Fatal(err) @@ -47,8 +47,8 @@ func checkScanValues(t *testing.T, ay interface{}, values ...int) { } for i, v := range a { - if string(v) != fmt.Sprintf("%d", values[i]) { - t.Fatal(fmt.Sprintf("%d %s != %d", string(v), values[i])) + if string(v) != fmt.Sprintf("%v", values[i]) { + t.Fatal(fmt.Sprintf("%d %s != %v", string(v), values[i])) } } } @@ -76,29 +76,6 @@ func checkScan(t *testing.T, c *ledis.Client, tp string) { } -func checkRevScan(t *testing.T, c *ledis.Client, tp string) { - if ay, err := ledis.Values(c.Do("XSCAN", tp, "", "count", 5, "DESC")); err != nil { - t.Fatal(err) - } else if len(ay) != 2 { - t.Fatal(len(ay)) - } else if n := ay[0].([]byte); string(n) != "5" { - t.Fatal(string(n)) - } else { - checkScanValues(t, ay[1], 9, 8, 7, 6, 5) - } - - if ay, err := ledis.Values(c.Do("XSCAN", tp, "5", "count", 6, "DESC")); err != nil { - t.Fatal(err) - } else if len(ay) != 2 { - t.Fatal(len(ay)) - } else if n := ay[0].([]byte); string(n) != "" { - t.Fatal(string(n)) - } else { - checkScanValues(t, ay[1], 4, 3, 2, 1, 0) - } - -} - func testKVScan(t *testing.T, c *ledis.Client) { for i := 0; i < 10; i++ { if _, err := c.Do("set", fmt.Sprintf("%d", i), []byte("value")); err != nil { @@ -107,10 +84,9 @@ func testKVScan(t *testing.T, c *ledis.Client) { } checkScan(t, c, "KV") - checkRevScan(t, c, "KV") } -func testHashScan(t *testing.T, c *ledis.Client) { +func testHashKeyScan(t *testing.T, c *ledis.Client) { for i := 0; i < 10; i++ { if _, err := c.Do("hset", fmt.Sprintf("%d", i), fmt.Sprintf("%d", i), []byte("value")); err != nil { t.Fatal(err) @@ -118,10 +94,9 @@ func testHashScan(t *testing.T, c *ledis.Client) { } checkScan(t, c, "HASH") - checkRevScan(t, c, "HASH") } -func testListScan(t *testing.T, c *ledis.Client) { +func testListKeyScan(t *testing.T, c *ledis.Client) { for i := 0; i < 10; i++ { if _, err := c.Do("lpush", fmt.Sprintf("%d", i), fmt.Sprintf("%d", i)); err != nil { t.Fatal(err) @@ -129,10 +104,9 @@ func testListScan(t *testing.T, c *ledis.Client) { } checkScan(t, c, "LIST") - checkRevScan(t, c, "LIST") } -func testZSetScan(t *testing.T, c *ledis.Client) { +func testZSetKeyScan(t *testing.T, c *ledis.Client) { for i := 0; i < 10; i++ { if _, err := c.Do("zadd", fmt.Sprintf("%d", i), i, []byte("value")); err != nil { t.Fatal(err) @@ -140,10 +114,9 @@ func testZSetScan(t *testing.T, c *ledis.Client) { } checkScan(t, c, "ZSET") - checkRevScan(t, c, "ZSET") } -func testSetScan(t *testing.T, c *ledis.Client) { +func testSetKeyScan(t *testing.T, c *ledis.Client) { for i := 0; i < 10; i++ { if _, err := c.Do("sadd", fmt.Sprintf("%d", i), fmt.Sprintf("%d", i)); err != nil { t.Fatal(err) @@ -151,5 +124,54 @@ func testSetScan(t *testing.T, c *ledis.Client) { } checkScan(t, c, "SET") - checkRevScan(t, c, "SET") +} + +func TestHashScan(t *testing.T) { + c := getTestConn() + defer c.Close() + + key := "scan_hash" + c.Do("HMSET", key, "a", 1, "b", 2) + + if ay, err := ledis.Values(c.Do("XHSCAN", key, "")); err != nil { + t.Fatal(err) + } else if len(ay) != 2 { + t.Fatal(len(ay)) + } else { + checkScanValues(t, ay[1], "a", 1, "b", 2) + } +} + +func TestSetScan(t *testing.T) { + c := getTestConn() + defer c.Close() + + key := "scan_set" + c.Do("SADD", key, "a", "b") + + if ay, err := ledis.Values(c.Do("XSSCAN", key, "")); err != nil { + t.Fatal(err) + } else if len(ay) != 2 { + t.Fatal(len(ay)) + } else { + checkScanValues(t, ay[1], "a", "b") + } + +} + +func TestZSetScan(t *testing.T) { + c := getTestConn() + defer c.Close() + + key := "scan_zset" + c.Do("ZADD", key, 1, "a", 2, "b") + + if ay, err := ledis.Values(c.Do("XZSCAN", key, "")); err != nil { + t.Fatal(err) + } else if len(ay) != 2 { + t.Fatal(len(ay)) + } else { + checkScanValues(t, ay[1], "a", 1, "b", 2) + } + }