websocket/compression.go

245 lines
4.9 KiB
Go
Raw Normal View History

2016-12-16 01:53:35 +03:00
// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"errors"
"io"
"strings"
2016-12-16 01:53:35 +03:00
"sync"
2018-01-30 13:02:50 +03:00
"compress/flate"
2016-12-16 01:53:35 +03:00
)
const (
	// minCompressionLevel is the value of flate.HuffmanOnly, written as a
	// literal because that constant is not defined in Go releases before 1.6.
	minCompressionLevel = -2
	// maxCompressionLevel is the strongest level compress/flate offers.
	maxCompressionLevel = flate.BestCompression

	// defaultCompressionLevel is flate.BestSpeed (1).
	defaultCompressionLevel = flate.BestSpeed
)
2016-12-16 01:53:35 +03:00
var (
2018-01-26 11:52:23 +03:00
flateWriterPools [maxCompressionLevel - minCompressionLevel + 1]sync.Pool
flateWriterDictPools [maxCompressionLevel - minCompressionLevel + 1]sync.Pool
flateReaderPool = sync.Pool{New: func() interface{} {
return flate.NewReader(nil)
}}
)
2018-01-29 09:10:19 +03:00
func decompressNoContextTakeover(r io.Reader, dict *[]byte) io.ReadCloser {
const tail =
// Add four bytes as specified in RFC
"\x00\x00\xff\xff" +
// Add final block to squelch unexpected EOF error from flate reader.
"\x01\x00\x00\xff\xff"
2016-12-18 02:33:06 +03:00
fr, _ := flateReaderPool.Get().(io.ReadCloser)
fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil)
2018-01-29 09:10:19 +03:00
return &flateReadWrapper{fr: fr}
}
2018-01-29 09:10:19 +03:00
func decompressContextTakeover(r io.Reader, dict *[]byte) io.ReadCloser {
2018-01-24 10:52:47 +03:00
const tail =
// Add four bytes as specified in RFC
"\x00\x00\xff\xff" +
// Add final block to squelch unexpected EOF error from flate reader.
"\x01\x00\x00\xff\xff"
fr, _ := flateReaderPool.Get().(io.ReadCloser)
2018-01-29 12:12:43 +03:00
fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), *dict)
2018-01-29 09:10:19 +03:00
return &flateReadWrapper{fr: fr, hasDict: true, dict: dict}
2018-01-24 10:52:47 +03:00
}
// isValidCompressionLevel reports whether level lies in the inclusive range
// [minCompressionLevel, maxCompressionLevel].
func isValidCompressionLevel(level int) bool {
	if level < minCompressionLevel {
		return false
	}
	return level <= maxCompressionLevel
}
2018-01-30 13:02:50 +03:00
// compressNoContextTakeover returns a writer that deflates one message into w
// at the given level. The flate writer comes from that level's pool in
// flateWriterPools (or is created on a pool miss) and writes through a
// truncWriter, which holds back the final four bytes of the stream from w.
// level must satisfy isValidCompressionLevel; otherwise the pool index is out
// of range.
func compressNoContextTakeover(w io.WriteCloser, level int) io.WriteCloser {
	pool := &flateWriterPools[level-minCompressionLevel]
	trunc := &truncWriter{w: w}

	fw, _ := pool.Get().(*flate.Writer)
	if fw != nil {
		// Reuse a pooled compressor, retargeted at this message.
		fw.Reset(trunc)
	} else {
		// flate.NewWriter only fails for a level outside [-2, 9].
		fw, _ = flate.NewWriter(trunc, level)
	}
	return &flateWriteWrapper{fw: fw, tw: trunc, p: pool}
}
// truncWriter is an io.Writer that writes all but the last four bytes of the
// stream to another io.Writer.
//
// The most recent four bytes are held back in p (only the first n are valid
// until the buffer first fills). After a flate Flush those held-back bytes are
// expected to be the 0x00 0x00 0xff 0xff marker, which is not forwarded.
type truncWriter struct {
	w io.WriteCloser
	n int
	p [4]byte
}

