impl: decompressContextTakeover

This commit is contained in:
misu 2018-01-24 16:52:47 +09:00
parent e0da4e377f
commit 36c43970ee
3 changed files with 67 additions and 6 deletions

View File

@ -25,7 +25,7 @@ var (
}}
)
func decompressNoContextTakeover(r io.Reader) io.ReadCloser {
func decompressNoContextTakeover(r io.Reader, b []byte) io.ReadCloser {
const tail =
// Add four bytes as specified in RFC
"\x00\x00\xff\xff" +
@ -37,6 +37,18 @@ func decompressNoContextTakeover(r io.Reader) io.ReadCloser {
return &flateReadWrapper{fr}
}
func decompressContextTakeover(r io.Reader, dict []byte) io.ReadCloser {
const tail =
// Add four bytes as specified in RFC
"\x00\x00\xff\xff" +
// Add final block to squelch unexpected EOF error from flate reader.
"\x01\x00\x00\xff\xff"
fr, _ := flateReaderPool.Get().(io.ReadCloser)
fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), dict)
return &flateReadWrapper{fr}
}
func isValidCompressionLevel(level int) bool {
return minCompressionLevel <= level && level <= maxCompressionLevel
}
@ -53,6 +65,18 @@ func compressNoContextTakeover(w io.WriteCloser, level int) io.WriteCloser {
return &flateWriteWrapper{fw: fw, tw: tw, p: p}
}
func compressContextTakeover(w io.WriteCloser, level int) io.WriteCloser {
p := &flateWriterPools[level-minCompressionLevel]
tw := &truncWriter{w: w}
fw, _ := p.Get().(*flate.Writer)
if fw == nil {
fw, _ = flate.NewWriter(tw, level)
} else {
fw.Reset(tw)
}
return &flateWriteWrapper{fw: fw, tw: tw, p: p}
}
// truncWriter is an io.Writer that writes all but the last four bytes of the
// stream to another io.Writer.
type truncWriter struct {
@ -120,14 +144,16 @@ func (w *flateWriteWrapper) Close() error {
}
type flateReadWrapper struct {
fr io.ReadCloser
fr io.ReadCloser // flate.NewReader
}
func (r *flateReadWrapper) Read(p []byte) (int, error) {
if r.fr == nil {
return 0, io.ErrClosedPipe
}
n, err := r.fr.Read(p)
if err == io.EOF {
// Preemptively place the reader back in the pool. This helps with
// scenarios where the application does not call NextReader() soon after

42
conn.go
View File

@ -38,6 +38,8 @@ const (
continuationFrame = 0
noFrame = -1
maxWindowBits = 1 << 15
)
// Close codes defined in RFC 6455, section 11.7.
@ -259,8 +261,12 @@ type Conn struct {
readErrCount int
messageReader *messageReader // the current low-level reader
readDecompress bool // whether last read frame had RSV1 set
newDecompressionReader func(io.Reader) io.ReadCloser
readDecompress bool // whether last read frame had RSV1 set
newDecompressionReader func(io.Reader, []byte) io.ReadCloser // arges may flateReadWrapper struct
contextTakeover bool
dict []byte
mutex sync.RWMutex
}
func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
@ -945,9 +951,14 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
if frameType == TextMessage || frameType == BinaryMessage {
c.messageReader = &messageReader{c}
c.reader = c.messageReader
if c.readDecompress {
c.reader = c.newDecompressionReader(c.reader)
switch {
case c.readDecompress && c.contextTakeover:
c.reader = c.newDecompressionReader(c.reader, c.dict)
case c.readDecompress:
c.reader = c.newDecompressionReader(c.reader, nil)
}
return frameType, c.reader, nil
}
}
@ -974,9 +985,11 @@ func (r *messageReader) Read(b []byte) (int, error) {
for c.readErr == nil {
if c.readRemaining > 0 {
// Determine the size of the data to be read.
if int64(len(b)) > c.readRemaining {
b = b[:c.readRemaining]
}
n, err := c.br.Read(b)
c.readErr = hideTempErr(err)
if c.isServer {
@ -986,6 +999,7 @@ func (r *messageReader) Read(b []byte) (int, error) {
if c.readRemaining > 0 && c.readErr == io.EOF {
c.readErr = errUnexpectedEOF
}
return n, c.readErr
}
@ -1023,6 +1037,12 @@ func (c *Conn) ReadMessage() (messageType int, p []byte, err error) {
return messageType, nil, err
}
p, err = ioutil.ReadAll(r)
// if context-takeover add payload to dictionary
if c.contextTakeover {
c.AddDict(p)
}
return messageType, p, err
}
@ -1139,6 +1159,20 @@ func (c *Conn) SetCompressionLevel(level int) error {
return nil
}
func (c *Conn) AddDict(b []byte) {
c.mutex.Lock()
defer c.mutex.Unlock()
// Todo I do not know whether to leave the dictionary with 32768 bytes or more
// If it is recognized as a duplicate character string,
// deleting a part of the character may make it impossible to decrypt it.
c.dict = append(b, c.dict...)
if len(c.dict) > maxWindowBits {
c.dict = c.dict[:maxWindowBits]
}
}
// FormatCloseMessage formats closeCode and text as a WebSocket close message.
// An empty message is returned for code CloseNoStatusReceived.
func FormatCloseMessage(closeCode int, text string) []byte {

View File

@ -186,6 +186,7 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
if compress {
switch {
case contextTakeover:
c.contextTakeover = contextTakeover
c.newCompressionWriter = compressContextTakeover
c.newDecompressionReader = decompressContextTakeover
default: