diff --git a/context.go b/context.go index 44c5352f..fdca9e3e 100644 --- a/context.go +++ b/context.go @@ -6,6 +6,7 @@ package gin import ( "errors" + "io" "math" "net/http" "strings" @@ -379,7 +380,21 @@ func (c *Context) File(filepath string) { http.ServeFile(c.Writer, c.Request, filepath) } -func (c *Context) Stream(step func(w http.ResponseWriter)) { +func (c *Context) SSEvent(name string, message interface{}) { + render.WriteSSEvent(c.Writer, name, message) +} + +func (c *Context) Header(code int, headers map[string]string) { + if len(headers) > 0 { + header := c.Writer.Header() + for key, value := range headers { + header.Set(key, value) + } + } + c.Writer.WriteHeader(code) +} + +func (c *Context) Stream(step func(w io.Writer) bool) { w := c.Writer clientGone := w.CloseNotify() for { @@ -387,8 +402,11 @@ func (c *Context) Stream(step func(w http.ResponseWriter)) { case <-clientGone: return default: - step(w) + keepopen := step(w) w.Flush() + if !keepopen { + return + } } } } diff --git a/render/ssevent.go b/render/ssevent.go new file mode 100644 index 00000000..34f4e475 --- /dev/null +++ b/render/ssevent.go @@ -0,0 +1,58 @@ +package render + +import ( + "encoding/json" + "fmt" + "net/http" + "reflect" +) + +type sseRender struct{} + +var SSEvent Render = sseRender{} + +func (_ sseRender) Render(w http.ResponseWriter, code int, data ...interface{}) error { + eventName := data[0].(string) + obj := data[1] + return WriteSSEvent(w, eventName, obj) +} + +func WriteSSEvent(w http.ResponseWriter, eventName string, data interface{}) error { + header := w.Header() + if len(header.Get("Content-Type")) == 0 { + w.Header().Set("Content-Type", "text/event-stream") + } + var stringData string + switch typeOfData(data) { + case reflect.Struct, reflect.Slice: + if jsonBytes, err := json.Marshal(data); err == nil { + stringData = string(jsonBytes) + } else { + return err + } + case reflect.Ptr: + stringData = escape(fmt.Sprintf("%v", &data)) + "\n" + default: + stringData = escape(fmt.Sprintf("%v", data)) + "\n" + } + _, err := fmt.Fprintf(w, "event: %s\ndata: %s\n", escape(eventName), stringData) + return err +} + +func typeOfData(data interface{}) reflect.Kind { + value := reflect.ValueOf(data) + valueType := value.Kind() + if valueType == reflect.Ptr { + newValue := value.Elem().Kind() + if newValue == reflect.Struct || newValue == reflect.Slice { + return newValue + } else { + return valueType + } + } + return valueType +} + +func escape(str string) string { + return str +}