// Write forwards everything except the final four bytes seen so far to w.
// On success it reports len(p) bytes consumed, as io.Writer requires, even
// though up to four of them are still held back in the buffer.
func (w *truncWriter) Write(p []byte) (int, error) {
	total := len(p)
	n := 0

	// Fill the holdback buffer first for simplicity.
	if w.n < len(w.p) {
		n = copy(w.p[w.n:], p)
		p = p[n:]
		w.n += n
		if len(p) == 0 {
			return n, nil
		}
	}

	// The last m bytes of p will displace the oldest m buffered bytes.
	m := len(p)
	if m > len(w.p) {
		m = len(w.p)
	}

	if _, err := w.w.Write(w.p[:m]); err != nil {
		// Only the n bytes copied into the buffer were taken from p.
		return n, err
	}

	copy(w.p[:], w.p[m:])
	copy(w.p[len(w.p)-m:], p[len(p)-m:])
	nn, err := w.w.Write(p[:len(p)-m])
	if err != nil {
		return n + nn, err
	}
	// The trailing m bytes of p are now buffered, so all of p was consumed.
	return total, nil
}
2016-12-18 02:33:06 +03:00
type flateWriteWrapper struct {
fw *flate.Writer
tw *truncWriter
p *sync.Pool
isDictWriter bool
}
2016-12-18 02:33:06 +03:00
func (w *flateWriteWrapper) Write(p []byte) (int, error) {
2016-12-16 01:53:35 +03:00
if w.fw == nil {
return 0, errWriteClosed
}
2018-01-29 09:10:19 +03:00
return w.fw.Write(p)
}
2016-12-18 02:33:06 +03:00
func (w *flateWriteWrapper) Close() error {
2016-12-16 01:53:35 +03:00
if w.fw == nil {
return errWriteClosed
}
err1 := w.fw.Flush()
2018-01-29 09:10:19 +03:00
if !w.isDictWriter {
w.p.Put(w.fw)
w.fw = nil
}
2018-01-29 09:10:19 +03:00
if w.tw.p != [4]byte{0, 0, 0xff, 0xff} {
return errors.New("websocket: internal error, unexpected bytes at end of flate stream")
}
2018-01-31 14:52:04 +03:00
if !w.isDictWriter {
w.tw.p = [4]byte{}
w.tw.n = 0
}
err2 := w.tw.w.Close()
if err1 != nil {
return err1
}
return err2
}
2016-12-18 02:33:06 +03:00
type flateReadWrapper struct {
2018-01-24 10:52:47 +03:00
fr io.ReadCloser // flate.NewReader
2018-01-29 09:10:19 +03:00
hasDict bool
dict *[]byte
2016-12-18 02:33:06 +03:00
}
func (r *flateReadWrapper) Read(p []byte) (int, error) {
if r.fr == nil {
return 0, io.ErrClosedPipe
}
2018-01-24 10:52:47 +03:00
2016-12-18 02:33:06 +03:00
n, err := r.fr.Read(p)
2018-01-24 10:52:47 +03:00
2016-12-18 02:33:06 +03:00
if err == io.EOF {
// Preemptively place the reader back in the pool. This helps with
// scenarios where the application does not call NextReader() soon after
// this final read.
r.Close()
}
2018-01-29 09:10:19 +03:00
if r.hasDict {
if n > 0 {
r.addDict(p[:n])
}
}
2016-12-18 02:33:06 +03:00
return n, err
}
func (r *flateReadWrapper) Close() error {
if r.fr == nil {
return io.ErrClosedPipe
}
err := r.fr.Close()
2018-01-29 09:10:19 +03:00
if !r.hasDict {
flateReaderPool.Put(r.fr)
}
2016-12-18 02:33:06 +03:00
r.fr = nil
return err
}
2018-01-29 09:10:19 +03:00
// addDict adds payload to dict.
func (r *flateReadWrapper) addDict(b []byte) {
*r.dict = append(*r.dict, b...)
if len(*r.dict) > maxWindowBits {
offset := len(*r.dict) - maxWindowBits
*r.dict = (*r.dict)[offset:]
}
}
2018-02-01 06:59:30 +03:00
type (
	// contextTakeoverWriterFactory owns one flate writer, and the truncWriter
	// it writes into, that are reused for every message so compression state
	// carries over between messages.
	contextTakeoverWriterFactory struct {
		fw *flate.Writer
		tw truncWriter
	}

	// flateTakeoverWriteWrapper is the per-message io.WriteCloser handed out
	// by contextTakeoverWriterFactory.newCompressionWriter.
	flateTakeoverWriteWrapper struct {
		f *contextTakeoverWriterFactory
	}
)

// newCompressionWriter points the shared truncWriter at w and returns a writer
// for one message. If the factory has no flate writer yet, one is created at
// the given level writing into the shared truncWriter. After that, level is
// ignored because the same compressor is reused for every message.
func (f *contextTakeoverWriterFactory) newCompressionWriter(w io.WriteCloser, level int) io.WriteCloser {
	f.tw.w = w
	f.tw.n = 0
	if f.fw == nil {
		// Without this, Write on a zero-value factory dereferences a nil
		// *flate.Writer. An invalid level leaves fw nil, as before.
		f.fw, _ = flate.NewWriter(&f.tw, level)
	}
	return &flateTakeoverWriteWrapper{f}
}

// Write compresses p into the current message. It returns errWriteClosed after
// Close.
func (w *flateTakeoverWriteWrapper) Write(p []byte) (int, error) {
	if w.f == nil {
		return 0, errWriteClosed
	}
	return w.f.fw.Write(p)
}

// Close flushes the shared compressor (keeping its state for later messages)
// and closes the underlying message writer. If the bytes held back by the
// truncWriter are not 0x00 0x00 0xff 0xff, the underlying writer is left
// unclosed and the Flush error, or an internal-error value, is returned.
func (w *flateTakeoverWriteWrapper) Close() error {
	if w.f == nil {
		return errWriteClosed
	}
	f := w.f
	w.f = nil
	flushErr := f.fw.Flush()
	if f.tw.p != [4]byte{0, 0, 0xff, 0xff} {
		// Prefer the real cause when Flush itself failed.
		if flushErr != nil {
			return flushErr
		}
		return errors.New("websocket: internal error, unexpected bytes at end of flate stream")
	}
	closeErr := f.tw.w.Close()
	if flushErr != nil {
		return flushErr
	}
	return closeErr
}