Add support for connection initialisation.

This commit is contained in:
Vladimir Mihailenco 2012-08-06 11:33:49 +03:00
parent 4e6fa48b48
commit c5c8ec6b0c
4 changed files with 116 additions and 20 deletions

View File

@ -11,7 +11,10 @@ Example 1:
import "github.com/vmihailenco/redis" import "github.com/vmihailenco/redis"
redisClient := redis.NewTCPClient(":6379", "", 0) address := ":6379"
password := "secret"
db := 0
redisClient := redis.NewTCPClient(address, password, db)
Example 2: Example 2:
@ -29,7 +32,7 @@ Example 2:
return nil return nil
} }
initConn := func(client *Client) error { initConn := func(client *redis.Client) error {
_, err := client.Auth("foo").Reply() _, err := client.Auth("foo").Reply()
if err != nil { if err != nil {
return err return err
@ -45,7 +48,7 @@ Example 2:
redisClient := redis.NewClient(openConn, closeConn, initConn) redisClient := redis.NewClient(openConn, closeConn, initConn)
`closeConn` and `initConn` functions can be `nil`. Both `closeConn` and `initConn` functions can be `nil`.
Running commands Running commands
---------------- ----------------
@ -199,4 +202,4 @@ Connection pool
Client uses connection pool with default capacity of 10 connections. To change pool capacity: Client uses connection pool with default capacity of 10 connections. To change pool capacity:
redisClient.ConnPool.MaxCap = 1 redisClient.ConnPool.(*redis.MultiConnPool).MaxCap = 1

View File

@ -21,7 +21,16 @@ func NewConn(rw io.ReadWriter) *Conn {
} }
} }
type ConnPool struct { type ConnPool interface {
Get() (*Conn, bool, error)
Add(*Conn)
Remove(*Conn)
Len() int
}
//------------------------------------------------------------------------------
type MultiConnPool struct {
Logger *log.Logger Logger *log.Logger
cond *sync.Cond cond *sync.Cond
conns []*Conn conns []*Conn
@ -30,13 +39,13 @@ type ConnPool struct {
cap, MaxCap int64 cap, MaxCap int64
} }
func NewConnPool(openConn OpenConnFunc, closeConn CloseConnFunc, maxCap int64) *ConnPool { func NewMultiConnPool(openConn OpenConnFunc, closeConn CloseConnFunc, maxCap int64) *MultiConnPool {
logger := log.New( logger := log.New(
os.Stdout, os.Stdout,
"redis.connpool: ", "redis.connpool: ",
log.Ldate|log.Ltime|log.Lshortfile, log.Ldate|log.Ltime|log.Lshortfile,
) )
return &ConnPool{ return &MultiConnPool{
cond: sync.NewCond(&sync.Mutex{}), cond: sync.NewCond(&sync.Mutex{}),
Logger: logger, Logger: logger,
conns: make([]*Conn, 0), conns: make([]*Conn, 0),
@ -46,7 +55,7 @@ func NewConnPool(openConn OpenConnFunc, closeConn CloseConnFunc, maxCap int64) *
} }
} }
func (p *ConnPool) Get() (*Conn, bool, error) { func (p *MultiConnPool) Get() (*Conn, bool, error) {
p.cond.L.Lock() p.cond.L.Lock()
defer p.cond.L.Unlock() defer p.cond.L.Unlock()
@ -72,14 +81,14 @@ func (p *ConnPool) Get() (*Conn, bool, error) {
return conn, false, nil return conn, false, nil
} }
func (p *ConnPool) Add(conn *Conn) { func (p *MultiConnPool) Add(conn *Conn) {
p.cond.L.Lock() p.cond.L.Lock()
defer p.cond.L.Unlock() defer p.cond.L.Unlock()
p.conns = append(p.conns, conn) p.conns = append(p.conns, conn)
p.cond.Signal() p.cond.Signal()
} }
func (p *ConnPool) Remove(conn *Conn) { func (p *MultiConnPool) Remove(conn *Conn) {
p.cond.L.Lock() p.cond.L.Lock()
p.cap-- p.cap--
p.cond.Signal() p.cond.Signal()
@ -90,6 +99,28 @@ func (p *ConnPool) Remove(conn *Conn) {
} }
} }
func (p *ConnPool) Len() int { func (p *MultiConnPool) Len() int {
return len(p.conns) return len(p.conns)
} }
//------------------------------------------------------------------------------
type OneConnPool struct {
conn *Conn
}
func NewOneConnPool(conn *Conn) *OneConnPool {
return &OneConnPool{conn: conn}
}
func (p *OneConnPool) Get() (*Conn, bool, error) {
return p.conn, false, nil
}
func (p *OneConnPool) Add(conn *Conn) {}
func (p *OneConnPool) Remove(conn *Conn) {}
func (p *OneConnPool) Len() int {
return 1
}

View File

@ -28,6 +28,10 @@ func TLSConnector(addr string, tlsConfig *tls.Config) OpenConnFunc {
} }
func AuthSelectFunc(password string, db int64) InitConnFunc { func AuthSelectFunc(password string, db int64) InitConnFunc {
if password == "" && db < 0 {
return nil
}
return func(client *Client) error { return func(client *Client) error {
if password != "" { if password != "" {
_, err := client.Auth(password).Reply() _, err := client.Auth(password).Reply()
@ -36,10 +40,12 @@ func AuthSelectFunc(password string, db int64) InitConnFunc {
} }
} }
if db >= 0 {
_, err := client.Select(db).Reply() _, err := client.Select(db).Reply()
if err != nil { if err != nil {
return err return err
} }
}
return nil return nil
} }
@ -51,7 +57,7 @@ func createReader() (*bufreader.Reader, error) {
type Client struct { type Client struct {
mtx sync.Mutex mtx sync.Mutex
ConnPool *ConnPool ConnPool ConnPool
InitConn InitConnFunc InitConn InitConnFunc
reqs []Req reqs []Req
@ -59,7 +65,7 @@ type Client struct {
func NewClient(openConn OpenConnFunc, closeConn CloseConnFunc, initConn InitConnFunc) *Client { func NewClient(openConn OpenConnFunc, closeConn CloseConnFunc, initConn InitConnFunc) *Client {
return &Client{ return &Client{
ConnPool: NewConnPool(openConn, closeConn, 10), ConnPool: NewMultiConnPool(openConn, closeConn, 10),
InitConn: initConn, InitConn: initConn,
} }
} }
@ -76,6 +82,23 @@ func NewTLSClient(addr string, tlsConfig *tls.Config, password string, db int64)
) )
} }
func (c *Client) conn() (*Conn, error) {
conn, isNew, err := c.ConnPool.Get()
if err != nil {
return nil, err
}
if isNew && c.InitConn != nil {
client := &Client{
ConnPool: NewOneConnPool(conn),
}
err = c.InitConn(client)
if err != nil {
return nil, err
}
}
return conn, nil
}
func (c *Client) WriteReq(buf []byte, conn *Conn) error { func (c *Client) WriteReq(buf []byte, conn *Conn) error {
_, err := conn.RW.Write(buf) _, err := conn.RW.Write(buf)
return err return err
@ -120,7 +143,7 @@ func (c *Client) Queue(req Req) {
} }
func (c *Client) Run(req Req) { func (c *Client) Run(req Req) {
conn, _, err := c.ConnPool.Get() conn, err := c.conn()
if err != nil { if err != nil {
req.SetErr(err) req.SetErr(err)
return return
@ -154,7 +177,7 @@ func (c *Client) RunQueued() ([]Req, error) {
c.reqs = make([]Req, 0) c.reqs = make([]Req, 0)
c.mtx.Unlock() c.mtx.Unlock()
conn, _, err := c.ConnPool.Get() conn, err := c.conn()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -223,7 +246,7 @@ func (c *Client) Exec() ([]Req, error) {
c.reqs = make([]Req, 0) c.reqs = make([]Req, 0)
c.mtx.Unlock() c.mtx.Unlock()
conn, _, err := c.ConnPool.Get() conn, err := c.conn()
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -1,6 +1,8 @@
package redis_test package redis_test
import ( import (
"io"
"net"
"strconv" "strconv"
"sync" "sync"
"testing" "testing"
@ -24,7 +26,7 @@ func Test(t *testing.T) { TestingT(t) }
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
func (t *RedisTest) SetUpTest(c *C) { func (t *RedisTest) SetUpTest(c *C) {
t.client = redis.NewTCPClient(":6379", "", 0) t.client = redis.NewTCPClient(":6379", "", -1)
_, err := t.client.Flushdb().Reply() _, err := t.client.Flushdb().Reply()
c.Check(err, IsNil) c.Check(err, IsNil)
@ -38,6 +40,24 @@ func (t *RedisTest) TearDownTest(c *C) {
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
func (t *RedisTest) TestInitConn(c *C) {
openConn := func() (io.ReadWriter, error) {
return net.Dial("tcp", ":6379")
}
isInitConnCalled := false
initConn := func(client *redis.Client) error {
isInitConnCalled = true
return nil
}
client := redis.NewClient(openConn, nil, initConn)
pong, err := client.Ping().Reply()
c.Check(err, IsNil)
c.Check(pong, Equals, "PONG")
c.Check(isInitConnCalled, Equals, true)
}
func (t *RedisTest) TestRunWithMissingReplyPart(c *C) { func (t *RedisTest) TestRunWithMissingReplyPart(c *C) {
req := t.client.Set("foo", "bar") req := t.client.Set("foo", "bar")
@ -1667,6 +1687,25 @@ func (t *RedisTest) BenchmarkRedisGet(c *C) {
} }
} }
func (t *RedisTest) BenchmarkRedisMGet(c *C) {
c.StopTimer()
_, err := t.client.MSet("foo1", "bar1", "foo2", "bar2").Reply()
c.Check(err, IsNil)
for i := 0; i < 10; i++ {
values, err := t.client.MGet("foo1", "foo2").Reply()
c.Check(err, IsNil)
c.Check(values, DeepEquals, []interface{}{"bar1", "bar2"})
}
c.StartTimer()
for i := 0; i < c.N; i++ {
t.client.MGet("foo1", "foo2").Reply()
}
}
func (t *RedisTest) BenchmarkRedisWriteRead(c *C) { func (t *RedisTest) BenchmarkRedisWriteRead(c *C) {
c.StopTimer() c.StopTimer()