forked from mirror/redis
Merge pull request #1572 from go-redis/fix/close-conn-on-context-timeout
Close the conn on context timeout
This commit is contained in:
commit
2e398ada86
5
error.go
5
error.go
|
@ -65,8 +65,11 @@ func isRedisError(err error) bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
func isBadConn(err error, allowTimeout bool) bool {
|
func isBadConn(err error, allowTimeout bool) bool {
|
||||||
if err == nil {
|
switch err {
|
||||||
|
case nil:
|
||||||
return false
|
return false
|
||||||
|
case context.Canceled, context.DeadlineExceeded:
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
if isRedisError(err) {
|
if isRedisError(err) {
|
||||||
|
|
|
@ -175,8 +175,8 @@ func redisRingOptions() *redis.RingOptions {
|
||||||
func performAsync(n int, cbs ...func(int)) *sync.WaitGroup {
|
func performAsync(n int, cbs ...func(int)) *sync.WaitGroup {
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
for _, cb := range cbs {
|
for _, cb := range cbs {
|
||||||
|
wg.Add(n)
|
||||||
for i := 0; i < n; i++ {
|
for i := 0; i < n; i++ {
|
||||||
wg.Add(1)
|
|
||||||
go func(cb func(int), i int) {
|
go func(cb func(int), i int) {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
|
|
20
race_test.go
20
race_test.go
|
@ -2,6 +2,7 @@ package redis_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
@ -295,6 +296,25 @@ var _ = Describe("races", func() {
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("should abort on context timeout", func() {
|
||||||
|
opt := redisClusterOptions()
|
||||||
|
client := cluster.newClusterClient(ctx, opt)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
wg := performAsync(C, func(_ int) {
|
||||||
|
_, err := client.XRead(ctx, &redis.XReadArgs{
|
||||||
|
Streams: []string{"test", "$"},
|
||||||
|
Block: 1 * time.Second,
|
||||||
|
}).Result()
|
||||||
|
Expect(err).To(Equal(context.Canceled))
|
||||||
|
})
|
||||||
|
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
cancel()
|
||||||
|
wg.Wait()
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
var _ = Describe("cluster races", func() {
|
var _ = Describe("cluster races", func() {
|
||||||
|
|
38
redis.go
38
redis.go
|
@ -4,6 +4,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-redis/redis/v8/internal"
|
"github.com/go-redis/redis/v8/internal"
|
||||||
|
@ -130,20 +131,7 @@ func (hs hooks) processTxPipeline(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hs hooks) withContext(ctx context.Context, fn func() error) error {
|
func (hs hooks) withContext(ctx context.Context, fn func() error) error {
|
||||||
done := ctx.Done()
|
|
||||||
if done == nil {
|
|
||||||
return fn()
|
return fn()
|
||||||
}
|
|
||||||
|
|
||||||
errc := make(chan error, 1)
|
|
||||||
go func() { errc <- fn() }()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-done:
|
|
||||||
return ctx.Err()
|
|
||||||
case err := <-errc:
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//------------------------------------------------------------------------------
|
//------------------------------------------------------------------------------
|
||||||
|
@ -316,8 +304,24 @@ func (c *baseClient) withConn(
|
||||||
c.releaseConn(ctx, cn, err)
|
c.releaseConn(ctx, cn, err)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
done := ctx.Done()
|
||||||
|
if done == nil {
|
||||||
err = fn(ctx, cn)
|
err = fn(ctx, cn)
|
||||||
return err
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
errc := make(chan error, 1)
|
||||||
|
go func() { errc <- fn(ctx, cn) }()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
_ = cn.Close()
|
||||||
|
|
||||||
|
err = ctx.Err()
|
||||||
|
return err
|
||||||
|
case err = <-errc:
|
||||||
|
return err
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -334,7 +338,7 @@ func (c *baseClient) process(ctx context.Context, cmd Cmder) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
retryTimeout := true
|
retryTimeout := uint32(1)
|
||||||
err := c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
|
err := c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
|
||||||
err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
|
err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
|
||||||
return writeCmd(wr, cmd)
|
return writeCmd(wr, cmd)
|
||||||
|
@ -345,7 +349,9 @@ func (c *baseClient) process(ctx context.Context, cmd Cmder) error {
|
||||||
|
|
||||||
err = cn.WithReader(ctx, c.cmdTimeout(cmd), cmd.readReply)
|
err = cn.WithReader(ctx, c.cmdTimeout(cmd), cmd.readReply)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
retryTimeout = cmd.readTimeout() == nil
|
if cmd.readTimeout() == nil {
|
||||||
|
atomic.StoreUint32(&retryTimeout, 1)
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -354,7 +360,7 @@ func (c *baseClient) process(ctx context.Context, cmd Cmder) error {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
retry = shouldRetry(err, retryTimeout)
|
retry = shouldRetry(err, atomic.LoadUint32(&retryTimeout) == 1)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
if err == nil || !retry {
|
if err == nil || !retry {
|
||||||
|
|
Loading…
Reference in New Issue