// evio/evio_test.go (724 lines, 15 KiB, Go)

// Copyright 2017 Joshua J Baker. All rights reserved.
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
package evio
import (
"bufio"
"fmt"
"io"
"math/rand"
"net"
"os"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
)
// TestServe runs the echo scenario of testServe concurrently in four
// configurations: the native backend ("tcp") and the stdlib backend
// ("tcp-net"), each with and without an additional unix-socket listener.
// Every configuration connects 10 clients that pipe random-sized payloads
// for 1-3 seconds, verify the echo, and then disconnect; the server exits
// once all of them have closed.
func TestServe(t *testing.T) {
	cases := []struct {
		network string
		addr    string
		unix    bool
	}{
		{"tcp", ":9990", false},
		{"tcp", ":9991", true},
		{"tcp-net", ":9992", false},
		{"tcp-net", ":9993", true},
	}
	var wg sync.WaitGroup
	wg.Add(len(cases))
	for _, c := range cases {
		go func(network, addr string, unix bool) {
			defer wg.Done()
			testServe(network, addr, unix, 10)
		}(c.network, c.addr, c.unix)
	}
	wg.Wait()
}
// testServe runs an echo server on network://addr and, when unix is true,
// also on a unix socket whose path is addr with ":" replaced by "socket".
// Each new connection is greeted with "sweetness\r\n" (TCP keep-alive set to
// five minutes) and all inbound data is echoed back. On the first tick
// nclients clients are launched via startClient; the server returns Shutdown
// once nclients connections have closed and every opened connection has
// been closed. A Serve error panics.
func testServe(network, addr string, unix bool, nclients int) {
	var (
		events          Events
		clientsLaunched bool
		opens, closes   int
	)
	events.Serving = func(srv Server) (action Action) {
		return
	}
	events.Opened = func(id int, info Info) (out []byte, opts Options, action Action) {
		opens++
		opts.TCPKeepAlive = 5 * time.Minute
		return []byte("sweetness\r\n"), opts, action
	}
	events.Closed = func(id int, err error) (action Action) {
		closes++
		if closes == nclients && opens == closes {
			action = Shutdown
		}
		return
	}
	events.Data = func(id int, in []byte) (out []byte, action Action) {
		return in, action
	}
	events.Tick = func() (delay time.Duration, action Action) {
		// Clients are started from the first tick so the listeners are
		// already accepting when they dial.
		if !clientsLaunched {
			clientsLaunched = true
			for i := 0; i < nclients; i++ {
				go startClient(network, addr)
			}
		}
		return time.Second / 5, action
	}

	addrs := []string{network + "://" + addr}
	if unix {
		socket := strings.Replace(addr, ":", "socket", 1)
		os.RemoveAll(socket)
		defer os.RemoveAll(socket)
		addrs = append(addrs, "unix://"+socket)
	}
	if err := Serve(events, addrs...); err != nil {
		panic(err)
	}
}
// startClient is one testServe client. It removes the "-net" suffix from
// network so net.Dial accepts it, dials addr, and expects the
// "sweetness\r\n" greeting. Then, for a random window of 1-3 seconds, it
// repeatedly writes a random payload of 0 to 1 MiB-1 bytes and reads back
// exactly as many bytes. A mismatching echo is printed rather than treated
// as fatal. Any I/O error panics.
func startClient(network, addr string) {
	rand.Seed(time.Now().UnixNano())
	conn, err := net.Dial(strings.Replace(network, "-net", "", -1), addr)
	if err != nil {
		panic(err)
	}
	defer conn.Close()

	reader := bufio.NewReader(conn)
	greeting, err := reader.ReadBytes('\n')
	if err != nil {
		panic(err)
	}
	if string(greeting) != "sweetness\r\n" {
		panic("bad header")
	}

	window := time.Duration((rand.Float64()*2 + 1) * float64(time.Second))
	deadline := time.Now().Add(window)
	for time.Now().Before(deadline) {
		size := rand.Int() % (1024 * 1024)
		sent := make([]byte, size)
		if _, err := rand.Read(sent); err != nil {
			panic(err)
		}
		if _, err := conn.Write(sent); err != nil {
			panic(err)
		}
		echoed := make([]byte, size)
		if _, err := io.ReadFull(reader, echoed); err != nil {
			panic(err)
		}
		if string(sent) != string(echoed) {
			// Reported only; a mismatch is deliberately not fatal here.
			fmt.Printf("mismatch: %d bytes\n", len(sent))
		}
	}
}
// TestWake runs testWake concurrently against a TCP address and a unix
// socket, each with both the native backend and the stdlib ("-net") backend.
func TestWake(t *testing.T) {
	type scenario struct {
		network, addr string
		stdlib        bool
	}
	scenarios := []scenario{
		{"tcp", ":9991", false},
		{"tcp", ":9992", true},
		{"unix", "socket1", false},
		{"unix", "socket2", true},
	}
	var wg sync.WaitGroup
	for _, sc := range scenarios {
		wg.Add(1)
		go func(sc scenario) {
			defer wg.Done()
			testWake(sc.network, sc.addr, sc.stdlib)
		}(sc)
	}
	wg.Wait()
}
// testWake exercises Server.Wake. A single client sends 1000 "msgN\r\n"
// lines and reads each one back before sending the next.
//
// The server never echoes directly from Data. Non-nil input is appended to
// cin. A background goroutine, coordinated through cond, moves cin into cout
// and calls srv.Wake(cid). A Data call with nil input then returns cout as
// the output. The test relies on Wake producing such a nil-input Data call.
// Closing the connection returns Shutdown and stops the background goroutine.
//
// stdlib selects the "-net" (stdlib) backend. Any error panics.
func testWake(network, addr string, stdlib bool) {
	var events Events
	var srv Server
	events.Serving = func(srvin Server) (action Action) {
		srv = srvin
		go func() {
			conn, err := net.Dial(network, addr)
			must(err)
			defer conn.Close()
			rd := bufio.NewReader(conn)
			for i := 0; i < 1000; i++ {
				line := []byte(fmt.Sprintf("msg%d\r\n", i))
				// A failed write must fail the test here rather than surface
				// later as a confusing read error or mismatch.
				_, err := conn.Write(line)
				must(err)
				data, err := rd.ReadBytes('\n')
				must(err)
				if string(data) != string(line) {
					panic("msg mismatch")
				}
			}
		}()
		return
	}
	// State shared between the event callbacks and the waker goroutine;
	// cin, cout and cclosed are guarded by cond.L.
	var cid int
	var cout []byte
	var cin []byte
	var cclosed bool
	var cond = sync.NewCond(&sync.Mutex{})
	events.Opened = func(id int, info Info) (out []byte, opts Options, action Action) {
		cid = id
		return
	}
	events.Closed = func(id int, err error) (action Action) {
		action = Shutdown
		cond.L.Lock()
		cclosed = true
		cond.Broadcast()
		cond.L.Unlock()
		return
	}
	// Waker goroutine: whenever woken, move pending input into the output
	// buffer and, if there is output, ask the server to wake the connection.
	go func() {
		cond.L.Lock()
		for !cclosed {
			if len(cin) > 0 {
				cout = append(cout, cin...)
				cin = nil
			}
			if len(cout) > 0 {
				srv.Wake(cid)
			}
			cond.Wait()
		}
		cond.L.Unlock()
	}()
	events.Data = func(id int, in []byte) (out []byte, action Action) {
		if in == nil {
			// Woken call: flush whatever the waker goroutine prepared.
			cond.L.Lock()
			out = cout
			cout = nil
			cond.L.Unlock()
		} else {
			// Real input: queue it and notify the waker goroutine.
			cond.L.Lock()
			cin = append(cin, in...)
			cond.Broadcast()
			cond.L.Unlock()
		}
		return
	}
	if stdlib {
		must(Serve(events, network+"-net://"+addr))
	} else {
		must(Serve(events, network+"://"+addr))
	}
}
// must panics with err when it is non-nil and does nothing otherwise. The
// tests use it to turn any setup or I/O failure into an immediate crash.
func must(err error) {
	if err == nil {
		return
	}
	panic(err)
}
// TestTick runs testTick concurrently on TCP and unix-socket addresses with
// both the native and the stdlib ("-net") backends.
func TestTick(t *testing.T) {
	type scenario struct {
		network, addr string
		stdlib        bool
	}
	scenarios := []scenario{
		{"tcp", ":9991", false},
		{"tcp", ":9992", true},
		{"unix", "socket1", false},
		{"unix", "socket2", true},
	}
	var wg sync.WaitGroup
	for _, sc := range scenarios {
		wg.Add(1)
		go func(sc scenario) {
			defer wg.Done()
			testTick(sc.network, sc.addr, sc.stdlib)
		}(sc)
	}
	wg.Wait()
}
// testTick checks that the server honors the delay returned by Events.Tick.
// The handler requests a 10ms delay 25 times and then returns Shutdown, so
// Serve should run for about 25*10ms = 250ms. It must finish within one
// second. stdlib selects the "-net" (stdlib) backend. A timing violation or
// a Serve error panics.
func testTick(network, addr string, stdlib bool) {
	const (
		ticks    = 25
		interval = 10 * time.Millisecond
	)
	var events Events
	var count int
	start := time.Now()
	events.Tick = func() (delay time.Duration, action Action) {
		if count == ticks {
			action = Shutdown
			return
		}
		count++
		delay = interval
		return
	}
	if stdlib {
		must(Serve(events, network+"-net://"+addr))
	} else {
		must(Serve(events, network+"://"+addr))
	}
	dur := time.Since(start)
	// The lower bound is half the nominal total. It still fails if the
	// returned delays are ignored, but leaves slack for timer jitter and for
	// however the server schedules the first tick.
	minDur := ticks * interval / 2
	if dur < minDur || dur > time.Second {
		panic("bad ticker timing")
	}
}
// TestShutdown runs testShutdown concurrently on TCP and unix-socket
// addresses, each with the native and the stdlib ("-net") backend.
func TestShutdown(t *testing.T) {
	type scenario struct {
		network, addr string
		stdlib        bool
	}
	scenarios := []scenario{
		{"tcp", ":9991", false},
		{"tcp", ":9992", true},
		{"unix", "socket1", false},
		{"unix", "socket2", true},
	}
	var wg sync.WaitGroup
	for _, sc := range scenarios {
		wg.Add(1)
		go func(sc scenario) {
			defer wg.Done()
			testShutdown(sc.network, sc.addr, sc.stdlib)
		}(sc)
	}
	wg.Wait()
}
// testShutdown checks that a Shutdown action closes every open connection.
// On the first tick it starts N clients, each of which blocks in Read and
// expects an error once its connection is torn down. On each later tick
// (every 200ms) it returns Shutdown once all N connections are open. After
// Serve returns, every Opened must have been matched by a Closed.
// stdlib selects the "-net" (stdlib) backend. Any violation panics.
func testShutdown(network, addr string, stdlib bool) {
	const N = 10
	var events Events
	var count int
	// clients is the number of currently open connections. It is touched
	// from the event callbacks and read after Serve returns, so every access
	// goes through sync/atomic.
	var clients int64
	events.Opened = func(id int, info Info) (out []byte, opts Options, action Action) {
		atomic.AddInt64(&clients, 1)
		return
	}
	events.Closed = func(id int, err error) (action Action) {
		atomic.AddInt64(&clients, -1)
		return
	}
	events.Tick = func() (delay time.Duration, action Action) {
		if count == 0 {
			// start clients
			for i := 0; i < N; i++ {
				go func() {
					conn, err := net.Dial(network, addr)
					must(err)
					defer conn.Close()
					_, err = conn.Read([]byte{0})
					if err == nil {
						panic("expected error")
					}
				}()
			}
		} else if atomic.LoadInt64(&clients) == N {
			action = Shutdown
		}
		count++
		delay = time.Second / 5
		return
	}
	if stdlib {
		must(Serve(events, network+"-net://"+addr))
	} else {
		must(Serve(events, network+"://"+addr))
	}
	if atomic.LoadInt64(&clients) != 0 {
		panic("did not call close on all clients")
	}
}
// TestDetach runs testDetach concurrently on TCP and unix-socket addresses,
// each with both the native and the stdlib ("-net") backend.
func TestDetach(t *testing.T) {
	type scenario struct {
		network, addr string
		stdlib        bool
	}
	scenarios := []scenario{
		{"tcp", ":9991", false},
		{"tcp", ":9992", true},
		{"unix", "socket1", false},
		{"unix", "socket2", true},
	}
	var wg sync.WaitGroup
	for _, sc := range scenarios {
		wg.Add(1)
		go func(sc scenario) {
			defer wg.Done()
			testDetach(sc.network, sc.addr, sc.stdlib)
		}(sc)
	}
	wg.Wait()
}
// testDetach checks the Detach action. The client sends
// payload + "--detached--" + payload, where payload is 10 MiB of random
// bytes. When Data has accumulated the whole message, it verifies it,
// echoes it back and returns Detach. The Detached handler then writes the
// same message again over the raw connection and closes it. The client
// checks both copies and that a final read fails. After that, the next tick
// returns Shutdown. stdlib selects the "-net" (stdlib) backend. Any mismatch
// or error panics.
func testDetach(network, addr string, stdlib bool) {
	rand.Seed(time.Now().UnixNano())
	rdat := make([]byte, 10*1024*1024)
	if _, err := rand.Read(rdat); err != nil {
		panic("random error: " + err.Error())
	}
	// expected = rdat + "--detached--" + rdat. It is built with appends into
	// a pre-sized slice, avoiding string round-trips that would each copy
	// the 10 MiB payload.
	const marker = "--detached--"
	expected := make([]byte, 0, 2*len(rdat)+len(marker))
	expected = append(expected, rdat...)
	expected = append(expected, marker...)
	expected = append(expected, rdat...)

	var cin []byte
	var events Events
	// Accumulate client input. Once the full message has arrived, verify it,
	// echo it back and detach the connection.
	events.Data = func(id int, in []byte) (out []byte, action Action) {
		cin = append(cin, in...)
		if len(cin) == len(expected) {
			if string(cin) != string(expected) {
				panic("mismatch client -> server")
			}
			return cin, Detach
		}
		return
	}
	// done is set by the client goroutine and polled by Tick.
	var done int64
	// On the detached connection, write the message a second time and close.
	events.Detached = func(id int, conn io.ReadWriteCloser) (action Action) {
		go func() {
			defer conn.Close()
			n, err := conn.Write(expected)
			must(err)
			if n != len(expected) {
				panic("not enough data written")
			}
		}()
		return
	}
	events.Serving = func(srv Server) (action Action) {
		go func() {
			// client connection
			conn, err := net.Dial(network, addr)
			must(err)
			defer conn.Close()
			_, err = conn.Write(expected)
			must(err)
			packet := make([]byte, len(expected))
			// First copy: the echo returned by Data before detaching.
			time.Sleep(time.Second / 3)
			_, err = io.ReadFull(conn, packet)
			must(err)
			if string(packet) != string(expected) {
				panic("mismatch server -> client 1")
			}
			// Second copy: written by the Detached handler.
			time.Sleep(time.Second / 3)
			_, err = io.ReadFull(conn, packet)
			must(err)
			if string(packet) != string(expected) {
				panic("mismatch server -> client 2")
			}
			// The Detached handler closed its side, so this read must fail.
			time.Sleep(time.Second / 3)
			_, err = conn.Read([]byte{0})
			if err == nil {
				panic("expected error, got nil")
			}
			atomic.StoreInt64(&done, 1)
		}()
		return
	}
	events.Tick = func() (delay time.Duration, action Action) {
		delay = time.Second / 5
		if atomic.LoadInt64(&done) == 1 {
			action = Shutdown
		}
		return
	}
	if stdlib {
		must(Serve(events, network+"-net://"+addr))
	} else {
		must(Serve(events, network+"://"+addr))
	}
}
// TestBadAddresses checks address parsing in Serve. An unknown scheme and a
// scheme-less address must be rejected. "tcp://" must be accepted; the
// server shuts down as soon as it starts serving.
func TestBadAddresses(t *testing.T) {
	var events Events
	events.Serving = func(srv Server) (action Action) {
		return Shutdown
	}
	for _, bad := range []string{"tulip://howdy", "howdy"} {
		if err := Serve(events, bad); err == nil {
			t.Fatalf("expected error")
		}
	}
	if err := Serve(events, "tcp://"); err != nil {
		t.Fatalf("expected nil, got '%v'", err)
	}
}
// TestInputStream checks that bytes left unconsumed via End are prepended to
// the next Begin input, and that End(nil) clears the carried-over bytes.
func TestInputStream(t *testing.T) {
	var stream InputStream
	hello := []byte("HELLO")
	if got := stream.Begin(hello); string(got) != string(hello) {
		t.Fatalf("expected '%v', got '%v'", hello, got)
	}
	// Leave "LO" unconsumed; it should be seen again ahead of "WLY".
	stream.End(hello[3:])
	if got := stream.Begin([]byte("WLY")); string(got) != "LOWLY" {
		t.Fatalf("expected '%v', got '%v'", "LOWLY", got)
	}
	// Everything consumed: the next Begin should see only its own input.
	stream.End(nil)
	if got := stream.Begin([]byte("PLAYER")); string(got) != "PLAYER" {
		t.Fatalf("expected '%v', got '%v'", "PLAYER", got)
	}
}
// TestPrePostwrite runs testPrePostwrite concurrently on TCP and unix-socket
// addresses, each with both the native and the stdlib ("-net") backend.
func TestPrePostwrite(t *testing.T) {
	type scenario struct {
		network, addr string
		stdlib        bool
	}
	scenarios := []scenario{
		{"tcp", ":9991", false},
		{"tcp", ":9992", true},
		{"unix", "socket1", false},
		{"unix", "socket2", true},
	}
	var wg sync.WaitGroup
	for _, sc := range scenarios {
		wg.Add(1)
		go func(sc scenario) {
			defer wg.Done()
			testPrePostwrite(sc.network, sc.addr, sc.stdlib)
		}(sc)
	}
	wg.Wait()
}
// testPrePostwrite checks the Prewrite and Postwrite hooks against a
// locally tracked copy of all pending output (tout).
//
// The server emits a numbered "hello N\r\n" line on open and on every Data
// call. Each emission calls srv.Wake(id), presumably so that Data runs again
// and keeps the stream going. Prewrite must report exactly len(tout).
// Postwrite trims the written amount from tout and must report the
// remainder. The client reads the first 1000 lines, checks their numbering
// and disconnects, which makes Closed return Shutdown.
// stdlib selects the "-net" (stdlib) backend. Any mismatch or error panics.
func testPrePostwrite(network, addr string, stdlib bool) {
	var (
		events  Events
		srv     Server
		packets int
		tout    []byte
	)
	// emit produces the next numbered line, records it as pending output and
	// wakes the connection again.
	emit := func(id int) []byte {
		packets++
		line := []byte(fmt.Sprintf("hello %d\r\n", packets))
		tout = append(tout, line...)
		srv.Wake(id)
		return line
	}
	events.Opened = func(id int, info Info) (out []byte, opts Options, action Action) {
		out = emit(id)
		return
	}
	events.Data = func(id int, in []byte) (out []byte, action Action) {
		out = emit(id)
		return
	}
	events.Prewrite = func(id int, amount int) (action Action) {
		if len(tout) != amount {
			panic("invalid prewrite amount")
		}
		return
	}
	events.Postwrite = func(id int, amount, remaining int) (action Action) {
		tout = tout[amount:]
		if len(tout) != remaining {
			panic("invalid postwrite amount")
		}
		return
	}
	events.Closed = func(id int, err error) (action Action) {
		return Shutdown
	}
	events.Serving = func(srvin Server) (action Action) {
		srv = srvin
		go func() {
			conn, err := net.Dial(network, addr)
			must(err)
			defer conn.Close()
			reader := bufio.NewReader(conn)
			for want := 1; want <= 1000; want++ {
				got, err := reader.ReadBytes('\n')
				must(err)
				ex := fmt.Sprintf("hello %d\r\n", want)
				if string(got) != ex {
					panic(fmt.Sprintf("expected '%v', got '%v'", ex, got))
				}
			}
		}()
		return
	}
	target := network + "://" + addr
	if stdlib {
		target = network + "-net://" + addr
	}
	must(Serve(events, target))
}
// TestTranslate runs testTranslate with the "passthrough" translator
// concurrently on TCP and unix-socket addresses, each with both the native
// and the stdlib ("-net") backend.
func TestTranslate(t *testing.T) {
	type scenario struct {
		network, addr string
		stdlib        bool
	}
	scenarios := []scenario{
		{"tcp", ":9991", false},
		{"tcp", ":9992", true},
		{"unix", "socket1", false},
		{"unix", "socket2", true},
	}
	var wg sync.WaitGroup
	for _, sc := range scenarios {
		wg.Add(1)
		go func(sc scenario) {
			defer wg.Done()
			testTranslate(sc.network, sc.addr, "passthrough", sc.stdlib)
		}(sc)
	}
	wg.Wait()
}
func testTranslate(network, addr string, kind string, stdlib bool) {
var events Events
events.Data = func(id int, in []byte) (out []byte, action Action) {
out = in
return
}
events.Closed = func(id int, err error) (action Action) {
action = Shutdown
return
}
events.Opened = func(id int, info Info) (out []byte, opts Options, action Action) {
out = []byte("sweetness\r\n")
return
}
events.Serving = func(srv Server) (action Action) {
go func() {
conn, err := net.Dial(network, addr)
must(err)
defer conn.Close()
line := "sweetness\r\n"
packet := make([]byte, len(line))
n, err := io.ReadFull(conn, packet)
must(err)
if n != len(line) {
panic("invalid amount")
}
if string(packet) != string(line) {
panic(fmt.Sprintf("expected '%v', got '%v'\n", line, packet))
}
for i := 0; i < 100; i++ {
line := fmt.Sprintf("hello %d\r\n", i)
n, err := conn.Write([]byte(line))
must(err)
if n != len(line) {
panic("invalid amount")
}
packet := make([]byte, len(line))
n, err = io.ReadFull(conn, packet)
must(err)
if n != len(line) {
panic("invalid amount")
}
if string(packet) != string(line) {
panic(fmt.Sprintf("expected '%v', got '%v'\n", line, packet))
}
}
}()
return
}
tevents := Translate(events,
func(id int, info Info) bool {
return true
},
func(id int, rw io.ReadWriter) io.ReadWriter {
switch kind {
case "passthrough":
return rw
}
panic("invalid kind")
},
)
if stdlib {
must(Serve(tevents, network+"-net://"+addr))
} else {
must(Serve(tevents, network+"://"+addr))
}
// test with no shoulds
tevents = Translate(events,
func(id int, info Info) bool {
return false
},
func(id int, rw io.ReadWriter) io.ReadWriter {
return rw
},
)
if stdlib {
must(Serve(tevents, network+"-net://"+addr))
} else {
must(Serve(tevents, network+"://"+addr))
}
}
// func TestVariousAddr(t *testing.T) {
// var events Events
// var kind string
// events.Serving = func(wake func(id int) bool, addrs []net.Addr) (action Action) {
// addr := addrs[0].(*net.TCPAddr)
// if (kind == "tcp4" && len(addr.IP) != 4) || (kind == "tcp6" && len(addr.IP) != 16) {
// println(len(addr.IP))
// panic("invalid ip")
// }
// go func(kind string) {
// conn, err := net.Dial(kind, ":9991")
// must(err)
// defer conn.Close()
// }(kind)
// return
// }
// events.Closed = func(id int, err error) (action Action) {
// return Shutdown
// }
// kind = "tcp4"
// must(Serve(events, "tcp4://:9991"))
// kind = "tcp6"
// must(Serve(events, "tcp6://:9991"))
// }