Relax default origin test.

Update the default origin test to treat no origin specified as OK. If
the client can create a request without the origin set, then the client
can also create a request with an arbitrary origin.
This commit is contained in:
Gary Burd 2014-06-30 14:57:20 -07:00
parent 3d0e89148e
commit 0f32413e5e
2 changed files with 16 additions and 8 deletions

View File

@ -137,7 +137,6 @@ type dialHandler struct {
var dialUpgrader = &Upgrader{ var dialUpgrader = &Upgrader{
ReadBufferSize: 1024, ReadBufferSize: 1024,
WriteBufferSize: 1024, WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool { return true },
} }
func (t dialHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (t dialHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
@ -242,3 +241,12 @@ func TestDialBadScheme(t *testing.T) {
t.Fatalf("Dial() did not return error") t.Fatalf("Dial() did not return error")
} }
} }
func TestDialBadOrigin(t *testing.T) {
s := httptest.NewServer(dialHandler{t})
defer s.Close()
_, _, err := DefaultDialer.Dial(s.URL, http.Header{"Origin": {"bad"}})
if err == nil {
t.Fatalf("Dial() did not return error")
}
}

View File

@ -48,8 +48,8 @@ type Upgrader struct {
Error func(w http.ResponseWriter, r *http.Request, status int, reason error) Error func(w http.ResponseWriter, r *http.Request, status int, reason error)
// CheckOrigin returns true if the request Origin header is acceptable. If // CheckOrigin returns true if the request Origin header is acceptable. If
// CheckOrigin is nil, the host in the Origin header must match the host of // CheckOrigin is nil, the host in the Origin header must not be set or
// the request. // must match the host of the request.
CheckOrigin func(r *http.Request) bool CheckOrigin func(r *http.Request) bool
} }
@ -63,13 +63,13 @@ func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status in
return nil, err return nil, err
} }
// checkSameOrigin returns true if the origin is equal to the request host. // checkSameOrigin returns true if the origin is not set or is equal to the request host.
func checkSameOrigin(r *http.Request) bool { func checkSameOrigin(r *http.Request) bool {
origin := r.Header.Get("Origin") origin := r.Header["Origin"]
if origin == "" { if len(origin) == 0 {
return false return true
} }
u, err := url.Parse(origin) u, err := url.Parse(origin[0])
if err != nil { if err != nil {
return false return false
} }