// Copyright 2013 Gary Burd. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package websocket import ( "bufio" "encoding/binary" "errors" "io" "io/ioutil" "math/rand" "net" "strconv" "time" ) // Close codes defined in RFC 6455, section 11.7. const ( CloseNormalClosure = 1000 CloseGoingAway = 1001 CloseProtocolError = 1002 CloseUnsupportedData = 1003 CloseNoStatusReceived = 1005 CloseAbnormalClosure = 1006 CloseInvalidFramePayloadData = 1007 ClosePolicyViolation = 1008 CloseMessageTooBig = 1009 CloseMandatoryExtension = 1010 CloseInternalServerErr = 1011 CloseTLSHandshake = 1015 ) // The message types are defined in RFC 6455, section 11.8. const ( // TextMessage denotes a text data message. The text message payload is // interpreted as UTF-8 encoded text data. TextMessage = 1 // BinaryMessage denotes a binary data message. BinaryMessage = 2 // CloseMessage denotes a close control message. The optional message // payload contains a numeric code and text. Use the FormatCloseMessage // function to format a close message payload. CloseMessage = 8 // PingMessage denotes a ping control message. The optional message payload // is UTF-8 encoded text. PingMessage = 9 // PongMessage denotes a ping control message. The optional message payload // is UTF-8 encoded text. 
PongMessage = 10 ) var ( continuationFrame = 0 noFrame = -1 ) var ( ErrCloseSent = errors.New("websocket: close sent") ErrReadLimit = errors.New("websocket: read limit exceeded") ) type websocketError struct { msg string temporary bool timeout bool } func (e *websocketError) Error() string { return e.msg } func (e *websocketError) Temporary() bool { return e.temporary } func (e *websocketError) Timeout() bool { return e.timeout } var ( errWriteTimeout = &websocketError{msg: "websocket: write timeout", timeout: true} errBadWriteOpCode = errors.New("websocket: bad write message type") errWriteClosed = errors.New("websocket: write closed") errInvalidControlFrame = errors.New("websocket: invalid control frame") ) const ( maxFrameHeaderSize = 2 + 8 + 4 // Fixed header + length + mask maxControlFramePayloadSize = 125 finalBit = 1 << 7 maskBit = 1 << 7 writeWait = time.Second ) func isControl(frameType int) bool { return frameType == CloseMessage || frameType == PingMessage || frameType == PongMessage } func isData(frameType int) bool { return frameType == TextMessage || frameType == BinaryMessage } func maskBytes(key [4]byte, pos int, b []byte) int { for i := range b { b[i] ^= key[pos&3] pos += 1 } return pos & 3 } func newMaskKey() [4]byte { n := rand.Uint32() return [4]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 32)} } // Conn represents a WebSocket connection. type Conn struct { conn net.Conn isServer bool subprotocol string // Write fields mu chan bool // used as mutex to protect write to conn and closeSent closeSent bool // true if close message was sent // Message writer fields. writeErr error writeBuf []byte // frame is constructed in this buffer. writePos int // end of data in writeBuf. writeFrameType int // type of the current frame. writeSeq int // incremented to invalidate message writers. writeDeadline time.Time // Read fields readErr error br *bufio.Reader readRemaining int64 // bytes remaining in current frame. 
readFinal bool // true the current message has more frames. readSeq int // incremented to invalidate message readers. readLength int64 // Message size. readLimit int64 // Maximum message size. readMaskPos int readMaskKey [4]byte handlePong func(string) error handlePing func(string) error } func newConn(conn net.Conn, isServer bool, readBufSize, writeBufSize int) *Conn { mu := make(chan bool, 1) mu <- true c := &Conn{ isServer: isServer, br: bufio.NewReaderSize(conn, readBufSize), conn: conn, mu: mu, readFinal: true, writeBuf: make([]byte, writeBufSize+maxFrameHeaderSize), writeFrameType: noFrame, writePos: maxFrameHeaderSize, } c.SetPingHandler(nil) c.SetPongHandler(nil) return c } // Subprotocol returns the negotiated protocol for the connection. func (c *Conn) Subprotocol() string { return c.subprotocol } // Close closes the underlying network connection without sending or waiting for a close frame. func (c *Conn) Close() error { return c.conn.Close() } // LocalAddr returns the local network address. func (c *Conn) LocalAddr() net.Addr { return c.conn.LocalAddr() } // RemoteAddr returns the remote network address. func (c *Conn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() } // Write methods func (c *Conn) write(frameType int, deadline time.Time, bufs ...[]byte) error { <-c.mu defer func() { c.mu <- true }() if c.closeSent { return ErrCloseSent } else if frameType == CloseMessage { c.closeSent = true } c.conn.SetWriteDeadline(deadline) for _, buf := range bufs { if len(buf) > 0 { n, err := c.conn.Write(buf) if n != len(buf) { // Close on partial write. c.conn.Close() } if err != nil { return err } } } return nil } // WriteControl writes a control message with the given deadline. The allowed // message types are CloseMessage, PingMessage and PongMessage. 
func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) error { if !isControl(messageType) { return errBadWriteOpCode } if len(data) > maxControlFramePayloadSize { return errInvalidControlFrame } b0 := byte(messageType) | finalBit b1 := byte(len(data)) if !c.isServer { b1 |= maskBit } buf := make([]byte, 0, maxFrameHeaderSize+maxControlFramePayloadSize) buf = append(buf, b0, b1) if c.isServer { buf = append(buf, data...) } else { key := newMaskKey() buf = append(buf, key[:]...) buf = append(buf, data...) maskBytes(key, 0, buf[6:]) } d := time.Hour * 1000 if !deadline.IsZero() { d = deadline.Sub(time.Now()) if d < 0 { return errWriteTimeout } } timer := time.NewTimer(d) select { case <-c.mu: timer.Stop() case <-timer.C: return errWriteTimeout } defer func() { c.mu <- true }() if c.closeSent { return ErrCloseSent } else if messageType == CloseMessage { c.closeSent = true } c.conn.SetWriteDeadline(deadline) n, err := c.conn.Write(buf) if n != 0 && n != len(buf) { c.conn.Close() } return err } // NextWriter returns a writer for the next message to send. The writer's // Close method flushes the complete message to the network. // // There can be at most one open writer on a connection. NextWriter closes the // previous writer if the application has not already done so. // // The NextWriter method and the writers returned from the method cannot be // accessed by more than one goroutine at a time. func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) { if c.writeErr != nil { return nil, c.writeErr } if c.writeFrameType != noFrame { if err := c.flushFrame(true, nil); err != nil { return nil, err } } if !isControl(messageType) && !isData(messageType) { return nil, errBadWriteOpCode } c.writeFrameType = messageType return messageWriter{c, c.writeSeq}, nil } func (c *Conn) flushFrame(final bool, extra []byte) error { length := c.writePos - maxFrameHeaderSize + len(extra) // Check for invalid control frames. 
if isControl(c.writeFrameType) && (!final || length > maxControlFramePayloadSize) { c.writeSeq += 1 c.writeFrameType = noFrame c.writePos = maxFrameHeaderSize return errInvalidControlFrame } b0 := byte(c.writeFrameType) if final { b0 |= finalBit } b1 := byte(0) if !c.isServer { b1 |= maskBit } // Assume that the frame starts at beginning of c.writeBuf. framePos := 0 if c.isServer { // Adjust up if mask not included in the header. framePos = 4 } switch { case length >= 65536: c.writeBuf[framePos] = b0 c.writeBuf[framePos+1] = b1 | 127 binary.BigEndian.PutUint64(c.writeBuf[framePos+2:], uint64(length)) case length > 125: framePos += 6 c.writeBuf[framePos] = b0 c.writeBuf[framePos+1] = b1 | 126 binary.BigEndian.PutUint16(c.writeBuf[framePos+2:], uint16(length)) default: framePos += 8 c.writeBuf[framePos] = b0 c.writeBuf[framePos+1] = b1 | byte(length) } if !c.isServer { key := newMaskKey() copy(c.writeBuf[maxFrameHeaderSize-4:], key[:]) maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:c.writePos]) if len(extra) > 0 { c.writeErr = errors.New("websocket: internal error, extra used in client mode") return c.writeErr } } // Write the buffers to the connection. c.writeErr = c.write(c.writeFrameType, c.writeDeadline, c.writeBuf[framePos:c.writePos], extra) // Setup for next frame. 
c.writePos = maxFrameHeaderSize c.writeFrameType = continuationFrame if final { c.writeSeq += 1 c.writeFrameType = noFrame } return c.writeErr } type messageWriter struct { c *Conn seq int } func (w messageWriter) err() error { c := w.c if c.writeSeq != w.seq { return errWriteClosed } if c.writeErr != nil { return c.writeErr } return nil } func (w messageWriter) ncopy(max int) (int, error) { n := len(w.c.writeBuf) - w.c.writePos if n <= 0 { if err := w.c.flushFrame(false, nil); err != nil { return 0, err } n = len(w.c.writeBuf) - w.c.writePos } if n > max { n = max } return n, nil } func (w messageWriter) write(final bool, p []byte) (int, error) { if err := w.err(); err != nil { return 0, err } if len(p) > 2*len(w.c.writeBuf) && w.c.isServer { // Don't buffer large messages. err := w.c.flushFrame(final, p) if err != nil { return 0, err } return len(p), nil } nn := len(p) for len(p) > 0 { n, err := w.ncopy(len(p)) if err != nil { return 0, err } copy(w.c.writeBuf[w.c.writePos:], p[:n]) w.c.writePos += n p = p[n:] } return nn, nil } func (w messageWriter) Write(p []byte) (int, error) { return w.write(false, p) } func (w messageWriter) WriteString(p string) (int, error) { if err := w.err(); err != nil { return 0, err } nn := len(p) for len(p) > 0 { n, err := w.ncopy(len(p)) if err != nil { return 0, err } copy(w.c.writeBuf[w.c.writePos:], p[:n]) w.c.writePos += n p = p[n:] } return nn, nil } func (w messageWriter) ReadFrom(r io.Reader) (nn int64, err error) { if err := w.err(); err != nil { return 0, err } for { if w.c.writePos == len(w.c.writeBuf) { err = w.c.flushFrame(false, nil) if err != nil { break } } var n int n, err = r.Read(w.c.writeBuf[w.c.writePos:]) w.c.writePos += n nn += int64(n) if err != nil { if err == io.EOF { err = nil } break } } return nn, err } func (w messageWriter) Close() error { if err := w.err(); err != nil { return err } return w.c.flushFrame(true, nil) } // WriteMessage is a helper method for getting a writer using NextWriter, // writing 
the message and closing the writer. func (c *Conn) WriteMessage(messageType int, data []byte) error { wr, err := c.NextWriter(messageType) if err != nil { return err } w := wr.(messageWriter) if _, err := w.write(true, data); err != nil { return err } if c.writeSeq == w.seq { if err := c.flushFrame(true, nil); err != nil { return err } } return nil } // SetWriteDeadline sets the deadline for future calls to NextWriter and the // io.WriteCloser returned from NextWriter. If the deadline is reached, the // call will fail with a timeout instead of blocking. A zero value for t means // Write will not time out. Even if Write times out, it may return n > 0, // indicating that some of the data was successfully written. func (c *Conn) SetWriteDeadline(t time.Time) error { c.writeDeadline = t return nil } // Read methods func (c *Conn) advanceFrame() (int, error) { // 1. Skip remainder of previous frame. if c.readRemaining > 0 { if _, err := io.CopyN(ioutil.Discard, c.br, c.readRemaining); err != nil { return noFrame, err } } // 2. Read and parse first two bytes of frame header. 
var b [8]byte if err := c.read(b[:2]); err != nil { return noFrame, err } final := b[0]&finalBit != 0 frameType := int(b[0] & 0xf) reserved := int((b[0] >> 4) & 0x7) mask := b[1]&maskBit != 0 c.readRemaining = int64(b[1] & 0x7f) if reserved != 0 { return noFrame, c.handleProtocolError("unexpected reserved bits " + strconv.Itoa(reserved)) } switch frameType { case CloseMessage, PingMessage, PongMessage: if c.readRemaining > maxControlFramePayloadSize { return noFrame, c.handleProtocolError("control frame length > 125") } if !final { return noFrame, c.handleProtocolError("control frame not final") } case TextMessage, BinaryMessage: if !c.readFinal { return noFrame, c.handleProtocolError("message start before final message frame") } c.readFinal = final case continuationFrame: if c.readFinal { return noFrame, c.handleProtocolError("continuation after final message frame") } c.readFinal = final default: return noFrame, c.handleProtocolError("unknown opcode " + strconv.Itoa(frameType)) } // 3. Read and parse frame length. switch c.readRemaining { case 126: if err := c.read(b[:2]); err != nil { return noFrame, err } c.readRemaining = int64(binary.BigEndian.Uint16(b[:2])) case 127: if err := c.read(b[:8]); err != nil { return noFrame, err } c.readRemaining = int64(binary.BigEndian.Uint64(b[:8])) } // 4. Handle frame masking. if mask != c.isServer { return noFrame, c.handleProtocolError("incorrect mask flag") } if mask { c.readMaskPos = 0 if err := c.read(c.readMaskKey[:]); err != nil { return noFrame, err } } // 5. For text and binary messages, enforce read limit and return. if frameType == continuationFrame || frameType == TextMessage || frameType == BinaryMessage { c.readLength += c.readRemaining if c.readLimit > 0 && c.readLength > c.readLimit { c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait)) return noFrame, ErrReadLimit } return frameType, nil } // 6. Read control frame payload. 
payload := make([]byte, c.readRemaining) c.readRemaining = 0 if err := c.read(payload); err != nil { return noFrame, err } maskBytes(c.readMaskKey, 0, payload) // 7. Process control frame payload. switch frameType { case PongMessage: if err := c.handlePong(string(payload)); err != nil { return noFrame, err } case PingMessage: if err := c.handlePing(string(payload)); err != nil { return noFrame, err } case CloseMessage: c.WriteControl(CloseMessage, []byte{}, time.Now().Add(writeWait)) if len(payload) < 2 { return noFrame, io.EOF } closeCode := binary.BigEndian.Uint16(payload) switch closeCode { case CloseNormalClosure, CloseGoingAway: return noFrame, io.EOF default: return noFrame, errors.New("websocket: close " + strconv.Itoa(int(closeCode)) + " " + string(payload[2:])) } } return frameType, nil } func (c *Conn) handleProtocolError(message string) error { c.WriteControl(CloseMessage, FormatCloseMessage(CloseProtocolError, message), time.Now().Add(writeWait)) return errors.New("websocket: " + message) } func (c *Conn) read(buf []byte) error { var err error for len(buf) > 0 && err == nil { var nn int nn, err = c.br.Read(buf) buf = buf[nn:] } if err == io.EOF { if len(buf) == 0 { err = nil } else { err = io.ErrUnexpectedEOF } } return err } // NextReader returns the next data message received from the peer. The // returned messageType is either TextMessage or BinaryMessage. // // There can be at most one open reader on a connection. NextReader discards // the previous message if the application has not already consumed it. // // The NextReader method and the readers returned from the method cannot be // accessed by more than one goroutine at a time. 
func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { c.readSeq += 1 c.readLength = 0 for c.readErr == nil { var frameType int frameType, c.readErr = c.advanceFrame() if frameType == TextMessage || frameType == BinaryMessage { return frameType, messageReader{c, c.readSeq}, nil } } return noFrame, nil, c.readErr } type messageReader struct { c *Conn seq int } func (r messageReader) Read(b []byte) (n int, err error) { if r.seq != r.c.readSeq { return 0, io.EOF } for r.c.readErr == nil { if r.c.readRemaining > 0 { if int64(len(b)) > r.c.readRemaining { b = b[:r.c.readRemaining] } r.c.readErr = r.c.read(b) r.c.readMaskPos = maskBytes(r.c.readMaskKey, r.c.readMaskPos, b) r.c.readRemaining -= int64(len(b)) return len(b), r.c.readErr } if r.c.readFinal { r.c.readSeq += 1 return 0, io.EOF } var frameType int frameType, r.c.readErr = r.c.advanceFrame() if frameType == TextMessage || frameType == BinaryMessage { r.c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader") } } return 0, r.c.readErr } // ReadMessage is a helper method for getting a reader using NextReader and // reading from that reader to a buffer. func (c *Conn) ReadMessage() (messageType int, p []byte, err error) { var r io.Reader messageType, r, err = c.NextReader() if err != nil { return messageType, nil, err } p, err = ioutil.ReadAll(r) return messageType, p, err } // SetReadDeadline sets the deadline for future calls to NextReader and the // io.Reader returned from NextReader. If the deadline is reached, the call // will fail with a timeout instead of blocking. A zero value for t means that // the methods will not time out. func (c *Conn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) } // SetReadLimit sets the maximum size for a message read from the peer. If a // message exceeds the limit, the connection sends a close frame to the peer // and returns ErrReadLimit to the application. 
func (c *Conn) SetReadLimit(limit int64) {
	c.readLimit = limit
}

// errNotTCPConn is returned by SetLinger and SetNoDelay when the underlying
// network connection is not a *net.TCPConn.
var errNotTCPConn = errors.New("websocket: underlying connection is not a *net.TCPConn")

// SetLinger sets the behavior of Close() on a connection which still has
// data waiting to be sent or to be acknowledged. If sec < 0 (the default),
// Close returns immediately and the operating system finishes sending the
// data in the background. If sec == 0, Close returns immediately and the
// operating system discards any unsent or unacknowledged data. If sec > 0,
// Close blocks for at most sec seconds waiting for data to be sent and
// acknowledged.
//
// If the underlying connection is not a *net.TCPConn, SetLinger returns an
// error instead of panicking.
func (c *Conn) SetLinger(sec int) error {
	tcpConn, ok := c.conn.(*net.TCPConn)
	if !ok {
		return errNotTCPConn
	}
	return tcpConn.SetLinger(sec)
}

// SetNoDelay controls whether the operating system should delay packet
// transmission in hopes of sending fewer packets (Nagle's algorithm).
// The default is true (no delay), meaning that data is sent as soon as
// possible after a Write.
//
// If the underlying connection is not a *net.TCPConn, SetNoDelay returns an
// error instead of panicking.
func (c *Conn) SetNoDelay(noDelay bool) error {
	tcpConn, ok := c.conn.(*net.TCPConn)
	if !ok {
		return errNotTCPConn
	}
	return tcpConn.SetNoDelay(noDelay)
}

// SetPingHandler sets the handler for ping messages received from the peer.
// The default ping handler sends a pong with the same payload to the peer,
// using a writeWait deadline and ignoring any write error.
func (c *Conn) SetPingHandler(h func(string) error) {
	if h == nil {
		h = func(message string) error {
			c.WriteControl(PongMessage, []byte(message), time.Now().Add(writeWait))
			return nil
		}
	}
	c.handlePing = h
}

// SetPongHandler sets the handler for pong messages received from the peer.
// The default pong handler does nothing.
func (c *Conn) SetPongHandler(h func(string) error) {
	if h == nil {
		h = func(string) error { return nil }
	}
	c.handlePong = h
}

// FormatCloseMessage formats closeCode and text as a WebSocket close message:
// a big-endian 2-byte code followed by text. WriteControl rejects payloads
// over 125 bytes, so text should be at most 123 bytes for use there.
func FormatCloseMessage(closeCode int, text string) []byte {
	buf := make([]byte, 2+len(text))
	binary.BigEndian.PutUint16(buf, uint16(closeCode))
	copy(buf[2:], text)
	return buf
}