diff --git a/ledis/scan.go b/ledis/scan.go index 466f1ec..cdf18dc 100644 --- a/ledis/scan.go +++ b/ledis/scan.go @@ -61,44 +61,19 @@ func buildMatchRegexp(match string) (*regexp.Regexp, error) { return r, nil } -func (db *DB) scanGeneric(storeDataType byte, key []byte, count int, - inclusive bool, match string, reverse bool) ([][]byte, error) { - var minKey, maxKey []byte - r, err := buildMatchRegexp(match) - if err != nil { - return nil, err - } - +func (db *DB) buildScanIterator(minKey []byte, maxKey []byte, inclusive bool, reverse bool) *store.RangeLimitIterator { tp := store.RangeOpen if !reverse { - if minKey, err = db.encodeScanMinKey(storeDataType, key); err != nil { - return nil, err - } - if maxKey, err = db.encodeScanMaxKey(storeDataType, nil); err != nil { - return nil, err - } - if inclusive { tp = store.RangeROpen } } else { - if minKey, err = db.encodeScanMinKey(storeDataType, nil); err != nil { - return nil, err - } - if maxKey, err = db.encodeScanMaxKey(storeDataType, key); err != nil { - return nil, err - } - if inclusive { tp = store.RangeLOpen } } - if count <= 0 { - count = defaultScanCount - } - var it *store.RangeLimitIterator if !reverse { it = db.bucket.RangeIterator(minKey, maxKey, tp) @@ -106,6 +81,53 @@ func (db *DB) scanGeneric(storeDataType byte, key []byte, count int, it = db.bucket.RevRangeIterator(minKey, maxKey, tp) } + return it +} + +func (db *DB) buildScanKeyRange(storeDataType byte, key []byte, reverse bool) (minKey []byte, maxKey []byte, err error) { + if !reverse { + if minKey, err = db.encodeScanMinKey(storeDataType, key); err != nil { + return + } + if maxKey, err = db.encodeScanMaxKey(storeDataType, nil); err != nil { + return + } + } else { + if minKey, err = db.encodeScanMinKey(storeDataType, nil); err != nil { + return + } + if maxKey, err = db.encodeScanMaxKey(storeDataType, key); err != nil { + return + } + } + return +} + +func checkScanCount(count int) int { + if count <= 0 { + count = defaultScanCount + } + + return count +} + +func (db *DB) scanGeneric(storeDataType byte, key []byte, count int, + inclusive bool, match string, reverse bool) ([][]byte, error) { + + r, err := buildMatchRegexp(match) + if err != nil { + return nil, err + } + + minKey, maxKey, err := db.buildScanKeyRange(storeDataType, key, reverse) + if err != nil { + return nil, err + } + + count = checkScanCount(count) + + it := db.buildScanIterator(minKey, maxKey, inclusive, reverse) + v := make([][]byte, 0, count) for i := 0; it.Valid() && i < count; it.Next() { @@ -123,22 +145,11 @@ func (db *DB) scanGeneric(storeDataType byte, key []byte, count int, } func (db *DB) encodeScanMinKey(storeDataType byte, key []byte) ([]byte, error) { - if len(key) == 0 { - return db.encodeScanKey(storeDataType, nil) - } else { - if err := checkKeySize(key); err != nil { - return nil, err - } - return db.encodeScanKey(storeDataType, key) - } + return db.encodeScanKey(storeDataType, key) } func (db *DB) encodeScanMaxKey(storeDataType byte, key []byte) ([]byte, error) { if len(key) > 0 { - if err := checkKeySize(key); err != nil { - return nil, err - } - return db.encodeScanKey(storeDataType, key) } @@ -162,12 +173,11 @@ func (db *DB) encodeScanKey(storeDataType byte, key []byte) ([]byte, error) { return db.zEncodeSizeKey(key), nil case SSizeType: return db.sEncodeSizeKey(key), nil - // case BitMetaType: - // return db.bEncodeMetaKey(key), nil default: return nil, errDataType } } + func (db *DB) decodeScanKey(storeDataType byte, ek []byte) ([]byte, error) { if len(ek) < 2 || ek[0] != db.index || ek[1] != storeDataType { return nil, errMetaKey @@ -177,33 +187,89 @@ func (db *DB) decodeScanKey(storeDataType byte, ek []byte) ([]byte, error) { // for specail data scan -func (db *DB) buildDataScanIterator(start []byte, stop []byte, inclusive bool) *store.RangeLimitIterator { - tp := store.RangeROpen - - if !inclusive { - tp = store.RangeOpen +func (db *DB) buildDataScanKeyRange(storeDataType byte, key []byte, cursor []byte, reverse bool) (minKey []byte, maxKey []byte, err error) { + if !reverse { + if minKey, err = db.encodeDataScanMinKey(storeDataType, key, cursor); err != nil { + return + } + if maxKey, err = db.encodeDataScanMaxKey(storeDataType, key, nil); err != nil { + return + } + } else { + if minKey, err = db.encodeDataScanMinKey(storeDataType, key, nil); err != nil { + return + } + if maxKey, err = db.encodeDataScanMaxKey(storeDataType, key, cursor); err != nil { + return + } } - it := db.bucket.RangeIterator(start, stop, tp) - return it - + return } -func (db *DB) HScan(key []byte, cursor []byte, count int, inclusive bool, match string) ([]FVPair, error) { +func (db *DB) encodeDataScanMinKey(storeDataType byte, key []byte, cursor []byte) ([]byte, error) { + return db.encodeDataScanKey(storeDataType, key, cursor) +} + +func (db *DB) encodeDataScanMaxKey(storeDataType byte, key []byte, cursor []byte) ([]byte, error) { + if len(cursor) > 0 { + return db.encodeDataScanKey(storeDataType, key, cursor) + } + + k, err := db.encodeDataScanKey(storeDataType, key, nil) + if err != nil { + return nil, err + } + + // here, the last byte is the start seperator, set it to stop seperator + k[len(k)-1] = k[len(k)-1] + 1 + return k, nil +} + +func (db *DB) encodeDataScanKey(storeDataType byte, key []byte, cursor []byte) ([]byte, error) { + switch storeDataType { + case HashType: + return db.hEncodeHashKey(key, cursor), nil + case ZSetType: + return db.zEncodeSetKey(key, cursor), nil + case SetType: + return db.sEncodeSetKey(key, cursor), nil + default: + return nil, errDataType + } +} + +func (db *DB) buildDataScanIterator(storeDataType byte, key []byte, cursor []byte, count int, + inclusive bool, reverse bool) (*store.RangeLimitIterator, error) { + if err := checkKeySize(key); err != nil { return nil, err } - start := db.hEncodeHashKey(key, cursor) - stop := db.hEncodeStopKey(key) + minKey, maxKey, err := db.buildDataScanKeyRange(storeDataType, key, cursor, reverse) + if err != nil { + return nil, err + } - v := make([]FVPair, 0, 16) + it := db.buildScanIterator(minKey, maxKey, inclusive, reverse) + + return it, nil +} + +func (db *DB) hScanGeneric(key []byte, cursor []byte, count int, inclusive bool, match string, reverse bool) ([]FVPair, error) { + count = checkScanCount(count) r, err := buildMatchRegexp(match) if err != nil { return nil, err } - it := db.buildDataScanIterator(start, stop, inclusive) + v := make([]FVPair, 0, count) + + it, err := db.buildDataScanIterator(HashType, key, cursor, count, inclusive, reverse) + if err != nil { + return nil, err + } + defer it.Close() for i := 0; it.Valid() && i < count; it.Next() { @@ -222,22 +288,29 @@ func (db *DB) HScan(key []byte, cursor []byte, count int, inclusive bool, match return v, nil } -func (db *DB) SScan(key []byte, cursor []byte, count int, inclusive bool, match string) ([][]byte, error) { - if err := checkKeySize(key); err != nil { - return nil, err - } +func (db *DB) HScan(key []byte, cursor []byte, count int, inclusive bool, match string) ([]FVPair, error) { + return db.hScanGeneric(key, cursor, count, inclusive, match, false) +} - start := db.sEncodeSetKey(key, cursor) - stop := db.sEncodeStopKey(key) +func (db *DB) HRevScan(key []byte, cursor []byte, count int, inclusive bool, match string) ([]FVPair, error) { + return db.hScanGeneric(key, cursor, count, inclusive, match, true) +} - v := make([][]byte, 0, 16) +func (db *DB) sScanGeneric(key []byte, cursor []byte, count int, inclusive bool, match string, reverse bool) ([][]byte, error) { + count = checkScanCount(count) r, err := buildMatchRegexp(match) if err != nil { return nil, err } - it := db.buildDataScanIterator(start, stop, inclusive) + v := make([][]byte, 0, count) + + it, err := db.buildDataScanIterator(SetType, key, cursor, count, inclusive, reverse) + if err != nil { + return nil, err + } + defer it.Close() for i := 0; it.Valid() && i < count; it.Next() { @@ -256,22 +329,29 @@ func (db *DB) SScan(key []byte, cursor []byte, count int, inclusive bool, match return v, nil } -func (db *DB) ZScan(key []byte, cursor []byte, count int, inclusive bool, match string) ([]ScorePair, error) { - if err := checkKeySize(key); err != nil { - return nil, err - } +func (db *DB) SScan(key []byte, cursor []byte, count int, inclusive bool, match string) ([][]byte, error) { + return db.sScanGeneric(key, cursor, count, inclusive, match, false) +} - start := db.zEncodeSetKey(key, cursor) - stop := db.zEncodeStopSetKey(key) +func (db *DB) SRevScan(key []byte, cursor []byte, count int, inclusive bool, match string) ([][]byte, error) { + return db.sScanGeneric(key, cursor, count, inclusive, match, true) +} - v := make([]ScorePair, 0, 16) +func (db *DB) zScanGeneric(key []byte, cursor []byte, count int, inclusive bool, match string, reverse bool) ([]ScorePair, error) { + count = checkScanCount(count) r, err := buildMatchRegexp(match) if err != nil { return nil, err } - it := db.buildDataScanIterator(start, stop, inclusive) + v := make([]ScorePair, 0, count) + + it, err := db.buildDataScanIterator(ZSetType, key, cursor, count, inclusive, reverse) + if err != nil { + return nil, err + } + defer it.Close() for i := 0; it.Valid() && i < count; it.Next() { @@ -294,3 +374,11 @@ func (db *DB) ZScan(key []byte, cursor []byte, count int, inclusive bool, match return v, nil } + +func (db *DB) ZScan(key []byte, cursor []byte, count int, inclusive bool, match string) ([]ScorePair, error) { + return db.zScanGeneric(key, cursor, count, inclusive, match, false) +} + +func (db *DB) ZRevScan(key []byte, cursor []byte, count int, inclusive bool, match string) ([]ScorePair, error) { + return db.zScanGeneric(key, cursor, count, inclusive, match, true) +} diff --git a/ledis/scan_test.go b/ledis/scan_test.go index 9505964..e7d3171 100644 --- a/ledis/scan_test.go +++ b/ledis/scan_test.go @@ -323,6 +323,16 @@ func TestDBHScan(t *testing.T) { } else if string(v[0].Field) != "222" { t.Fatal(string(v[0].Field)) } + + v, err = db.HRevScan(key, []byte("19"), 1, false, "") + if err != nil { + t.Fatal(err) + } else if len(v) != 1 { + t.Fatal("invalid count", len(v)) + } else if string(v[0].Field) != "1234" { + t.Fatal(string(v[0].Field)) + } + } func TestDBSScan(t *testing.T) { @@ -346,6 +356,16 @@ func TestDBSScan(t *testing.T) { } else if string(v[0]) != "222" { t.Fatal(string(v[0])) } + + v, err = db.SRevScan(key, []byte("19"), 1, false, "") + if err != nil { + t.Fatal(err) + } else if len(v) != 1 { + t.Fatal("invalid count", len(v)) + } else if string(v[0]) != "1234" { + t.Fatal(string(v[0])) + } + } func TestDBZScan(t *testing.T) { @@ -369,4 +389,14 @@ func TestDBZScan(t *testing.T) { } else if string(v[0].Member) != "222" { t.Fatal(string(v[0].Member)) } + + v, err = db.ZRevScan(key, []byte("19"), 1, false, "") + if err != nil { + t.Fatal(err) + } else if len(v) != 1 { + t.Fatal("invalid count", len(v)) + } else if string(v[0].Member) != "1234" { + t.Fatal(string(v[0].Member)) + } + } diff --git a/server/cmd_scan.go b/server/cmd_scan.go index 1e62d8f..dc579c4 100644 --- a/server/cmd_scan.go +++ b/server/cmd_scan.go @@ -9,13 +9,15 @@ import ( "strings" ) -func parseScanArgs(args [][]byte) (cursor []byte, match string, count int, err error) { +func parseScanArgs(args [][]byte) (cursor []byte, match string, count int, desc bool, err error) { cursor = args[0] args = args[1:] count = 10 + desc = false + for i := 0; i < len(args); { switch strings.ToUpper(hack.String(args[i])) { case "MATCH": @@ -25,7 +27,7 @@ func parseScanArgs(args [][]byte) (cursor []byte, match string, count int, err e } match = hack.String(args[i+1]) - i = i + 2 + i++ case "COUNT": if i+1 >= len(args) { err = ErrCmdParams @@ -37,17 +39,23 @@ func parseScanArgs(args [][]byte) (cursor []byte, match string, count int, err e return } - i = i + 2 + i++ + case "ASC": + desc = false + case "DESC": + desc = true default: err = fmt.Errorf("invalid argument %s", args[i]) return } + + i++ } return } -// XSCAN type cursor [MATCH match] [COUNT count] +// XSCAN type cursor [MATCH match] [COUNT count] [ASC|DESC] func xscanCommand(c *client) error { args := c.args @@ -71,13 +79,20 @@ func xscanCommand(c *client) error { return fmt.Errorf("invalid key type %s", args[0]) } - cursor, match, count, err := parseScanArgs(args[1:]) + cursor, match, count, desc, err := parseScanArgs(args[1:]) if err != nil { return err } - ay, err := c.db.Scan(dataType, cursor, count, false, match) + var ay [][]byte + + if !desc { + ay, err = c.db.Scan(dataType, cursor, count, false, match) + } else { + ay, err = c.db.RevScan(dataType, cursor, count, false, match) + } + if err != nil { return err } @@ -93,7 +108,7 @@ func xscanCommand(c *client) error { return nil } -// XHSCAN key cursor [MATCH match] [COUNT count] +// XHSCAN key cursor [MATCH match] [COUNT count] [ASC|DESC] func xhscanCommand(c *client) error { args := c.args @@ -103,13 +118,20 @@ func xhscanCommand(c *client) error { key := args[0] - cursor, match, count, err := parseScanArgs(args[1:]) + cursor, match, count, desc, err := parseScanArgs(args[1:]) if err != nil { return err } - ay, err := c.db.HScan(key, cursor, count, false, match) + var ay []ledis.FVPair + + if !desc { + ay, err = c.db.HScan(key, cursor, count, false, match) + } else { + ay, err = c.db.HRevScan(key, cursor, count, false, match) + } + if err != nil { return err } @@ -133,7 +155,7 @@ func xhscanCommand(c *client) error { return nil } -// XSSCAN key cursor [MATCH match] [COUNT count] +// XSSCAN key cursor [MATCH match] [COUNT count] [ASC|DESC] func xsscanCommand(c *client) error { args := c.args @@ -143,13 +165,20 @@ func xsscanCommand(c *client) error { key := args[0] - cursor, match, count, err := parseScanArgs(args[1:]) + cursor, match, count, desc, err := parseScanArgs(args[1:]) if err != nil { return err } - ay, err := c.db.SScan(key, cursor, count, false, match) + var ay [][]byte + + if !desc { + ay, err = c.db.SScan(key, cursor, count, false, match) + } else { + ay, err = c.db.SRevScan(key, cursor, count, false, match) + } + if err != nil { return err } @@ -167,7 +196,7 @@ func xsscanCommand(c *client) error { return nil } -// XZSCAN key cursor [MATCH match] [COUNT count] +// XZSCAN key cursor [MATCH match] [COUNT count] [ASC|DESC] func xzscanCommand(c *client) error { args := c.args @@ -177,13 +206,20 @@ func xzscanCommand(c *client) error { key := args[0] - cursor, match, count, err := parseScanArgs(args[1:]) + cursor, match, count, desc, err := parseScanArgs(args[1:]) if err != nil { return err } - ay, err := c.db.ZScan(key, cursor, count, false, match) + var ay []ledis.ScorePair + + if !desc { + ay, err = c.db.ZScan(key, cursor, count, false, match) + } else { + ay, err = c.db.ZRevScan(key, cursor, count, false, match) + } + if err != nil { return err }