// evio/evio_translate.go (264 lines, 6.3 KiB, Go)

// Copyright 2017 Joshua J Baker. All rights reserved.
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
package evio
import (
"io"
"net"
"sync"
"time"
)
// tconn is the per-connection state kept by Translate. Every two-element
// field is indexed by stream: Translate queues network input on stream 0
// and handler output on stream 1.
type tconn struct {
	// Per-stream state. The fields for stream st are guarded by cond[st].L.
	cond    [2]*sync.Cond    // stream locks, signalled when prebuf grows or the stream closes
	closed  [2]bool          // set to true once the stream has been torn down
	prebuf  [2][]byte        // bytes queued but not yet fed to the translation pipeline
	postbuf [2][]byte        // bytes that have come out of the translation pipeline
	rd      [2]io.ReadCloser // read ends of the per-stream pipes
	wr      [2]io.Writer     // write ends of the per-stream pipes

	// Connection-wide state.
	mu     sync.Mutex // guards err only
	action Action     // last action returned by the wrapped Opened/Data handler
	err    error      // first error recorded by the pipeline, if any
}
// write appends a copy of b to the pending input of stream st and signals
// every goroutine waiting on that stream's condition variable.
func (c *tconn) write(st int, b []byte) {
	cond := c.cond[st]
	cond.L.Lock()
	defer cond.L.Unlock()
	c.prebuf[st] = append(c.prebuf[st], b...)
	cond.Broadcast()
}
// read hands the caller everything translated so far on stream st and
// leaves that stream's output buffer empty. It returns nil when nothing is
// pending.
func (c *tconn) read(st int) []byte {
	lock := c.cond[st].L
	lock.Lock()
	defer lock.Unlock()
	var taken []byte
	taken, c.postbuf[st] = c.postbuf[st], nil
	return taken
}
// Read lets the translator consume raw inbound bytes. It reads from the
// read end of the stream-0 pipe.
func (c *tconn) Read(p []byte) (n int, err error) {
	src := c.rd[0]
	return src.Read(p)
}

// Write lets the translator emit outbound bytes. It writes into the write
// end of the stream-1 pipe.
func (c *tconn) Write(p []byte) (n int, err error) {
	dst := c.wr[1]
	return dst.Write(p)
}
// nopConn adapts an io.ReadWriter to net.Conn. Read and Write are promoted
// from the embedded ReadWriter; every other net.Conn method does nothing
// and reports success (or a nil address).
type nopConn struct{ io.ReadWriter }

// Compile-time proof that *nopConn satisfies net.Conn.
var _ net.Conn = (*nopConn)(nil)

func (*nopConn) LocalAddr() net.Addr { return nil }

func (*nopConn) RemoteAddr() net.Addr { return nil }

func (*nopConn) SetDeadline(time.Time) error { return nil }

func (*nopConn) SetReadDeadline(time.Time) error { return nil }

func (*nopConn) SetWriteDeadline(time.Time) error { return nil }

func (*nopConn) Close() error { return nil }

