forked from mirror/redis
Add support for connection initialisation.
This commit is contained in:
parent
4e6fa48b48
commit
c5c8ec6b0c
11
README.md
11
README.md
|
@ -11,7 +11,10 @@ Example 1:
|
||||||
import "github.com/vmihailenco/redis"
|
import "github.com/vmihailenco/redis"
|
||||||
|
|
||||||
|
|
||||||
redisClient := redis.NewTCPClient(":6379", "", 0)
|
address := ":6379"
|
||||||
|
password := "secret"
|
||||||
|
db := 0
|
||||||
|
redisClient := redis.NewTCPClient(address, password, db)
|
||||||
|
|
||||||
Example 2:
|
Example 2:
|
||||||
|
|
||||||
|
@ -29,7 +32,7 @@ Example 2:
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
initConn := func(client *Client) error {
|
initConn := func(client *redis.Client) error {
|
||||||
_, err := client.Auth("foo").Reply()
|
_, err := client.Auth("foo").Reply()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -45,7 +48,7 @@ Example 2:
|
||||||
|
|
||||||
redisClient := redis.NewClient(openConn, closeConn, initConn)
|
redisClient := redis.NewClient(openConn, closeConn, initConn)
|
||||||
|
|
||||||
`closeConn` and `initConn` functions can be `nil`.
|
Both `closeConn` and `initConn` functions can be `nil`.
|
||||||
|
|
||||||
Running commands
|
Running commands
|
||||||
----------------
|
----------------
|
||||||
|
@ -199,4 +202,4 @@ Connection pool
|
||||||
|
|
||||||
Client uses connection pool with default capacity of 10 connections. To change pool capacity:
|
Client uses connection pool with default capacity of 10 connections. To change pool capacity:
|
||||||
|
|
||||||
redisClient.ConnPool.MaxCap = 1
|
redisClient.ConnPool.(*redis.MultiConnPool).MaxCap = 1
|
||||||
|
|
45
connpool.go
45
connpool.go
|
@ -21,7 +21,16 @@ func NewConn(rw io.ReadWriter) *Conn {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type ConnPool struct {
|
type ConnPool interface {
|
||||||
|
Get() (*Conn, bool, error)
|
||||||
|
Add(*Conn)
|
||||||
|
Remove(*Conn)
|
||||||
|
Len() int
|
||||||
|
}
|
||||||
|
|
||||||
|
//------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type MultiConnPool struct {
|
||||||
Logger *log.Logger
|
Logger *log.Logger
|
||||||
cond *sync.Cond
|
cond *sync.Cond
|
||||||
conns []*Conn
|
conns []*Conn
|
||||||
|
@ -30,13 +39,13 @@ type ConnPool struct {
|
||||||
cap, MaxCap int64
|
cap, MaxCap int64
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConnPool(openConn OpenConnFunc, closeConn CloseConnFunc, maxCap int64) *ConnPool {
|
func NewMultiConnPool(openConn OpenConnFunc, closeConn CloseConnFunc, maxCap int64) *MultiConnPool {
|
||||||
logger := log.New(
|
logger := log.New(
|
||||||
os.Stdout,
|
os.Stdout,
|
||||||
"redis.connpool: ",
|
"redis.connpool: ",
|
||||||
log.Ldate|log.Ltime|log.Lshortfile,
|
log.Ldate|log.Ltime|log.Lshortfile,
|
||||||
)
|
)
|
||||||
return &ConnPool{
|
return &MultiConnPool{
|
||||||
cond: sync.NewCond(&sync.Mutex{}),
|
cond: sync.NewCond(&sync.Mutex{}),
|
||||||
Logger: logger,
|
Logger: logger,
|
||||||
conns: make([]*Conn, 0),
|
conns: make([]*Conn, 0),
|
||||||
|
@ -46,7 +55,7 @@ func NewConnPool(openConn OpenConnFunc, closeConn CloseConnFunc, maxCap int64) *
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ConnPool) Get() (*Conn, bool, error) {
|
func (p *MultiConnPool) Get() (*Conn, bool, error) {
|
||||||
p.cond.L.Lock()
|
p.cond.L.Lock()
|
||||||
defer p.cond.L.Unlock()
|
defer p.cond.L.Unlock()
|
||||||
|
|
||||||
|
@ -72,14 +81,14 @@ func (p *ConnPool) Get() (*Conn, bool, error) {
|
||||||
return conn, false, nil
|
return conn, false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ConnPool) Add(conn *Conn) {
|
func (p *MultiConnPool) Add(conn *Conn) {
|
||||||
p.cond.L.Lock()
|
p.cond.L.Lock()
|
||||||
defer p.cond.L.Unlock()
|
defer p.cond.L.Unlock()
|
||||||
p.conns = append(p.conns, conn)
|
p.conns = append(p.conns, conn)
|
||||||
p.cond.Signal()
|
p.cond.Signal()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ConnPool) Remove(conn *Conn) {
|
func (p *MultiConnPool) Remove(conn *Conn) {
|
||||||
p.cond.L.Lock()
|
p.cond.L.Lock()
|
||||||
p.cap--
|
p.cap--
|
||||||
p.cond.Signal()
|
p.cond.Signal()
|
||||||
|
@ -90,6 +99,28 @@ func (p *ConnPool) Remove(conn *Conn) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ConnPool) Len() int {
|
func (p *MultiConnPool) Len() int {
|
||||||
return len(p.conns)
|
return len(p.conns)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type OneConnPool struct {
|
||||||
|
conn *Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewOneConnPool(conn *Conn) *OneConnPool {
|
||||||
|
return &OneConnPool{conn: conn}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *OneConnPool) Get() (*Conn, bool, error) {
|
||||||
|
return p.conn, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *OneConnPool) Add(conn *Conn) {}
|
||||||
|
|
||||||
|
func (p *OneConnPool) Remove(conn *Conn) {}
|
||||||
|
|
||||||
|
func (p *OneConnPool) Len() int {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
39
redis.go
39
redis.go
|
@ -28,6 +28,10 @@ func TLSConnector(addr string, tlsConfig *tls.Config) OpenConnFunc {
|
||||||
}
|
}
|
||||||
|
|
||||||
func AuthSelectFunc(password string, db int64) InitConnFunc {
|
func AuthSelectFunc(password string, db int64) InitConnFunc {
|
||||||
|
if password == "" && db < 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
return func(client *Client) error {
|
return func(client *Client) error {
|
||||||
if password != "" {
|
if password != "" {
|
||||||
_, err := client.Auth(password).Reply()
|
_, err := client.Auth(password).Reply()
|
||||||
|
@ -36,9 +40,11 @@ func AuthSelectFunc(password string, db int64) InitConnFunc {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := client.Select(db).Reply()
|
if db >= 0 {
|
||||||
if err != nil {
|
_, err := client.Select(db).Reply()
|
||||||
return err
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -51,7 +57,7 @@ func createReader() (*bufreader.Reader, error) {
|
||||||
|
|
||||||
type Client struct {
|
type Client struct {
|
||||||
mtx sync.Mutex
|
mtx sync.Mutex
|
||||||
ConnPool *ConnPool
|
ConnPool ConnPool
|
||||||
InitConn InitConnFunc
|
InitConn InitConnFunc
|
||||||
|
|
||||||
reqs []Req
|
reqs []Req
|
||||||
|
@ -59,7 +65,7 @@ type Client struct {
|
||||||
|
|
||||||
func NewClient(openConn OpenConnFunc, closeConn CloseConnFunc, initConn InitConnFunc) *Client {
|
func NewClient(openConn OpenConnFunc, closeConn CloseConnFunc, initConn InitConnFunc) *Client {
|
||||||
return &Client{
|
return &Client{
|
||||||
ConnPool: NewConnPool(openConn, closeConn, 10),
|
ConnPool: NewMultiConnPool(openConn, closeConn, 10),
|
||||||
InitConn: initConn,
|
InitConn: initConn,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -76,6 +82,23 @@ func NewTLSClient(addr string, tlsConfig *tls.Config, password string, db int64)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) conn() (*Conn, error) {
|
||||||
|
conn, isNew, err := c.ConnPool.Get()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if isNew && c.InitConn != nil {
|
||||||
|
client := &Client{
|
||||||
|
ConnPool: NewOneConnPool(conn),
|
||||||
|
}
|
||||||
|
err = c.InitConn(client)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Client) WriteReq(buf []byte, conn *Conn) error {
|
func (c *Client) WriteReq(buf []byte, conn *Conn) error {
|
||||||
_, err := conn.RW.Write(buf)
|
_, err := conn.RW.Write(buf)
|
||||||
return err
|
return err
|
||||||
|
@ -120,7 +143,7 @@ func (c *Client) Queue(req Req) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Run(req Req) {
|
func (c *Client) Run(req Req) {
|
||||||
conn, _, err := c.ConnPool.Get()
|
conn, err := c.conn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
req.SetErr(err)
|
req.SetErr(err)
|
||||||
return
|
return
|
||||||
|
@ -154,7 +177,7 @@ func (c *Client) RunQueued() ([]Req, error) {
|
||||||
c.reqs = make([]Req, 0)
|
c.reqs = make([]Req, 0)
|
||||||
c.mtx.Unlock()
|
c.mtx.Unlock()
|
||||||
|
|
||||||
conn, _, err := c.ConnPool.Get()
|
conn, err := c.conn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -223,7 +246,7 @@ func (c *Client) Exec() ([]Req, error) {
|
||||||
c.reqs = make([]Req, 0)
|
c.reqs = make([]Req, 0)
|
||||||
c.mtx.Unlock()
|
c.mtx.Unlock()
|
||||||
|
|
||||||
conn, _, err := c.ConnPool.Get()
|
conn, err := c.conn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
package redis_test
|
package redis_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -24,7 +26,7 @@ func Test(t *testing.T) { TestingT(t) }
|
||||||
//------------------------------------------------------------------------------
|
//------------------------------------------------------------------------------
|
||||||
|
|
||||||
func (t *RedisTest) SetUpTest(c *C) {
|
func (t *RedisTest) SetUpTest(c *C) {
|
||||||
t.client = redis.NewTCPClient(":6379", "", 0)
|
t.client = redis.NewTCPClient(":6379", "", -1)
|
||||||
_, err := t.client.Flushdb().Reply()
|
_, err := t.client.Flushdb().Reply()
|
||||||
c.Check(err, IsNil)
|
c.Check(err, IsNil)
|
||||||
|
|
||||||
|
@ -38,6 +40,24 @@ func (t *RedisTest) TearDownTest(c *C) {
|
||||||
|
|
||||||
//------------------------------------------------------------------------------
|
//------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func (t *RedisTest) TestInitConn(c *C) {
|
||||||
|
openConn := func() (io.ReadWriter, error) {
|
||||||
|
return net.Dial("tcp", ":6379")
|
||||||
|
}
|
||||||
|
|
||||||
|
isInitConnCalled := false
|
||||||
|
initConn := func(client *redis.Client) error {
|
||||||
|
isInitConnCalled = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
client := redis.NewClient(openConn, nil, initConn)
|
||||||
|
pong, err := client.Ping().Reply()
|
||||||
|
c.Check(err, IsNil)
|
||||||
|
c.Check(pong, Equals, "PONG")
|
||||||
|
c.Check(isInitConnCalled, Equals, true)
|
||||||
|
}
|
||||||
|
|
||||||
func (t *RedisTest) TestRunWithMissingReplyPart(c *C) {
|
func (t *RedisTest) TestRunWithMissingReplyPart(c *C) {
|
||||||
req := t.client.Set("foo", "bar")
|
req := t.client.Set("foo", "bar")
|
||||||
|
|
||||||
|
@ -1667,6 +1687,25 @@ func (t *RedisTest) BenchmarkRedisGet(c *C) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *RedisTest) BenchmarkRedisMGet(c *C) {
|
||||||
|
c.StopTimer()
|
||||||
|
|
||||||
|
_, err := t.client.MSet("foo1", "bar1", "foo2", "bar2").Reply()
|
||||||
|
c.Check(err, IsNil)
|
||||||
|
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
values, err := t.client.MGet("foo1", "foo2").Reply()
|
||||||
|
c.Check(err, IsNil)
|
||||||
|
c.Check(values, DeepEquals, []interface{}{"bar1", "bar2"})
|
||||||
|
}
|
||||||
|
|
||||||
|
c.StartTimer()
|
||||||
|
|
||||||
|
for i := 0; i < c.N; i++ {
|
||||||
|
t.client.MGet("foo1", "foo2").Reply()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (t *RedisTest) BenchmarkRedisWriteRead(c *C) {
|
func (t *RedisTest) BenchmarkRedisWriteRead(c *C) {
|
||||||
c.StopTimer()
|
c.StopTimer()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue