ConnPool check fd for bad conns (#1824)

* conncheck for badconn (#1821)

* format imports

* fix ut: pool with badconn

* fix unstable ut: should facilitate failover

* Revert "fix unstable ut: should facilitate failover"

This reverts commit c7eeca2a5c.

* fix test error

Signed-off-by: monkey92t <golang@88.com>

Co-authored-by: hidu <duv123+github@gmail.com>
Co-authored-by: monkey92t <golang@88.com>
This commit is contained in:
do it 2021-07-20 17:23:48 +08:00 committed by GitHub
parent 62fc2c821b
commit 346bfafddd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 204 additions and 18 deletions

View File

@ -0,0 +1,45 @@
// +build linux darwin dragonfly freebsd netbsd openbsd solaris illumos
package pool
import (
"errors"
"io"
"net"
"syscall"
)
var errUnexpectedRead = errors.New("unexpected read from socket")
func connCheck(conn net.Conn) error {
sysConn, ok := conn.(syscall.Conn)
if !ok {
return nil
}
rawConn, err := sysConn.SyscallConn()
if err != nil {
return err
}
var sysErr error
err = rawConn.Read(func(fd uintptr) bool {
var buf [1]byte
n, err := syscall.Read(int(fd), buf[:])
switch {
case n == 0 && err == nil:
sysErr = io.EOF
case n > 0:
sysErr = errUnexpectedRead
case err == syscall.EAGAIN || err == syscall.EWOULDBLOCK:
sysErr = nil
default:
sysErr = err
}
return true
})
if err != nil {
return err
}
return sysErr
}

View File

@ -0,0 +1,9 @@
// +build !linux,!darwin,!dragonfly,!freebsd,!netbsd,!openbsd,!solaris,!illumos
package pool
import "net"
func connCheck(conn net.Conn) error {
return nil
}

View File

@ -0,0 +1,46 @@
// +build linux darwin dragonfly freebsd netbsd openbsd solaris illumos
package pool
import (
"net"
"net/http/httptest"
"testing"
"time"
)
func Test_connCheck(t *testing.T) {
// tests with real conns
ts := httptest.NewServer(nil)
defer ts.Close()
t.Run("good conn", func(t *testing.T) {
conn, err := net.DialTimeout(ts.Listener.Addr().Network(), ts.Listener.Addr().String(), time.Second)
if err != nil {
t.Fatalf(err.Error())
}
defer conn.Close()
if err = connCheck(conn); err != nil {
t.Fatalf(err.Error())
}
conn.Close()
if err = connCheck(conn); err == nil {
t.Fatalf("expect has error")
}
})
t.Run("bad conn 2", func(t *testing.T) {
conn, err := net.DialTimeout(ts.Listener.Addr().Network(), ts.Listener.Addr().String(), time.Second)
if err != nil {
t.Fatalf(err.Error())
}
defer conn.Close()
ts.Close()
if err = connCheck(conn); err == nil {
t.Fatalf("expect has err")
}
})
}

View File

@ -2,9 +2,12 @@ package pool_test
import ( import (
"context" "context"
"fmt"
"net" "net"
"sync" "sync"
"syscall"
"testing" "testing"
"time"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
@ -32,5 +35,87 @@ func perform(n int, cbs ...func(int)) {
} }
func dummyDialer(context.Context) (net.Conn, error) { func dummyDialer(context.Context) (net.Conn, error) {
return &net.TCPConn{}, nil // return &net.TCPConn{}, nil
return newDummyConn(), nil
}
func newDummyConn() net.Conn {
return &dummyConn{
rawConn: &dummyRawConn{},
}
}
var _ net.Conn = (*dummyConn)(nil)
var _ syscall.Conn = (*dummyConn)(nil)
type dummyConn struct {
rawConn *dummyRawConn
}
func (d *dummyConn) SyscallConn() (syscall.RawConn, error) {
return d.rawConn, nil
}
var errDummy = fmt.Errorf("dummyConn err")
func (d *dummyConn) Read(b []byte) (n int, err error) {
return 0, errDummy
}
func (d *dummyConn) Write(b []byte) (n int, err error) {
return 0, errDummy
}
func (d *dummyConn) Close() error {
d.rawConn.Close()
return nil
}
func (d *dummyConn) LocalAddr() net.Addr {
return &net.TCPAddr{}
}
func (d *dummyConn) RemoteAddr() net.Addr {
return &net.TCPAddr{}
}
func (d *dummyConn) SetDeadline(t time.Time) error {
return nil
}
func (d *dummyConn) SetReadDeadline(t time.Time) error {
return nil
}
func (d *dummyConn) SetWriteDeadline(t time.Time) error {
return nil
}
var _ syscall.RawConn = (*dummyRawConn)(nil)
type dummyRawConn struct {
closed bool
mux sync.Mutex
}
func (d *dummyRawConn) Control(f func(fd uintptr)) error {
return nil
}
func (d *dummyRawConn) Read(f func(fd uintptr) (done bool)) error {
d.mux.Lock()
defer d.mux.Unlock()
if d.closed {
return fmt.Errorf("dummyRawConn closed")
}
return nil
}
func (d *dummyRawConn) Write(f func(fd uintptr) (done bool)) error {
return nil
}
func (d *dummyRawConn) Close() {
d.mux.Lock()
d.closed = true
d.mux.Unlock()
} }

View File

@ -520,7 +520,7 @@ func (p *ConnPool) reapStaleConn() *Conn {
func (p *ConnPool) isStaleConn(cn *Conn) bool { func (p *ConnPool) isStaleConn(cn *Conn) bool {
if p.opt.IdleTimeout == 0 && p.opt.MaxConnAge == 0 { if p.opt.IdleTimeout == 0 && p.opt.MaxConnAge == 0 {
return false return connCheck(cn.netConn) != nil
} }
now := time.Now() now := time.Now()
@ -531,5 +531,5 @@ func (p *ConnPool) isStaleConn(cn *Conn) bool {
return true return true
} }
return false return connCheck(cn.netConn) != nil
} }

View File

@ -6,10 +6,10 @@ import (
"testing" "testing"
"time" "time"
"github.com/go-redis/redis/v8/internal/pool"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
"github.com/go-redis/redis/v8/internal/pool"
) )
var _ = Describe("ConnPool", func() { var _ = Describe("ConnPool", func() {
@ -285,6 +285,8 @@ var _ = Describe("conns reaper", func() {
cn.SetUsedAt(time.Now().Add(-2 * idleTimeout)) cn.SetUsedAt(time.Now().Add(-2 * idleTimeout))
case "aged": case "aged":
cn.SetCreatedAt(time.Now().Add(-2 * maxAge)) cn.SetCreatedAt(time.Now().Add(-2 * maxAge))
case "conncheck":
cn.Close()
} }
conns = append(conns, cn) conns = append(conns, cn)
staleConns = append(staleConns, cn) staleConns = append(staleConns, cn)
@ -371,6 +373,7 @@ var _ = Describe("conns reaper", func() {
assert("idle") assert("idle")
assert("aged") assert("aged")
assert("conncheck")
}) })
var _ = Describe("race", func() { var _ = Describe("race", func() {

View File

@ -12,10 +12,10 @@ import (
"testing" "testing"
"time" "time"
"github.com/go-redis/redis/v8"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
"github.com/go-redis/redis/v8"
) )
const ( const (
@ -117,7 +117,7 @@ func TestGinkgoSuite(t *testing.T) {
RunSpecs(t, "go-redis") RunSpecs(t, "go-redis")
} }
//------------------------------------------------------------------------------ // ------------------------------------------------------------------------------
func redisOptions() *redis.Options { func redisOptions() *redis.Options {
return &redis.Options{ return &redis.Options{
@ -364,7 +364,7 @@ func startSentinel(port, masterName, masterPort string) (*redisProcess, error) {
return p, nil return p, nil
} }
//------------------------------------------------------------------------------ // ------------------------------------------------------------------------------
type badConnError string type badConnError string
@ -409,7 +409,7 @@ func (cn *badConn) Write([]byte) (int, error) {
return 0, badConnError("bad connection") return 0, badConnError("bad connection")
} }
//------------------------------------------------------------------------------ // ------------------------------------------------------------------------------
type hook struct { type hook struct {
beforeProcess func(ctx context.Context, cmd redis.Cmder) (context.Context, error) beforeProcess func(ctx context.Context, cmd redis.Cmder) (context.Context, error)

View File

@ -87,8 +87,9 @@ var _ = Describe("pool", func() {
cn.SetNetConn(&badConn{}) cn.SetNetConn(&badConn{})
client.Pool().Put(ctx, cn) client.Pool().Put(ctx, cn)
// connCheck will automatically remove damaged connections.
err = client.Ping(ctx).Err() err = client.Ping(ctx).Err()
Expect(err).To(MatchError("bad connection")) Expect(err).NotTo(HaveOccurred())
val, err := client.Ping(ctx).Result() val, err := client.Ping(ctx).Result()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())

View File

@ -191,7 +191,7 @@ var _ = Describe("NewFailoverClusterClient", func() {
err = master.Shutdown(ctx).Err() err = master.Shutdown(ctx).Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Eventually(func() error { Eventually(func() error {
return sentinelMaster.Ping(ctx).Err() return master.Ping(ctx).Err()
}, "15s", "100ms").Should(HaveOccurred()) }, "15s", "100ms").Should(HaveOccurred())
// Check that client picked up new master. // Check that client picked up new master.

View File

@ -123,7 +123,7 @@ var _ = Describe("Tx", func() {
Expect(num).To(Equal(int64(N))) Expect(num).To(Equal(int64(N)))
}) })
It("should recover from bad connection", func() { It("should remove from bad connection", func() {
// Put bad connection in the pool. // Put bad connection in the pool.
cn, err := client.Pool().Get(context.Background()) cn, err := client.Pool().Get(context.Background())
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
@ -134,17 +134,14 @@ var _ = Describe("Tx", func() {
do := func() error { do := func() error {
err := client.Watch(ctx, func(tx *redis.Tx) error { err := client.Watch(ctx, func(tx *redis.Tx) error {
_, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error { _, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Ping(ctx) return pipe.Ping(ctx).Err()
return nil
}) })
return err return err
}) })
return err return err
} }
err = do() // connCheck will automatically remove damaged connections.
Expect(err).To(MatchError("bad connection"))
err = do() err = do()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
}) })