From aee0cc6caed486c2053122a6f23774424effd65e Mon Sep 17 00:00:00 2001 From: Stephanie Hingtgen Date: Thu, 21 Oct 2021 14:54:46 -0600 Subject: [PATCH] feat: add helper func, cleanup code --- cluster.go | 12 +------ cluster_test.go | 84 ++++++++++++++----------------------------------- go.mod | 1 + go.sum | 3 ++ options.go | 44 ++++++++++++++++---------- 5 files changed, 55 insertions(+), 89 deletions(-) diff --git a/cluster.go b/cluster.go index 98c16ff..3632ce9 100644 --- a/cluster.go +++ b/cluster.go @@ -173,17 +173,7 @@ func ParseClusterURL(redisURL string) (*ClusterOptions, error) { // add base URL to the array of addresses // more addresses may be added through the URL params - h, p, err := net.SplitHostPort(u.Host) - if err != nil { - h = u.Host - } - if h == "" { - h = "localhost" - } - if p == "" { - p = "6379" - } - + h, p := getHostPortWithDefaults(u) o.Addrs = append(o.Addrs, net.JoinHostPort(h, p)) // setup username, password, and other configurations diff --git a/cluster_test.go b/cluster_test.go index f923112..92512ab 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" "net" - "reflect" "strconv" "strings" "sync" @@ -15,6 +14,7 @@ import ( . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" + "github.com/stretchr/testify/assert" "github.com/go-redis/redis/v8" "github.com/go-redis/redis/v8/internal/hashtag" @@ -1300,7 +1300,7 @@ func TestParseClusterURL(t *testing.T) { }, { test: "ParseRedissURL", url: "rediss://localhost:123", - o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, TLSConfig: &tls.Config{ /* no deep comparison */ }}, + o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, TLSConfig: &tls.Config{ServerName: "localhost"}}, }, { test: "MissingRedisPort", url: "redis://localhost", @@ -1308,15 +1308,15 @@ func TestParseClusterURL(t *testing.T) { }, { test: "MissingRedissPort", url: "rediss://localhost", - o: &redis.ClusterOptions{Addrs: []string{"localhost:6379"}, TLSConfig: &tls.Config{ /* no deep comparison */ }}, + o: &redis.ClusterOptions{Addrs: []string{"localhost:6379"}, TLSConfig: &tls.Config{ServerName: "localhost"}}, }, { test: "MultipleRedisURLs", url: "redis://localhost:123?addr=localhost:1234&addr=localhost:12345", - o: &redis.ClusterOptions{Addrs: []string{"localhost:123", "localhost:12345", "localhost:1234"}}, + o: &redis.ClusterOptions{Addrs: []string{"localhost:123", "localhost:1234", "localhost:12345"}}, }, { test: "MultipleRedissURLs", url: "rediss://localhost:123?addr=localhost:1234&addr=localhost:12345", - o: &redis.ClusterOptions{Addrs: []string{"localhost:123", "localhost:12345", "localhost:1234"}, TLSConfig: &tls.Config{ /* no deep comparison */ }}, + o: &redis.ClusterOptions{Addrs: []string{"localhost:123", "localhost:1234", "localhost:12345"}, TLSConfig: &tls.Config{ServerName: "localhost"}}, }, { test: "OnlyPassword", url: "redis://:bar@localhost:123", @@ -1332,7 +1332,7 @@ func TestParseClusterURL(t *testing.T) { }, { test: "RedissUsernamePassword", url: "rediss://foo:bar@localhost:123?addr=localhost:1234", - o: &redis.ClusterOptions{Addrs: []string{"localhost:123", "localhost:1234"}, Username: "foo", Password: "bar", TLSConfig: &tls.Config{ /* no deep comparison */ }}, + o: &redis.ClusterOptions{Addrs: []string{"localhost:123", "localhost:1234"}, Username: "foo", Password: "bar", TLSConfig: &tls.Config{ServerName: "localhost"}}, }, { test: "QueryParameters", url: "redis://localhost:123?read_timeout=2&pool_fifo=true&addr=localhost:1234", @@ -1403,59 +1403,21 @@ func TestParseClusterURL(t *testing.T) { func comprareOptions(t *testing.T, actual, expected *redis.ClusterOptions) { t.Helper() - - if !reflect.DeepEqual(actual.Addrs, expected.Addrs) { - t.Errorf("got %q, want %q", actual.Addrs, expected.Addrs) - } - if actual.TLSConfig == nil && expected.TLSConfig != nil { - t.Errorf("got nil TLSConfig, expected a TLSConfig") - } - if actual.TLSConfig != nil && expected.TLSConfig == nil { - t.Errorf("got TLSConfig, expected no TLSConfig") - } - if actual.Username != expected.Username { - t.Errorf("Username: got %q, expected %q", actual.Username, expected.Username) - } - if actual.Password != expected.Password { - t.Errorf("Password: got %q, expected %q", actual.Password, expected.Password) - } - if actual.MaxRetries != expected.MaxRetries { - t.Errorf("MaxRetries: got %v, expected %v", actual.MaxRetries, expected.MaxRetries) - } - if actual.MinRetryBackoff != expected.MinRetryBackoff { - t.Errorf("MinRetryBackoff: got %v, expected %v", actual.MinRetryBackoff, expected.MinRetryBackoff) - } - if actual.MaxRetryBackoff != expected.MaxRetryBackoff { - t.Errorf("MaxRetryBackoff: got %v, expected %v", actual.MaxRetryBackoff, expected.MaxRetryBackoff) - } - if actual.DialTimeout != expected.DialTimeout { - t.Errorf("DialTimeout: got %v, expected %v", actual.DialTimeout, expected.DialTimeout) - } - if actual.ReadTimeout != expected.ReadTimeout { - t.Errorf("ReadTimeout: got %v, expected %v", actual.ReadTimeout, expected.ReadTimeout) - } - if actual.WriteTimeout != expected.WriteTimeout { - t.Errorf("WriteTimeout: got %v, expected %v", actual.WriteTimeout, expected.WriteTimeout) - } - if actual.PoolFIFO != expected.PoolFIFO { - t.Errorf("PoolFIFO: got %v, expected %v", actual.PoolFIFO, expected.PoolFIFO) - } - if actual.PoolSize != expected.PoolSize { - t.Errorf("PoolSize: got %v, expected %v", actual.PoolSize, expected.PoolSize) - } - if actual.MinIdleConns != expected.MinIdleConns { - t.Errorf("MinIdleConns: got %v, expected %v", actual.MinIdleConns, expected.MinIdleConns) - } - if actual.MaxConnAge != expected.MaxConnAge { - t.Errorf("MaxConnAge: got %v, expected %v", actual.MaxConnAge, expected.MaxConnAge) - } - if actual.PoolTimeout != expected.PoolTimeout { - t.Errorf("PoolTimeout: got %v, expected %v", actual.PoolTimeout, expected.PoolTimeout) - } - if actual.IdleTimeout != expected.IdleTimeout { - t.Errorf("IdleTimeout: got %v, expected %v", actual.IdleTimeout, expected.IdleTimeout) - } - if actual.IdleCheckFrequency != expected.IdleCheckFrequency { - t.Errorf("IdleCheckFrequency: got %v, expected %v", actual.IdleCheckFrequency, expected.IdleCheckFrequency) - } + assert.Equal(t, expected.Addrs, actual.Addrs) + assert.Equal(t, expected.TLSConfig, actual.TLSConfig) + assert.Equal(t, expected.Username, actual.Username) + assert.Equal(t, expected.Password, actual.Password) + assert.Equal(t, expected.MaxRetries, actual.MaxRetries) + assert.Equal(t, expected.MinRetryBackoff, actual.MinRetryBackoff) + assert.Equal(t, expected.MaxRetryBackoff, actual.MaxRetryBackoff) + assert.Equal(t, expected.DialTimeout, actual.DialTimeout) + assert.Equal(t, expected.ReadTimeout, actual.ReadTimeout) + assert.Equal(t, expected.WriteTimeout, actual.WriteTimeout) + assert.Equal(t, expected.PoolFIFO, actual.PoolFIFO) + assert.Equal(t, expected.PoolSize, actual.PoolSize) + assert.Equal(t, expected.MinIdleConns, actual.MinIdleConns) + assert.Equal(t, expected.MaxConnAge, actual.MaxConnAge) + assert.Equal(t, expected.PoolTimeout, actual.PoolTimeout) + assert.Equal(t, expected.IdleTimeout, actual.IdleTimeout) + assert.Equal(t, expected.IdleCheckFrequency, actual.IdleCheckFrequency) } diff --git a/go.mod b/go.mod index 6852913..d34c75d 100644 --- a/go.mod +++ b/go.mod @@ -8,4 +8,5 @@ require ( github.com/google/go-cmp v0.5.6 // indirect github.com/onsi/ginkgo v1.16.4 github.com/onsi/gomega v1.16.0 + github.com/stretchr/testify v1.5.1 ) diff --git a/go.sum b/go.sum index d9aec34..db546ad 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,7 @@ github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE= github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= @@ -36,8 +37,10 @@ github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7J github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= github.com/onsi/gomega v1.16.0 h1:6gjqkI8iiRHMvdccRJM8rVKjCWk6ZIm6FTm3ddIe4/c= github.com/onsi/gomega v1.16.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= diff --git a/options.go b/options.go index ed20c51..c68b0ad 100644 --- a/options.go +++ b/options.go @@ -240,16 +240,7 @@ func setupTCPConn(u *url.URL) (*Options, error) { o.Username, o.Password = getUserPassword(u) - h, p, err := net.SplitHostPort(u.Host) - if err != nil { - h = u.Host - } - if h == "" { - h = "localhost" - } - if p == "" { - p = "6379" - } + h, p := getHostPortWithDefaults(u) o.Addr = net.JoinHostPort(h, p) f := strings.FieldsFunc(u.Path, func(r rune) bool { @@ -259,6 +250,7 @@ func setupTCPConn(u *url.URL) (*Options, error) { case 0: o.DB = 0 case 1: + var err error if o.DB, err = strconv.Atoi(f[0]); err != nil { return nil, fmt.Errorf("redis: invalid database number: %q", f[0]) } @@ -273,6 +265,23 @@ func setupTCPConn(u *url.URL) (*Options, error) { return setupConnParams(u, o) } +// getHostPortWithDefaults is a helper function that splits the url into +// a host and a port. If the host is missing, it defaults to localhost +// and if the port is missing, it defaults to 6379 +func getHostPortWithDefaults(u *url.URL) (string, string) { + host, port, err := net.SplitHostPort(u.Host) + if err != nil { + host = u.Host + } + if host == "" { + host = "localhost" + } + if port == "" { + port = "6379" + } + return host, port +} + func setupUnixConn(u *url.URL) (*Options, error) { o := &Options{ Network: "unix", @@ -292,19 +301,20 @@ type queryOptions struct { } func (o *queryOptions) string(name string) string { - vs := o.q[name] - if len(vs) == 0 { + if len(o.q[name]) == 0 { return "" } + // get the first item from the array to return + // and remove it so it isn't processed again + param := o.q[name][0] + o.q[name] = o.q[name][1:] - // enable detection of unknown parameters - if len(vs) > 1 { - o.q[name] = o.q[name][:len(vs)-1] - } else { + // remove the key to enable detection of unknown params + if len(o.q[name]) == 0 { delete(o.q, name) } - return vs[len(vs)-1] + return param } func (o *queryOptions) int(name string) int {