diff --git a/cluster.go b/cluster.go index e0f2b40..0c405c0 100644 --- a/cluster.go +++ b/cluster.go @@ -47,12 +47,16 @@ func (c *ClusterClient) Close() error { // ------------------------------------------------------------------------ // getClient returns a Client for a given address. -func (c *ClusterClient) getClient(addr string) *Client { +func (c *ClusterClient) getClient(addr string) (*Client, error) { + if addr == "" { + return c.randomClient() + } + c.clientsMx.RLock() client, ok := c.clients[addr] if ok { c.clientsMx.RUnlock() - return client + return client, nil } c.clientsMx.RUnlock() @@ -66,14 +70,24 @@ func (c *ClusterClient) getClient(addr string) *Client { } c.clientsMx.Unlock() - return client + return client, nil +} + +func (c *ClusterClient) slotAddrs(slot int) []string { + c.slotsMx.RLock() + addrs := c.slots[slot] + c.slotsMx.RUnlock() + return addrs } // randomClient returns a Client for the first live node. func (c *ClusterClient) randomClient() (client *Client, err error) { for i := 0; i < 10; i++ { n := rand.Intn(len(c.addrs)) - client = c.getClient(c.addrs[n]) + client, err = c.getClient(c.addrs[n]) + if err != nil { + continue + } err = client.Ping().Err() if err == nil { return client, nil @@ -82,27 +96,22 @@ func (c *ClusterClient) randomClient() (client *Client, err error) { return nil, err } -// Process a command func (c *ClusterClient) process(cmd Cmder) { - var client *Client var ask bool c.reloadIfDue() slot := hashSlot(cmd.clusterKey()) - c.slotsMx.RLock() - addrs := c.slots[slot] - c.slotsMx.RUnlock() - if len(addrs) > 0 { - client = c.getClient(addrs[0]) // First address is master. - } else { - var err error - client, err = c.randomClient() - if err != nil { - cmd.setErr(err) - return - } + var addr string + if addrs := c.slotAddrs(slot); len(addrs) > 0 { + addr = addrs[0] // First address is master. + } + + client, err := c.getClient(addr) + if err != nil { + cmd.setErr(err) + return } for attempt := 0; attempt <= c.opt.getMaxRedirects(); attempt++ { @@ -132,24 +141,22 @@ func (c *ClusterClient) process(cmd Cmder) { continue } - // Check the error message, return if unexpected - parts := strings.SplitN(err.Error(), " ", 3) - if len(parts) != 3 { - return + var moved bool + var addr string + moved, ask, addr = isMovedError(err) + if moved || ask { + if moved { + c.scheduleReload() + } + client, err = c.getClient(addr) + if err != nil { + return + } + cmd.reset() + continue } - // Handle MOVE and ASK redirections, return on any other error - switch parts[0] { - case "MOVED": - c.scheduleReload() - client = c.getClient(parts[2]) - case "ASK": - ask = true - client = c.getClient(parts[2]) - default: - return - } - cmd.reset() + break } } diff --git a/cluster_client_test.go b/cluster_client_test.go index ad5c101..6e44b10 100644 --- a/cluster_client_test.go +++ b/cluster_client_test.go @@ -5,6 +5,23 @@ import ( . "github.com/onsi/gomega" ) +// GetSlot returns the cached slot addresses +func (c *ClusterClient) GetSlot(pos int) []string { + c.slotsMx.RLock() + defer c.slotsMx.RUnlock() + + return c.slots[pos] +} + +// SwapSlot swaps a slot's master/slave address +// for testing MOVED redirects +func (c *ClusterClient) SwapSlot(pos int) []string { + c.slotsMx.Lock() + defer c.slotsMx.Unlock() + c.slots[pos][0], c.slots[pos][1] = c.slots[pos][1], c.slots[pos][0] + return c.slots[pos] +} + var _ = Describe("ClusterClient", func() { var subject *ClusterClient diff --git a/cluster_pipeline.go b/cluster_pipeline.go new file mode 100644 index 0000000..a9e61b2 --- /dev/null +++ b/cluster_pipeline.go @@ -0,0 +1,128 @@ +package redis + +// ClusterPipeline is not thread-safe. +type ClusterPipeline struct { + commandable + + cmds []Cmder + cluster *ClusterClient + closed bool +} + +// Pipeline creates a new pipeline which is able to execute commands +// against multiple shards. +func (c *ClusterClient) Pipeline() *ClusterPipeline { + pipe := &ClusterPipeline{ + cluster: c, + cmds: make([]Cmder, 0, 10), + } + pipe.commandable.process = pipe.process + return pipe +} + +func (c *ClusterPipeline) process(cmd Cmder) { + c.cmds = append(c.cmds, cmd) +} + +// Close marks the pipeline as closed +func (c *ClusterPipeline) Close() error { + c.closed = true + return nil +} + +// Discard resets the pipeline and discards queued commands +func (c *ClusterPipeline) Discard() error { + if c.closed { + return errClosed + } + c.cmds = c.cmds[:0] + return nil +} + +func (c *ClusterPipeline) Exec() (cmds []Cmder, retErr error) { + if c.closed { + return nil, errClosed + } + if len(c.cmds) == 0 { + return []Cmder{}, nil + } + + cmds = c.cmds + c.cmds = make([]Cmder, 0, 10) + + cmdsMap := make(map[string][]Cmder) + for _, cmd := range cmds { + slot := hashSlot(cmd.clusterKey()) + addrs := c.cluster.slotAddrs(slot) + + var addr string + if len(addrs) > 0 { + addr = addrs[0] // First address is master. + } + + cmdsMap[addr] = append(cmdsMap[addr], cmd) + } + + for attempt := 0; attempt <= c.cluster.opt.getMaxRedirects(); attempt++ { + failedCmds := make(map[string][]Cmder) + + for addr, cmds := range cmdsMap { + client, err := c.cluster.getClient(addr) + if err != nil { + setCmdsErr(cmds, err) + retErr = err + continue + } + + cn, err := client.conn() + if err != nil { + setCmdsErr(cmds, err) + retErr = err + continue + } + + failedCmds, err = c.execClusterCmds(cn, cmds, failedCmds) + if err != nil { + retErr = err + } + client.freeConn(cn, err) + } + + cmdsMap = failedCmds + } + + return cmds, retErr +} + +func (c *ClusterPipeline) execClusterCmds( + cn *conn, cmds []Cmder, failedCmds map[string][]Cmder, +) (map[string][]Cmder, error) { + if err := cn.writeCmds(cmds...); err != nil { + setCmdsErr(cmds, err) + return failedCmds, err + } + + var firstCmdErr error + for i, cmd := range cmds { + err := cmd.parseReply(cn.rd) + if err == nil { + continue + } + if isNetworkError(err) { + cmd.reset() + failedCmds[""] = append(failedCmds[""], cmds[i:]...) + break + } else if moved, ask, addr := isMovedError(err); moved { + c.cluster.scheduleReload() + cmd.reset() + failedCmds[addr] = append(failedCmds[addr], cmd) + } else if ask { + cmd.reset() + failedCmds[addr] = append(failedCmds[addr], NewCmd("ASKING"), cmd) + } else if firstCmdErr == nil { + firstCmdErr = err + } + } + + return failedCmds, firstCmdErr +} diff --git a/cluster_test.go b/cluster_test.go index 5d12cd3..70bdd35 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -2,9 +2,11 @@ package redis_test import ( "math/rand" + "time" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" + "gopkg.in/redis.v2" ) @@ -181,22 +183,50 @@ var _ = Describe("Cluster", func() { It("should follow redirects", func() { Expect(client.Set("A", "VALUE", 0).Err()).NotTo(HaveOccurred()) Expect(redis.HashSlot("A")).To(Equal(6373)) - - // Slot 6373 is stored on the second node - defer func() { - scenario.masters()[1].ClusterFailover() - }() - - slave := scenario.slaves()[1] - Expect(slave.ClusterFailover().Err()).NotTo(HaveOccurred()) - Eventually(func() string { - return slave.Info().Val() - }, "10s", "200ms").Should(ContainSubstring("role:master")) + Expect(client.SwapSlot(6373)).To(Equal([]string{"127.0.0.1:8224", "127.0.0.1:8221"})) val, err := client.Get("A").Result() Expect(err).NotTo(HaveOccurred()) Expect(val).To(Equal("VALUE")) + Expect(client.GetSlot(6373)).To(Equal([]string{"127.0.0.1:8224", "127.0.0.1:8221"})) + + val, err = client.Get("A").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).To(Equal("VALUE")) + Expect(client.GetSlot(6373)).To(Equal([]string{"127.0.0.1:8221", "127.0.0.1:8224"})) }) + + It("should perform multi-pipelines", func() { + // Dummy command to load slots info. + Expect(client.Ping().Err()).NotTo(HaveOccurred()) + + slot := redis.HashSlot("A") + Expect(client.SwapSlot(slot)).To(Equal([]string{"127.0.0.1:8224", "127.0.0.1:8221"})) + + pipe := client.Pipeline() + defer pipe.Close() + + keys := []string{"A", "B", "C", "D", "E", "F", "G"} + for i, key := range keys { + pipe.Set(key, key+"_value", 0) + pipe.Expire(key, time.Duration(i+1)*time.Hour) + } + for _, key := range keys { + pipe.Get(key) + pipe.TTL(key) + } + + cmds, err := pipe.Exec() + Expect(err).NotTo(HaveOccurred()) + Expect(cmds).To(HaveLen(28)) + Expect(cmds[14].(*redis.StringCmd).Val()).To(Equal("A_value")) + Expect(cmds[15].(*redis.DurationCmd).Val()).To(BeNumerically("~", 1*time.Hour, time.Second)) + Expect(cmds[20].(*redis.StringCmd).Val()).To(Equal("D_value")) + Expect(cmds[21].(*redis.DurationCmd).Val()).To(BeNumerically("~", 4*time.Hour, time.Second)) + Expect(cmds[26].(*redis.StringCmd).Val()).To(Equal("G_value")) + Expect(cmds[27].(*redis.DurationCmd).Val()).To(BeNumerically("~", 7*time.Hour, time.Second)) + }) + }) }) diff --git a/command_test.go b/command_test.go index a02847e..633c8a5 100644 --- a/command_test.go +++ b/command_test.go @@ -117,51 +117,52 @@ var _ = Describe("Command", func() { }) Describe("races", func() { + var C, N = 10, 1000 + if testing.Short() { + N = 100 + } It("should echo", func() { - var n = 10000 - if testing.Short() { - n = 1000 - } - wg := &sync.WaitGroup{} - wg.Add(n) - for i := 0; i < n; i++ { + for i := 0; i < C; i++ { + wg.Add(1) + go func(i int) { + defer GinkgoRecover() defer wg.Done() - msg := "echo" + strconv.Itoa(i) - echo := client.Echo(msg) - Expect(echo.Err()).NotTo(HaveOccurred()) - Expect(echo.Val()).To(Equal(msg)) + for j := 0; j < N; j++ { + msg := "echo" + strconv.Itoa(i) + echo := client.Echo(msg) + Expect(echo.Err()).NotTo(HaveOccurred()) + Expect(echo.Val()).To(Equal(msg)) + } }(i) } wg.Wait() }) It("should incr", func() { - var n = 10000 - if testing.Short() { - n = 1000 - } - key := "TestIncrFromGoroutines" wg := &sync.WaitGroup{} - wg.Add(n) - for i := 0; i < n; i++ { + for i := 0; i < C; i++ { + wg.Add(1) + go func() { defer GinkgoRecover() defer wg.Done() - err := client.Incr(key).Err() - Expect(err).NotTo(HaveOccurred()) + for j := 0; j < N; j++ { + err := client.Incr(key).Err() + Expect(err).NotTo(HaveOccurred()) + } }() } wg.Wait() val, err := client.Get(key).Int64() Expect(err).NotTo(HaveOccurred()) - Expect(val).To(Equal(int64(n))) + Expect(val).To(Equal(int64(C * N))) }) }) diff --git a/commands_test.go b/commands_test.go index e87845a..b265bbe 100644 --- a/commands_test.go +++ b/commands_test.go @@ -2245,53 +2245,49 @@ var _ = Describe("Commands", func() { Describe("watch/unwatch", func() { - var safeIncr = func() ([]redis.Cmder, error) { - multi := client.Multi() - defer multi.Close() - - watch := multi.Watch("key") - Expect(watch.Err()).NotTo(HaveOccurred()) - Expect(watch.Val()).To(Equal("OK")) - - get := multi.Get("key") - Expect(get.Err()).NotTo(HaveOccurred()) - Expect(get.Val()).NotTo(Equal(redis.Nil)) - - v, err := strconv.ParseInt(get.Val(), 10, 64) - Expect(err).NotTo(HaveOccurred()) - - return multi.Exec(func() error { - multi.Set("key", strconv.FormatInt(v+1, 10), 0) - return nil - }) - } - It("should WatchUnwatch", func() { - var n = 10000 + var C, N = 10, 1000 if testing.Short() { - n = 1000 + N = 100 } err := client.Set("key", "0", 0).Err() Expect(err).NotTo(HaveOccurred()) wg := &sync.WaitGroup{} - for i := 0; i < n; i++ { + for i := 0; i < C; i++ { wg.Add(1) go func() { defer GinkgoRecover() defer wg.Done() - for { - cmds, err := safeIncr() + multi := client.Multi() + defer multi.Close() + + for j := 0; j < N; j++ { + val, err := multi.Watch("key").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).To(Equal("OK")) + + val, err = multi.Get("key").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).NotTo(Equal(redis.Nil)) + + num, err := strconv.ParseInt(val, 10, 64) + Expect(err).NotTo(HaveOccurred()) + + cmds, err := multi.Exec(func() error { + multi.Set("key", strconv.FormatInt(num+1, 10), 0) + return nil + }) if err == redis.TxFailedErr { + j-- continue } Expect(err).NotTo(HaveOccurred()) Expect(cmds).To(HaveLen(1)) Expect(cmds[0].Err()).NotTo(HaveOccurred()) - break } }() } @@ -2299,7 +2295,7 @@ var _ = Describe("Commands", func() { val, err := client.Get("key").Int64() Expect(err).NotTo(HaveOccurred()) - Expect(val).To(Equal(int64(n))) + Expect(val).To(Equal(int64(C * N))) }) }) diff --git a/error.go b/error.go index 33159d5..0e031f3 100644 --- a/error.go +++ b/error.go @@ -4,6 +4,7 @@ import ( "fmt" "io" "net" + "strings" ) // Redis nil reply. @@ -30,3 +31,25 @@ func isNetworkError(err error) bool { } return false } + +func isMovedError(err error) (moved bool, ask bool, addr string) { + if _, ok := err.(redisError); !ok { + return + } + + parts := strings.SplitN(err.Error(), " ", 3) + if len(parts) != 3 { + return + } + + switch parts[0] { + case "MOVED": + moved = true + addr = parts[2] + case "ASK": + ask = true + addr = parts[2] + } + + return +} diff --git a/pipeline.go b/pipeline.go index 33e51b4..9490696 100644 --- a/pipeline.go +++ b/pipeline.go @@ -54,14 +54,13 @@ func (c *Pipeline) Exec() ([]Cmder, error) { if c.closed { return nil, errClosed } + if len(c.cmds) == 0 { + return []Cmder{}, nil + } cmds := c.cmds c.cmds = make([]Cmder, 0, 0) - if len(cmds) == 0 { - return []Cmder{}, nil - } - cn, err := c.client.conn() if err != nil { setCmdsErr(cmds, err) @@ -84,11 +83,16 @@ func (c *Pipeline) execCmds(cn *conn, cmds []Cmder) error { } var firstCmdErr error - for _, cmd := range cmds { - if err := cmd.parseReply(cn.rd); err != nil { - if firstCmdErr == nil { - firstCmdErr = err - } + for i, cmd := range cmds { + err := cmd.parseReply(cn.rd) + if err == nil { + continue + } + if firstCmdErr == nil { + firstCmdErr = err + } + if isNetworkError(err) { + setCmdsErr(cmds[i:], err) } } diff --git a/redis_test.go b/redis_test.go index ed53b9a..b4b9195 100644 --- a/redis_test.go +++ b/redis_test.go @@ -127,7 +127,7 @@ func TestGinkgoSuite(t *testing.T) { func execCmd(name string, args ...string) (*os.Process, error) { cmd := exec.Command(name, args...) - if false { + if testing.Verbose() { cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr }