diff --git a/options.go b/options.go index 373037c..27e2359 100644 --- a/options.go +++ b/options.go @@ -29,6 +29,7 @@ type Limiter interface { ReportResult(result error) } +// Options keeps the settings to setup redis connection. type Options struct { // The network type, either tcp or unix. // Default is tcp. @@ -187,23 +188,32 @@ func (opt *Options) clone() *Options { } // ParseURL parses an URL into Options that can be used to connect to Redis. +// Scheme is required. +// There are two connection types: by tcp socket and by unix socket. +// Tcp connection: +// redis://:@:/ +// Unix connection: +// unix://:@?db= func ParseURL(redisURL string) (*Options, error) { - o := &Options{Network: "tcp"} u, err := url.Parse(redisURL) if err != nil { return nil, err } - if u.Scheme != "redis" && u.Scheme != "rediss" { - return nil, errors.New("invalid redis URL scheme: " + u.Scheme) + switch u.Scheme { + case "redis", "rediss": + return setupTCPConn(u) + case "unix": + return setupUnixConn(u) + default: + return nil, fmt.Errorf("invalid redis URL scheme: %s", u.Scheme) } +} - if u.User != nil { - o.Username = u.User.Username() - if p, ok := u.User.Password(); ok { - o.Password = p - } - } +func setupTCPConn(u *url.URL) (*Options, error) { + o := &Options{Network: "tcp"} + + o.Username, o.Password = getUserPassword(u) if len(u.Query()) > 0 { return nil, errors.New("no options supported") @@ -232,15 +242,52 @@ func ParseURL(redisURL string) (*Options, error) { return nil, fmt.Errorf("invalid redis database number: %q", f[0]) } default: - return nil, errors.New("invalid redis URL path: " + u.Path) + return nil, fmt.Errorf("invalid redis URL path: %s", u.Path) } if u.Scheme == "rediss" { o.TLSConfig = &tls.Config{ServerName: h} } + return o, nil } +func setupUnixConn(u *url.URL) (*Options, error) { + o := &Options{ + Network: "unix", + } + + if strings.TrimSpace(u.Path) == "" { // path is required with unix connection + return nil, errors.New("empty redis unix socket path") + } + o.Addr = u.Path + + o.Username, o.Password = getUserPassword(u) + + dbStr := u.Query().Get("db") + if dbStr == "" { + return o, nil // if database is not set, connect to 0 db. + } + db, err := strconv.Atoi(dbStr) + if err != nil { + return nil, fmt.Errorf("invalid reids database number: %s", err) + } + o.DB = db + + return o, nil +} + +func getUserPassword(u *url.URL) (string, string) { + var user, password string + if u.User != nil { + user = u.User.Username() + if p, ok := u.User.Password(); ok { + password = p + } + } + return user, password +} + func newConnPool(opt *Options) *pool.ConnPool { return pool.NewConnPool(&pool.Options{ Dialer: func(ctx context.Context) (net.Conn, error) { diff --git a/options_test.go b/options_test.go index f3468ff..4ee3870 100644 --- a/options_test.go +++ b/options_test.go @@ -66,6 +66,30 @@ func TestParseURL(t *testing.T) { 0, false, nil, "foo", "bar", }, + { + "unix:///tmp/redis.sock", + "/tmp/redis.sock", + 0, false, nil, + "", "", + }, + { + "unix://foo:bar@/tmp/redis.sock", + "/tmp/redis.sock", + 0, false, nil, + "foo", "bar", + }, + { + "unix://foo:bar@/tmp/redis.sock?db=3", + "/tmp/redis.sock", + 3, false, nil, + "foo", "bar", + }, + { + "unix://foo:bar@/tmp/redis.sock?db=test", + "/tmp/redis.sock", + 0, false, errors.New("invalid reids database number: strconv.Atoi: parsing \"test\": invalid syntax"), + "", "", + }, { "redis://localhost/?abc=123", "",