From b648f206c26b1333ae0c90e91ab3e227859bd8d6 Mon Sep 17 00:00:00 2001 From: David Dollar Date: Mon, 27 Nov 2017 19:10:45 -0500 Subject: [PATCH] Use ASCII case folding in same origin test --- server.go | 2 +- server_test.go | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/server.go b/server.go index 6ae97c5..0441e5d 100644 --- a/server.go +++ b/server.go @@ -76,7 +76,7 @@ func checkSameOrigin(r *http.Request) bool { if err != nil { return false } - return u.Host == r.Host + return equalASCIIFold(u.Host, r.Host) } func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string { diff --git a/server_test.go b/server_test.go index 0a28141..c43dbb2 100644 --- a/server_test.go +++ b/server_test.go @@ -49,3 +49,21 @@ func TestIsWebSocketUpgrade(t *testing.T) { } } } + +var checkSameOriginTests = []struct { + ok bool + r *http.Request +}{ + {false, &http.Request{Host: "example.org", Header: map[string][]string{"Origin": []string{"https://other.org"}}}}, + {true, &http.Request{Host: "example.org", Header: map[string][]string{"Origin": []string{"https://example.org"}}}}, + {true, &http.Request{Host: "Example.org", Header: map[string][]string{"Origin": []string{"https://example.org"}}}}, +} + +func TestCheckSameOrigin(t *testing.T) { + for _, tt := range checkSameOriginTests { + ok := checkSameOrigin(tt.r) + if tt.ok != ok { + t.Errorf("checkSameOrigin(%+v) returned %v, want %v", tt.r, ok, tt.ok) + } + } +}