mod: dict strategy

This commit is contained in:
misu 2018-01-29 15:10:19 +09:00
parent d0e8769234
commit e2dd00db3d
3 changed files with 75 additions and 60 deletions

View File

@ -26,7 +26,7 @@ var (
}} }}
) )
func decompressNoContextTakeover(r io.Reader, dict []byte) io.ReadCloser { func decompressNoContextTakeover(r io.Reader, dict *[]byte) io.ReadCloser {
const tail = const tail =
// Add four bytes as specified in RFC // Add four bytes as specified in RFC
"\x00\x00\xff\xff" + "\x00\x00\xff\xff" +
@ -35,10 +35,10 @@ func decompressNoContextTakeover(r io.Reader, dict []byte) io.ReadCloser {
fr, _ := flateReaderPool.Get().(io.ReadCloser) fr, _ := flateReaderPool.Get().(io.ReadCloser)
fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil) fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil)
return &flateReadWrapper{fr} return &flateReadWrapper{fr: fr}
} }
func decompressContextTakeover(r io.Reader, dict []byte) io.ReadCloser { func decompressContextTakeover(r io.Reader, dict *[]byte) io.ReadCloser {
const tail = const tail =
// Add four bytes as specified in RFC // Add four bytes as specified in RFC
"\x00\x00\xff\xff" + "\x00\x00\xff\xff" +
@ -46,15 +46,21 @@ func decompressContextTakeover(r io.Reader, dict []byte) io.ReadCloser {
"\x01\x00\x00\xff\xff" "\x01\x00\x00\xff\xff"
fr, _ := flateReaderPool.Get().(io.ReadCloser) fr, _ := flateReaderPool.Get().(io.ReadCloser)
fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), dict)
return &flateReadWrapper{fr} if dict != nil {
fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), *dict)
} else {
fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil)
}
return &flateReadWrapper{fr: fr, hasDict: true, dict: dict}
} }
func isValidCompressionLevel(level int) bool { func isValidCompressionLevel(level int) bool {
return minCompressionLevel <= level && level <= maxCompressionLevel return minCompressionLevel <= level && level <= maxCompressionLevel
} }
func compressNoContextTakeover(w io.WriteCloser, level int, dict []byte) io.WriteCloser { func compressNoContextTakeover(w io.WriteCloser, level int, dict *[]byte) io.WriteCloser {
p := &flateWriterPools[level-minCompressionLevel] p := &flateWriterPools[level-minCompressionLevel]
tw := &truncWriter{w: w} tw := &truncWriter{w: w}
fw, _ := p.Get().(*flate.Writer) fw, _ := p.Get().(*flate.Writer)
@ -66,19 +72,18 @@ func compressNoContextTakeover(w io.WriteCloser, level int, dict []byte) io.Writ
return &flateWriteWrapper{fw: fw, tw: tw, p: p} return &flateWriteWrapper{fw: fw, tw: tw, p: p}
} }
func compressContextTakeover(w io.WriteCloser, level int, dict []byte) io.WriteCloser { func compressContextTakeover(w io.WriteCloser, level int, dict *[]byte) io.WriteCloser {
p := &flateWriterDictPools[level-minCompressionLevel]
tw := &truncWriter{w: w} tw := &truncWriter{w: w}
fw, _ := p.Get().(*flate.Writer) var fw *flate.Writer
if fw == nil {
// use WriterDict if dict != nil {
fw, _ = flate.NewWriterDict(tw, level, dict) fw, _ = flate.NewWriterDict(tw, level, *dict)
} else { } else {
fw.Reset(tw) fw, _ = flate.NewWriterDict(tw, level, nil)
} }
return &flateWriteWrapper{fw: fw, tw: tw, p: p} return &flateWriteWrapper{fw: fw, tw: tw, hasDict: true, dict: dict}
} }
// truncWriter is an io.Writer that writes all but the last four bytes of the // truncWriter is an io.Writer that writes all but the last four bytes of the
@ -121,12 +126,20 @@ type flateWriteWrapper struct {
fw *flate.Writer fw *flate.Writer
tw *truncWriter tw *truncWriter
p *sync.Pool p *sync.Pool
hasDict bool
dict *[]byte
} }
func (w *flateWriteWrapper) Write(p []byte) (int, error) { func (w *flateWriteWrapper) Write(p []byte) (int, error) {
if w.fw == nil { if w.fw == nil {
return 0, errWriteClosed return 0, errWriteClosed
} }
if w.hasDict {
w.addDict(p)
}
return w.fw.Write(p) return w.fw.Write(p)
} }
@ -135,7 +148,11 @@ func (w *flateWriteWrapper) Close() error {
return errWriteClosed return errWriteClosed
} }
err1 := w.fw.Flush() err1 := w.fw.Flush()
w.p.Put(w.fw)
if !w.hasDict {
w.p.Put(w.fw)
}
w.fw = nil w.fw = nil
if w.tw.p != [4]byte{0, 0, 0xff, 0xff} { if w.tw.p != [4]byte{0, 0, 0xff, 0xff} {
return errors.New("websocket: internal error, unexpected bytes at end of flate stream") return errors.New("websocket: internal error, unexpected bytes at end of flate stream")
@ -147,8 +164,21 @@ func (w *flateWriteWrapper) Close() error {
return err2 return err2
} }
// addDict adds payload to dict.
func (w *flateWriteWrapper) addDict(b []byte) {
*w.dict = append(*w.dict, b...)
if len(*w.dict) > maxWindowBits {
offset := len(*w.dict) - maxWindowBits
*w.dict = (*w.dict)[offset:]
}
}
type flateReadWrapper struct { type flateReadWrapper struct {
fr io.ReadCloser // flate.NewReader fr io.ReadCloser // flate.NewReader
hasDict bool
dict *[]byte
} }
func (r *flateReadWrapper) Read(p []byte) (int, error) { func (r *flateReadWrapper) Read(p []byte) (int, error) {
@ -164,6 +194,13 @@ func (r *flateReadWrapper) Read(p []byte) (int, error) {
// this final read. // this final read.
r.Close() r.Close()
} }
if r.hasDict {
if n > 0 {
r.addDict(p[:n])
}
}
return n, err return n, err
} }
@ -172,7 +209,21 @@ func (r *flateReadWrapper) Close() error {
return io.ErrClosedPipe return io.ErrClosedPipe
} }
err := r.fr.Close() err := r.fr.Close()
flateReaderPool.Put(r.fr)
if !r.hasDict {
flateReaderPool.Put(r.fr)
}
r.fr = nil r.fr = nil
return err return err
} }
// addDict adds payload to dict.
func (r *flateReadWrapper) addDict(b []byte) {
*r.dict = append(*r.dict, b...)
if len(*r.dict) > maxWindowBits {
offset := len(*r.dict) - maxWindowBits
*r.dict = (*r.dict)[offset:]
}
}

