mod: flate.writer for context-takeover

This commit is contained in:
misu 2018-02-01 12:59:30 +09:00
parent fb7d67a34a
commit f8b4a0f71d
5 changed files with 60 additions and 49 deletions

View File

@ -6,6 +6,7 @@ package websocket
import (
"bytes"
"compress/flate"
"crypto/tls"
"errors"
"io"
@ -322,7 +323,11 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
switch {
case cmwb && smwb:
conn.contextTakeover = true
conn.newCompressionWriter = compressContextTakeover
var f contextTakeoverWriterFactory
f.fw, _ = flate.NewWriter(&f.tw, 2) // level is specified in Dialer, Upgrader
conn.newCompressionWriter = f.newCompressionWriter
conn.newDecompressionReader = decompressContextTakeover
default:
conn.newCompressionWriter = compressNoContextTakeover

View File

@ -6,7 +6,6 @@ package websocket
import (
"errors"
"fmt"
"io"
"strings"
"sync"
@ -69,18 +68,6 @@ func compressNoContextTakeover(w io.WriteCloser, level int) io.WriteCloser {
return &flateWriteWrapper{fw: fw, tw: tw, p: p}
}
func compressContextTakeover(w io.WriteCloser, level int) io.WriteCloser {
// p := &flateWriterDictPools[level-minCompressionLevel]
// tw := &truncWriter{w: w}
// fw, _ := p.Get().(*flate.Writer)
// if fw == nil {
// fw, _ = flate.NewWriterDict(tw, level, nil)
// } else {
// fw.Reset(tw)
// }
return &flateWriteWrapper{}
}
// truncWriter is an io.Writer that writes all but the last four bytes of the
// stream to another io.Writer.
type truncWriter struct {
@ -91,9 +78,6 @@ type truncWriter struct {
func (w *truncWriter) Write(p []byte) (int, error) {
n := 0
fmt.Printf("\x1b[32m Start truncWriter.Write %#v \x1b[0m\n", p)
fmt.Printf("\x1b[32m truncWriter w.n -> len %#v \x1b[0m\n", w.n)
fmt.Printf("\x1b[32m truncWriter w.p %#v \x1b[0m\n", w.p)
// fill buffer first for simplicity.
if w.n < len(w.p) {
@ -110,18 +94,13 @@ func (w *truncWriter) Write(p []byte) (int, error) {
m = len(w.p)
}
fmt.Printf("\x1b[32m Write will truncWriter.Write %#v \x1b[0m\n", w.p[:m])
if nn, err := w.w.Write(w.p[:m]); err != nil {
fmt.Printf("\x1b[32m w.w.Write Error truncWriter.Write %#v \x1b[0m\n", err)
return n + nn, err
}
copy(w.p[:], w.p[m:])
copy(w.p[len(w.p)-m:], p[len(p)-m:])
nn, err := w.w.Write(p[:len(p)-m])
fmt.Printf("\x1b[32m End truncWriter.Write %#v \x1b[0m\n", p)
fmt.Printf("\x1b[32m End truncWriter w.p %#v \x1b[0m\n", w.p)
return n + nn, err
}
@ -138,8 +117,6 @@ func (w *flateWriteWrapper) Write(p []byte) (int, error) {
return 0, errWriteClosed
}
fmt.Printf("flateWriteWrapper will Write %#v \n", p)
return w.fw.Write(p)
}
@ -149,9 +126,6 @@ func (w *flateWriteWrapper) Close() error {
}
err1 := w.fw.Flush()
fmt.Printf("w.tw.n -> -> %#v \n", w.tw.n)
fmt.Printf("w.tw.p -> -> %#v \n", w.tw.p)
if !w.isDictWriter {
w.p.Put(w.fw)
w.fw = nil
@ -171,8 +145,6 @@ func (w *flateWriteWrapper) Close() error {
return err1
}
fmt.Printf("err2 %#v \n", err2)
return err2
}
@ -229,3 +201,44 @@ func (r *flateReadWrapper) addDict(b []byte) {
*r.dict = (*r.dict)[offset:]
}
}
type (
contextTakeoverWriterFactory struct {
fw *flate.Writer
tw truncWriter
}
flateTakeoverWriteWrapper struct {
f *contextTakeoverWriterFactory
}
)
func (f *contextTakeoverWriterFactory) newCompressionWriter(w io.WriteCloser, level int) io.WriteCloser {
f.tw.w = w
f.tw.n = 0
return &flateTakeoverWriteWrapper{f}
}
func (w *flateTakeoverWriteWrapper) Write(p []byte) (int, error) {
if w.f == nil {
return 0, errWriteClosed
}
return w.f.fw.Write(p)
}
func (w *flateTakeoverWriteWrapper) Close() error {
if w.f == nil {
return errWriteClosed
}
f := w.f
w.f = nil
err1 := f.fw.Flush()
if f.tw.p != [4]byte{0, 0, 0xff, 0xff} {
return errors.New("websocket: internal error, unexpected bytes at end of flate stream")
}
err2 := f.tw.w.Close()
if err1 != nil {
return err1
}
return err2
}

View File

@ -2,6 +2,7 @@ package websocket
import (
"bytes"
"compress/flate"
"fmt"
"io"
"io/ioutil"
@ -71,7 +72,9 @@ func BenchmarkWriteWithCompressionOfContextTakeover(b *testing.B) {
messages := textMessages(100)
c.enableWriteCompression = true
c.contextTakeover = true
c.newCompressionWriter = compressContextTakeover
var f contextTakeoverWriterFactory
f.fw, _ = flate.NewWriter(&f.tw, 2) // level is specified in Dialer, Upgrader
c.newCompressionWriter = f.newCompressionWriter
b.ResetTimer()
for i := 0; i < b.N; i++ {
c.WriteMessage(TextMessage, messages[i%len(messages)])

21
conn.go
View File

@ -6,7 +6,6 @@ package websocket
import (
"bufio"
"compress/flate"
"encoding/binary"
"errors"
"io"
@ -346,6 +345,7 @@ func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize in
rxDict: &[]byte{},
}
c.SetCloseHandler(nil)
c.SetPingHandler(nil)
c.SetPongHandler(nil)
@ -517,23 +517,8 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
mw.compress = true
// For context-takeover
if c.contextTakeover {
if fww, ok := c.compressionWriters[messageType]; ok {
// tw := &truncWriter{w: c.writer}
//Todo reset trunkwriter inside flate.Writer.
// fw, _ := flate.NewWriterDict(tw, c.compressionLevel, []byte("Hello"))
// fww.fw.Reset(tw)
// fww.fw = fw
fww.tw.w = c.writer
return fww, nil
} else {
tw := &truncWriter{w: c.writer}
fw, _ := flate.NewWriterDict(tw, c.compressionLevel, nil)
fww := &flateWriteWrapper{fw: fw, tw: tw, isDictWriter: true}
c.compressionWriters[messageType] = fww
return fww, nil
}
c.writer = c.newCompressionWriter(c.writer, c.compressionLevel)
return c.writer, nil
}
c.writer = c.newCompressionWriter(c.writer, c.compressionLevel)

View File

@ -6,6 +6,7 @@ package websocket
import (
"bufio"
"compress/flate"
"errors"
"net"
"net/http"
@ -187,7 +188,11 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
switch {
case contextTakeover:
c.contextTakeover = contextTakeover
c.newCompressionWriter = compressContextTakeover
var f contextTakeoverWriterFactory
f.fw, _ = flate.NewWriter(&f.tw, 2) // level is specified in Dialer, Upgrader
c.newCompressionWriter = f.newCompressionWriter
c.newDecompressionReader = decompressContextTakeover
default:
c.newCompressionWriter = compressNoContextTakeover