diff --git a/README.md b/README.md index 9cb46af..cc52425 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,10 @@ Example 1: import "github.com/vmihailenco/redis" - redisClient := redis.NewTCPClient(":6379", "", 0) + address := ":6379" + password := "secret" + db := 0 + redisClient := redis.NewTCPClient(address, password, db) Example 2: @@ -29,7 +32,7 @@ Example 2: return nil } - initConn := func(client *Client) error { + initConn := func(client *redis.Client) error { _, err := client.Auth("foo").Reply() if err != nil { return err @@ -45,7 +48,7 @@ Example 2: redisClient := redis.NewClient(openConn, closeConn, initConn) -`closeConn` and `initConn` functions can be `nil`. +Both `closeConn` and `initConn` functions can be `nil`. Running commands ---------------- @@ -199,4 +202,4 @@ Connection pool Client uses connection pool with default capacity of 10 connections. To change pool capacity: - redisClient.ConnPool.MaxCap = 1 + redisClient.ConnPool.(*redis.MultiConnPool).MaxCap = 1 diff --git a/connpool.go b/connpool.go index 80e722b..c14ada9 100644 --- a/connpool.go +++ b/connpool.go @@ -21,7 +21,16 @@ func NewConn(rw io.ReadWriter) *Conn { } } -type ConnPool struct { +type ConnPool interface { + Get() (*Conn, bool, error) + Add(*Conn) + Remove(*Conn) + Len() int +} + +//------------------------------------------------------------------------------ + +type MultiConnPool struct { Logger *log.Logger cond *sync.Cond conns []*Conn @@ -30,13 +39,13 @@ type ConnPool struct { cap, MaxCap int64 } -func NewConnPool(openConn OpenConnFunc, closeConn CloseConnFunc, maxCap int64) *ConnPool { +func NewMultiConnPool(openConn OpenConnFunc, closeConn CloseConnFunc, maxCap int64) *MultiConnPool { logger := log.New( os.Stdout, "redis.connpool: ", log.Ldate|log.Ltime|log.Lshortfile, ) - return &ConnPool{ + return &MultiConnPool{ cond: sync.NewCond(&sync.Mutex{}), Logger: logger, conns: make([]*Conn, 0), @@ -46,7 +55,7 @@ func NewConnPool(openConn OpenConnFunc, closeConn CloseConnFunc, maxCap int64) * } } -func (p *ConnPool) Get() (*Conn, bool, error) { +func (p *MultiConnPool) Get() (*Conn, bool, error) { p.cond.L.Lock() defer p.cond.L.Unlock() @@ -72,14 +81,14 @@ func (p *ConnPool) Get() (*Conn, bool, error) { return conn, false, nil } -func (p *ConnPool) Add(conn *Conn) { +func (p *MultiConnPool) Add(conn *Conn) { p.cond.L.Lock() defer p.cond.L.Unlock() p.conns = append(p.conns, conn) p.cond.Signal() } -func (p *ConnPool) Remove(conn *Conn) { +func (p *MultiConnPool) Remove(conn *Conn) { p.cond.L.Lock() p.cap-- p.cond.Signal() @@ -90,6 +99,28 @@ func (p *ConnPool) Remove(conn *Conn) { } } -func (p *ConnPool) Len() int { +func (p *MultiConnPool) Len() int { return len(p.conns) } + +//------------------------------------------------------------------------------ + +type OneConnPool struct { + conn *Conn +} + +func NewOneConnPool(conn *Conn) *OneConnPool { + return &OneConnPool{conn: conn} +} + +func (p *OneConnPool) Get() (*Conn, bool, error) { + return p.conn, false, nil +} + +func (p *OneConnPool) Add(conn *Conn) {} + +func (p *OneConnPool) Remove(conn *Conn) {} + +func (p *OneConnPool) Len() int { + return 1 +} diff --git a/redis.go b/redis.go index 5bf7533..a586a7a 100644 --- a/redis.go +++ b/redis.go @@ -28,6 +28,10 @@ func TLSConnector(addr string, tlsConfig *tls.Config) OpenConnFunc { } func AuthSelectFunc(password string, db int64) InitConnFunc { + if password == "" && db < 0 { + return nil + } + return func(client *Client) error { if password != "" { _, err := client.Auth(password).Reply() @@ -36,9 +40,11 @@ func AuthSelectFunc(password string, db int64) InitConnFunc { } } - _, err := client.Select(db).Reply() - if err != nil { - return err + if db >= 0 { + _, err := client.Select(db).Reply() + if err != nil { + return err + } } return nil @@ -51,7 +57,7 @@ func createReader() (*bufreader.Reader, error) { type Client struct { mtx sync.Mutex - ConnPool *ConnPool + ConnPool ConnPool InitConn InitConnFunc reqs []Req @@ -59,7 +65,7 @@ type Client struct { func NewClient(openConn OpenConnFunc, closeConn CloseConnFunc, initConn InitConnFunc) *Client { return &Client{ - ConnPool: NewConnPool(openConn, closeConn, 10), + ConnPool: NewMultiConnPool(openConn, closeConn, 10), InitConn: initConn, } } @@ -76,6 +82,23 @@ func NewTLSClient(addr string, tlsConfig *tls.Config, password string, db int64) ) } +func (c *Client) conn() (*Conn, error) { + conn, isNew, err := c.ConnPool.Get() + if err != nil { + return nil, err + } + if isNew && c.InitConn != nil { + client := &Client{ + ConnPool: NewOneConnPool(conn), + } + err = c.InitConn(client) + if err != nil { + return nil, err + } + } + return conn, nil +} + func (c *Client) WriteReq(buf []byte, conn *Conn) error { _, err := conn.RW.Write(buf) return err @@ -120,7 +143,7 @@ func (c *Client) Queue(req Req) { } func (c *Client) Run(req Req) { - conn, _, err := c.ConnPool.Get() + conn, err := c.conn() if err != nil { req.SetErr(err) return @@ -154,7 +177,7 @@ func (c *Client) RunQueued() ([]Req, error) { c.reqs = make([]Req, 0) c.mtx.Unlock() - conn, _, err := c.ConnPool.Get() + conn, err := c.conn() if err != nil { return nil, err } @@ -223,7 +246,7 @@ func (c *Client) Exec() ([]Req, error) { c.reqs = make([]Req, 0) c.mtx.Unlock() - conn, _, err := c.ConnPool.Get() + conn, err := c.conn() if err != nil { return nil, err } diff --git a/redis_test.go b/redis_test.go index 9695272..eca596b 100644 --- a/redis_test.go +++ b/redis_test.go @@ -1,6 +1,8 @@ package redis_test import ( + "io" + "net" "strconv" "sync" "testing" @@ -24,7 +26,7 @@ func Test(t *testing.T) { TestingT(t) } //------------------------------------------------------------------------------ func (t *RedisTest) SetUpTest(c *C) { - t.client = redis.NewTCPClient(":6379", "", 0) + t.client = redis.NewTCPClient(":6379", "", -1) _, err := t.client.Flushdb().Reply() c.Check(err, IsNil) @@ -38,6 +40,24 @@ func (t *RedisTest) TearDownTest(c *C) { //------------------------------------------------------------------------------ +func (t *RedisTest) TestInitConn(c *C) { + openConn := func() (io.ReadWriter, error) { + return net.Dial("tcp", ":6379") + } + + isInitConnCalled := false + initConn := func(client *redis.Client) error { + isInitConnCalled = true + return nil + } + + client := redis.NewClient(openConn, nil, initConn) + pong, err := client.Ping().Reply() + c.Check(err, IsNil) + c.Check(pong, Equals, "PONG") + c.Check(isInitConnCalled, Equals, true) +} + func (t *RedisTest) TestRunWithMissingReplyPart(c *C) { req := t.client.Set("foo", "bar") @@ -1667,6 +1687,25 @@ func (t *RedisTest) BenchmarkRedisGet(c *C) { } } +func (t *RedisTest) BenchmarkRedisMGet(c *C) { + c.StopTimer() + + _, err := t.client.MSet("foo1", "bar1", "foo2", "bar2").Reply() + c.Check(err, IsNil) + + for i := 0; i < 10; i++ { + values, err := t.client.MGet("foo1", "foo2").Reply() + c.Check(err, IsNil) + c.Check(values, DeepEquals, []interface{}{"bar1", "bar2"}) + } + + c.StartTimer() + + for i := 0; i < c.N; i++ { + t.client.MGet("foo1", "foo2").Reply() + } +} + func (t *RedisTest) BenchmarkRedisWriteRead(c *C) { c.StopTimer()