41
conn.go
View File

@ -243,7 +243,7 @@ type Conn struct {
enableWriteCompression bool enableWriteCompression bool
compressionLevel int compressionLevel int
newCompressionWriter func(io.WriteCloser, int, []byte) io.WriteCloser newCompressionWriter func(io.WriteCloser, int, *[]byte) io.WriteCloser
// Read fields // Read fields
reader io.ReadCloser // the current reader returned to the application reader io.ReadCloser // the current reader returned to the application
@ -261,12 +261,12 @@ type Conn struct {
readErrCount int readErrCount int
messageReader *messageReader // the current low-level reader messageReader *messageReader // the current low-level reader
readDecompress bool // whether last read frame had RSV1 set readDecompress bool // whether last read frame had RSV1 set
newDecompressionReader func(io.Reader, []byte) io.ReadCloser // arges may flateReadWrapper struct newDecompressionReader func(io.Reader, *[]byte) io.ReadCloser // arges may flateReadWrapper struct
contextTakeover bool contextTakeover bool
txDict []byte txDict *[]byte
rxDict []byte rxDict *[]byte
} }
func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn { func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
@ -336,6 +336,9 @@ func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize in
writeBuf: writeBuf, writeBuf: writeBuf,
enableWriteCompression: true, enableWriteCompression: true,
compressionLevel: defaultCompressionLevel, compressionLevel: defaultCompressionLevel,
txDict: &[]byte{},
rxDict: &[]byte{},
} }
c.SetCloseHandler(nil) c.SetCloseHandler(nil)
c.SetPingHandler(nil) c.SetPingHandler(nil)
@ -763,9 +766,6 @@ func (c *Conn) WriteMessage(messageType int, data []byte) error {
if _, err = w.Write(data); err != nil { if _, err = w.Write(data); err != nil {
return err return err
} }
if c.contextTakeover {
c.AddTxDict(data)
}
return w.Close() return w.Close()
} }
@ -1046,11 +1046,6 @@ func (c *Conn) ReadMessage() (messageType int, p []byte, err error) {
} }
p, err = ioutil.ReadAll(r) p, err = ioutil.ReadAll(r)
// if context-takeover add payload to dictionary
if c.contextTakeover {
c.AddRxDict(p)
}
return messageType, p, err return messageType, p, err
} }
@ -1167,26 +1162,6 @@ func (c *Conn) SetCompressionLevel(level int) error {
return nil return nil
} }
// AddTxDict adds payload to txDict.
func (c *Conn) AddTxDict(b []byte) {
c.txDict = append(c.txDict, b...)
if len(c.txDict) > maxWindowBits {
offset := len(c.txDict) - maxWindowBits
c.txDict = c.txDict[offset:]
}
}
// AddTxDict adds payload to rxDict.
func (c *Conn) AddRxDict(b []byte) {
c.rxDict = append(c.rxDict, b...)
if len(c.rxDict) > maxWindowBits {
offset := len(c.rxDict) - maxWindowBits
c.rxDict = c.rxDict[offset:]
}
}
// FormatCloseMessage formats closeCode and text as a WebSocket close message. // FormatCloseMessage formats closeCode and text as a WebSocket close message.
// An empty message is returned for code CloseNoStatusReceived. // An empty message is returned for code CloseNoStatusReceived.
func FormatCloseMessage(closeCode int, text string) []byte { func FormatCloseMessage(closeCode int, text string) []byte {

View File

@ -494,14 +494,3 @@ func TestBufioReuse(t *testing.T) {
} }
} }
func BenchmarkAddDict(b *testing.B) {
w := ioutil.Discard
c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024)
messages := textMessages(100)
b.ResetTimer()
for i := 0; i < b.N; i++ {
c.AddRxDict(messages[i%len(messages)])
}
b.ReportAllocs()
}