mirror of https://github.com/gorilla/websocket.git
upgrade: compressContextTakeover reader
This commit is contained in:
parent
aca4275801
commit
e7575e215d
|
@ -333,10 +333,13 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
|
|||
conn.contextTakeover = true
|
||||
|
||||
var f contextTakeoverWriterFactory
|
||||
f.fw, _ = flate.NewWriter(&f.tw, d.CompressionLevel) // level is specified in Dialer, Upgrader
|
||||
f.fw, _ = flate.NewWriter(&f.tw, d.CompressionLevel)
|
||||
conn.newCompressionWriter = f.newCompressionWriter
|
||||
|
||||
conn.newDecompressionReader = decompressContextTakeover
|
||||
var frf contextTakeoverReaderFactory
|
||||
fr := flate.NewReader(nil)
|
||||
frf.fr = fr
|
||||
conn.newDecompressionReader = frf.newDeCompressionReader
|
||||
default:
|
||||
conn.newCompressionWriter = compressNoContextTakeover
|
||||
conn.newDecompressionReader = decompressNoContextTakeover
|
||||
|
|
117
compression.go
117
compression.go
|
@ -17,41 +17,27 @@ const (
|
|||
minCompressionLevel = -2 // flate.HuffmanOnly not defined in Go < 1.6
|
||||
maxCompressionLevel = flate.BestCompression
|
||||
defaultCompressionLevel = 1
|
||||
|
||||
tail =
|
||||
// Add four bytes as specified in RFC
|
||||
"\x00\x00\xff\xff" +
|
||||
// Add final block to squelch unexpected EOF error from flate reader.
|
||||
"\x01\x00\x00\xff\xff"
|
||||
)
|
||||
|
||||
var (
|
||||
flateWriterPools [maxCompressionLevel - minCompressionLevel + 1]sync.Pool
|
||||
flateWriterDictPools [maxCompressionLevel - minCompressionLevel + 1]sync.Pool
|
||||
flateReaderPool = sync.Pool{New: func() interface{} {
|
||||
flateWriterPools [maxCompressionLevel - minCompressionLevel + 1]sync.Pool
|
||||
flateReaderPool = sync.Pool{New: func() interface{} {
|
||||
return flate.NewReader(nil)
|
||||
}}
|
||||
)
|
||||
|
||||
func decompressNoContextTakeover(r io.Reader, dict *[]byte) io.ReadCloser {
|
||||
const tail =
|
||||
// Add four bytes as specified in RFC
|
||||
"\x00\x00\xff\xff" +
|
||||
// Add final block to squelch unexpected EOF error from flate reader.
|
||||
"\x01\x00\x00\xff\xff"
|
||||
|
||||
func decompressNoContextTakeover(r io.Reader) io.ReadCloser {
|
||||
fr, _ := flateReaderPool.Get().(io.ReadCloser)
|
||||
fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil)
|
||||
return &flateReadWrapper{fr: fr}
|
||||
}
|
||||
|
||||
func decompressContextTakeover(r io.Reader, dict *[]byte) io.ReadCloser {
|
||||
const tail =
|
||||
// Add four bytes as specified in RFC
|
||||
"\x00\x00\xff\xff" +
|
||||
// Add final block to squelch unexpected EOF error from flate reader.
|
||||
"\x01\x00\x00\xff\xff"
|
||||
|
||||
fr, _ := flateReaderPool.Get().(io.ReadCloser)
|
||||
fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), *dict)
|
||||
|
||||
return &flateReadWrapper{fr: fr, hasDict: true, dict: dict}
|
||||
}
|
||||
|
||||
func isValidCompressionLevel(level int) bool {
|
||||
return minCompressionLevel <= level && level <= maxCompressionLevel
|
||||
}
|
||||
|
@ -108,8 +94,6 @@ type flateWriteWrapper struct {
|
|||
fw *flate.Writer
|
||||
tw *truncWriter
|
||||
p *sync.Pool
|
||||
|
||||
isDictWriter bool
|
||||
}
|
||||
|
||||
func (w *flateWriteWrapper) Write(p []byte) (int, error) {
|
||||
|
@ -126,19 +110,15 @@ func (w *flateWriteWrapper) Close() error {
|
|||
}
|
||||
err1 := w.fw.Flush()
|
||||
|
||||
if !w.isDictWriter {
|
||||
w.p.Put(w.fw)
|
||||
w.fw = nil
|
||||
}
|
||||
w.p.Put(w.fw)
|
||||
w.fw = nil
|
||||
|
||||
if w.tw.p != [4]byte{0, 0, 0xff, 0xff} {
|
||||
return errors.New("websocket: internal error, unexpected bytes at end of flate stream")
|
||||
}
|
||||
|
||||
if !w.isDictWriter {
|
||||
w.tw.p = [4]byte{}
|
||||
w.tw.n = 0
|
||||
}
|
||||
w.tw.p = [4]byte{}
|
||||
w.tw.n = 0
|
||||
|
||||
err2 := w.tw.w.Close()
|
||||
if err1 != nil {
|
||||
|
@ -150,9 +130,6 @@ func (w *flateWriteWrapper) Close() error {
|
|||
|
||||
type flateReadWrapper struct {
|
||||
fr io.ReadCloser // flate.NewReader
|
||||
|
||||
hasDict bool
|
||||
dict *[]byte
|
||||
}
|
||||
|
||||
func (r *flateReadWrapper) Read(p []byte) (int, error) {
|
||||
|
@ -169,12 +146,6 @@ func (r *flateReadWrapper) Read(p []byte) (int, error) {
|
|||
r.Close()
|
||||
}
|
||||
|
||||
if r.hasDict {
|
||||
if n > 0 {
|
||||
r.addDict(p[:n])
|
||||
}
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
|
@ -184,24 +155,12 @@ func (r *flateReadWrapper) Close() error {
|
|||
}
|
||||
err := r.fr.Close()
|
||||
|
||||
if !r.hasDict {
|
||||
flateReaderPool.Put(r.fr)
|
||||
}
|
||||
flateReaderPool.Put(r.fr)
|
||||
|
||||
r.fr = nil
|
||||
return err
|
||||
}
|
||||
|
||||
// addDict adds payload to dict.
|
||||
func (r *flateReadWrapper) addDict(b []byte) {
|
||||
*r.dict = append(*r.dict, b...)
|
||||
|
||||
if len(*r.dict) > maxWindowBits {
|
||||
offset := len(*r.dict) - maxWindowBits
|
||||
*r.dict = (*r.dict)[offset:]
|
||||
}
|
||||
}
|
||||
|
||||
type (
|
||||
contextTakeoverWriterFactory struct {
|
||||
fw *flate.Writer
|
||||
|
@ -242,3 +201,51 @@ func (w *flateTakeoverWriteWrapper) Close() error {
|
|||
}
|
||||
return err2
|
||||
}
|
||||
|
||||
type (
|
||||
contextTakeoverReaderFactory struct {
|
||||
fr io.ReadCloser
|
||||
window []byte
|
||||
}
|
||||
|
||||
flateTakeoverReadWrapper struct {
|
||||
f *contextTakeoverReaderFactory
|
||||
}
|
||||
)
|
||||
|
||||
func (f *contextTakeoverReaderFactory) newDeCompressionReader(r io.Reader) io.ReadCloser {
|
||||
f.fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), f.window)
|
||||
return &flateTakeoverReadWrapper{f}
|
||||
}
|
||||
|
||||
func (r *flateTakeoverReadWrapper) Read(p []byte) (int, error) {
|
||||
if r.f.fr == nil {
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
|
||||
n, err := r.f.fr.Read(p)
|
||||
|
||||
// add dictionary
|
||||
r.f.window = append(r.f.window, p[:n]...)
|
||||
if len(r.f.window) > maxWindowBits {
|
||||
offset := len(r.f.window) - maxWindowBits
|
||||
r.f.window = r.f.window[offset:]
|
||||
}
|
||||
|
||||
if err == io.EOF {
|
||||
// Preemptively place the reader back in the pool. This helps with
|
||||
// scenarios where the application does not call NextReader() soon after
|
||||
// this final read.
|
||||
r.Close()
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (r *flateTakeoverReadWrapper) Close() error {
|
||||
if r.f.fr == nil {
|
||||
return io.ErrClosedPipe
|
||||
}
|
||||
err := r.f.fr.Close()
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -91,7 +91,7 @@ func BenchmarkReadWithCompression(b *testing.B) {
|
|||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
r := bytes.NewReader(messages[i%len(messages)])
|
||||
reader := c.newDecompressionReader(r, nil)
|
||||
reader := c.newDecompressionReader(r)
|
||||
ioutil.ReadAll(reader)
|
||||
}
|
||||
b.ReportAllocs()
|
||||
|
@ -102,12 +102,15 @@ func BenchmarkReadWithCompressionOfContextTakeover(b *testing.B) {
|
|||
c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024)
|
||||
c.enableWriteCompression = true
|
||||
c.contextTakeover = true
|
||||
c.newDecompressionReader = decompressContextTakeover
|
||||
var frf contextTakeoverReaderFactory
|
||||
fr := flate.NewReader(nil)
|
||||
frf.fr = fr
|
||||
c.newDecompressionReader = frf.newDeCompressionReader
|
||||
messages := textMessages(100)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
r := bytes.NewReader(messages[i%len(messages)])
|
||||
reader := c.newDecompressionReader(r, c.rxDict)
|
||||
reader := c.newDecompressionReader(r)
|
||||
ioutil.ReadAll(reader)
|
||||
}
|
||||
b.ReportAllocs()
|
||||
|
|
17
conn.go
17
conn.go
|
@ -263,8 +263,8 @@ type Conn struct {
|
|||
readErrCount int
|
||||
messageReader *messageReader // the current low-level reader
|
||||
|
||||
readDecompress bool // whether last read frame had RSV1 set
|
||||
newDecompressionReader func(io.Reader, *[]byte) io.ReadCloser // arges may flateReadWrapper struct
|
||||
readDecompress bool // whether last read frame had RSV1 set
|
||||
newDecompressionReader func(io.Reader) io.ReadCloser // arges may flateReadWrapper struct
|
||||
|
||||
contextTakeover bool
|
||||
rxDict *[]byte
|
||||
|
@ -955,13 +955,14 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
|
|||
if frameType == TextMessage || frameType == BinaryMessage {
|
||||
c.messageReader = &messageReader{c}
|
||||
c.reader = c.messageReader
|
||||
c.reader = c.newDecompressionReader(c.reader)
|
||||
|
||||
switch {
|
||||
case c.readDecompress && c.contextTakeover:
|
||||
c.reader = c.newDecompressionReader(c.reader, c.rxDict)
|
||||
case c.readDecompress:
|
||||
c.reader = c.newDecompressionReader(c.reader, nil)
|
||||
}
|
||||
// switch {
|
||||
// case c.readDecompress && c.contextTakeover:
|
||||
// c.reader = c.newDecompressionReader(c.reader, c.rxDict)
|
||||
// case c.readDecompress:
|
||||
// c.reader = c.newDecompressionReader(c.reader, nil)
|
||||
// }
|
||||
|
||||
return frameType, c.reader, nil
|
||||
}
|
||||
|
|
11
server.go
11
server.go
|
@ -199,11 +199,14 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
|
|||
case contextTakeover && u.EnableContextTakeover:
|
||||
c.contextTakeover = contextTakeover
|
||||
|
||||
var f contextTakeoverWriterFactory
|
||||
f.fw, _ = flate.NewWriter(&f.tw, u.CompressionLevel) // level is specified in Dialer, Upgrader
|
||||
c.newCompressionWriter = f.newCompressionWriter
|
||||
var fwf contextTakeoverWriterFactory
|
||||
fwf.fw, _ = flate.NewWriter(&fwf.tw, u.CompressionLevel)
|
||||
c.newCompressionWriter = fwf.newCompressionWriter
|
||||
|
||||
c.newDecompressionReader = decompressContextTakeover
|
||||
var frf contextTakeoverReaderFactory
|
||||
fr := flate.NewReader(nil)
|
||||
frf.fr = fr
|
||||
c.newDecompressionReader = frf.newDeCompressionReader
|
||||
default:
|
||||
c.newCompressionWriter = compressNoContextTakeover
|
||||
c.newDecompressionReader = decompressNoContextTakeover
|
||||
|
|
Loading…
Reference in New Issue