diff --git a/internal/pool/conn.go b/internal/pool/conn.go index 1dc6e97..e9a2585 100644 --- a/internal/pool/conn.go +++ b/internal/pool/conn.go @@ -58,23 +58,31 @@ func (cn *Conn) RemoteAddr() net.Addr { } func (cn *Conn) WithReader(ctx context.Context, timeout time.Duration, fn func(rd *proto.Reader) error) error { - tm := cn.deadline(ctx, timeout) - _ = cn.netConn.SetReadDeadline(tm) + err := cn.netConn.SetReadDeadline(cn.deadline(ctx, timeout)) + if err != nil { + return err + } return fn(cn.rd) } func (cn *Conn) WithWriter( ctx context.Context, timeout time.Duration, fn func(wr *proto.Writer) error, ) error { - tm := cn.deadline(ctx, timeout) - _ = cn.netConn.SetWriteDeadline(tm) - - firstErr := fn(cn.wr) - err := cn.wr.Flush() - if err != nil && firstErr == nil { - firstErr = err + err := cn.netConn.SetWriteDeadline(cn.deadline(ctx, timeout)) + if err != nil { + return err } - return firstErr + + if cn.wr.Buffered() > 0 { + cn.wr.Reset(cn.netConn) + } + + err = fn(cn.wr) + if err != nil { + return err + } + + return cn.wr.Flush() } func (cn *Conn) Close() error { diff --git a/internal/proto/writer.go b/internal/proto/writer.go index 7e77a3b..cd83d65 100644 --- a/internal/proto/writer.go +++ b/internal/proto/writer.go @@ -152,6 +152,10 @@ func (w *Writer) crlf() error { return w.wr.WriteByte('\n') } +func (w *Writer) Buffered() int { + return w.wr.Buffered() +} + func (w *Writer) Reset(wr io.Writer) { w.wr.Reset(wr) } diff --git a/main_test.go b/main_test.go index 33c970c..0ffa75f 100644 --- a/main_test.go +++ b/main_test.go @@ -343,6 +343,14 @@ type badConn struct { var _ net.Conn = &badConn{} +func (cn *badConn) SetReadDeadline(t time.Time) error { + return nil +} + +func (cn *badConn) SetWriteDeadline(t time.Time) error { + return nil +} + func (cn *badConn) Read([]byte) (int, error) { if cn.readDelay != 0 { time.Sleep(cn.readDelay)