add new interface Extractor for making token extraction pluggable

This commit is contained in:
Dave Grijalva 2016-06-06 15:27:44 -07:00
parent b6d201ffa0
commit bb45bfcdec
4 changed files with 109 additions and 24 deletions

75
request/extractor.go Normal file
View File

@ -0,0 +1,75 @@
package request
import (
"errors"
"net/http"
)
// Errors
var (
ErrNoTokenInRequest = errors.New("no token present in request")
)
// Interface for extracting a token from an HTTP request
type Extractor interface {
ExtractToken(*http.Request) (string, error)
}
// Extract token from headers
type HeaderExtractor []string
func (e HeaderExtractor) ExtractToken(req *http.Request) (string, error) {
// loop over header names and return the first one that contains data
for _, header := range e {
if ah := req.Header.Get(header); ah != "" {
return ah, nil
}
}
return "", ErrNoTokenInRequest
}
// Extract token from request args
type ArgumentExtractor []string
func (e ArgumentExtractor) ExtractToken(req *http.Request) (string, error) {
// Make sure form is parsed
req.ParseMultipartForm(10e6)
// loop over arg names and return the first one that contains data
for _, arg := range e {
if ah := req.Form.Get(arg); ah != "" {
return ah, nil
}
}
return "", ErrNoTokenInRequest
}
// Tries extractors in order until one works or an error occurs
type MultiExtractor []Extractor
func (e MultiExtractor) ExtractToken(req *http.Request) (string, error) {
// loop over header names and return the first one that contains data
for _, extractor := range e {
if tok, err := extractor.ExtractToken(req); tok != "" {
return tok, nil
} else if err != ErrNoTokenInRequest {
return "", err
}
}
return "", ErrNoTokenInRequest
}
// Wrap an Extractor in this to post-process the value before it's handed off
type PostExtractionFilter struct {
Extractor
Filter func(string) (string, error)
}
func (e *PostExtractionFilter) ExtractToken(req *http.Request) (string, error) {
if tok, err := e.Extractor.ExtractToken(req); tok != "" {
return e.Filter(tok)
} else {
return "", err
}
}

25
request/oauth2.go Normal file
View File

@ -0,0 +1,25 @@
package request
import (
"strings"
)
// Extract Authorization header and strip 'Bearer ' from it
var AuthorizationHeaderExtractor = &PostExtractionFilter{
HeaderExtractor{"Authorization"},
func(tok string) (string, error) {
// Should be a bearer token
if len(tok) > 6 && strings.ToUpper(tok[0:7]) == "BEARER " {
return tok[7:], nil
}
return tok, nil
},
}
// Extractor for OAuth2 access tokens
var OAuth2Extractor = &MultiExtractor{
// Look for authorization token first
AuthorizationHeaderExtractor,
// Extract access_token from form or GET argument
&ArgumentExtractor{"access_token"},
}

View File

@ -1,39 +1,24 @@
package request package request
import ( import (
"errors"
"github.com/dgrijalva/jwt-go" "github.com/dgrijalva/jwt-go"
"net/http" "net/http"
"strings"
)
// Errors
var (
ErrNoTokenInRequest = errors.New("no token present in request")
) )
// Try to find the token in an http.Request. // Try to find the token in an http.Request.
// This method will call ParseMultipartForm if there's no token in the header. // This method will call ParseMultipartForm if there's no token in the header.
// Currently, it looks in the Authorization header as well as // Currently, it looks in the Authorization header as well as
// looking for an 'access_token' request parameter in req.Form. // looking for an 'access_token' request parameter in req.Form.
func ParseFromRequest(req *http.Request, keyFunc jwt.Keyfunc) (token *jwt.Token, err error) { func ParseFromRequest(req *http.Request, extractor Extractor, keyFunc jwt.Keyfunc) (token *jwt.Token, err error) {
return ParseFromRequestWithClaims(req, jwt.MapClaims{}, keyFunc) return ParseFromRequestWithClaims(req, extractor, jwt.MapClaims{}, keyFunc)
} }
func ParseFromRequestWithClaims(req *http.Request, claims jwt.Claims, keyFunc jwt.Keyfunc) (token *jwt.Token, err error) { func ParseFromRequestWithClaims(req *http.Request, extractor Extractor, claims jwt.Claims, keyFunc jwt.Keyfunc) (token *jwt.Token, err error) {
// Look for an Authorization header // Extract token from request
if ah := req.Header.Get("Authorization"); ah != "" { tokStr, err := extractor.ExtractToken(req)
// Should be a bearer token if err != nil {
if len(ah) > 6 && strings.ToUpper(ah[0:7]) == "BEARER " { return nil, err
return jwt.ParseWithClaims(ah[7:], claims, keyFunc)
}
} }
// Look for "access_token" parameter return jwt.ParseWithClaims(tokStr, claims, keyFunc)
req.ParseMultipartForm(10e6)
if tokStr := req.Form.Get("access_token"); tokStr != "" {
return jwt.ParseWithClaims(tokStr, claims, keyFunc)
}
return nil, ErrNoTokenInRequest
} }

View File

@ -65,7 +65,7 @@ func TestParseRequest(t *testing.T) {
r.Header.Set(k, tokenString) r.Header.Set(k, tokenString)
} }
} }
token, err := ParseFromRequestWithClaims(r, jwt.MapClaims{}, keyfunc) token, err := ParseFromRequestWithClaims(r, OAuth2Extractor, jwt.MapClaims{}, keyfunc)
if token == nil { if token == nil {
t.Errorf("[%v] Token was not found: %v", data.name, err) t.Errorf("[%v] Token was not found: %v", data.name, err)