// NopConn returns a net.Conn with a no-op LocalAddr, RemoteAddr,
// SetDeadline, SetWriteDeadline, SetReadDeadline, and Close methods wrapping
// the provided ReadWriter rw.
func NopConn(rw io.ReadWriter) net.Conn {
	conn := &nopConn{ReadWriter: rw}
	return conn
}
// Translate provides a utility for performing byte level translation
// on the input and output streams for a connection. This is useful for
// things like compression, encryption, TLS, etc. The function wraps
// existing events and returns new events that manage the translation.
// The `should` parameter is an optional function that can be used to
// ignore or accept the translation for a specific connection.
// The `translate` parameter is a function that provides a ReadWriter
// for each new connection and returns a ReadWriter that performs the
// actual translation.
//
// Every translated connection has two streams:
//   - Stream 0 carries inbound bytes: network input -> translator -> events.Data.
//   - Stream 1 carries outbound bytes: handler output -> translator -> network.
//
// Each stream is driven by two goroutines joined by an io.Pipe. The event
// loop is nudged through the wake function passed to Serving whenever
// translated bytes, a deferred action, or a translation error need
// attention.
//
// If the translator fails with an error other than io.EOF or
// io.ErrClosedPipe, the connection is closed. That error is passed to
// events.Closed unless evio already supplied one.
func Translate(
	events Events,
	should func(id int, addr Addr) bool,
	translate func(rd io.ReadWriter) io.ReadWriter,
) Events {
	tevents := events
	var wake func(id int) bool
	var mu sync.Mutex // guards idc
	idc := make(map[int]*tconn)

	get := func(id int) *tconn {
		mu.Lock()
		defer mu.Unlock()
		return idc[id]
	}

	// fail records err as c's final error (the first one wins). It then wakes
	// the event loop so tevents.Data can close the connection. io.EOF and
	// io.ErrClosedPipe only mean a stream ended or was torn down by destroy,
	// so they are ignored.
	fail := func(c *tconn, id int, err error) {
		if err == io.EOF || err == io.ErrClosedPipe {
			return
		}
		c.mu.Lock()
		if c.err == nil {
			c.err = err
		}
		c.mu.Unlock()
		wake(id)
	}

	// failed reports whether a pipeline goroutine has recorded an error for c.
	failed := func(c *tconn) bool {
		c.mu.Lock()
		defer c.mu.Unlock()
		return c.err != nil
	}

	// pumpWrites feeds bytes queued by c.write(st, ...) into wr. It stops
	// when the stream is closed or wr fails.
	pumpWrites := func(c *tconn, id, st int, wr io.Writer) {
		cond := c.cond[st]
		for {
			cond.L.Lock()
			for !c.closed[st] && len(c.prebuf[st]) == 0 {
				cond.Wait()
			}
			if c.closed[st] {
				cond.L.Unlock()
				return
			}
			buf := c.prebuf[st]
			c.prebuf[st] = nil
			cond.L.Unlock()
			// wr may block (it is a pipe or the translator), so write without
			// holding the stream lock; c.write can keep queueing meanwhile.
			// io.Writer guarantees a full write when err == nil, so nothing
			// needs to be re-queued.
			if _, err := wr.Write(buf); err != nil {
				fail(c, id, err)
				return
			}
		}
	}

	// pumpReads moves bytes produced by rd into c.postbuf[st]. It then wakes
	// the event loop so tevents.Data can deliver them.
	pumpReads := func(c *tconn, id, st int, rd io.Reader) {
		var packet [2048]byte
		for {
			n, err := rd.Read(packet[:])
			// Per the io.Reader contract, bytes returned alongside an error
			// are valid and must be kept.
			if n > 0 {
				c.cond[st].L.Lock()
				c.postbuf[st] = append(c.postbuf[st], packet[:n]...)
				c.cond[st].L.Unlock()
				wake(id)
			}
			if err != nil {
				fail(c, id, err)
				return
			}
		}
	}

	// create builds and registers the translation state for connection id.
	create := func(id int) *tconn {
		c := &tconn{
			cond: [2]*sync.Cond{
				sync.NewCond(&sync.Mutex{}),
				sync.NewCond(&sync.Mutex{}),
			},
		}
		// Wire both pipes before handing c to translate. tconn.Read and
		// tconn.Write go straight to c.rd[0] and c.wr[1], so a translator
		// that performs I/O eagerly must never see nil pipes.
		for st := range c.rd {
			c.rd[st], c.wr[st] = io.Pipe()
		}
		tc := translate(c)

		mu.Lock()
		idc[id] = c
		mu.Unlock()

		// Stream 0: raw input -> pipe 0 -> translator (reads via c.Read) -> postbuf[0].
		go pumpWrites(c, id, 0, c.wr[0])
		go pumpReads(c, id, 0, tc)
		// Stream 1: handler output -> translator (writes via c.Write) -> pipe 1 -> postbuf[1].
		go pumpWrites(c, id, 1, tc)
		go pumpReads(c, id, 1, c.rd[1])
		return c
	}

	// destroy tears down c's pipes and stops its write pumps. It then forgets
	// id and returns the first pipeline error recorded for c, if any.
	destroy := func(c *tconn, id int) error {
		for st := range c.rd {
			// Closing both pipe ends makes pumps blocked on a pipe return.
			c.rd[st].Close()
			if wr, ok := c.wr[st].(io.Closer); ok {
				wr.Close()
			}
			c.cond[st].L.Lock()
			c.closed[st] = true
			c.cond[st].Broadcast()
			c.cond[st].L.Unlock()
		}
		mu.Lock()
		delete(idc, id)
		mu.Unlock()
		c.mu.Lock()
		defer c.mu.Unlock()
		return c.err
	}

	tevents.Serving = func(wakefn func(id int) bool, addrs []net.Addr) (action Action) {
		wake = wakefn
		if events.Serving != nil {
			action = events.Serving(wakefn, addrs)
		}
		return
	}

	tevents.Opened = func(id int, addr Addr) (out []byte, opts Options, action Action) {
		if should != nil && !should(id, addr) {
			// Translation not wanted for this connection: pass straight through.
			if events.Opened != nil {
				out, opts, action = events.Opened(id, addr)
			}
			return
		}
		c := create(id)
		if events.Opened == nil {
			return
		}
		// The handler's action is deferred to the wake path of tevents.Data.
		// That way any greeting is translated and flushed before the action
		// takes effect.
		out, opts, c.action = events.Opened(id, addr)
		queued := len(out) > 0
		if queued {
			c.write(1, out)
			out = nil
		}
		// Wake for the deferred action too. Otherwise an action returned
		// without output would wait until the peer sent data.
		if queued || c.action != None {
			wake(id)
		}
		return
	}

	tevents.Closed = func(id int, err error) (action Action) {
		if c := get(id); c != nil {
			// Prefer the error evio reported; fall back to a translation error.
			if ferr := destroy(c, id); err == nil {
				err = ferr
			}
		}
		if events.Closed != nil {
			action = events.Closed(id, err)
		}
		return
	}

	tevents.Data = func(id int, in []byte) (out []byte, action Action) {
		c := get(id)
		if c == nil {
			if events.Data != nil {
				out, action = events.Data(id, in)
			}
			return
		}
		if in != nil {
			// New raw input from the network. Once the handler has asked for
			// an action, report it instead of accepting more input.
			if len(in) > 0 {
				if c.action != None {
					return nil, c.action
				}
				// c.write copies in, so evio may reuse its buffer afterwards.
				c.write(0, in)
			}
			return
		}

		// in == nil: a wake-up. Translated output goes first so it is
		// flushed before any deferred action takes effect.
		if out = c.read(1); len(out) > 0 {
			wake(id)
			return
		}
		if c.action != None {
			return nil, c.action
		}
		in = c.read(0)
		if len(in) == 0 {
			// Nothing left to deliver; a broken translator ends the connection.
			if failed(c) {
				action = Close
			}
			return
		}
		if events.Data != nil {
			out, c.action = events.Data(id, in)
			if len(out) > 0 {
				c.write(1, out)
				out = nil
			}
		}
		// Come back for translated output, more input, or the action/error.
		wake(id)
		return
	}
	return tevents
}