diff --git a/ring.go b/ring.go index aa47543..1fd0a72 100644 --- a/ring.go +++ b/ring.go @@ -275,6 +275,7 @@ func (c *ringShards) Heartbeat(frequency time.Duration) { func (c *ringShards) rebalance() { hash := newConsistentHash(c.opt) var shardsNum int + c.mu.Lock() for name, shard := range c.shards { if shard.IsUp() { hash.Add(name) @@ -282,7 +283,6 @@ func (c *ringShards) rebalance() { } } - c.mu.Lock() c.hash = hash c.len = shardsNum c.mu.Unlock() @@ -653,6 +653,39 @@ func (c *Ring) Close() error { return c.shards.Close() } +func (c *Ring) Watch(fn func(*Tx) error, keys ...string) error { + if len(keys) == 0 { + return fmt.Errorf("redis: Watch requires at least one key") + } + + var shards []*ringShard + for _, key := range keys { + if key != "" { + shard, err := c.shards.GetByKey(hashtag.Key(key)) + if err != nil { + return err + } + + shards = append(shards, shard) + } + } + + if len(shards) == 0 { + return fmt.Errorf("redis: Watch requires at least one shard") + } + + if len(shards) > 1 { + for _, shard := range shards[1:] { + if shard.Client != shards[0].Client { + err := fmt.Errorf("redis: Watch requires all keys to be in the same shard") + return err + } + } + } + + return shards[0].Client.Watch(fn, keys...) +} + func newConsistentHash(opt *RingOptions) *consistenthash.Map { return consistenthash.New(opt.HashReplicas, consistenthash.Hash(opt.Hash)) } diff --git a/ring_test.go b/ring_test.go index 1f5bf0d..4ff0898 100644 --- a/ring_test.go +++ b/ring_test.go @@ -3,6 +3,9 @@ package redis_test import ( "crypto/rand" "fmt" + "net" + "strconv" + "sync" "time" "github.com/go-redis/redis" @@ -186,3 +189,263 @@ var _ = Describe("empty Redis Ring", func() { Expect(err).To(MatchError("redis: all ring shards are down")) }) }) + +var _ = Describe("Ring watch", func() { + const heartbeat = 100 * time.Millisecond + + var ring *redis.Ring + + BeforeEach(func() { + opt := redisRingOptions() + opt.HeartbeatFrequency = heartbeat + ring = redis.NewRing(opt) + + err := ring.ForEachShard(func(cl *redis.Client) error { + return cl.FlushDB().Err() + }) + Expect(err).NotTo(HaveOccurred()) + }) + + AfterEach(func() { + Expect(ring.Close()).NotTo(HaveOccurred()) + }) + + It("should Watch", func() { + var incr func(string) error + + // Transactionally increments key using GET and SET commands. + incr = func(key string) error { + err := ring.Watch(func(tx *redis.Tx) error { + n, err := tx.Get(key).Int64() + if err != nil && err != redis.Nil { + return err + } + + _, err = tx.Pipelined(func(pipe redis.Pipeliner) error { + pipe.Set(key, strconv.FormatInt(n+1, 10), 0) + return nil + }) + return err + }, key) + if err == redis.TxFailedErr { + return incr(key) + } + return err + } + + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer GinkgoRecover() + defer wg.Done() + + err := incr("key") + Expect(err).NotTo(HaveOccurred()) + }() + } + wg.Wait() + + n, err := ring.Get("key").Int64() + Expect(err).NotTo(HaveOccurred()) + Expect(n).To(Equal(int64(100))) + }) + + It("should discard", func() { + err := ring.Watch(func(tx *redis.Tx) error { + cmds, err := tx.Pipelined(func(pipe redis.Pipeliner) error { + pipe.Set("key1", "hello1", 0) + pipe.Discard() + pipe.Set("key2", "hello2", 0) + return nil + }) + Expect(err).NotTo(HaveOccurred()) + Expect(cmds).To(HaveLen(1)) + return err + }, "key1", "key2") + Expect(err).NotTo(HaveOccurred()) + + get := ring.Get("key1") + Expect(get.Err()).To(Equal(redis.Nil)) + Expect(get.Val()).To(Equal("")) + + get = ring.Get("key2") + Expect(get.Err()).NotTo(HaveOccurred()) + Expect(get.Val()).To(Equal("hello2")) + }) + + It("returns no error when there are no commands", func() { + err := ring.Watch(func(tx *redis.Tx) error { + _, err := tx.Pipelined(func(redis.Pipeliner) error { return nil }) + return err + }, "key") + Expect(err).NotTo(HaveOccurred()) + + v, err := ring.Ping().Result() + Expect(err).NotTo(HaveOccurred()) + Expect(v).To(Equal("PONG")) + }) + + It("should exec bulks", func() { + const N = 20000 + + err := ring.Watch(func(tx *redis.Tx) error { + cmds, err := tx.Pipelined(func(pipe redis.Pipeliner) error { + for i := 0; i < N; i++ { + pipe.Incr("key") + } + return nil + }) + Expect(err).NotTo(HaveOccurred()) + Expect(len(cmds)).To(Equal(N)) + for _, cmd := range cmds { + Expect(cmd.Err()).NotTo(HaveOccurred()) + } + return err + }, "key") + Expect(err).NotTo(HaveOccurred()) + + num, err := ring.Get("key").Int64() + Expect(err).NotTo(HaveOccurred()) + Expect(num).To(Equal(int64(N))) + }) + + It("should Watch/Unwatch", func() { + var C, N int + + err := ring.Set("key", "0", 0).Err() + Expect(err).NotTo(HaveOccurred()) + + perform(C, func(id int) { + for i := 0; i < N; i++ { + err := ring.Watch(func(tx *redis.Tx) error { + val, err := tx.Get("key").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).NotTo(Equal(redis.Nil)) + + num, err := strconv.ParseInt(val, 10, 64) + Expect(err).NotTo(HaveOccurred()) + + cmds, err := tx.Pipelined(func(pipe redis.Pipeliner) error { + pipe.Set("key", strconv.FormatInt(num+1, 10), 0) + return nil + }) + Expect(cmds).To(HaveLen(1)) + return err + }, "key") + if err == redis.TxFailedErr { + i-- + continue + } + Expect(err).NotTo(HaveOccurred()) + } + }) + + val, err := ring.Get("key").Int64() + Expect(err).NotTo(HaveOccurred()) + Expect(val).To(Equal(int64(C * N))) + }) + + It("should close Tx without closing the client", func() { + err := ring.Watch(func(tx *redis.Tx) error { + _, err := tx.Pipelined(func(pipe redis.Pipeliner) error { + pipe.Ping() + return nil + }) + return err + }, "key") + Expect(err).NotTo(HaveOccurred()) + + Expect(ring.Ping().Err()).NotTo(HaveOccurred()) + }) + + It("respects max size on multi", func() { + perform(1000, func(id int) { + var ping *redis.StatusCmd + + err := ring.Watch(func(tx *redis.Tx) error { + cmds, err := tx.Pipelined(func(pipe redis.Pipeliner) error { + ping = pipe.Ping() + return nil + }) + Expect(err).NotTo(HaveOccurred()) + Expect(cmds).To(HaveLen(1)) + return err + }, "key") + Expect(err).NotTo(HaveOccurred()) + + Expect(ping.Err()).NotTo(HaveOccurred()) + Expect(ping.Val()).To(Equal("PONG")) + }) + + ring.ForEachShard(func(cl *redis.Client) error { + pool := cl.Pool() + Expect(pool.Len()).To(BeNumerically("<=", 10)) + Expect(pool.IdleLen()).To(BeNumerically("<=", 10)) + Expect(pool.Len()).To(Equal(pool.IdleLen())) + + return nil + }) + }) +}) + +var _ = Describe("Ring Tx timeout", func() { + const heartbeat = 100 * time.Millisecond + + var ring *redis.Ring + + AfterEach(func() { + _ = ring.Close() + }) + + testTimeout := func() { + It("Tx timeouts", func() { + err := ring.Watch(func(tx *redis.Tx) error { + return tx.Ping().Err() + }, "foo") + Expect(err).To(HaveOccurred()) + Expect(err.(net.Error).Timeout()).To(BeTrue()) + }) + + It("Tx Pipeline timeouts", func() { + err := ring.Watch(func(tx *redis.Tx) error { + _, err := tx.Pipelined(func(pipe redis.Pipeliner) error { + pipe.Ping() + return nil + }) + return err + }, "foo") + Expect(err).To(HaveOccurred()) + Expect(err.(net.Error).Timeout()).To(BeTrue()) + }) + } + + const pause = 5 * time.Second + + Context("read/write timeout", func() { + BeforeEach(func() { + opt := redisRingOptions() + opt.ReadTimeout = 250 * time.Millisecond + opt.WriteTimeout = 250 * time.Millisecond + opt.HeartbeatFrequency = heartbeat + ring = redis.NewRing(opt) + + err := ring.ForEachShard(func(client *redis.Client) error { + return client.ClientPause(pause).Err() + }) + Expect(err).NotTo(HaveOccurred()) + }) + + AfterEach(func() { + _ = ring.ForEachShard(func(client *redis.Client) error { + defer GinkgoRecover() + Eventually(func() error { + return client.Ping().Err() + }, 2*pause).ShouldNot(HaveOccurred()) + return nil + }) + }) + + testTimeout() + }) +})