diff --git a/internal/pool/conncheck.go b/internal/pool/conncheck.go new file mode 100644 index 0000000..5dd60df --- /dev/null +++ b/internal/pool/conncheck.go @@ -0,0 +1,45 @@ +// +build linux darwin dragonfly freebsd netbsd openbsd solaris illumos + +package pool + +import ( + "errors" + "io" + "net" + "syscall" +) + +var errUnexpectedRead = errors.New("unexpected read from socket") + +func connCheck(conn net.Conn) error { + sysConn, ok := conn.(syscall.Conn) + if !ok { + return nil + } + rawConn, err := sysConn.SyscallConn() + if err != nil { + return err + } + + var sysErr error + err = rawConn.Read(func(fd uintptr) bool { + var buf [1]byte + n, err := syscall.Read(int(fd), buf[:]) + switch { + case n == 0 && err == nil: + sysErr = io.EOF + case n > 0: + sysErr = errUnexpectedRead + case err == syscall.EAGAIN || err == syscall.EWOULDBLOCK: + sysErr = nil + default: + sysErr = err + } + return true + }) + if err != nil { + return err + } + + return sysErr +} diff --git a/internal/pool/conncheck_dummy.go b/internal/pool/conncheck_dummy.go new file mode 100644 index 0000000..1daf986 --- /dev/null +++ b/internal/pool/conncheck_dummy.go @@ -0,0 +1,9 @@ +// +build !linux,!darwin,!dragonfly,!freebsd,!netbsd,!openbsd,!solaris,!illumos + +package pool + +import "net" + +func connCheck(conn net.Conn) error { + return nil +} diff --git a/internal/pool/conncheck_test.go b/internal/pool/conncheck_test.go new file mode 100644 index 0000000..0332174 --- /dev/null +++ b/internal/pool/conncheck_test.go @@ -0,0 +1,46 @@ +// +build linux darwin dragonfly freebsd netbsd openbsd solaris illumos + +package pool + +import ( + "net" + "net/http/httptest" + "testing" + "time" +) + +func Test_connCheck(t *testing.T) { + // tests with real conns + ts := httptest.NewServer(nil) + defer ts.Close() + + t.Run("good conn", func(t *testing.T) { + conn, err := net.DialTimeout(ts.Listener.Addr().Network(), ts.Listener.Addr().String(), time.Second) + if err != nil { + t.Fatalf(err.Error()) + } + defer conn.Close() + if err = connCheck(conn); err != nil { + t.Fatalf(err.Error()) + } + conn.Close() + + if err = connCheck(conn); err == nil { + t.Fatalf("expect has error") + } + }) + + t.Run("bad conn 2", func(t *testing.T) { + conn, err := net.DialTimeout(ts.Listener.Addr().Network(), ts.Listener.Addr().String(), time.Second) + if err != nil { + t.Fatalf(err.Error()) + } + defer conn.Close() + + ts.Close() + + if err = connCheck(conn); err == nil { + t.Fatalf("expect has err") + } + }) +} diff --git a/internal/pool/main_test.go b/internal/pool/main_test.go index 2365dbc..c54a38d 100644 --- a/internal/pool/main_test.go +++ b/internal/pool/main_test.go @@ -2,9 +2,12 @@ package pool_test import ( "context" + "fmt" "net" "sync" + "syscall" "testing" + "time" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -32,5 +35,87 @@ func perform(n int, cbs ...func(int)) { } func dummyDialer(context.Context) (net.Conn, error) { - return &net.TCPConn{}, nil + // return &net.TCPConn{}, nil + return newDummyConn(), nil +} + +func newDummyConn() net.Conn { + return &dummyConn{ + rawConn: &dummyRawConn{}, + } +} + +var _ net.Conn = (*dummyConn)(nil) +var _ syscall.Conn = (*dummyConn)(nil) + +type dummyConn struct { + rawConn *dummyRawConn +} + +func (d *dummyConn) SyscallConn() (syscall.RawConn, error) { + return d.rawConn, nil +} + +var errDummy = fmt.Errorf("dummyConn err") + +func (d *dummyConn) Read(b []byte) (n int, err error) { + return 0, errDummy +} + +func (d *dummyConn) Write(b []byte) (n int, err error) { + return 0, errDummy +} + +func (d *dummyConn) Close() error { + d.rawConn.Close() + return nil +} + +func (d *dummyConn) LocalAddr() net.Addr { + return &net.TCPAddr{} +} + +func (d *dummyConn) RemoteAddr() net.Addr { + return &net.TCPAddr{} +} + +func (d *dummyConn) SetDeadline(t time.Time) error { + return nil +} + +func (d *dummyConn) SetReadDeadline(t time.Time) error { + return nil +} + +func (d *dummyConn) SetWriteDeadline(t time.Time) error { + return nil +} + +var _ syscall.RawConn = (*dummyRawConn)(nil) + +type dummyRawConn struct { + closed bool + mux sync.Mutex +} + +func (d *dummyRawConn) Control(f func(fd uintptr)) error { + return nil +} + +func (d *dummyRawConn) Read(f func(fd uintptr) (done bool)) error { + d.mux.Lock() + defer d.mux.Unlock() + if d.closed { + return fmt.Errorf("dummyRawConn closed") + } + return nil +} + +func (d *dummyRawConn) Write(f func(fd uintptr) (done bool)) error { + return nil +} +func (d *dummyRawConn) Close() { + d.mux.Lock() + d.closed = true + d.mux.Unlock() } diff --git a/internal/pool/pool.go b/internal/pool/pool.go index 91b55e4..577923a 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -520,7 +520,7 @@ func (p *ConnPool) reapStaleConn() *Conn { func (p *ConnPool) isStaleConn(cn *Conn) bool { if p.opt.IdleTimeout == 0 && p.opt.MaxConnAge == 0 { - return false + return connCheck(cn.netConn) != nil } now := time.Now() @@ -531,5 +531,5 @@ func (p *ConnPool) isStaleConn(cn *Conn) bool { return true } - return false + return connCheck(cn.netConn) != nil } diff --git a/internal/pool/pool_test.go b/internal/pool/pool_test.go index 795aef3..6c94fc2 100644 --- a/internal/pool/pool_test.go +++ b/internal/pool/pool_test.go @@ -6,10 +6,10 @@ import ( "testing" "time" - "github.com/go-redis/redis/v8/internal/pool" - . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" + + "github.com/go-redis/redis/v8/internal/pool" ) var _ = Describe("ConnPool", func() { @@ -285,6 +285,8 @@ var _ = Describe("conns reaper", func() { cn.SetUsedAt(time.Now().Add(-2 * idleTimeout)) case "aged": cn.SetCreatedAt(time.Now().Add(-2 * maxAge)) + case "conncheck": + cn.Close() } conns = append(conns, cn) staleConns = append(staleConns, cn) @@ -371,6 +373,7 @@ var _ = Describe("conns reaper", func() { assert("idle") assert("aged") + assert("conncheck") }) var _ = Describe("race", func() { diff --git a/main_test.go b/main_test.go index 0cb2b1d..dd9d874 100644 --- a/main_test.go +++ b/main_test.go @@ -12,10 +12,10 @@ import ( "testing" "time" - "github.com/go-redis/redis/v8" - . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" + + "github.com/go-redis/redis/v8" ) const ( @@ -117,7 +117,7 @@ func TestGinkgoSuite(t *testing.T) { RunSpecs(t, "go-redis") } -//------------------------------------------------------------------------------ +// ------------------------------------------------------------------------------ func redisOptions() *redis.Options { return &redis.Options{ @@ -364,7 +364,7 @@ func startSentinel(port, masterName, masterPort string) (*redisProcess, error) { return p, nil } -//------------------------------------------------------------------------------ +// ------------------------------------------------------------------------------ type badConnError string @@ -409,7 +409,7 @@ func (cn *badConn) Write([]byte) (int, error) { return 0, badConnError("bad connection") } -//------------------------------------------------------------------------------ +// ------------------------------------------------------------------------------ type hook struct { beforeProcess func(ctx context.Context, cmd redis.Cmder) (context.Context, error) diff --git a/pool_test.go b/pool_test.go index 08acc6d..8131819 100644 --- a/pool_test.go +++ b/pool_test.go @@ -87,8 +87,9 @@ var _ = Describe("pool", func() { cn.SetNetConn(&badConn{}) client.Pool().Put(ctx, cn) + // connCheck will automatically remove damaged connections. err = client.Ping(ctx).Err() - Expect(err).To(MatchError("bad connection")) + Expect(err).NotTo(HaveOccurred()) val, err := client.Ping(ctx).Result() Expect(err).NotTo(HaveOccurred()) diff --git a/sentinel_test.go b/sentinel_test.go index 7b4aabd..5faf9cf 100644 --- a/sentinel_test.go +++ b/sentinel_test.go @@ -191,7 +191,7 @@ var _ = Describe("NewFailoverClusterClient", func() { err = master.Shutdown(ctx).Err() Expect(err).NotTo(HaveOccurred()) Eventually(func() error { - return sentinelMaster.Ping(ctx).Err() + return master.Ping(ctx).Err() }, "15s", "100ms").Should(HaveOccurred()) // Check that client picked up new master. diff --git a/tx_test.go b/tx_test.go index 4681122..11e5b0d 100644 --- a/tx_test.go +++ b/tx_test.go @@ -123,7 +123,7 @@ var _ = Describe("Tx", func() { Expect(num).To(Equal(int64(N))) }) - It("should recover from bad connection", func() { + It("should remove from bad connection", func() { // Put bad connection in the pool. cn, err := client.Pool().Get(context.Background()) Expect(err).NotTo(HaveOccurred()) @@ -134,17 +134,14 @@ var _ = Describe("Tx", func() { do := func() error { err := client.Watch(ctx, func(tx *redis.Tx) error { _, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error { - pipe.Ping(ctx) - return nil + return pipe.Ping(ctx).Err() }) return err }) return err } - err = do() - Expect(err).To(MatchError("bad connection")) - + // connCheck will automatically remove damaged connections. err = do() Expect(err).NotTo(HaveOccurred()) })