diff --git a/compression.go b/compression.go index 813ffb1..64cbb44 100644 --- a/compression.go +++ b/compression.go @@ -25,7 +25,7 @@ var ( }} ) -func decompressNoContextTakeover(r io.Reader) io.ReadCloser { +func decompressNoContextTakeover(r io.Reader, b []byte) io.ReadCloser { const tail = // Add four bytes as specified in RFC "\x00\x00\xff\xff" + @@ -37,6 +37,18 @@ func decompressNoContextTakeover(r io.Reader) io.ReadCloser { return &flateReadWrapper{fr} } +func decompressContextTakeover(r io.Reader, dict []byte) io.ReadCloser { + const tail = + // Add four bytes as specified in RFC + "\x00\x00\xff\xff" + + // Add final block to squelch unexpected EOF error from flate reader. + "\x01\x00\x00\xff\xff" + + fr, _ := flateReaderPool.Get().(io.ReadCloser) + fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), dict) + return &flateReadWrapper{fr} +} + func isValidCompressionLevel(level int) bool { return minCompressionLevel <= level && level <= maxCompressionLevel } @@ -53,6 +65,18 @@ func compressNoContextTakeover(w io.WriteCloser, level int) io.WriteCloser { return &flateWriteWrapper{fw: fw, tw: tw, p: p} } +func compressContextTakeover(w io.WriteCloser, level int) io.WriteCloser { + p := &flateWriterPools[level-minCompressionLevel] + tw := &truncWriter{w: w} + fw, _ := p.Get().(*flate.Writer) + if fw == nil { + fw, _ = flate.NewWriter(tw, level) + } else { + fw.Reset(tw) + } + return &flateWriteWrapper{fw: fw, tw: tw, p: p} +} + // truncWriter is an io.Writer that writes all but the last four bytes of the // stream to another io.Writer. type truncWriter struct { @@ -120,14 +144,16 @@ func (w *flateWriteWrapper) Close() error { } type flateReadWrapper struct { - fr io.ReadCloser + fr io.ReadCloser // flate.NewReader } func (r *flateReadWrapper) Read(p []byte) (int, error) { if r.fr == nil { return 0, io.ErrClosedPipe } + n, err := r.fr.Read(p) + if err == io.EOF { // Preemptively place the reader back in the pool. This helps with // scenarios where the application does not call NextReader() soon after diff --git a/conn.go b/conn.go index cd3569d..cf4a9af 100644 --- a/conn.go +++ b/conn.go @@ -38,6 +38,8 @@ const ( continuationFrame = 0 noFrame = -1 + + maxWindowBits = 1 << 15 ) // Close codes defined in RFC 6455, section 11.7. @@ -259,8 +261,12 @@ type Conn struct { readErrCount int messageReader *messageReader // the current low-level reader - readDecompress bool // whether last read frame had RSV1 set - newDecompressionReader func(io.Reader) io.ReadCloser + readDecompress bool // whether last read frame had RSV1 set + newDecompressionReader func(io.Reader, []byte) io.ReadCloser // arges may flateReadWrapper struct + + contextTakeover bool + dict []byte + mutex sync.RWMutex } func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn { @@ -945,9 +951,14 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { if frameType == TextMessage || frameType == BinaryMessage { c.messageReader = &messageReader{c} c.reader = c.messageReader - if c.readDecompress { - c.reader = c.newDecompressionReader(c.reader) + + switch { + case c.readDecompress && c.contextTakeover: + c.reader = c.newDecompressionReader(c.reader, c.dict) + case c.readDecompress: + c.reader = c.newDecompressionReader(c.reader, nil) } + return frameType, c.reader, nil } } @@ -974,9 +985,11 @@ func (r *messageReader) Read(b []byte) (int, error) { for c.readErr == nil { if c.readRemaining > 0 { + // Determine the size of the data to be read. if int64(len(b)) > c.readRemaining { b = b[:c.readRemaining] } + n, err := c.br.Read(b) c.readErr = hideTempErr(err) if c.isServer { @@ -986,6 +999,7 @@ func (r *messageReader) Read(b []byte) (int, error) { if c.readRemaining > 0 && c.readErr == io.EOF { c.readErr = errUnexpectedEOF } + return n, c.readErr } @@ -1023,6 +1037,12 @@ func (c *Conn) ReadMessage() (messageType int, p []byte, err error) { return messageType, nil, err } p, err = ioutil.ReadAll(r) + + // if context-takeover add payload to dictionary + if c.contextTakeover { + c.AddDict(p) + } + return messageType, p, err } @@ -1139,6 +1159,20 @@ func (c *Conn) SetCompressionLevel(level int) error { return nil } +func (c *Conn) AddDict(b []byte) { + c.mutex.Lock() + defer c.mutex.Unlock() + + // Todo I do not know whether to leave the dictionary with 32768 bytes or more + // If it is recognized as a duplicate character string, + // deleting a part of the character may make it impossible to decrypt it. + c.dict = append(b, c.dict...) + + if len(c.dict) > maxWindowBits { + c.dict = c.dict[:maxWindowBits] + } +} + // FormatCloseMessage formats closeCode and text as a WebSocket close message. // An empty message is returned for code CloseNoStatusReceived. func FormatCloseMessage(closeCode int, text string) []byte { diff --git a/server.go b/server.go index 8c20621..67c82b7 100644 --- a/server.go +++ b/server.go @@ -186,6 +186,7 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade if compress { switch { case contextTakeover: + c.contextTakeover = contextTakeover c.newCompressionWriter = compressContextTakeover c.newDecompressionReader = decompressContextTakeover default: