From b03dcbad2ae3e2d00ee5589ae43c2a4eff5daa6d Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Fri, 18 Apr 2014 00:07:36 +0200 Subject: [PATCH] Updated Origin check. The host in the Origin header must match the host of the request by default. --- examples/autobahn/server.go | 10 ++++++++-- examples/chat/conn.go | 5 +---- examples/filewatch/main.go | 5 +---- server.go | 27 +++++++++++++++++++++++++-- 4 files changed, 35 insertions(+), 12 deletions(-) diff --git a/examples/autobahn/server.go b/examples/autobahn/server.go index e00c3ec..89720ba 100644 --- a/examples/autobahn/server.go +++ b/examples/autobahn/server.go @@ -30,7 +30,10 @@ import ( func echoCopy(w http.ResponseWriter, r *http.Request, writerOnly bool) { u := websocket.Upgrader{ ReadBufferSize: 4096, - WriteBufferSize: 4096} + WriteBufferSize: 4096, + CheckOrigin: func(r *http.Request) bool { + return true + }} conn, err := u.Upgrade(w, r, nil) if err != nil { log.Println("Upgrade:", err) @@ -92,7 +95,10 @@ func echoCopyFull(w http.ResponseWriter, r *http.Request) { func echoReadAll(w http.ResponseWriter, r *http.Request, writeMessage bool) { u := websocket.Upgrader{ ReadBufferSize: 4096, - WriteBufferSize: 4096} + WriteBufferSize: 4096, + CheckOrigin: func(r *http.Request) bool { + return true + }} conn, err := u.Upgrade(w, r, nil) if err != nil { log.Println("Upgrade:", err) diff --git a/examples/chat/conn.go b/examples/chat/conn.go index 0be32dc..ac8ca41 100644 --- a/examples/chat/conn.go +++ b/examples/chat/conn.go @@ -91,10 +91,7 @@ func serveWs(w http.ResponseWriter, r *http.Request) { } u := websocket.Upgrader{ ReadBufferSize: 1024, - WriteBufferSize: 1024, - CheckOrigin: func(r *http.Request) bool { - return r.Header.Get("Origin") == "http://"+r.Host - }} + WriteBufferSize: 1024} ws, err := u.Upgrade(w, r, nil) if err != nil { if _, ok := err.(websocket.HandshakeError); !ok { diff --git a/examples/filewatch/main.go b/examples/filewatch/main.go index b338d89..405aa70 100644 --- a/examples/filewatch/main.go +++ b/examples/filewatch/main.go @@ -109,10 +109,7 @@ func writer(ws *websocket.Conn, lastMod time.Time) { func serveWs(w http.ResponseWriter, r *http.Request) { u := websocket.Upgrader{ ReadBufferSize: 1024, - WriteBufferSize: 1024, - CheckOrigin: func(r *http.Request) bool { - return r.Header.Get("Origin") == "http://"+r.Host - }} + WriteBufferSize: 1024} ws, err := u.Upgrade(w, r, nil) if err != nil { if _, ok := err.(websocket.HandshakeError); !ok { diff --git a/server.go b/server.go index 0f5130c..9251734 100644 --- a/server.go +++ b/server.go @@ -9,6 +9,7 @@ import ( "errors" "net" "net/http" + "net/url" "strings" "time" ) @@ -42,7 +43,8 @@ type Upgrader struct { Error func(w http.ResponseWriter, r *http.Request, status int, reason error) // CheckOrigin returns true if the request Origin header is acceptable. - // If CheckOrigin is nil, then no origin check is done. + // If CheckOrigin is nil, the host in the Origin header must match + // the host of the request. CheckOrigin func(r *http.Request) bool } @@ -70,6 +72,19 @@ func (u *Upgrader) hasSubprotocol(subprotocol string) bool { return false } +// Check if host in Origin header matches host of request +func (u *Upgrader) checkSameOrigin(r *http.Request) bool { + origin := r.Header.Get("Origin") + if origin == "" { + return false + } + uri, err := url.ParseRequestURI(origin) + if err != nil { + return false + } + return uri.Host == r.Host +} + // Upgrade upgrades the HTTP server connection to the WebSocket protocol. // // The responseHeader is included in the response to the client's upgrade @@ -100,7 +115,11 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade return nil, err } - if u.CheckOrigin != nil && !u.CheckOrigin(r) { + checkOrigin := u.CheckOrigin + if checkOrigin == nil { + checkOrigin = u.checkSameOrigin + } + if !checkOrigin(r) { err := HandshakeError{"websocket: origin not allowed"} u.returnError(w, r, http.StatusForbidden, err) return nil, err @@ -229,6 +248,10 @@ func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header, u.Error = func(w http.ResponseWriter, r *http.Request, status int, reason error) { // don't return errors to maintain backwards compatibility } + u.CheckOrigin = func(r *http.Request) bool { + // allow all connections by default + return true + } return u.Upgrade(w, r, responseHeader) }