diff --git a/compression.go b/compression.go index d642631..850654e 100644 --- a/compression.go +++ b/compression.go @@ -12,12 +12,15 @@ import ( "sync" ) +const ( + minCompressionLevel = flate.HuffmanOnly + maxCompressionLevel = flate.BestCompression + defaultCompressionLevel = 1 +) + var ( - flateWriterPool = sync.Pool{New: func() interface{} { - fw, _ := flate.NewWriter(nil, 3) - return fw - }} - flateReaderPool = sync.Pool{New: func() interface{} { + flateWriterPools [maxCompressionLevel - minCompressionLevel]sync.Pool + flateReaderPool = sync.Pool{New: func() interface{} { return flate.NewReader(nil) }} ) @@ -34,11 +37,20 @@ func decompressNoContextTakeover(r io.Reader) io.ReadCloser { return &flateReadWrapper{fr} } -func compressNoContextTakeover(w io.WriteCloser) io.WriteCloser { +func isValidCompressionLevel(level int) bool { + return minCompressionLevel <= level && level <= maxCompressionLevel +} + +func compressNoContextTakeover(w io.WriteCloser, level int) io.WriteCloser { + p := &flateWriterPools[level-minCompressionLevel] tw := &truncWriter{w: w} - fw, _ := flateWriterPool.Get().(*flate.Writer) - fw.Reset(tw) - return &flateWriteWrapper{fw: fw, tw: tw} + fw, _ := p.Get().(*flate.Writer) + if fw == nil { + fw, _ = flate.NewWriter(tw, level) + } else { + fw.Reset(tw) + } + return &flateWriteWrapper{fw: fw, tw: tw, p: p} } // truncWriter is an io.Writer that writes all but the last four bytes of the @@ -80,6 +92,7 @@ func (w *truncWriter) Write(p []byte) (int, error) { type flateWriteWrapper struct { fw *flate.Writer tw *truncWriter + p *sync.Pool } func (w *flateWriteWrapper) Write(p []byte) (int, error) { @@ -94,7 +107,7 @@ func (w *flateWriteWrapper) Close() error { return errWriteClosed } err1 := w.fw.Flush() - flateWriterPool.Put(w.fw) + w.p.Put(w.fw) w.fw = nil if w.tw.p != [4]byte{0, 0, 0xff, 0xff} { return errors.New("websocket: internal error, unexpected bytes at end of flate stream") diff --git a/compression_test.go b/compression_test.go index ba39482..659cf42 100644 --- a/compression_test.go +++ b/compression_test.go @@ -64,3 +64,17 @@ func BenchmarkWriteWithCompression(b *testing.B) { } b.ReportAllocs() } + +func TestValidCompressionLevel(t *testing.T) { + c := newConn(fakeNetConn{}, false, 1024, 1024) + for _, level := range []int{minCompressionLevel - 1, maxCompressionLevel + 1} { + if err := c.SetCompressionLevel(level); err == nil { + t.Errorf("no error for level %d", level) + } + } + for _, level := range []int{minCompressionLevel, maxCompressionLevel} { + if err := c.SetCompressionLevel(level); err != nil { + t.Errorf("error for level %d", level) + } + } +} diff --git a/conn.go b/conn.go index 075099a..5a6f65d 100644 --- a/conn.go +++ b/conn.go @@ -241,7 +241,8 @@ type Conn struct { writeErr error enableWriteCompression bool - newCompressionWriter func(io.WriteCloser) io.WriteCloser + compressionLevel int + newCompressionWriter func(io.WriteCloser, int) io.WriteCloser // Read fields reader io.ReadCloser // the current reader returned to the application @@ -285,6 +286,7 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) readFinal: true, writeBuf: make([]byte, writeBufferSize+maxFrameHeaderSize), enableWriteCompression: true, + compressionLevel: defaultCompressionLevel, } c.SetCloseHandler(nil) c.SetPingHandler(nil) @@ -450,7 +452,7 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) { } c.writer = mw if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) { - w := c.newCompressionWriter(c.writer) + w := c.newCompressionWriter(c.writer, c.compressionLevel) mw.compress = true c.writer = w } @@ -1061,6 +1063,20 @@ func (c *Conn) EnableWriteCompression(enable bool) { c.enableWriteCompression = enable } +// SetCompressionLevel sets the flate compression level for subsequent text and +// binary messages. This function is a noop if compression was not negotiated +// with the peer. Valid levels range from -2 to 9. Level -1 uses the default +// compression level. Level -2 uses Huffman compression only, Level 0 does not +// attempt any compression. Levels 1 through 9 range from best speed to best +// compression. +func (c *Conn) SetCompressionLevel(level int) error { + if !isValidCompressionLevel(level) { + return errors.New("websocket: invalid compression level") + } + c.compressionLevel = level + return nil +} + // FormatCloseMessage formats closeCode and text as a WebSocket close message. func FormatCloseMessage(closeCode int, text string) []byte { buf := make([]byte, 2+len(text)) diff --git a/doc.go b/doc.go index 44a2882..282d5a8 100644 --- a/doc.go +++ b/doc.go @@ -118,9 +118,10 @@ // // Applications are responsible for ensuring that no more than one goroutine // calls the write methods (NextWriter, SetWriteDeadline, WriteMessage, -// WriteJSON) concurrently and that no more than one goroutine calls the read -// methods (NextReader, SetReadDeadline, ReadMessage, ReadJSON, SetPongHandler, -// SetPingHandler) concurrently. +// WriteJSON, EnableWriteCompression, SetCompressionLevel) concurrently and +// that no more than one goroutine calls the read methods (NextReader, +// SetReadDeadline, ReadMessage, ReadJSON, SetPongHandler, SetPingHandler) +// concurrently. // // The Close and WriteControl methods can be called concurrently with all other // methods.