// Copyright 2014 Manu Martinez-Almeida. All rights reserved. // Use of this source code is governed by a MIT style // license that can be found in the LICENSE file. package gin import ( "errors" "io" "io/ioutil" "math" "mime/multipart" "net" "net/http" "net/url" "strings" "time" "github.com/gin-contrib/sse" "github.com/gin-gonic/gin/binding" "github.com/gin-gonic/gin/render" ) // Content-Type MIME of the most common data formats const ( MIMEJSON = binding.MIMEJSON MIMEHTML = binding.MIMEHTML MIMEXML = binding.MIMEXML MIMEXML2 = binding.MIMEXML2 MIMEPlain = binding.MIMEPlain MIMEPOSTForm = binding.MIMEPOSTForm MIMEMultipartPOSTForm = binding.MIMEMultipartPOSTForm ) const ( defaultMemory = 32 << 20 // 32 MB abortIndex int8 = math.MaxInt8 / 2 ) // Context is the most important part of gin. It allows us to pass variables between middleware, // manage the flow, validate the JSON of a request and render a JSON response for example. type Context struct { writermem responseWriter Request *http.Request Writer ResponseWriter Params Params handlers HandlersChain index int8 engine *Engine Keys map[string]interface{} Errors errorMsgs Accepted []string } /************************************/ /********** CONTEXT CREATION ********/ /************************************/ func (c *Context) reset() { c.Writer = &c.writermem c.Params = c.Params[0:0] c.handlers = nil c.index = -1 c.Keys = nil c.Errors = c.Errors[0:0] c.Accepted = nil } // Copy returns a copy of the current context that can be safely used outside the request's scope. // This has to be used when the context has to be passed to a goroutine. func (c *Context) Copy() *Context { var cp = *c cp.writermem.ResponseWriter = nil cp.Writer = &cp.writermem cp.index = abortIndex cp.handlers = nil return &cp } // HandlerName returns the main handler's name. 
For example if the handler is "handleGetUsers()", this // function will return "main.handleGetUsers" func (c *Context) HandlerName() string { return nameOfFunction(c.handlers.Last()) } // Handler returns the main handler. func (c *Context) Handler() HandlerFunc { return c.handlers.Last() } /************************************/ /*********** FLOW CONTROL ***********/ /************************************/ // Next should be used only inside middleware. // It executes the pending handlers in the chain inside the calling handler. // See example in GitHub. func (c *Context) Next() { c.index++ s := int8(len(c.handlers)) for ; c.index < s; c.index++ { c.handlers[c.index](c) } } // IsAborted returns true if the current context was aborted. func (c *Context) IsAborted() bool { return c.index >= abortIndex } // Abort prevents pending handlers from being called. Note that this will not stop the current handler. // Let's say you have an authorization middleware that validates that the current request is authorized. If the // authorization fails (ex: the password does not match), call Abort to ensure the remaining handlers // for this request are not called. func (c *Context) Abort() { c.index = abortIndex } // AbortWithStatus calls `Abort()` and writes the headers with the specified status code. // For example, a failed attempt to authenticate a request could use: context.AbortWithStatus(401). func (c *Context) AbortWithStatus(code int) { c.Status(code) c.Writer.WriteHeaderNow() c.Abort() } // AbortWithStatusJSON calls `Abort()` and then `JSON` internally. This method stops the chain, writes the status code and return a JSON body // It also sets the Content-Type as "application/json". func (c *Context) AbortWithStatusJSON(code int, jsonObj interface{}) { c.Abort() c.JSON(code, jsonObj) } // AbortWithError calls `AbortWithStatus()` and `Error()` internally. This method stops the chain, writes the status code and // pushes the specified error to `c.Errors`. 
// See Context.Error() for more details. func (c *Context) AbortWithError(code int, err error) *Error { c.AbortWithStatus(code) return c.Error(err) } /************************************/ /********* ERROR MANAGEMENT *********/ /************************************/ // Attaches an error to the current context. The error is pushed to a list of errors. // It's a good idea to call Error for each error that occurred during the resolution of a request. // A middleware can be used to collect all the errors // and push them to a database together, print a log, or append it in the HTTP response. func (c *Context) Error(err error) *Error { var parsedError *Error switch err.(type) { case *Error: parsedError = err.(*Error) default: parsedError = &Error{ Err: err, Type: ErrorTypePrivate, } } c.Errors = append(c.Errors, parsedError) return parsedError } /************************************/ /******** METADATA MANAGEMENT********/ /************************************/ // Set is used to store a new key/value pair exclusively for this context. // It also lazy initializes c.Keys if it was not used previously. func (c *Context) Set(key string, value interface{}) { if c.Keys == nil { c.Keys = make(map[string]interface{}) } c.Keys[key] = value } // Get returns the value for the given key, ie: (value, true). // If the value does not exists it returns (nil, false) func (c *Context) Get(key string) (value interface{}, exists bool) { value, exists = c.Keys[key] return } // MustGet returns the value for the given key if it exists, otherwise it panics. func (c *Context) MustGet(key string) interface{} { if value, exists := c.Get(key); exists { return value } panic("Key \"" + key + "\" does not exist") } // GetString returns the value associated with the key as a string. func (c *Context) GetString(key string) (s string) { if val, ok := c.Get(key); ok && val != nil { s, _ = val.(string) } return } // GetBool returns the value associated with the key as a boolean. 
func (c *Context) GetBool(key string) (b bool) { if val, ok := c.Get(key); ok && val != nil { b, _ = val.(bool) } return } // GetInt returns the value associated with the key as an integer. func (c *Context) GetInt(key string) (i int) { if val, ok := c.Get(key); ok && val != nil { i, _ = val.(int) } return } // GetInt64 returns the value associated with the key as an integer. func (c *Context) GetInt64(key string) (i64 int64) { if val, ok := c.Get(key); ok && val != nil { i64, _ = val.(int64) } return } // GetFloat64 returns the value associated with the key as a float64. func (c *Context) GetFloat64(key string) (f64 float64) { if val, ok := c.Get(key); ok && val != nil { f64, _ = val.(float64) } return } // GetTime returns the value associated with the key as time. func (c *Context) GetTime(key string) (t time.Time) { if val, ok := c.Get(key); ok && val != nil { t, _ = val.(time.Time) } return } // GetDuration returns the value associated with the key as a duration. func (c *Context) GetDuration(key string) (d time.Duration) { if val, ok := c.Get(key); ok && val != nil { d, _ = val.(time.Duration) } return } // GetStringSlice returns the value associated with the key as a slice of strings. func (c *Context) GetStringSlice(key string) (ss []string) { if val, ok := c.Get(key); ok && val != nil { ss, _ = val.([]string) } return } // GetStringMap returns the value associated with the key as a map of interfaces. func (c *Context) GetStringMap(key string) (sm map[string]interface{}) { if val, ok := c.Get(key); ok && val != nil { sm, _ = val.(map[string]interface{}) } return } // GetStringMapString returns the value associated with the key as a map of strings. func (c *Context) GetStringMapString(key string) (sms map[string]string) { if val, ok := c.Get(key); ok && val != nil { sms, _ = val.(map[string]string) } return } // GetStringMapStringSlice returns the value associated with the key as a map to a slice of strings. 
func (c *Context) GetStringMapStringSlice(key string) (smss map[string][]string) { if val, ok := c.Get(key); ok && val != nil { smss, _ = val.(map[string][]string) } return } /************************************/ /************ INPUT DATA ************/ /************************************/ // Param returns the value of the URL param. // It is a shortcut for c.Params.ByName(key) // router.GET("/user/:id", func(c *gin.Context) { // // a GET request to /user/john // id := c.Param("id") // id == "john" // }) func (c *Context) Param(key string) string { return c.Params.ByName(key) } // Query returns the keyed url query value if it exists, // otherwise it returns an empty string `("")`. // It is shortcut for `c.Request.URL.Query().Get(key)` // GET /path?id=1234&name=Manu&value= // c.Query("id") == "1234" // c.Query("name") == "Manu" // c.Query("value") == "" // c.Query("wtf") == "" func (c *Context) Query(key string) string { value, _ := c.GetQuery(key) return value } // DefaultQuery returns the keyed url query value if it exists, // otherwise it returns the specified defaultValue string. // See: Query() and GetQuery() for further information. // GET /?name=Manu&lastname= // c.DefaultQuery("name", "unknown") == "Manu" // c.DefaultQuery("id", "none") == "none" // c.DefaultQuery("lastname", "none") == "" func (c *Context) DefaultQuery(key, defaultValue string) string { if value, ok := c.GetQuery(key); ok { return value } return defaultValue } // GetQuery is like Query(), it returns the keyed url query value // if it exists `(value, true)` (even when the value is an empty string), // otherwise it returns `("", false)`. 
// It is shortcut for `c.Request.URL.Query().Get(key)` // GET /?name=Manu&lastname= // ("Manu", true) == c.GetQuery("name") // ("", false) == c.GetQuery("id") // ("", true) == c.GetQuery("lastname") func (c *Context) GetQuery(key string) (string, bool) { if values, ok := c.GetQueryArray(key); ok { return values[0], ok } return "", false } // QueryArray returns a slice of strings for a given query key. // The length of the slice depends on the number of params with the given key. func (c *Context) QueryArray(key string) []string { values, _ := c.GetQueryArray(key) return values } // GetQueryArray returns a slice of strings for a given query key, plus // a boolean value whether at least one value exists for the given key. func (c *Context) GetQueryArray(key string) ([]string, bool) { req := c.Request if values, ok := req.URL.Query()[key]; ok && len(values) > 0 { return values, true } return []string{}, false } // PostForm returns the specified key from a POST urlencoded form or multipart form // when it exists, otherwise it returns an empty string `("")`. func (c *Context) PostForm(key string) string { value, _ := c.GetPostForm(key) return value } // DefaultPostForm returns the specified key from a POST urlencoded form or multipart form // when it exists, otherwise it returns the specified defaultValue string. // See: PostForm() and GetPostForm() for further information. func (c *Context) DefaultPostForm(key, defaultValue string) string { if value, ok := c.GetPostForm(key); ok { return value } return defaultValue } // GetPostForm is like PostForm(key). It returns the specified key from a POST urlencoded // form or multipart form when it exists `(value, true)` (even when the value is an empty string), // otherwise it returns ("", false). 
// For example, during a PATCH request to update the user's email: // email=mail@example.com --> ("mail@example.com", true) := GetPostForm("email") // set email to "mail@example.com" // email= --> ("", true) := GetPostForm("email") // set email to "" // --> ("", false) := GetPostForm("email") // do nothing with email func (c *Context) GetPostForm(key string) (string, bool) { if values, ok := c.GetPostFormArray(key); ok { return values[0], ok } return "", false } // PostFormArray returns a slice of strings for a given form key. // The length of the slice depends on the number of params with the given key. func (c *Context) PostFormArray(key string) []string { values, _ := c.GetPostFormArray(key) return values } // GetPostFormArray returns a slice of strings for a given form key, plus // a boolean value whether at least one value exists for the given key. func (c *Context) GetPostFormArray(key string) ([]string, bool) { req := c.Request req.ParseForm() req.ParseMultipartForm(defaultMemory) if values := req.PostForm[key]; len(values) > 0 { return values, true } if req.MultipartForm != nil && req.MultipartForm.File != nil { if values := req.MultipartForm.Value[key]; len(values) > 0 { return values, true } } return []string{}, false } // FormFile returns the first file for the provided form key. func (c *Context) FormFile(name string) (*multipart.FileHeader, error) { _, fh, err := c.Request.FormFile(name) return fh, err } // MultipartForm is the parsed multipart form, including file uploads. 
func (c *Context) MultipartForm() (*multipart.Form, error) { err := c.Request.ParseMultipartForm(defaultMemory) return c.Request.MultipartForm, err } // Bind checks the Content-Type to select a binding engine automatically, // Depending the "Content-Type" header different bindings are used: // "application/json" --> JSON binding // "application/xml" --> XML binding // otherwise --> returns an error // It parses the request's body as JSON if Content-Type == "application/json" using JSON or XML as a JSON input. // It decodes the json payload into the struct specified as a pointer. // Like ParseBody() but this method also writes a 400 error if the json is not valid. func (c *Context) Bind(obj interface{}) error { b := binding.Default(c.Request.Method, c.ContentType()) return c.BindWith(obj, b) } // BindJSON is a shortcut for c.BindWith(obj, binding.JSON) func (c *Context) BindJSON(obj interface{}) error { return c.BindWith(obj, binding.JSON) } // BindWith binds the passed struct pointer using the specified binding engine. // See the binding package. func (c *Context) BindWith(obj interface{}, b binding.Binding) error { if err := b.Bind(c.Request, obj); err != nil { c.AbortWithError(400, err).SetType(ErrorTypeBind) return err } return nil } // ClientIP implements a best effort algorithm to return the real client IP, it parses // X-Real-IP and X-Forwarded-For in order to work properly with reverse-proxies such us: nginx or haproxy. // Use X-Forwarded-For before X-Real-Ip as nginx uses X-Real-Ip with the proxy's IP. 
func (c *Context) ClientIP() string { if c.engine.ForwardedByClientIP { clientIP := c.requestHeader("X-Forwarded-For") if index := strings.IndexByte(clientIP, ','); index >= 0 { clientIP = clientIP[0:index] } clientIP = strings.TrimSpace(clientIP) if len(clientIP) > 0 { return clientIP } clientIP = strings.TrimSpace(c.requestHeader("X-Real-Ip")) if len(clientIP) > 0 { return clientIP } } if c.engine.AppEngine { if addr := c.Request.Header.Get("X-Appengine-Remote-Addr"); addr != "" { return addr } } if ip, _, err := net.SplitHostPort(strings.TrimSpace(c.Request.RemoteAddr)); err == nil { return ip } return "" } // ContentType returns the Content-Type header of the request. func (c *Context) ContentType() string { return filterFlags(c.requestHeader("Content-Type")) } // IsWebsocket returns true if the request headers indicate that a websocket // handshake is being initiated by the client. func (c *Context) IsWebsocket() bool { if strings.Contains(strings.ToLower(c.requestHeader("Connection")), "upgrade") && strings.ToLower(c.requestHeader("Upgrade")) == "websocket" { return true } return false } func (c *Context) requestHeader(key string) string { if values, _ := c.Request.Header[key]; len(values) > 0 { return values[0] } return "" } /************************************/ /******** RESPONSE RENDERING ********/ /************************************/ // bodyAllowedForStatus is a copy of http.bodyAllowedForStatus non-exported function func bodyAllowedForStatus(status int) bool { switch { case status >= 100 && status <= 199: return false case status == 204: return false case status == 304: return false } return true } func (c *Context) Status(code int) { c.writermem.WriteHeader(code) } // Header is a intelligent shortcut for c.Writer.Header().Set(key, value) // It writes a header in the response. 
// If value == "", this method removes the header `c.Writer.Header().Del(key)` func (c *Context) Header(key, value string) { if len(value) == 0 { c.Writer.Header().Del(key) } else { c.Writer.Header().Set(key, value) } } // GetHeader returns value from request headers func (c *Context) GetHeader(key string) string { return c.requestHeader(key) } // GetRawData return stream data func (c *Context) GetRawData() ([]byte, error) { return ioutil.ReadAll(c.Request.Body) } func (c *Context) SetCookie( name string, value string, maxAge int, path string, domain string, secure bool, httpOnly bool, ) { if path == "" { path = "/" } http.SetCookie(c.Writer, &http.Cookie{ Name: name, Value: url.QueryEscape(value), MaxAge: maxAge, Path: path, Domain: domain, Secure: secure, HttpOnly: httpOnly, }) } func (c *Context) Cookie(name string) (string, error) { cookie, err := c.Request.Cookie(name) if err != nil { return "", err } val, _ := url.QueryUnescape(cookie.Value) return val, nil } func (c *Context) Render(code int, r render.Render) { c.Status(code) if !bodyAllowedForStatus(code) { r.WriteContentType(c.Writer) c.Writer.WriteHeaderNow() return } if err := r.Render(c.Writer); err != nil { panic(err) } } // HTML renders the HTTP template specified by its file name. // It also updates the HTTP code and sets the Content-Type as "text/html". // See http://golang.org/doc/articles/wiki/ func (c *Context) HTML(code int, name string, obj interface{}) { instance := c.engine.HTMLRender.Instance(name, obj) c.Render(code, instance) } // IndentedJSON serializes the given struct as pretty JSON (indented + endlines) into the response body. // It also sets the Content-Type as "application/json". // WARNING: we recommend to use this only for development purposes since printing pretty JSON is // more CPU and bandwidth consuming. Use Context.JSON() instead. 
func (c *Context) IndentedJSON(code int, obj interface{}) {
	payload := render.IndentedJSON{Data: obj}
	c.Render(code, payload)
}

// JSON serializes the given struct as JSON into the response body.
// It also sets the Content-Type as "application/json".
func (c *Context) JSON(code int, obj interface{}) {
	payload := render.JSON{Data: obj}
	c.Render(code, payload)
}

// XML serializes the given struct as XML into the response body.
// It also sets the Content-Type as "application/xml".
func (c *Context) XML(code int, obj interface{}) {
	payload := render.XML{Data: obj}
	c.Render(code, payload)
}

// YAML serializes the given struct as YAML into the response body.
func (c *Context) YAML(code int, obj interface{}) {
	payload := render.YAML{Data: obj}
	c.Render(code, payload)
}

// String writes the given string into the response body; format and values
// are handed to the render.String renderer.
func (c *Context) String(code int, format string, values ...interface{}) {
	payload := render.String{Format: format, Data: values}
	c.Render(code, payload)
}

// Redirect returns a HTTP redirect to the specific location.
// The redirect code travels inside the renderer; Render itself is called
// with -1 as its status.
func (c *Context) Redirect(code int, location string) {
	redirect := render.Redirect{
		Code:     code,
		Location: location,
		Request:  c.Request,
	}
	c.Render(-1, redirect)
}

// Data writes some data into the body stream and updates the HTTP code.
func (c *Context) Data(code int, contentType string, data []byte) {
	payload := render.Data{
		ContentType: contentType,
		Data:        data,
	}
	c.Render(code, payload)
}

// File writes the specified file into the body stream by delegating to
// http.ServeFile.
func (c *Context) File(filepath string) {
	http.ServeFile(c.Writer, c.Request, filepath)
}

// SSEvent writes a Server-Sent Event into the body stream.
func (c *Context) SSEvent(name string, message interface{}) { c.Render(-1, sse.Event{ Event: name, Data: message, }) } func (c *Context) Stream(step func(w io.Writer) bool) { w := c.Writer clientGone := w.CloseNotify() for { select { case <-clientGone: return default: keepOpen := step(w) w.Flush() if !keepOpen { return } } } } /************************************/ /******** CONTENT NEGOTIATION *******/ /************************************/ type Negotiate struct { Offered []string HTMLName string HTMLData interface{} JSONData interface{} XMLData interface{} Data interface{} } func (c *Context) Negotiate(code int, config Negotiate) { switch c.NegotiateFormat(config.Offered...) { case binding.MIMEJSON: data := chooseData(config.JSONData, config.Data) c.JSON(code, data) case binding.MIMEHTML: data := chooseData(config.HTMLData, config.Data) c.HTML(code, config.HTMLName, data) case binding.MIMEXML: data := chooseData(config.XMLData, config.Data) c.XML(code, data) default: c.AbortWithError(http.StatusNotAcceptable, errors.New("the accepted formats are not offered by the server")) } } func (c *Context) NegotiateFormat(offered ...string) string { assert1(len(offered) > 0, "you must provide at least one offer") if c.Accepted == nil { c.Accepted = parseAccept(c.requestHeader("Accept")) } if len(c.Accepted) == 0 { return offered[0] } for _, accepted := range c.Accepted { for _, offert := range offered { if accepted == offert { return offert } } } return "" } func (c *Context) SetAccepted(formats ...string) { c.Accepted = formats } /************************************/ /***** GOLANG.ORG/X/NET/CONTEXT *****/ /************************************/ func (c *Context) Deadline() (deadline time.Time, ok bool) { return } func (c *Context) Done() <-chan struct{} { return nil } func (c *Context) Err() error { return nil } func (c *Context) Value(key interface{}) interface{} { if key == 0 { return c.Request } if keyAsString, ok := key.(string); ok { val, _ := c.Get(keyAsString) 
return val } return nil }