diff --git a/multi.go b/multi.go index 9ecd3dd..c05d5e1 100644 --- a/multi.go +++ b/multi.go @@ -19,6 +19,11 @@ func (c *Client) MultiClient() (*MultiClient, error) { }, nil } +func (c *MultiClient) Close() error { + c.Unwatch() + return c.Client.Close() +} + func (c *MultiClient) Watch(keys ...string) *StatusReq { args := append([]string{"WATCH"}, keys...) req := NewStatusReq(args...) diff --git a/pubsub.go b/pubsub.go index f49f89e..0b9c6b5 100644 --- a/pubsub.go +++ b/pubsub.go @@ -41,13 +41,18 @@ func (c *PubSubClient) consumeMessages(conn *Conn) { for { msg := &Message{} - replyI, err := req.ParseReply(conn.Rd) + replyIface, err := req.ParseReply(conn.Rd) if err != nil { msg.Err = err c.ch <- msg break } - reply := replyI.([]interface{}) + reply, ok := replyIface.([]interface{}) + if !ok { + msg.Err = fmt.Errorf("redis: unexpected reply type %T", replyIface) + c.ch <- msg + return + } msgName := reply[0].(string) switch msgName { diff --git a/redis_test.go b/redis_test.go index 711381f..f277fb0 100644 --- a/redis_test.go +++ b/redis_test.go @@ -39,33 +39,50 @@ func sortStrings(slice []string) []string { //------------------------------------------------------------------------------ func (t *RedisTest) SetUpTest(c *C) { + if t.client == nil { + openConn := func() (io.ReadWriteCloser, error) { + t.openedConnCount++ + return net.Dial("tcp", redisAddr) + } + initConn := func(c *redis.Client) error { + t.initedConnCount++ + return nil + } + closeConn := func(conn io.ReadWriteCloser) error { + t.closedConnCount++ + return nil + } + + t.client = redis.NewClient(openConn, closeConn, initConn) + t.client.ConnPool.(*redis.MultiConnPool).MaxCap = 10 + } + t.openedConnCount = 0 - openConn := func() (io.ReadWriteCloser, error) { - t.openedConnCount++ - return net.Dial("tcp", redisAddr) - } t.closedConnCount = 0 - closeConn := func(conn io.ReadWriteCloser) error { - t.closedConnCount++ - return nil - } t.initedConnCount = 0 - initConn := func(c *redis.Client) error { - t.initedConnCount++ - return nil - } - t.client = redis.NewClient(openConn, closeConn, initConn) - t.client.ConnPool.(*redis.MultiConnPool).MaxCap = 10 - c.Assert(t.client.FlushDb().Err(), IsNil) + t.resetRedis(c) } func (t *RedisTest) TearDownTest(c *C) { - c.Assert(t.client.FlushDb().Err(), IsNil) - c.Assert(t.client.Close(), IsNil) - c.Assert(t.openedConnCount, Equals, t.closedConnCount) + t.resetRedis(c) c.Assert(t.openedConnCount, Equals, t.initedConnCount) } +func (t *RedisTest) resetRedis(c *C) { + c.Assert(t.client.Select(1).Err(), IsNil) + c.Assert(t.client.FlushDb().Err(), IsNil) + + c.Assert(t.client.Select(0).Err(), IsNil) + c.Assert(t.client.FlushDb().Err(), IsNil) +} + +func (t *RedisTest) resetClient(c *C) { + c.Assert(t.client.Close(), IsNil) + t.openedConnCount = 0 + t.initedConnCount = 0 + t.closedConnCount = 0 +} + //------------------------------------------------------------------------------ func (t *RedisTest) TestRunWithouthCheckingErrVal(c *C) { @@ -106,6 +123,8 @@ func (t *RedisTest) TestGetBigVal(c *C) { //------------------------------------------------------------------------------ func (t *RedisTest) TestConnPoolMaxCap(c *C) { + t.resetClient(c) + wg := &sync.WaitGroup{} for i := 0; i < 1000; i++ { wg.Add(1) @@ -124,6 +143,8 @@ func (t *RedisTest) TestConnPoolMaxCap(c *C) { } func (t *RedisTest) TestConnPoolMaxCapOnPipelineClient(c *C) { + t.resetClient(c) + wg := &sync.WaitGroup{} for i := 0; i < 1000; i++ { wg.Add(1) @@ -151,6 +172,8 @@ func (t *RedisTest) TestConnPoolMaxCapOnPipelineClient(c *C) { } func (t *RedisTest) TestConnPoolMaxCapOnMultiClient(c *C) { + t.resetClient(c) + wg := &sync.WaitGroup{} for i := 0; i < 1000; i++ { wg.Add(1) @@ -180,6 +203,8 @@ func (t *RedisTest) TestConnPoolMaxCapOnMultiClient(c *C) { } func (t *RedisTest) TestConnPoolMaxCapOnPubSubClient(c *C) { + t.resetClient(c) + wg := &sync.WaitGroup{} for i := 0; i < 1000; i++ { wg.Add(1) @@ -2477,9 +2502,12 @@ func (t *RedisTest) TestCmdBgRewriteAOF(c *C) { } func (t *RedisTest) TestCmdBgSave(c *C) { + // workaround for "ERR Can't BGSAVE while AOF log rewriting is in progress" + time.Sleep(time.Second) + r := t.client.BgSave() - c.Assert(r.Err(), ErrorMatches, "ERR Can't BGSAVE while AOF log rewriting is in progress") - c.Assert(r.Val(), Equals, "") + c.Assert(r.Err(), IsNil) + c.Assert(r.Val(), Equals, "Background saving started") } func (t *RedisTest) TestCmdClientKill(c *C) { @@ -2662,7 +2690,7 @@ func (t *RedisTest) BenchmarkRedisMGet(c *C) { for i := 0; i < 10; i++ { mGet := t.client.MGet("key1", "key2") c.Assert(mGet.Err(), IsNil) - c.Assert(mGet.Val(), DeepEquals, []string{"hello1", "hello2"}) + c.Assert(mGet.Val(), DeepEquals, []interface{}{"hello1", "hello2"}) } c.StartTimer()