General refactoring. Part 2.

This commit is contained in:
Manu Mtz-Almeida 2014-10-09 01:40:42 +02:00
parent 030706c39a
commit aa7b00a083
7 changed files with 163 additions and 147 deletions

104
auth.go
View File

@ -16,70 +16,29 @@ const (
)
type (
BasicAuthPair struct {
Code string
Accounts map[string]string
authPair struct {
Value string
User string
}
Accounts map[string]string
Pairs []BasicAuthPair
authPairs []authPair
)
func (a Pairs) Len() int { return len(a) }
func (a Pairs) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a Pairs) Less(i, j int) bool { return a[i].Code < a[j].Code }
func processCredentials(accounts Accounts) (Pairs, error) {
if len(accounts) == 0 {
return nil, errors.New("Empty list of authorized credentials.")
}
pairs := make(Pairs, 0, len(accounts))
for user, password := range accounts {
if len(user) == 0 || len(password) == 0 {
return nil, errors.New("User or password is empty")
}
base := user + ":" + password
code := "Basic " + base64.StdEncoding.EncodeToString([]byte(base))
pairs = append(pairs, BasicAuthPair{code, user})
}
// We have to sort the credentials in order to use bsearch later.
sort.Sort(pairs)
return pairs, nil
}
func secureCompare(given, actual string) bool {
if subtle.ConstantTimeEq(int32(len(given)), int32(len(actual))) == 1 {
return subtle.ConstantTimeCompare([]byte(given), []byte(actual)) == 1
} else {
/* Securely compare actual to itself to keep constant time, but always return false */
return subtle.ConstantTimeCompare([]byte(actual), []byte(actual)) == 1 && false
}
}
func searchCredential(pairs Pairs, auth string) string {
if len(auth) == 0 {
return ""
}
// Search user in the slice of allowed credentials
r := sort.Search(len(pairs), func(i int) bool { return pairs[i].Code >= auth })
if r < len(pairs) && secureCompare(pairs[r].Code, auth) {
return pairs[r].User
} else {
return ""
}
}
func (a authPairs) Len() int { return len(a) }
func (a authPairs) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a authPairs) Less(i, j int) bool { return a[i].Value < a[j].Value }
// Implements a basic Basic HTTP Authorization. It takes as argument a map[string]string where
// the key is the user name and the value is the password.
func BasicAuth(accounts Accounts) HandlerFunc {
pairs, err := processCredentials(accounts)
pairs, err := processAccounts(accounts)
if err != nil {
panic(err)
}
return func(c *Context) {
// Search user in the slice of allowed credentials
user := searchCredential(pairs, c.Request.Header.Get("Authorization"))
if len(user) == 0 {
user, ok := searchCredential(pairs, c.Request.Header.Get("Authorization"))
if !ok {
// Credentials doesn't match, we return 401 Unauthorized and abort request.
c.Writer.Header().Set("WWW-Authenticate", "Basic realm=\"Authorization Required\"")
c.Fail(401, errors.New("Unauthorized"))
@ -90,3 +49,46 @@ func BasicAuth(accounts Accounts) HandlerFunc {
}
}
}
func processAccounts(accounts Accounts) (authPairs, error) {
if len(accounts) == 0 {
return nil, errors.New("Empty list of authorized credentials")
}
pairs := make(authPairs, 0, len(accounts))
for user, password := range accounts {
if len(user) == 0 {
return nil, errors.New("User can not be empty")
}
base := user + ":" + password
value := "Basic " + base64.StdEncoding.EncodeToString([]byte(base))
pairs = append(pairs, authPair{
Value: value,
User: user,
})
}
// We have to sort the credentials in order to use bsearch later.
sort.Sort(pairs)
return pairs, nil
}
func searchCredential(pairs authPairs, auth string) (string, bool) {
if len(auth) == 0 {
return "", false
}
// Search user in the slice of allowed credentials
r := sort.Search(len(pairs), func(i int) bool { return pairs[i].Value >= auth })
if r < len(pairs) && secureCompare(pairs[r].Value, auth) {
return pairs[r].User, true
} else {
return "", false
}
}
func secureCompare(given, actual string) bool {
if subtle.ConstantTimeEq(int32(len(given)), int32(len(actual))) == 1 {
return subtle.ConstantTimeCompare([]byte(given), []byte(actual)) == 1
} else {
/* Securely compare actual to itself to keep constant time, but always return false */
return subtle.ConstantTimeCompare([]byte(actual), []byte(actual)) == 1 && false
}
}

View File

@ -197,6 +197,17 @@ func (c *Context) MustGet(key string) interface{} {
return value
}
func (c *Context) ClientIP() string {
clientIP := c.Request.Header.Get("X-Real-IP")
if len(clientIP) == 0 {
clientIP = c.Request.Header.Get("X-Forwarded-For")
}
if len(clientIP) == 0 {
clientIP = c.Request.RemoteAddr
}
return clientIP
}
/************************************/
/********* PARSING REQUEST **********/
/************************************/

32
gin.go
View File

@ -30,6 +30,7 @@ type (
Engine struct {
*RouterGroup
HTMLRender render.Render
Default404Body []byte
pool sync.Pool
allNoRoute []HandlerFunc
noRoute []HandlerFunc
@ -37,21 +38,6 @@ type (
}
)
func (engine *Engine) handle404(w http.ResponseWriter, req *http.Request) {
c := engine.createContext(w, req, nil, engine.allNoRoute)
// set 404 by default, useful for logging
c.Writer.WriteHeader(404)
c.Next()
if !c.Writer.Written() {
if c.Writer.Status() == 404 {
c.Data(-1, MIMEPlain, []byte("404 page not found"))
} else {
c.Writer.WriteHeaderNow()
}
}
engine.reuseContext(c)
}
// Returns a new blank Engine instance without any middleware attached.
// The most basic configuration
func New() *Engine {
@ -62,6 +48,7 @@ func New() *Engine {
engine: engine,
}
engine.router = httprouter.New()
engine.Default404Body = []byte("404 page not found")
engine.router.NotFound = engine.handle404
engine.pool.New = func() interface{} {
c := &Context{Engine: engine}
@ -119,6 +106,21 @@ func (engine *Engine) rebuild404Handlers() {
engine.allNoRoute = engine.combineHandlers(engine.noRoute)
}
func (engine *Engine) handle404(w http.ResponseWriter, req *http.Request) {
c := engine.createContext(w, req, nil, engine.allNoRoute)
// set 404 by default, useful for logging
c.Writer.WriteHeader(404)
c.Next()
if !c.Writer.Written() {
if c.Writer.Status() == 404 {
c.Data(-1, MIMEPlain, engine.Default404Body)
} else {
c.Writer.WriteHeaderNow()
}
}
engine.reuseContext(c)
}
// ServeHTTP makes the router implement the http.Handler interface.
func (engine *Engine) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
engine.router.ServeHTTP(writer, request)

110
logger.go
View File

@ -10,6 +10,17 @@ import (
"time"
)
var (
green = string([]byte{27, 91, 57, 55, 59, 52, 50, 109})
white = string([]byte{27, 91, 57, 48, 59, 52, 55, 109})
yellow = string([]byte{27, 91, 57, 55, 59, 52, 51, 109})
red = string([]byte{27, 91, 57, 55, 59, 52, 49, 109})
blue = string([]byte{27, 91, 57, 55, 59, 52, 52, 109})
magenta = string([]byte{27, 91, 57, 55, 59, 52, 53, 109})
cyan = string([]byte{27, 91, 57, 55, 59, 52, 54, 109})
reset = string([]byte{27, 91, 48, 109})
)
func ErrorLogger() HandlerFunc {
return ErrorLoggerT(ErrorTypeAll)
}
@ -26,17 +37,6 @@ func ErrorLoggerT(typ uint32) HandlerFunc {
}
}
var (
green = string([]byte{27, 91, 57, 55, 59, 52, 50, 109})
white = string([]byte{27, 91, 57, 48, 59, 52, 55, 109})
yellow = string([]byte{27, 91, 57, 55, 59, 52, 51, 109})
red = string([]byte{27, 91, 57, 55, 59, 52, 49, 109})
blue = string([]byte{27, 91, 57, 55, 59, 52, 52, 109})
magenta = string([]byte{27, 91, 57, 55, 59, 52, 53, 109})
cyan = string([]byte{27, 91, 57, 55, 59, 52, 54, 109})
reset = string([]byte{27, 91, 48, 109})
)
func Logger() HandlerFunc {
stdlogger := log.New(os.Stdout, "", 0)
//errlogger := log.New(os.Stderr, "", 0)
@ -48,58 +48,58 @@ func Logger() HandlerFunc {
// Process request
c.Next()
// save the IP of the requester
requester := c.Request.Header.Get("X-Real-IP")
// if the requester-header is empty, check the forwarded-header
if len(requester) == 0 {
requester = c.Request.Header.Get("X-Forwarded-For")
}
// if the requester is still empty, use the hard-coded address from the socket
if len(requester) == 0 {
requester = c.Request.RemoteAddr
}
var color string
code := c.Writer.Status()
switch {
case code >= 200 && code <= 299:
color = green
case code >= 300 && code <= 399:
color = white
case code >= 400 && code <= 499:
color = yellow
default:
color = red
}
var methodColor string
method := c.Request.Method
switch {
case method == "GET":
methodColor = blue
case method == "POST":
methodColor = cyan
case method == "PUT":
methodColor = yellow
case method == "DELETE":
methodColor = red
case method == "PATCH":
methodColor = green
case method == "HEAD":
methodColor = magenta
case method == "OPTIONS":
methodColor = white
}
// Stop timer
end := time.Now()
latency := end.Sub(start)
clientIP := c.ClientIP()
method := c.Request.Method
statusCode := c.Writer.Status()
statusColor := colorForStatus(statusCode)
methodColor := colorForMethod(method)
stdlogger.Printf("[GIN] %v |%s %3d %s| %12v | %s |%s %s %-7s %s\n%s",
end.Format("2006/01/02 - 15:04:05"),
color, code, reset,
statusColor, statusCode, reset,
latency,
requester,
clientIP,
methodColor, reset, method,
c.Request.URL.Path,
c.Errors.String(),
)
}
}
func colorForStatus(code int) string {
switch {
case code >= 200 && code <= 299:
return green
case code >= 300 && code <= 399:
return white
case code >= 400 && code <= 499:
return yellow
default:
return red
}
}
func colorForMethod(method string) string {
switch {
case method == "GET":
return blue
case method == "POST":
return cyan
case method == "PUT":
return yellow
case method == "DELETE":
return red
case method == "PATCH":
return green
case method == "HEAD":
return magenta
case method == "OPTIONS":
return white
default:
return reset
}
}

21
mode.go
View File

@ -5,6 +5,7 @@
package gin
import (
"fmt"
"os"
)
@ -24,6 +25,15 @@ const (
var gin_mode int = debugCode
var mode_name string = DebugMode
func init() {
value := os.Getenv(GIN_MODE)
if len(value) == 0 {
SetMode(DebugMode)
} else {
SetMode(value)
}
}
func SetMode(value string) {
switch value {
case DebugMode:
@ -33,7 +43,7 @@ func SetMode(value string) {
case TestMode:
gin_mode = testCode
default:
panic("gin mode unknown, the allowed modes are: " + DebugMode + " and " + ReleaseMode)
panic("gin mode unknown: " + value)
}
mode_name = value
}
@ -46,11 +56,8 @@ func IsDebugging() bool {
return gin_mode == debugCode
}
func init() {
value := os.Getenv(GIN_MODE)
if len(value) == 0 {
SetMode(DebugMode)
} else {
SetMode(value)
func debugPrint(format string, values ...interface{}) {
if IsDebugging() {
fmt.Printf("[GIN-debug] "+format, values)
}
}

View File

@ -48,7 +48,7 @@ func (group *RouterGroup) Handle(httpMethod, relativePath string, handlers []Han
handlers = group.combineHandlers(handlers)
if IsDebugging() {
nuHandlers := len(handlers)
handlerName := nameOfFuncion(handlers[nuHandlers-1])
handlerName := nameOfFunction(handlers[nuHandlers-1])
debugPrint("%-5s %-25s --> %s (%d handlers)\n", httpMethod, absolutePath, handlerName, nuHandlers)
}
@ -105,6 +105,8 @@ func (group *RouterGroup) Static(relativePath, root string) {
absolutePath := group.calculateAbsolutePath(relativePath)
handler := group.createStaticHandler(absolutePath, root)
absolutePath = path.Join(absolutePath, "/*filepath")
// Register GET and HEAD handlers
group.GET(absolutePath, handler)
group.HEAD(absolutePath, handler)
}
@ -120,8 +122,7 @@ func (group *RouterGroup) combineHandlers(handlers []HandlerFunc) []HandlerFunc
finalSize := len(group.Handlers) + len(handlers)
mergedHandlers := make([]HandlerFunc, 0, finalSize)
mergedHandlers = append(mergedHandlers, group.Handlers...)
mergedHandlers = append(mergedHandlers, handlers...)
return mergedHandlers
return append(mergedHandlers, handlers...)
}
func (group *RouterGroup) calculateAbsolutePath(relativePath string) string {

View File

@ -6,7 +6,6 @@ package gin
import (
"encoding/xml"
"fmt"
"reflect"
"runtime"
"strings"
@ -39,20 +38,14 @@ func (h H) MarshalXML(e *xml.Encoder, start xml.StartElement) error {
}
func filterFlags(content string) string {
for i, a := range content {
if a == ' ' || a == ';' {
for i, char := range content {
if char == ' ' || char == ';' {
return content[:i]
}
}
return content
}
func debugPrint(format string, values ...interface{}) {
if IsDebugging() {
fmt.Printf("[GIN-debug] "+format, values)
}
}
func chooseData(custom, wildcard interface{}) interface{} {
if custom == nil {
if wildcard == nil {
@ -84,6 +77,6 @@ func lastChar(str string) uint8 {
return str[size-1]
}
func nameOfFuncion(f interface{}) string {
func nameOfFunction(f interface{}) string {
return runtime.FuncForPC(reflect.ValueOf(f).Pointer()).Name()
}