mirror of https://github.com/gorilla/websocket.git
Improve client/server tests
Tests must not call *testing.T methods after the test function returns. Use a sync.WaitGroup to ensure that server handler functions complete before tests return.
This commit is contained in:
parent
7e5e9b5a25
commit
688592ebe6
|
@ -23,6 +23,7 @@ import (
|
||||||
"net/url"
|
"net/url"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
@ -44,12 +45,15 @@ var cstDialer = Dialer{
|
||||||
HandshakeTimeout: 30 * time.Second,
|
HandshakeTimeout: 30 * time.Second,
|
||||||
}
|
}
|
||||||
|
|
||||||
type cstHandler struct{ *testing.T }
|
type cstHandler struct {
|
||||||
|
*testing.T
|
||||||
|
s *cstServer
|
||||||
|
}
|
||||||
|
|
||||||
type cstServer struct {
|
type cstServer struct {
|
||||||
*httptest.Server
|
URL string
|
||||||
URL string
|
Server *httptest.Server
|
||||||
t *testing.T
|
wg sync.WaitGroup
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -58,9 +62,15 @@ const (
|
||||||
cstRequestURI = cstPath + "?" + cstRawQuery
|
cstRequestURI = cstPath + "?" + cstRawQuery
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func (s *cstServer) Close() {
|
||||||
|
s.Server.Close()
|
||||||
|
// Wait for handler functions to complete.
|
||||||
|
s.wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
func newServer(t *testing.T) *cstServer {
|
func newServer(t *testing.T) *cstServer {
|
||||||
var s cstServer
|
var s cstServer
|
||||||
s.Server = httptest.NewServer(cstHandler{t})
|
s.Server = httptest.NewServer(cstHandler{T: t, s: &s})
|
||||||
s.Server.URL += cstRequestURI
|
s.Server.URL += cstRequestURI
|
||||||
s.URL = makeWsProto(s.Server.URL)
|
s.URL = makeWsProto(s.Server.URL)
|
||||||
return &s
|
return &s
|
||||||
|
@ -68,13 +78,19 @@ func newServer(t *testing.T) *cstServer {
|
||||||
|
|
||||||
func newTLSServer(t *testing.T) *cstServer {
|
func newTLSServer(t *testing.T) *cstServer {
|
||||||
var s cstServer
|
var s cstServer
|
||||||
s.Server = httptest.NewTLSServer(cstHandler{t})
|
s.Server = httptest.NewTLSServer(cstHandler{T: t, s: &s})
|
||||||
s.Server.URL += cstRequestURI
|
s.Server.URL += cstRequestURI
|
||||||
s.URL = makeWsProto(s.Server.URL)
|
s.URL = makeWsProto(s.Server.URL)
|
||||||
return &s
|
return &s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Because tests wait for a response from a server, we are guaranteed that
|
||||||
|
// the wait group count is incremented before the test waits on the group
|
||||||
|
// in the call to (*cstServer).Close().
|
||||||
|
t.s.wg.Add(1)
|
||||||
|
defer t.s.wg.Done()
|
||||||
|
|
||||||
if r.URL.Path != cstPath {
|
if r.URL.Path != cstPath {
|
||||||
t.Logf("path=%v, want %v", r.URL.Path, cstPath)
|
t.Logf("path=%v, want %v", r.URL.Path, cstPath)
|
||||||
http.Error(w, "bad path", http.StatusBadRequest)
|
http.Error(w, "bad path", http.StatusBadRequest)
|
||||||
|
|
Loading…
Reference in New Issue