Merge pull request #1572 from go-redis/fix/close-conn-on-context-timeout

Close the conn on context timeout
This commit is contained in:
Vladimir Mihailenco 2020-12-06 11:34:07 +02:00 committed by GitHub
commit 2e398ada86
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 50 additions and 21 deletions

View File

@ -65,8 +65,11 @@ func isRedisError(err error) bool {
} }
func isBadConn(err error, allowTimeout bool) bool { func isBadConn(err error, allowTimeout bool) bool {
if err == nil { switch err {
case nil:
return false return false
case context.Canceled, context.DeadlineExceeded:
return true
} }
if isRedisError(err) { if isRedisError(err) {

View File

@ -175,8 +175,8 @@ func redisRingOptions() *redis.RingOptions {
func performAsync(n int, cbs ...func(int)) *sync.WaitGroup { func performAsync(n int, cbs ...func(int)) *sync.WaitGroup {
var wg sync.WaitGroup var wg sync.WaitGroup
for _, cb := range cbs { for _, cb := range cbs {
wg.Add(n)
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
wg.Add(1)
go func(cb func(int), i int) { go func(cb func(int), i int) {
defer GinkgoRecover() defer GinkgoRecover()
defer wg.Done() defer wg.Done()

View File

@ -2,6 +2,7 @@ package redis_test
import ( import (
"bytes" "bytes"
"context"
"fmt" "fmt"
"net" "net"
"strconv" "strconv"
@ -295,6 +296,25 @@ var _ = Describe("races", func() {
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
}) })
}) })
It("should abort on context timeout", func() {
opt := redisClusterOptions()
client := cluster.newClusterClient(ctx, opt)
ctx, cancel := context.WithCancel(context.Background())
wg := performAsync(C, func(_ int) {
_, err := client.XRead(ctx, &redis.XReadArgs{
Streams: []string{"test", "$"},
Block: 1 * time.Second,
}).Result()
Expect(err).To(Equal(context.Canceled))
})
time.Sleep(10 * time.Millisecond)
cancel()
wg.Wait()
})
}) })
var _ = Describe("cluster races", func() { var _ = Describe("cluster races", func() {

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"sync/atomic"
"time" "time"
"github.com/go-redis/redis/v8/internal" "github.com/go-redis/redis/v8/internal"
@ -130,20 +131,7 @@ func (hs hooks) processTxPipeline(
} }
func (hs hooks) withContext(ctx context.Context, fn func() error) error { func (hs hooks) withContext(ctx context.Context, fn func() error) error {
done := ctx.Done()
if done == nil {
return fn() return fn()
}
errc := make(chan error, 1)
go func() { errc <- fn() }()
select {
case <-done:
return ctx.Err()
case err := <-errc:
return err
}
} }
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
@ -316,8 +304,24 @@ func (c *baseClient) withConn(
c.releaseConn(ctx, cn, err) c.releaseConn(ctx, cn, err)
}() }()
done := ctx.Done()
if done == nil {
err = fn(ctx, cn) err = fn(ctx, cn)
return err return err
}
errc := make(chan error, 1)
go func() { errc <- fn(ctx, cn) }()
select {
case <-done:
_ = cn.Close()
err = ctx.Err()
return err
case err = <-errc:
return err
}
}) })
} }
@ -334,7 +338,7 @@ func (c *baseClient) process(ctx context.Context, cmd Cmder) error {
} }
} }
retryTimeout := true retryTimeout := uint32(1)
err := c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { err := c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
return writeCmd(wr, cmd) return writeCmd(wr, cmd)
@ -345,7 +349,9 @@ func (c *baseClient) process(ctx context.Context, cmd Cmder) error {
err = cn.WithReader(ctx, c.cmdTimeout(cmd), cmd.readReply) err = cn.WithReader(ctx, c.cmdTimeout(cmd), cmd.readReply)
if err != nil { if err != nil {
retryTimeout = cmd.readTimeout() == nil if cmd.readTimeout() == nil {
atomic.StoreUint32(&retryTimeout, 1)
}
return err return err
} }
@ -354,7 +360,7 @@ func (c *baseClient) process(ctx context.Context, cmd Cmder) error {
if err == nil { if err == nil {
return nil return nil
} }
retry = shouldRetry(err, retryTimeout) retry = shouldRetry(err, atomic.LoadUint32(&retryTimeout) == 1)
return err return err
}) })
if err == nil || !retry { if err == nil || !retry {