diff --git a/redcon.go b/redcon.go index 8dd29b4..c359bc8 100644 --- a/redcon.go +++ b/redcon.go @@ -113,6 +113,8 @@ type Conn interface { PeekPipeline() []Command // NetConn returns the base net.Conn connection NetConn() net.Conn + // WriteBulkFrom write bulk from io.Reader, size n + WriteBulkFrom(n int64, rb io.Reader) } // NewServer returns a new Redcon server configured on "tcp" network net. @@ -494,6 +496,9 @@ func (c *conn) PeekPipeline() []Command { func (c *conn) NetConn() net.Conn { return c.conn } +func (c *conn) WriteBulkFrom(n int64, rb io.Reader) { + c.wr.WriteBulkFrom(n, rb) +} // BaseWriter returns the underlying connection writer, if any func BaseWriter(c Conn) *Writer { @@ -589,15 +594,29 @@ type Writer struct { w io.Writer b []byte err error + + // buff use io buffer write to w(io.Writer) + // for io.Copy r(io.Reader) to w(io.Writer) + buff *bufio.Writer } // NewWriter creates a new RESP writer. func NewWriter(wr io.Writer) *Writer { return &Writer{ - w: wr, + w: wr, + buff: bufio.NewWriter(wr), } } +func (w *Writer) WriteBulkFrom(n int64, r io.Reader) { + if w != nil && w.err != nil { + return + } + w.buff.Write(appendPrefix(w.b, '$', n)) + io.Copy(w.buff, r) + w.buff.Write([]byte{'\r', '\n'}) +} + // WriteNull writes a null to the client func (w *Writer) WriteNull() { if w.err != nil { @@ -656,6 +675,10 @@ func (w *Writer) SetBuffer(raw []byte) { // Flush writes all unflushed Write* calls to the underlying writer. func (w *Writer) Flush() error { + if w.buff != nil { + w.buff.Flush() + } + if w.err != nil { return w.err } diff --git a/redcon_test.go b/redcon_test.go index 757320e..987f105 100644 --- a/redcon_test.go +++ b/redcon_test.go @@ -365,6 +365,35 @@ func testServerNetwork(t *testing.T, network, laddr string) { <-done } +func TestConnImpl(t *testing.T) { + var i interface{} = &conn{} + if _, ok := i.(Conn); !ok { + t.Fatalf("conn does not implement Conn interface") + } +} + +func TestWriteBulkFrom(t *testing.T) { + wbuf := &bytes.Buffer{} + wr := NewWriter(wbuf) + rbuf := &bytes.Buffer{} + testStr := "hello world" + rbuf.WriteString(testStr) + wr.WriteBulkFrom(int64(len(testStr)), rbuf) + wr.Flush() + if wbuf.String() != fmt.Sprintf("$%d\r\n%s\r\n", len(testStr), testStr) { + t.Fatal("failed") + } + wbuf.Reset() + testStr1 := "hi world" + rbuf.WriteString(testStr1) + wr.WriteBulkFrom(int64(len(testStr1)), rbuf) + wr.Flush() + if wbuf.String() != fmt.Sprintf("$%d\r\n%s\r\n", len(testStr1), testStr1) { + t.Fatal("failed") + } + wbuf.Reset() +} + func TestWriter(t *testing.T) { buf := &bytes.Buffer{} wr := NewWriter(buf)