impl: decompressContextTakeover

This commit is contained in:
misu 2018-01-24 16:52:47 +09:00
parent e0da4e377f
commit 36c43970ee
3 changed files with 67 additions and 6 deletions

View File

@ -25,7 +25,7 @@ var (
}} }}
) )
func decompressNoContextTakeover(r io.Reader) io.ReadCloser { func decompressNoContextTakeover(r io.Reader, b []byte) io.ReadCloser {
const tail = const tail =
// Add four bytes as specified in RFC // Add four bytes as specified in RFC
"\x00\x00\xff\xff" + "\x00\x00\xff\xff" +
@ -37,6 +37,18 @@ func decompressNoContextTakeover(r io.Reader) io.ReadCloser {
return &flateReadWrapper{fr} return &flateReadWrapper{fr}
} }
func decompressContextTakeover(r io.Reader, dict []byte) io.ReadCloser {
const tail =
// Add four bytes as specified in RFC
"\x00\x00\xff\xff" +
// Add final block to squelch unexpected EOF error from flate reader.
"\x01\x00\x00\xff\xff"
fr, _ := flateReaderPool.Get().(io.ReadCloser)
fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), dict)
return &flateReadWrapper{fr}
}
func isValidCompressionLevel(level int) bool { func isValidCompressionLevel(level int) bool {
return minCompressionLevel <= level && level <= maxCompressionLevel return minCompressionLevel <= level && level <= maxCompressionLevel
} }
@ -53,6 +65,18 @@ func compressNoContextTakeover(w io.WriteCloser, level int) io.WriteCloser {
return &flateWriteWrapper{fw: fw, tw: tw, p: p} return &flateWriteWrapper{fw: fw, tw: tw, p: p}
} }
func compressContextTakeover(w io.WriteCloser, level int) io.WriteCloser {
p := &flateWriterPools[level-minCompressionLevel]
tw := &truncWriter{w: w}
fw, _ := p.Get().(*flate.Writer)
if fw == nil {
fw, _ = flate.NewWriter(tw, level)
} else {
fw.Reset(tw)
}
return &flateWriteWrapper{fw: fw, tw: tw, p: p}
}
// truncWriter is an io.Writer that writes all but the last four bytes of the // truncWriter is an io.Writer that writes all but the last four bytes of the
// stream to another io.Writer. // stream to another io.Writer.
type truncWriter struct { type truncWriter struct {
@ -120,14 +144,16 @@ func (w *flateWriteWrapper) Close() error {
} }
type flateReadWrapper struct { type flateReadWrapper struct {
fr io.ReadCloser fr io.ReadCloser // flate.NewReader
} }
func (r *flateReadWrapper) Read(p []byte) (int, error) { func (r *flateReadWrapper) Read(p []byte) (int, error) {
if r.fr == nil { if r.fr == nil {
return 0, io.ErrClosedPipe return 0, io.ErrClosedPipe
} }
n, err := r.fr.Read(p) n, err := r.fr.Read(p)
if err == io.EOF { if err == io.EOF {
// Preemptively place the reader back in the pool. This helps with // Preemptively place the reader back in the pool. This helps with
// scenarios where the application does not call NextReader() soon after // scenarios where the application does not call NextReader() soon after

42
conn.go
View File

@ -38,6 +38,8 @@ const (
continuationFrame = 0 continuationFrame = 0
noFrame = -1 noFrame = -1
maxWindowBits = 1 << 15
) )
// Close codes defined in RFC 6455, section 11.7. // Close codes defined in RFC 6455, section 11.7.
@ -259,8 +261,12 @@ type Conn struct {
readErrCount int readErrCount int
messageReader *messageReader // the current low-level reader messageReader *messageReader // the current low-level reader
readDecompress bool // whether last read frame had RSV1 set readDecompress bool // whether last read frame had RSV1 set
newDecompressionReader func(io.Reader) io.ReadCloser newDecompressionReader func(io.Reader, []byte) io.ReadCloser // arges may flateReadWrapper struct
contextTakeover bool
dict []byte
mutex sync.RWMutex
} }
func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn { func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
@ -945,9 +951,14 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
if frameType == TextMessage || frameType == BinaryMessage { if frameType == TextMessage || frameType == BinaryMessage {
c.messageReader = &messageReader{c} c.messageReader = &messageReader{c}
c.reader = c.messageReader c.reader = c.messageReader
if c.readDecompress {
c.reader = c.newDecompressionReader(c.reader) switch {
case c.readDecompress && c.contextTakeover:
c.reader = c.newDecompressionReader(c.reader, c.dict)
case c.readDecompress:
c.reader = c.newDecompressionReader(c.reader, nil)
} }
return frameType, c.reader, nil return frameType, c.reader, nil
} }
} }
@ -974,9 +985,11 @@ func (r *messageReader) Read(b []byte) (int, error) {
for c.readErr == nil { for c.readErr == nil {
if c.readRemaining > 0 { if c.readRemaining > 0 {
// Determine the size of the data to be read.
if int64(len(b)) > c.readRemaining { if int64(len(b)) > c.readRemaining {
b = b[:c.readRemaining] b = b[:c.readRemaining]
} }
n, err := c.br.Read(b) n, err := c.br.Read(b)
c.readErr = hideTempErr(err) c.readErr = hideTempErr(err)
if c.isServer { if c.isServer {
@ -986,6 +999,7 @@ func (r *messageReader) Read(b []byte) (int, error) {
if c.readRemaining > 0 && c.readErr == io.EOF { if c.readRemaining > 0 && c.readErr == io.EOF {
c.readErr = errUnexpectedEOF c.readErr = errUnexpectedEOF
} }
return n, c.readErr return n, c.readErr
} }
@ -1023,6 +1037,12 @@ func (c *Conn) ReadMessage() (messageType int, p []byte, err error) {
return messageType, nil, err return messageType, nil, err
} }
p, err = ioutil.ReadAll(r) p, err = ioutil.ReadAll(r)
// if context-takeover add payload to dictionary
if c.contextTakeover {
c.AddDict(p)
}
return messageType, p, err return messageType, p, err
} }
@ -1139,6 +1159,20 @@ func (c *Conn) SetCompressionLevel(level int) error {
return nil return nil
} }
func (c *Conn) AddDict(b []byte) {
c.mutex.Lock()
defer c.mutex.Unlock()
// Todo I do not know whether to leave the dictionary with 32768 bytes or more
// If it is recognized as a duplicate character string,
// deleting a part of the character may make it impossible to decrypt it.
c.dict = append(b, c.dict...)
if len(c.dict) > maxWindowBits {
c.dict = c.dict[:maxWindowBits]
}
}
// FormatCloseMessage formats closeCode and text as a WebSocket close message. // FormatCloseMessage formats closeCode and text as a WebSocket close message.
// An empty message is returned for code CloseNoStatusReceived. // An empty message is returned for code CloseNoStatusReceived.
func FormatCloseMessage(closeCode int, text string) []byte { func FormatCloseMessage(closeCode int, text string) []byte {

View File

@ -186,6 +186,7 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
if compress { if compress {
switch { switch {
case contextTakeover: case contextTakeover:
c.contextTakeover = contextTakeover
c.newCompressionWriter = compressContextTakeover c.newCompressionWriter = compressContextTakeover
c.newDecompressionReader = decompressContextTakeover c.newDecompressionReader = decompressContextTakeover
default: default: