mirror of https://github.com/gorilla/websocket.git
impl: decompressContextTakeover
This commit is contained in:
parent
e0da4e377f
commit
36c43970ee
|
@ -25,7 +25,7 @@ var (
|
||||||
}}
|
}}
|
||||||
)
|
)
|
||||||
|
|
||||||
func decompressNoContextTakeover(r io.Reader) io.ReadCloser {
|
func decompressNoContextTakeover(r io.Reader, b []byte) io.ReadCloser {
|
||||||
const tail =
|
const tail =
|
||||||
// Add four bytes as specified in RFC
|
// Add four bytes as specified in RFC
|
||||||
"\x00\x00\xff\xff" +
|
"\x00\x00\xff\xff" +
|
||||||
|
@ -37,6 +37,18 @@ func decompressNoContextTakeover(r io.Reader) io.ReadCloser {
|
||||||
return &flateReadWrapper{fr}
|
return &flateReadWrapper{fr}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func decompressContextTakeover(r io.Reader, dict []byte) io.ReadCloser {
|
||||||
|
const tail =
|
||||||
|
// Add four bytes as specified in RFC
|
||||||
|
"\x00\x00\xff\xff" +
|
||||||
|
// Add final block to squelch unexpected EOF error from flate reader.
|
||||||
|
"\x01\x00\x00\xff\xff"
|
||||||
|
|
||||||
|
fr, _ := flateReaderPool.Get().(io.ReadCloser)
|
||||||
|
fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), dict)
|
||||||
|
return &flateReadWrapper{fr}
|
||||||
|
}
|
||||||
|
|
||||||
func isValidCompressionLevel(level int) bool {
|
func isValidCompressionLevel(level int) bool {
|
||||||
return minCompressionLevel <= level && level <= maxCompressionLevel
|
return minCompressionLevel <= level && level <= maxCompressionLevel
|
||||||
}
|
}
|
||||||
|
@ -53,6 +65,18 @@ func compressNoContextTakeover(w io.WriteCloser, level int) io.WriteCloser {
|
||||||
return &flateWriteWrapper{fw: fw, tw: tw, p: p}
|
return &flateWriteWrapper{fw: fw, tw: tw, p: p}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func compressContextTakeover(w io.WriteCloser, level int) io.WriteCloser {
|
||||||
|
p := &flateWriterPools[level-minCompressionLevel]
|
||||||
|
tw := &truncWriter{w: w}
|
||||||
|
fw, _ := p.Get().(*flate.Writer)
|
||||||
|
if fw == nil {
|
||||||
|
fw, _ = flate.NewWriter(tw, level)
|
||||||
|
} else {
|
||||||
|
fw.Reset(tw)
|
||||||
|
}
|
||||||
|
return &flateWriteWrapper{fw: fw, tw: tw, p: p}
|
||||||
|
}
|
||||||
|
|
||||||
// truncWriter is an io.Writer that writes all but the last four bytes of the
|
// truncWriter is an io.Writer that writes all but the last four bytes of the
|
||||||
// stream to another io.Writer.
|
// stream to another io.Writer.
|
||||||
type truncWriter struct {
|
type truncWriter struct {
|
||||||
|
@ -120,14 +144,16 @@ func (w *flateWriteWrapper) Close() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
type flateReadWrapper struct {
|
type flateReadWrapper struct {
|
||||||
fr io.ReadCloser
|
fr io.ReadCloser // flate.NewReader
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *flateReadWrapper) Read(p []byte) (int, error) {
|
func (r *flateReadWrapper) Read(p []byte) (int, error) {
|
||||||
if r.fr == nil {
|
if r.fr == nil {
|
||||||
return 0, io.ErrClosedPipe
|
return 0, io.ErrClosedPipe
|
||||||
}
|
}
|
||||||
|
|
||||||
n, err := r.fr.Read(p)
|
n, err := r.fr.Read(p)
|
||||||
|
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
// Preemptively place the reader back in the pool. This helps with
|
// Preemptively place the reader back in the pool. This helps with
|
||||||
// scenarios where the application does not call NextReader() soon after
|
// scenarios where the application does not call NextReader() soon after
|
||||||
|
|
40
conn.go
40
conn.go
|
@ -38,6 +38,8 @@ const (
|
||||||
|
|
||||||
continuationFrame = 0
|
continuationFrame = 0
|
||||||
noFrame = -1
|
noFrame = -1
|
||||||
|
|
||||||
|
maxWindowBits = 1 << 15
|
||||||
)
|
)
|
||||||
|
|
||||||
// Close codes defined in RFC 6455, section 11.7.
|
// Close codes defined in RFC 6455, section 11.7.
|
||||||
|
@ -260,7 +262,11 @@ type Conn struct {
|
||||||
messageReader *messageReader // the current low-level reader
|
messageReader *messageReader // the current low-level reader
|
||||||
|
|
||||||
readDecompress bool // whether last read frame had RSV1 set
|
readDecompress bool // whether last read frame had RSV1 set
|
||||||
newDecompressionReader func(io.Reader) io.ReadCloser
|
newDecompressionReader func(io.Reader, []byte) io.ReadCloser // arges may flateReadWrapper struct
|
||||||
|
|
||||||
|
contextTakeover bool
|
||||||
|
dict []byte
|
||||||
|
mutex sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
|
func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
|
||||||
|
@ -945,9 +951,14 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
|
||||||
if frameType == TextMessage || frameType == BinaryMessage {
|
if frameType == TextMessage || frameType == BinaryMessage {
|
||||||
c.messageReader = &messageReader{c}
|
c.messageReader = &messageReader{c}
|
||||||
c.reader = c.messageReader
|
c.reader = c.messageReader
|
||||||
if c.readDecompress {
|
|
||||||
c.reader = c.newDecompressionReader(c.reader)
|
switch {
|
||||||
|
case c.readDecompress && c.contextTakeover:
|
||||||
|
c.reader = c.newDecompressionReader(c.reader, c.dict)
|
||||||
|
case c.readDecompress:
|
||||||
|
c.reader = c.newDecompressionReader(c.reader, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
return frameType, c.reader, nil
|
return frameType, c.reader, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -974,9 +985,11 @@ func (r *messageReader) Read(b []byte) (int, error) {
|
||||||
for c.readErr == nil {
|
for c.readErr == nil {
|
||||||
|
|
||||||
if c.readRemaining > 0 {
|
if c.readRemaining > 0 {
|
||||||
|
// Determine the size of the data to be read.
|
||||||
if int64(len(b)) > c.readRemaining {
|
if int64(len(b)) > c.readRemaining {
|
||||||
b = b[:c.readRemaining]
|
b = b[:c.readRemaining]
|
||||||
}
|
}
|
||||||
|
|
||||||
n, err := c.br.Read(b)
|
n, err := c.br.Read(b)
|
||||||
c.readErr = hideTempErr(err)
|
c.readErr = hideTempErr(err)
|
||||||
if c.isServer {
|
if c.isServer {
|
||||||
|
@ -986,6 +999,7 @@ func (r *messageReader) Read(b []byte) (int, error) {
|
||||||
if c.readRemaining > 0 && c.readErr == io.EOF {
|
if c.readRemaining > 0 && c.readErr == io.EOF {
|
||||||
c.readErr = errUnexpectedEOF
|
c.readErr = errUnexpectedEOF
|
||||||
}
|
}
|
||||||
|
|
||||||
return n, c.readErr
|
return n, c.readErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1023,6 +1037,12 @@ func (c *Conn) ReadMessage() (messageType int, p []byte, err error) {
|
||||||
return messageType, nil, err
|
return messageType, nil, err
|
||||||
}
|
}
|
||||||
p, err = ioutil.ReadAll(r)
|
p, err = ioutil.ReadAll(r)
|
||||||
|
|
||||||
|
// if context-takeover add payload to dictionary
|
||||||
|
if c.contextTakeover {
|
||||||
|
c.AddDict(p)
|
||||||
|
}
|
||||||
|
|
||||||
return messageType, p, err
|
return messageType, p, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1139,6 +1159,20 @@ func (c *Conn) SetCompressionLevel(level int) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Conn) AddDict(b []byte) {
|
||||||
|
c.mutex.Lock()
|
||||||
|
defer c.mutex.Unlock()
|
||||||
|
|
||||||
|
// Todo I do not know whether to leave the dictionary with 32768 bytes or more
|
||||||
|
// If it is recognized as a duplicate character string,
|
||||||
|
// deleting a part of the character may make it impossible to decrypt it.
|
||||||
|
c.dict = append(b, c.dict...)
|
||||||
|
|
||||||
|
if len(c.dict) > maxWindowBits {
|
||||||
|
c.dict = c.dict[:maxWindowBits]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// FormatCloseMessage formats closeCode and text as a WebSocket close message.
|
// FormatCloseMessage formats closeCode and text as a WebSocket close message.
|
||||||
// An empty message is returned for code CloseNoStatusReceived.
|
// An empty message is returned for code CloseNoStatusReceived.
|
||||||
func FormatCloseMessage(closeCode int, text string) []byte {
|
func FormatCloseMessage(closeCode int, text string) []byte {
|
||||||
|
|
|
@ -186,6 +186,7 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
|
||||||
if compress {
|
if compress {
|
||||||
switch {
|
switch {
|
||||||
case contextTakeover:
|
case contextTakeover:
|
||||||
|
c.contextTakeover = contextTakeover
|
||||||
c.newCompressionWriter = compressContextTakeover
|
c.newCompressionWriter = compressContextTakeover
|
||||||
c.newDecompressionReader = decompressContextTakeover
|
c.newDecompressionReader = decompressContextTakeover
|
||||||
default:
|
default:
|
||||||
|
|
Loading…
Reference in New Issue