diff --git a/client_server_test.go b/client_server_test.go index 21672c0..e30be94 100644 --- a/client_server_test.go +++ b/client_server_test.go @@ -137,7 +137,6 @@ type dialHandler struct { var dialUpgrader = &Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, - CheckOrigin: func(r *http.Request) bool { return true }, } func (t dialHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -242,3 +241,12 @@ func TestDialBadScheme(t *testing.T) { t.Fatalf("Dial() did not return error") } } + +func TestDialBadOrigin(t *testing.T) { + s := httptest.NewServer(dialHandler{t}) + defer s.Close() + _, _, err := DefaultDialer.Dial(s.URL, http.Header{"Origin": {"bad"}}) + if err == nil { + t.Fatalf("Dial() did not return error") + } +} diff --git a/server.go b/server.go index 65e9d9f..4de5bb0 100644 --- a/server.go +++ b/server.go @@ -48,8 +48,8 @@ type Upgrader struct { Error func(w http.ResponseWriter, r *http.Request, status int, reason error) // CheckOrigin returns true if the request Origin header is acceptable. If - // CheckOrigin is nil, the host in the Origin header must match the host of - // the request. + // CheckOrigin is nil, the host in the Origin header must not be set or + // must match the host of the request. CheckOrigin func(r *http.Request) bool } @@ -63,13 +63,13 @@ func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status in return nil, err } -// checkSameOrigin returns true if the origin is equal to the request host. +// checkSameOrigin returns true if the origin is not set or is equal to the request host. func checkSameOrigin(r *http.Request) bool { - origin := r.Header.Get("Origin") - if origin == "" { - return false + origin := r.Header["Origin"] + if len(origin) == 0 { + return true } - u, err := url.Parse(origin) + u, err := url.Parse(origin[0]) if err != nil { return false }