mirror of https://github.com/gorilla/websocket.git
mod: dict strategy
This commit is contained in:
parent
d0e8769234
commit
e2dd00db3d
|
@ -26,7 +26,7 @@ var (
|
||||||
}}
|
}}
|
||||||
)
|
)
|
||||||
|
|
||||||
func decompressNoContextTakeover(r io.Reader, dict []byte) io.ReadCloser {
|
func decompressNoContextTakeover(r io.Reader, dict *[]byte) io.ReadCloser {
|
||||||
const tail =
|
const tail =
|
||||||
// Add four bytes as specified in RFC
|
// Add four bytes as specified in RFC
|
||||||
"\x00\x00\xff\xff" +
|
"\x00\x00\xff\xff" +
|
||||||
|
@ -35,10 +35,10 @@ func decompressNoContextTakeover(r io.Reader, dict []byte) io.ReadCloser {
|
||||||
|
|
||||||
fr, _ := flateReaderPool.Get().(io.ReadCloser)
|
fr, _ := flateReaderPool.Get().(io.ReadCloser)
|
||||||
fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil)
|
fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil)
|
||||||
return &flateReadWrapper{fr}
|
return &flateReadWrapper{fr: fr}
|
||||||
}
|
}
|
||||||
|
|
||||||
func decompressContextTakeover(r io.Reader, dict []byte) io.ReadCloser {
|
func decompressContextTakeover(r io.Reader, dict *[]byte) io.ReadCloser {
|
||||||
const tail =
|
const tail =
|
||||||
// Add four bytes as specified in RFC
|
// Add four bytes as specified in RFC
|
||||||
"\x00\x00\xff\xff" +
|
"\x00\x00\xff\xff" +
|
||||||
|
@ -46,15 +46,21 @@ func decompressContextTakeover(r io.Reader, dict []byte) io.ReadCloser {
|
||||||
"\x01\x00\x00\xff\xff"
|
"\x01\x00\x00\xff\xff"
|
||||||
|
|
||||||
fr, _ := flateReaderPool.Get().(io.ReadCloser)
|
fr, _ := flateReaderPool.Get().(io.ReadCloser)
|
||||||
fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), dict)
|
|
||||||
return &flateReadWrapper{fr}
|
if dict != nil {
|
||||||
|
fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), *dict)
|
||||||
|
} else {
|
||||||
|
fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &flateReadWrapper{fr: fr, hasDict: true, dict: dict}
|
||||||
}
|
}
|
||||||
|
|
||||||
func isValidCompressionLevel(level int) bool {
|
func isValidCompressionLevel(level int) bool {
|
||||||
return minCompressionLevel <= level && level <= maxCompressionLevel
|
return minCompressionLevel <= level && level <= maxCompressionLevel
|
||||||
}
|
}
|
||||||
|
|
||||||
func compressNoContextTakeover(w io.WriteCloser, level int, dict []byte) io.WriteCloser {
|
func compressNoContextTakeover(w io.WriteCloser, level int, dict *[]byte) io.WriteCloser {
|
||||||
p := &flateWriterPools[level-minCompressionLevel]
|
p := &flateWriterPools[level-minCompressionLevel]
|
||||||
tw := &truncWriter{w: w}
|
tw := &truncWriter{w: w}
|
||||||
fw, _ := p.Get().(*flate.Writer)
|
fw, _ := p.Get().(*flate.Writer)
|
||||||
|
@ -66,19 +72,18 @@ func compressNoContextTakeover(w io.WriteCloser, level int, dict []byte) io.Writ
|
||||||
return &flateWriteWrapper{fw: fw, tw: tw, p: p}
|
return &flateWriteWrapper{fw: fw, tw: tw, p: p}
|
||||||
}
|
}
|
||||||
|
|
||||||
func compressContextTakeover(w io.WriteCloser, level int, dict []byte) io.WriteCloser {
|
func compressContextTakeover(w io.WriteCloser, level int, dict *[]byte) io.WriteCloser {
|
||||||
p := &flateWriterDictPools[level-minCompressionLevel]
|
|
||||||
tw := &truncWriter{w: w}
|
tw := &truncWriter{w: w}
|
||||||
|
|
||||||
fw, _ := p.Get().(*flate.Writer)
|
var fw *flate.Writer
|
||||||
if fw == nil {
|
|
||||||
// use WriterDict
|
if dict != nil {
|
||||||
fw, _ = flate.NewWriterDict(tw, level, dict)
|
fw, _ = flate.NewWriterDict(tw, level, *dict)
|
||||||
} else {
|
} else {
|
||||||
fw.Reset(tw)
|
fw, _ = flate.NewWriterDict(tw, level, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &flateWriteWrapper{fw: fw, tw: tw, p: p}
|
return &flateWriteWrapper{fw: fw, tw: tw, hasDict: true, dict: dict}
|
||||||
}
|
}
|
||||||
|
|
||||||
// truncWriter is an io.Writer that writes all but the last four bytes of the
|
// truncWriter is an io.Writer that writes all but the last four bytes of the
|
||||||
|
@ -121,12 +126,20 @@ type flateWriteWrapper struct {
|
||||||
fw *flate.Writer
|
fw *flate.Writer
|
||||||
tw *truncWriter
|
tw *truncWriter
|
||||||
p *sync.Pool
|
p *sync.Pool
|
||||||
|
|
||||||
|
hasDict bool
|
||||||
|
dict *[]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *flateWriteWrapper) Write(p []byte) (int, error) {
|
func (w *flateWriteWrapper) Write(p []byte) (int, error) {
|
||||||
if w.fw == nil {
|
if w.fw == nil {
|
||||||
return 0, errWriteClosed
|
return 0, errWriteClosed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if w.hasDict {
|
||||||
|
w.addDict(p)
|
||||||
|
}
|
||||||
|
|
||||||
return w.fw.Write(p)
|
return w.fw.Write(p)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -135,7 +148,11 @@ func (w *flateWriteWrapper) Close() error {
|
||||||
return errWriteClosed
|
return errWriteClosed
|
||||||
}
|
}
|
||||||
err1 := w.fw.Flush()
|
err1 := w.fw.Flush()
|
||||||
w.p.Put(w.fw)
|
|
||||||
|
if !w.hasDict {
|
||||||
|
w.p.Put(w.fw)
|
||||||
|
}
|
||||||
|
|
||||||
w.fw = nil
|
w.fw = nil
|
||||||
if w.tw.p != [4]byte{0, 0, 0xff, 0xff} {
|
if w.tw.p != [4]byte{0, 0, 0xff, 0xff} {
|
||||||
return errors.New("websocket: internal error, unexpected bytes at end of flate stream")
|
return errors.New("websocket: internal error, unexpected bytes at end of flate stream")
|
||||||
|
@ -147,8 +164,21 @@ func (w *flateWriteWrapper) Close() error {
|
||||||
return err2
|
return err2
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// addDict adds payload to dict.
|
||||||
|
func (w *flateWriteWrapper) addDict(b []byte) {
|
||||||
|
*w.dict = append(*w.dict, b...)
|
||||||
|
|
||||||
|
if len(*w.dict) > maxWindowBits {
|
||||||
|
offset := len(*w.dict) - maxWindowBits
|
||||||
|
*w.dict = (*w.dict)[offset:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type flateReadWrapper struct {
|
type flateReadWrapper struct {
|
||||||
fr io.ReadCloser // flate.NewReader
|
fr io.ReadCloser // flate.NewReader
|
||||||
|
|
||||||
|
hasDict bool
|
||||||
|
dict *[]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *flateReadWrapper) Read(p []byte) (int, error) {
|
func (r *flateReadWrapper) Read(p []byte) (int, error) {
|
||||||
|
@ -164,6 +194,13 @@ func (r *flateReadWrapper) Read(p []byte) (int, error) {
|
||||||
// this final read.
|
// this final read.
|
||||||
r.Close()
|
r.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if r.hasDict {
|
||||||
|
if n > 0 {
|
||||||
|
r.addDict(p[:n])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return n, err
|
return n, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -172,7 +209,21 @@ func (r *flateReadWrapper) Close() error {
|
||||||
return io.ErrClosedPipe
|
return io.ErrClosedPipe
|
||||||
}
|
}
|
||||||
err := r.fr.Close()
|
err := r.fr.Close()
|
||||||
flateReaderPool.Put(r.fr)
|
|
||||||
|
if !r.hasDict {
|
||||||
|
flateReaderPool.Put(r.fr)
|
||||||
|
}
|
||||||
|
|
||||||
r.fr = nil
|
r.fr = nil
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// addDict adds payload to dict.
|
||||||
|
func (r *flateReadWrapper) addDict(b []byte) {
|
||||||
|
*r.dict = append(*r.dict, b...)
|
||||||
|
|
||||||
|
if len(*r.dict) > maxWindowBits {
|
||||||
|
offset := len(*r.dict) - maxWindowBits
|
||||||
|
*r.dict = (*r.dict)[offset:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
41
conn.go
41
conn.go
|
@ -243,7 +243,7 @@ type Conn struct {
|
||||||
|
|
||||||
enableWriteCompression bool
|
enableWriteCompression bool
|
||||||
compressionLevel int
|
compressionLevel int
|
||||||
newCompressionWriter func(io.WriteCloser, int, []byte) io.WriteCloser
|
newCompressionWriter func(io.WriteCloser, int, *[]byte) io.WriteCloser
|
||||||
|
|
||||||
// Read fields
|
// Read fields
|
||||||
reader io.ReadCloser // the current reader returned to the application
|
reader io.ReadCloser // the current reader returned to the application
|
||||||
|
@ -261,12 +261,12 @@ type Conn struct {
|
||||||
readErrCount int
|
readErrCount int
|
||||||
messageReader *messageReader // the current low-level reader
|
messageReader *messageReader // the current low-level reader
|
||||||
|
|
||||||
readDecompress bool // whether last read frame had RSV1 set
|
readDecompress bool // whether last read frame had RSV1 set
|
||||||
newDecompressionReader func(io.Reader, []byte) io.ReadCloser // arges may flateReadWrapper struct
|
newDecompressionReader func(io.Reader, *[]byte) io.ReadCloser // arges may flateReadWrapper struct
|
||||||
|
|
||||||
contextTakeover bool
|
contextTakeover bool
|
||||||
txDict []byte
|
txDict *[]byte
|
||||||
rxDict []byte
|
rxDict *[]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
|
func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
|
||||||
|
@ -336,6 +336,9 @@ func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize in
|
||||||
writeBuf: writeBuf,
|
writeBuf: writeBuf,
|
||||||
enableWriteCompression: true,
|
enableWriteCompression: true,
|
||||||
compressionLevel: defaultCompressionLevel,
|
compressionLevel: defaultCompressionLevel,
|
||||||
|
|
||||||
|
txDict: &[]byte{},
|
||||||
|
rxDict: &[]byte{},
|
||||||
}
|
}
|
||||||
c.SetCloseHandler(nil)
|
c.SetCloseHandler(nil)
|
||||||
c.SetPingHandler(nil)
|
c.SetPingHandler(nil)
|
||||||
|
@ -763,9 +766,6 @@ func (c *Conn) WriteMessage(messageType int, data []byte) error {
|
||||||
if _, err = w.Write(data); err != nil {
|
if _, err = w.Write(data); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if c.contextTakeover {
|
|
||||||
c.AddTxDict(data)
|
|
||||||
}
|
|
||||||
return w.Close()
|
return w.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1046,11 +1046,6 @@ func (c *Conn) ReadMessage() (messageType int, p []byte, err error) {
|
||||||
}
|
}
|
||||||
p, err = ioutil.ReadAll(r)
|
p, err = ioutil.ReadAll(r)
|
||||||
|
|
||||||
// if context-takeover add payload to dictionary
|
|
||||||
if c.contextTakeover {
|
|
||||||
c.AddRxDict(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
return messageType, p, err
|
return messageType, p, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1167,26 +1162,6 @@ func (c *Conn) SetCompressionLevel(level int) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddTxDict adds payload to txDict.
|
|
||||||
func (c *Conn) AddTxDict(b []byte) {
|
|
||||||
c.txDict = append(c.txDict, b...)
|
|
||||||
|
|
||||||
if len(c.txDict) > maxWindowBits {
|
|
||||||
offset := len(c.txDict) - maxWindowBits
|
|
||||||
c.txDict = c.txDict[offset:]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddTxDict adds payload to rxDict.
|
|
||||||
func (c *Conn) AddRxDict(b []byte) {
|
|
||||||
c.rxDict = append(c.rxDict, b...)
|
|
||||||
|
|
||||||
if len(c.rxDict) > maxWindowBits {
|
|
||||||
offset := len(c.rxDict) - maxWindowBits
|
|
||||||
c.rxDict = c.rxDict[offset:]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// FormatCloseMessage formats closeCode and text as a WebSocket close message.
|
// FormatCloseMessage formats closeCode and text as a WebSocket close message.
|
||||||
// An empty message is returned for code CloseNoStatusReceived.
|
// An empty message is returned for code CloseNoStatusReceived.
|
||||||
func FormatCloseMessage(closeCode int, text string) []byte {
|
func FormatCloseMessage(closeCode int, text string) []byte {
|
||||||
|
|
11
conn_test.go
11
conn_test.go
|
@ -494,14 +494,3 @@ func TestBufioReuse(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkAddDict(b *testing.B) {
|
|
||||||
w := ioutil.Discard
|
|
||||||
c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024)
|
|
||||||
messages := textMessages(100)
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
c.AddRxDict(messages[i%len(messages)])
|
|
||||||
}
|
|
||||||
b.ReportAllocs()
|
|
||||||
}
|
|
||||||
|
|
Loading…
Reference in New Issue