mirror of https://github.com/gorilla/websocket.git
Relax default origin test.
Update the default origin test to treat no origin specified as OK. If the client can create a request without the origin set, then the client can also create a request with an arbitrary origin.
This commit is contained in:
parent
3d0e89148e
commit
0f32413e5e
|
@ -137,7 +137,6 @@ type dialHandler struct {
|
||||||
var dialUpgrader = &Upgrader{
|
var dialUpgrader = &Upgrader{
|
||||||
ReadBufferSize: 1024,
|
ReadBufferSize: 1024,
|
||||||
WriteBufferSize: 1024,
|
WriteBufferSize: 1024,
|
||||||
CheckOrigin: func(r *http.Request) bool { return true },
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t dialHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (t dialHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
@ -242,3 +241,12 @@ func TestDialBadScheme(t *testing.T) {
|
||||||
t.Fatalf("Dial() did not return error")
|
t.Fatalf("Dial() did not return error")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDialBadOrigin(t *testing.T) {
|
||||||
|
s := httptest.NewServer(dialHandler{t})
|
||||||
|
defer s.Close()
|
||||||
|
_, _, err := DefaultDialer.Dial(s.URL, http.Header{"Origin": {"bad"}})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("Dial() did not return error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
14
server.go
14
server.go
|
@ -48,8 +48,8 @@ type Upgrader struct {
|
||||||
Error func(w http.ResponseWriter, r *http.Request, status int, reason error)
|
Error func(w http.ResponseWriter, r *http.Request, status int, reason error)
|
||||||
|
|
||||||
// CheckOrigin returns true if the request Origin header is acceptable. If
|
// CheckOrigin returns true if the request Origin header is acceptable. If
|
||||||
// CheckOrigin is nil, the host in the Origin header must match the host of
|
// CheckOrigin is nil, the host in the Origin header must not be set or
|
||||||
// the request.
|
// must match the host of the request.
|
||||||
CheckOrigin func(r *http.Request) bool
|
CheckOrigin func(r *http.Request) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -63,13 +63,13 @@ func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status in
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// checkSameOrigin returns true if the origin is equal to the request host.
|
// checkSameOrigin returns true if the origin is not set or is equal to the request host.
|
||||||
func checkSameOrigin(r *http.Request) bool {
|
func checkSameOrigin(r *http.Request) bool {
|
||||||
origin := r.Header.Get("Origin")
|
origin := r.Header["Origin"]
|
||||||
if origin == "" {
|
if len(origin) == 0 {
|
||||||
return false
|
return true
|
||||||
}
|
}
|
||||||
u, err := url.Parse(origin)
|
u, err := url.Parse(origin[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue