Use node address instead of relying on loopback reported by redis

This commit is contained in:
Dimitrij Denissenko 2017-06-29 22:43:19 +01:00 committed by Vladimir Mihailenco
parent b52814fa17
commit 94ea195dc1
2 changed files with 26 additions and 5 deletions

View File

@ -3,6 +3,7 @@ package redis
import ( import (
"fmt" "fmt"
"math/rand" "math/rand"
"net"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@ -244,16 +245,22 @@ type clusterState struct {
slots [][]*clusterNode slots [][]*clusterNode
} }
func newClusterState(nodes *clusterNodes, slots []ClusterSlot) (*clusterState, error) { func newClusterState(nodes *clusterNodes, slots []ClusterSlot, origin string) (*clusterState, error) {
c := clusterState{ c := clusterState{
nodes: nodes, nodes: nodes,
slots: make([][]*clusterNode, hashtag.SlotNumber), slots: make([][]*clusterNode, hashtag.SlotNumber),
} }
isLoopbackOrigin := isLoopbackAddr(origin)
for _, slot := range slots { for _, slot := range slots {
var nodes []*clusterNode var nodes []*clusterNode
for _, slotNode := range slot.Nodes { for _, slotNode := range slot.Nodes {
node, err := c.nodes.Get(slotNode.Addr) addr := slotNode.Addr
if !isLoopbackOrigin && isLoopbackAddr(addr) {
addr = origin
}
node, err := c.nodes.Get(addr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -661,7 +668,7 @@ func (c *ClusterClient) reloadSlots() (*clusterState, error) {
return nil, err return nil, err
} }
return newClusterState(c.nodes, slots) return newClusterState(c.nodes, slots, node.Client.opt.Addr)
} }
// reaper closes idle connections to the cluster. // reaper closes idle connections to the cluster.
@ -960,3 +967,17 @@ func (c *ClusterClient) txPipelineReadQueued(
return firstErr return firstErr
} }
func isLoopbackAddr(addr string) bool {
host, _, err := net.SplitHostPort(addr)
if err != nil {
return false
}
ip := net.ParseIP(host)
if ip == nil {
return false
}
return ip.IsLoopback()
}

View File

@ -2888,12 +2888,12 @@ var _ = Describe("Commands", func() {
It("returns map of commands", func() { It("returns map of commands", func() {
cmds, err := client.Command().Result() cmds, err := client.Command().Result()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(len(cmds)).To(BeNumerically("~", 173, 5)) Expect(len(cmds)).To(BeNumerically("~", 180, 10))
cmd := cmds["mget"] cmd := cmds["mget"]
Expect(cmd.Name).To(Equal("mget")) Expect(cmd.Name).To(Equal("mget"))
Expect(cmd.Arity).To(Equal(int8(-2))) Expect(cmd.Arity).To(Equal(int8(-2)))
Expect(cmd.Flags).To(Equal([]string{"readonly"})) Expect(cmd.Flags).To(ContainElement("readonly"))
Expect(cmd.FirstKeyPos).To(Equal(int8(1))) Expect(cmd.FirstKeyPos).To(Equal(int8(1)))
Expect(cmd.LastKeyPos).To(Equal(int8(-1))) Expect(cmd.LastKeyPos).To(Equal(int8(-1)))
Expect(cmd.StepCount).To(Equal(int8(1))) Expect(cmd.StepCount).To(Equal(int8(1)))