upgrade: add context-takeover test to client_server_test

This commit is contained in:
misu 2018-02-02 13:48:44 +09:00
parent c83088956f
commit e79bb70823
1 changed files with 92 additions and 8 deletions

View File

@ -42,6 +42,8 @@ var cstDialer = Dialer{
type cstHandler struct{ *testing.T }
type cstContextTakeoverHandler struct{ *testing.T }
type cstServer struct {
*httptest.Server
URL string
@ -61,6 +63,14 @@ func newServer(t *testing.T) *cstServer {
return &s
}
func newContextTakeoverServer(t *testing.T) *cstServer {
var s cstServer
s.Server = httptest.NewServer(cstContextTakeoverHandler{t})
s.Server.URL += cstRequestURI
s.URL = makeWsProto(s.Server.URL)
return &s
}
func newTLSServer(t *testing.T) *cstServer {
var s cstServer
s.Server = httptest.NewTLSServer(cstHandler{t})
@ -118,6 +128,80 @@ func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}
func (t cstContextTakeoverHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != cstPath {
t.Logf("path=%v, want %v", r.URL.Path, cstPath)
http.Error(w, "bad path", 400)
return
}
if r.URL.RawQuery != cstRawQuery {
t.Logf("query=%v, want %v", r.URL.RawQuery, cstRawQuery)
http.Error(w, "bad path", 400)
return
}
subprotos := Subprotocols(r)
if !reflect.DeepEqual(subprotos, cstDialer.Subprotocols) {
t.Logf("subprotols=%v, want %v", subprotos, cstDialer.Subprotocols)
http.Error(w, "bad protocol", 400)
return
}
cu := cstUpgrader
cu.CompressionLevel = defaultCompressionLevel
cu.EnableContextTakeover = true
ws, err := cu.Upgrade(w, r, http.Header{"Set-Cookie": {"sessionID=1234"}})
if err != nil {
t.Logf("Upgrade: %v", err)
return
}
defer ws.Close()
if ws.Subprotocol() != "p1" {
t.Logf("Subprotocol() = %s, want p1", ws.Subprotocol())
ws.Close()
return
}
// first message
op, rd, err := ws.NextReader()
if err != nil {
t.Logf("NextReader: %v", err)
return
}
wr, err := ws.NextWriter(op)
if err != nil {
t.Logf("NextWriter: %v", err)
return
}
if _, err = io.Copy(wr, rd); err != nil {
t.Logf("NextWriter: %v", err)
return
}
if err := wr.Close(); err != nil {
t.Logf("Close: %v", err)
return
}
// second message
op, rd, err = ws.NextReader()
if err != nil {
t.Logf("NextReader: %v", err)
return
}
wr, err = ws.NextWriter(op)
if err != nil {
t.Logf("NextWriter: %v", err)
return
}
if _, err = io.Copy(wr, rd); err != nil {
t.Logf("NextWriter: %v", err)
return
}
if err := wr.Close(); err != nil {
t.Logf("Close: %v", err)
return
}
}
func makeWsProto(s string) string {
return "ws" + strings.TrimPrefix(s, "http")
}
@ -161,11 +245,11 @@ func multipleSendRecv(t *testing.T, ws *Conn) {
t.Fatalf("message=%s, want %s", p, message)
}
message_2 := "Can you read message?"
nextMessage := "Can you read message?"
if err := ws.SetWriteDeadline(time.Now().Add(time.Second)); err != nil {
t.Fatalf("SetWriteDeadline: %v", err)
}
if err := ws.WriteMessage(TextMessage, []byte(message_2)); err != nil {
if err := ws.WriteMessage(TextMessage, []byte(nextMessage)); err != nil {
t.Fatalf("_WriteMessage: %v", err)
}
if err := ws.SetReadDeadline(time.Now().Add(time.Second)); err != nil {
@ -174,10 +258,10 @@ func multipleSendRecv(t *testing.T, ws *Conn) {
_, p, err = ws.ReadMessage()
if err != nil {
t.Fatalf("_ReadMessage: %v", err) // _ReadMessage: websocket: close 1006 (abnormal closure): unexpected EOF
t.Fatalf("_ReadMessage: %v", err)
}
if string(p) != message {
t.Fatalf("_message=%s, want %s", p, message_2)
if string(p) != nextMessage {
t.Fatalf("_message=%s, want %s", p, nextMessage)
}
}
@ -562,20 +646,20 @@ func TestDialCompression(t *testing.T) {
}
func TestDialCompressionOfContextTakeover(t *testing.T) {
s := newServer(t)
s := newContextTakeoverServer(t)
defer s.Close()
dialer := cstDialer
dialer.EnableCompression = true
dialer.EnableContextTakeover = true
dialer.CompressionLevel = 2
ws, _, err := dialer.Dial(s.URL, nil)
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer ws.Close()
// Todo multiple send and receive.
sendRecv(t, ws)
multipleSendRecv(t, ws)
}
func TestSocksProxyDial(t *testing.T) {