From 7456a0e473166133ad86b50e8296efdcc45a4e81 Mon Sep 17 00:00:00 2001 From: Dimitrij Denissenko Date: Wed, 13 Apr 2016 09:52:47 +0100 Subject: [PATCH] Add scan iterator. --- command.go | 19 ++++++----- commands.go | 28 ++++++++++----- example_test.go | 20 +++++++++++ iterator.go | 75 ++++++++++++++++++++++++++++++++++++++++ iterator_test.go | 89 ++++++++++++++++++++++++++++++++++++++++++++++++ redis.go | 3 +- 6 files changed, 216 insertions(+), 18 deletions(-) create mode 100644 iterator.go create mode 100644 iterator_test.go diff --git a/command.go b/command.go index e10d2fb0..8596c716 100644 --- a/command.go +++ b/command.go @@ -689,41 +689,42 @@ type ScanCmd struct { baseCmd cursor int64 - keys []string + page []string } func NewScanCmd(args ...interface{}) *ScanCmd { - return &ScanCmd{baseCmd: baseCmd{_args: args, _clusterKeyPos: 1}} + return &ScanCmd{ + baseCmd: baseCmd{_args: args, _clusterKeyPos: 1}, + } } func (cmd *ScanCmd) reset() { cmd.cursor = 0 - cmd.keys = nil + cmd.page = nil cmd.err = nil } -// TODO: cursor should be string to match redis type // TODO: swap return values func (cmd *ScanCmd) Val() (int64, []string) { - return cmd.cursor, cmd.keys + return cmd.cursor, cmd.page } func (cmd *ScanCmd) Result() (int64, []string, error) { - return cmd.cursor, cmd.keys, cmd.err + return cmd.cursor, cmd.page, cmd.err } func (cmd *ScanCmd) String() string { - return cmdString(cmd, cmd.keys) + return cmdString(cmd, cmd.page) } func (cmd *ScanCmd) readReply(cn *pool.Conn) error { - keys, cursor, err := readScanReply(cn) + page, cursor, err := readScanReply(cn) if err != nil { cmd.err = err return cmd.err } - cmd.keys = keys + cmd.page = page cmd.cursor = cursor return nil } diff --git a/commands.go b/commands.go index 96bd1376..8ff6dc58 100644 --- a/commands.go +++ b/commands.go @@ -318,7 +318,7 @@ func (c *commandable) Type(key string) *StatusCmd { return cmd } -func (c *commandable) Scan(cursor int64, match string, count int64) *ScanCmd { +func (c *commandable) Scan(cursor int64, match string, count int64) Scanner { args := []interface{}{"SCAN", cursor} if match != "" { args = append(args, "MATCH", match) @@ -328,10 +328,13 @@ func (c *commandable) Scan(cursor int64, match string, count int64) *ScanCmd { } cmd := NewScanCmd(args...) c.Process(cmd) - return cmd + return Scanner{ + client: c, + ScanCmd: cmd, + } } -func (c *commandable) SScan(key string, cursor int64, match string, count int64) *ScanCmd { +func (c *commandable) SScan(key string, cursor int64, match string, count int64) Scanner { args := []interface{}{"SSCAN", key, cursor} if match != "" { args = append(args, "MATCH", match) @@ -341,10 +344,13 @@ func (c *commandable) SScan(key string, cursor int64, match string, count int64) } cmd := NewScanCmd(args...) c.Process(cmd) - return cmd + return Scanner{ + client: c, + ScanCmd: cmd, + } } -func (c *commandable) HScan(key string, cursor int64, match string, count int64) *ScanCmd { +func (c *commandable) HScan(key string, cursor int64, match string, count int64) Scanner { args := []interface{}{"HSCAN", key, cursor} if match != "" { args = append(args, "MATCH", match) @@ -354,10 +360,13 @@ func (c *commandable) HScan(key string, cursor int64, match string, count int64) } cmd := NewScanCmd(args...) c.Process(cmd) - return cmd + return Scanner{ + client: c, + ScanCmd: cmd, + } } -func (c *commandable) ZScan(key string, cursor int64, match string, count int64) *ScanCmd { +func (c *commandable) ZScan(key string, cursor int64, match string, count int64) Scanner { args := []interface{}{"ZSCAN", key, cursor} if match != "" { args = append(args, "MATCH", match) @@ -367,7 +376,10 @@ func (c *commandable) ZScan(key string, cursor int64, match string, count int64) } cmd := NewScanCmd(args...) c.Process(cmd) - return cmd + return Scanner{ + client: c, + ScanCmd: cmd, + } } //------------------------------------------------------------------------------ diff --git a/example_test.go b/example_test.go index 2e1ea595..abac6af6 100644 --- a/example_test.go +++ b/example_test.go @@ -314,3 +314,23 @@ func Example_customCommand() { fmt.Printf("%q %s", v, err) // Output: "" redis: nil } + +func ExampleScanIterator() { + iter := client.Scan(0, "", 0).Iterator() + for iter.Next() { + fmt.Println(iter.Val()) + } + if err := iter.Err(); err != nil { + panic(err) + } +} + +func ExampleScanCmd_Iterator() { + iter := client.Scan(0, "", 0).Iterator() + for iter.Next() { + fmt.Println(iter.Val()) + } + if err := iter.Err(); err != nil { + panic(err) + } +} diff --git a/iterator.go b/iterator.go new file mode 100644 index 00000000..c57c5afd --- /dev/null +++ b/iterator.go @@ -0,0 +1,75 @@ +package redis + +import "sync" + +type Scanner struct { + client *commandable + *ScanCmd +} + +// Iterator creates a new ScanIterator. +func (s Scanner) Iterator() *ScanIterator { + return &ScanIterator{ + Scanner: s, + } +} + +// ScanIterator is used to incrementally iterate over a collection of elements. +// It's safe for concurrent use by multiple goroutines. +type ScanIterator struct { + mu sync.Mutex // protects Scanner and pos + Scanner + pos int +} + +// Err returns the last iterator error, if any. +func (it *ScanIterator) Err() error { + it.mu.Lock() + err := it.ScanCmd.Err() + it.mu.Unlock() + return err +} + +// Next advances the cursor and returns true if more values can be read. +func (it *ScanIterator) Next() bool { + it.mu.Lock() + defer it.mu.Unlock() + + // Instantly return on errors. + if it.ScanCmd.Err() != nil { + return false + } + + // Advance cursor, check if we are still within range. + if it.pos < len(it.ScanCmd.page) { + it.pos++ + return true + } + + // Return if there is more data to fetch. + if it.ScanCmd.cursor == 0 { + return false + } + + // Fetch next page. + it.ScanCmd._args[1] = it.ScanCmd.cursor + it.ScanCmd.reset() + it.client.Process(it.ScanCmd) + if it.ScanCmd.Err() != nil { + return false + } + + it.pos = 1 + return len(it.ScanCmd.page) > 0 +} + +// Val returns the key/field at the current cursor position. +func (it *ScanIterator) Val() string { + var v string + it.mu.Lock() + if it.ScanCmd.Err() == nil && it.pos > 0 && it.pos <= len(it.ScanCmd.page) { + v = it.ScanCmd.page[it.pos-1] + } + it.mu.Unlock() + return v +} diff --git a/iterator_test.go b/iterator_test.go new file mode 100644 index 00000000..327feeb5 --- /dev/null +++ b/iterator_test.go @@ -0,0 +1,89 @@ +package redis_test + +import ( + "fmt" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + + "gopkg.in/redis.v4" +) + +var _ = Describe("ScanIterator", func() { + var client *redis.Client + + var seed = func(n int) error { + pipe := client.Pipeline() + for i := 1; i <= n; i++ { + pipe.Set(fmt.Sprintf("K%02d", i), "x", 0).Err() + } + _, err := pipe.Exec() + return err + } + + BeforeEach(func() { + client = redis.NewClient(redisOptions()) + Expect(client.FlushDb().Err()).NotTo(HaveOccurred()) + }) + + AfterEach(func() { + Expect(client.Close()).NotTo(HaveOccurred()) + }) + + It("should scan across empty DBs", func() { + iter := client.Scan(0, "", 10).Iterator() + Expect(iter.Next()).To(BeFalse()) + Expect(iter.Err()).NotTo(HaveOccurred()) + }) + + It("should scan across one page", func() { + Expect(seed(7)).NotTo(HaveOccurred()) + + var vals []string + iter := client.Scan(0, "", 0).Iterator() + for iter.Next() { + vals = append(vals, iter.Val()) + } + Expect(iter.Err()).NotTo(HaveOccurred()) + Expect(vals).To(ConsistOf([]string{"K01", "K02", "K03", "K04", "K05", "K06", "K07"})) + }) + + It("should scan across multiple pages", func() { + Expect(seed(71)).NotTo(HaveOccurred()) + + var vals []string + iter := client.Scan(0, "", 10).Iterator() + for iter.Next() { + vals = append(vals, iter.Val()) + } + Expect(iter.Err()).NotTo(HaveOccurred()) + Expect(vals).To(HaveLen(71)) + Expect(vals).To(ContainElement("K01")) + Expect(vals).To(ContainElement("K71")) + }) + + It("should scan to page borders", func() { + Expect(seed(20)).NotTo(HaveOccurred()) + + var vals []string + iter := client.Scan(0, "", 10).Iterator() + for iter.Next() { + vals = append(vals, iter.Val()) + } + Expect(iter.Err()).NotTo(HaveOccurred()) + Expect(vals).To(HaveLen(20)) + }) + + It("should scan with match", func() { + Expect(seed(33)).NotTo(HaveOccurred()) + + var vals []string + iter := client.Scan(0, "K*2*", 10).Iterator() + for iter.Next() { + vals = append(vals, iter.Val()) + } + Expect(iter.Err()).NotTo(HaveOccurred()) + Expect(vals).To(HaveLen(13)) + }) + +}) diff --git a/redis.go b/redis.go index be3aa027..9222db91 100644 --- a/redis.go +++ b/redis.go @@ -146,12 +146,13 @@ type Client struct { func newClient(opt *Options, pool pool.Pooler) *Client { base := baseClient{opt: opt, connPool: pool} - return &Client{ + client := &Client{ baseClient: base, commandable: commandable{ process: base.process, }, } + return client } // NewClient returns a client to the Redis Server specified by Options.