Updated Origin check.

The host in the Origin header must match the host of the request by default.
This commit is contained in:
Joachim Bauch 2014-04-18 00:07:36 +02:00
parent f99474eb97
commit b03dcbad2a
4 changed files with 35 additions and 12 deletions

View File

@ -30,7 +30,10 @@ import (
func echoCopy(w http.ResponseWriter, r *http.Request, writerOnly bool) { func echoCopy(w http.ResponseWriter, r *http.Request, writerOnly bool) {
u := websocket.Upgrader{ u := websocket.Upgrader{
ReadBufferSize: 4096, ReadBufferSize: 4096,
WriteBufferSize: 4096} WriteBufferSize: 4096,
CheckOrigin: func(r *http.Request) bool {
return true
}}
conn, err := u.Upgrade(w, r, nil) conn, err := u.Upgrade(w, r, nil)
if err != nil { if err != nil {
log.Println("Upgrade:", err) log.Println("Upgrade:", err)
@ -92,7 +95,10 @@ func echoCopyFull(w http.ResponseWriter, r *http.Request) {
func echoReadAll(w http.ResponseWriter, r *http.Request, writeMessage bool) { func echoReadAll(w http.ResponseWriter, r *http.Request, writeMessage bool) {
u := websocket.Upgrader{ u := websocket.Upgrader{
ReadBufferSize: 4096, ReadBufferSize: 4096,
WriteBufferSize: 4096} WriteBufferSize: 4096,
CheckOrigin: func(r *http.Request) bool {
return true
}}
conn, err := u.Upgrade(w, r, nil) conn, err := u.Upgrade(w, r, nil)
if err != nil { if err != nil {
log.Println("Upgrade:", err) log.Println("Upgrade:", err)

View File

@ -91,10 +91,7 @@ func serveWs(w http.ResponseWriter, r *http.Request) {
} }
u := websocket.Upgrader{ u := websocket.Upgrader{
ReadBufferSize: 1024, ReadBufferSize: 1024,
WriteBufferSize: 1024, WriteBufferSize: 1024}
CheckOrigin: func(r *http.Request) bool {
return r.Header.Get("Origin") == "http://"+r.Host
}}
ws, err := u.Upgrade(w, r, nil) ws, err := u.Upgrade(w, r, nil)
if err != nil { if err != nil {
if _, ok := err.(websocket.HandshakeError); !ok { if _, ok := err.(websocket.HandshakeError); !ok {

View File

@ -109,10 +109,7 @@ func writer(ws *websocket.Conn, lastMod time.Time) {
func serveWs(w http.ResponseWriter, r *http.Request) { func serveWs(w http.ResponseWriter, r *http.Request) {
u := websocket.Upgrader{ u := websocket.Upgrader{
ReadBufferSize: 1024, ReadBufferSize: 1024,
WriteBufferSize: 1024, WriteBufferSize: 1024}
CheckOrigin: func(r *http.Request) bool {
return r.Header.Get("Origin") == "http://"+r.Host
}}
ws, err := u.Upgrade(w, r, nil) ws, err := u.Upgrade(w, r, nil)
if err != nil { if err != nil {
if _, ok := err.(websocket.HandshakeError); !ok { if _, ok := err.(websocket.HandshakeError); !ok {

View File

@ -9,6 +9,7 @@ import (
"errors" "errors"
"net" "net"
"net/http" "net/http"
"net/url"
"strings" "strings"
"time" "time"
) )
@ -42,7 +43,8 @@ type Upgrader struct {
Error func(w http.ResponseWriter, r *http.Request, status int, reason error) Error func(w http.ResponseWriter, r *http.Request, status int, reason error)
// CheckOrigin returns true if the request Origin header is acceptable. // CheckOrigin returns true if the request Origin header is acceptable.
// If CheckOrigin is nil, then no origin check is done. // If CheckOrigin is nil, the host in the Origin header must match
// the host of the request.
CheckOrigin func(r *http.Request) bool CheckOrigin func(r *http.Request) bool
} }
@ -70,6 +72,19 @@ func (u *Upgrader) hasSubprotocol(subprotocol string) bool {
return false return false
} }
// Check if host in Origin header matches host of request
func (u *Upgrader) checkSameOrigin(r *http.Request) bool {
origin := r.Header.Get("Origin")
if origin == "" {
return false
}
uri, err := url.ParseRequestURI(origin)
if err != nil {
return false
}
return uri.Host == r.Host
}
// Upgrade upgrades the HTTP server connection to the WebSocket protocol. // Upgrade upgrades the HTTP server connection to the WebSocket protocol.
// //
// The responseHeader is included in the response to the client's upgrade // The responseHeader is included in the response to the client's upgrade
@ -100,7 +115,11 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
return nil, err return nil, err
} }
if u.CheckOrigin != nil && !u.CheckOrigin(r) { checkOrigin := u.CheckOrigin
if checkOrigin == nil {
checkOrigin = u.checkSameOrigin
}
if !checkOrigin(r) {
err := HandshakeError{"websocket: origin not allowed"} err := HandshakeError{"websocket: origin not allowed"}
u.returnError(w, r, http.StatusForbidden, err) u.returnError(w, r, http.StatusForbidden, err)
return nil, err return nil, err
@ -229,6 +248,10 @@ func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header,
u.Error = func(w http.ResponseWriter, r *http.Request, status int, reason error) { u.Error = func(w http.ResponseWriter, r *http.Request, status int, reason error) {
// don't return errors to maintain backwards compatibility // don't return errors to maintain backwards compatibility
} }
u.CheckOrigin = func(r *http.Request) bool {
// allow all connections by default
return true
}
return u.Upgrade(w, r, responseHeader) return u.Upgrade(w, r, responseHeader)
} }