diff --git a/client_server_test.go b/client_server_test.go index 67dd346..610fbe2 100644 --- a/client_server_test.go +++ b/client_server_test.go @@ -23,6 +23,7 @@ import ( "net/url" "reflect" "strings" + "sync" "testing" "time" ) @@ -44,12 +45,15 @@ var cstDialer = Dialer{ HandshakeTimeout: 30 * time.Second, } -type cstHandler struct{ *testing.T } +type cstHandler struct { + *testing.T + s *cstServer +} type cstServer struct { - *httptest.Server - URL string - t *testing.T + URL string + Server *httptest.Server + wg sync.WaitGroup } const ( @@ -58,9 +62,15 @@ const ( cstRequestURI = cstPath + "?" + cstRawQuery ) +func (s *cstServer) Close() { + s.Server.Close() + // Wait for handler functions to complete. + s.wg.Wait() +} + func newServer(t *testing.T) *cstServer { var s cstServer - s.Server = httptest.NewServer(cstHandler{t}) + s.Server = httptest.NewServer(cstHandler{T: t, s: &s}) s.Server.URL += cstRequestURI s.URL = makeWsProto(s.Server.URL) return &s @@ -68,13 +78,19 @@ func newServer(t *testing.T) *cstServer { func newTLSServer(t *testing.T) *cstServer { var s cstServer - s.Server = httptest.NewTLSServer(cstHandler{t}) + s.Server = httptest.NewTLSServer(cstHandler{T: t, s: &s}) s.Server.URL += cstRequestURI s.URL = makeWsProto(s.Server.URL) return &s } func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Because tests wait for a response from a server, we are guaranteed that + // the wait group count is incremented before the test waits on the group + // in the call to (*cstServer).Close(). + t.s.wg.Add(1) + defer t.s.wg.Done() + if r.URL.Path != cstPath { t.Logf("path=%v, want %v", r.URL.Path, cstPath) http.Error(w, "bad path", http.StatusBadRequest)