commit 273ecadfcae3306e1943ac6d4160ef02cac48247 Author: Gary Burd Date: Wed Oct 16 09:41:47 2013 -0700 Initial commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..0026861 --- /dev/null +++ b/.gitignore @@ -0,0 +1,22 @@ +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.exe diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..09e5be6 --- /dev/null +++ b/LICENSE @@ -0,0 +1,23 @@ +Copyright (c) 2013, Gorilla web toolkit +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + + Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + Redistributions in binary form must reproduce the above copyright notice, this + list of conditions and the following disclaimer in the documentation and/or + other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..42d2bc1 --- /dev/null +++ b/README.md @@ -0,0 +1,26 @@ +# WebSocket + +This project is a [Go](http://golang.org/) implementation of the +[WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol. + +The project passes the server tests in the [Autobahn WebSockets Test +Suite](http://autobahn.ws/testsuite) using the application in the [examples/autobahn +subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn). + +## Documentation + +* [Reference](http://godoc.org/github.com/gorilla/websocket) +* [Chat example](https://github.com/gorilla/websocket/tree/master/examples/chat) + +## Features + +- Send and receive ping, pong and close control messages. +- Limit size of received messages. +- Stream messages. +- Specify IO buffer sizes. +- Application has full control over origin checks and sub-protocol negotiation. + +## Installation + + go get github.com/gorilla/websocket + diff --git a/client.go b/client.go new file mode 100644 index 0000000..be87444 --- /dev/null +++ b/client.go @@ -0,0 +1,69 @@ +// Copyright 2013 Gary Burd. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "errors" + "net" + "net/http" + "net/url" + "strings" +) + +// ErrBadHandshake is returned when the server response to opening handshake is +// invalid. +var ErrBadHandshake = errors.New("websocket: bad handshake") + +// NewClient creates a new client connection using the given net connection. +// The URL u specifies the host and request URI. Use requestHeader to specify +// the origin (Origin), subprotocols (Set-WebSocket-Protocol) and cookies +// (Cookie). Use the response.Header to get the selected subprotocol +// (Sec-WebSocket-Protocol) and cookies (Set-Cookie). +// +// If the WebSocket handshake fails, ErrBadHandshake is returned along with a +// non-nil *http.Response so that callers can handle redirects, authentication, +// etc. +func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufSize, writeBufSize int) (c *Conn, response *http.Response, err error) { + challengeKey, err := generateChallengeKey() + if err != nil { + return nil, nil, err + } + acceptKey := computeAcceptKey(challengeKey) + + c = newConn(netConn, false, readBufSize, writeBufSize) + p := c.writeBuf[:0] + p = append(p, "GET "...) + p = append(p, u.RequestURI()...) + p = append(p, " HTTP/1.1\r\nHost: "...) + p = append(p, u.Host...) + p = append(p, "\r\nUpgrade: websocket\r\nConnection: upgrade\r\nSec-WebSocket-Version: 13\r\nSec-WebSocket-Key: "...) + p = append(p, challengeKey...) + p = append(p, "\r\n"...) + for k, vs := range requestHeader { + for _, v := range vs { + p = append(p, k...) + p = append(p, ": "...) + p = append(p, v...) + p = append(p, "\r\n"...) + } + } + p = append(p, "\r\n"...) + + if _, err := netConn.Write(p); err != nil { + return nil, nil, err + } + + resp, err := http.ReadResponse(c.br, &http.Request{Method: "GET", URL: u}) + if err != nil { + return nil, nil, err + } + if resp.StatusCode != 101 || + !strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") || + !strings.EqualFold(resp.Header.Get("Connection"), "upgrade") || + resp.Header.Get("Sec-Websocket-Accept") != acceptKey { + return nil, resp, ErrBadHandshake + } + return c, resp, nil +} diff --git a/client_server_test.go b/client_server_test.go new file mode 100644 index 0000000..6464e36 --- /dev/null +++ b/client_server_test.go @@ -0,0 +1,114 @@ +// Copyright 2013 Gary Burd. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket_test + +import ( + "io" + "io/ioutil" + "net" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" + + "github.com/gorilla/websocket" +) + +type wsHandler struct { + *testing.T +} + +func (t wsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.Method != "GET" { + http.Error(w, "Method not allowed", 405) + t.Logf("bad method: %s", r.Method) + return + } + if r.Header.Get("Origin") != "http://"+r.Host { + http.Error(w, "Origin not allowed", 403) + t.Logf("bad origin: %s", r.Header.Get("Origin")) + return + } + ws, err := websocket.Upgrade(w, r, http.Header{"Set-Cookie": {"sessionID=1234"}}, 1024, 1024) + if _, ok := err.(websocket.HandshakeError); ok { + t.Logf("bad handshake: %v", err) + http.Error(w, "Not a websocket handshake", 400) + return + } else if err != nil { + t.Logf("upgrade error: %v", err) + return + } + defer ws.Close() + for { + op, r, err := ws.NextReader() + if err != nil { + if err != io.EOF { + t.Logf("NextReader: %v", err) + } + return + } + if op == websocket.PongMessage { + continue + } + w, err := ws.NextWriter(op) + if err != nil { + t.Logf("NextWriter: %v", err) + return + } + if _, err = io.Copy(w, r); err != nil { + t.Logf("Copy: %v", err) + return + } + if err := w.Close(); err != nil { + t.Logf("Close: %v", err) + return + } + } +} + +func TestClientServer(t *testing.T) { + s := httptest.NewServer(wsHandler{t}) + defer s.Close() + u, _ := url.Parse(s.URL) + c, err := net.Dial("tcp", u.Host) + if err != nil { + t.Fatalf("Dial: %v", err) + } + ws, resp, err := websocket.NewClient(c, u, http.Header{"Origin": {s.URL}}, 1024, 1024) + if err != nil { + t.Fatalf("NewClient: %v", err) + } + defer ws.Close() + + var sessionID string + for _, c := range resp.Cookies() { + if c.Name == "sessionID" { + sessionID = c.Value + } + } + if sessionID != "1234" { + t.Error("Set-Cookie not received from the server.") + } + + w, _ := ws.NextWriter(websocket.TextMessage) + io.WriteString(w, "HELLO") + w.Close() + ws.SetReadDeadline(time.Now().Add(1 * time.Second)) + op, r, err := ws.NextReader() + if err != nil { + t.Fatalf("NextReader: %v", err) + } + if op != websocket.TextMessage { + t.Fatalf("op=%d, want %d", op, websocket.TextMessage) + } + b, err := ioutil.ReadAll(r) + if err != nil { + t.Fatalf("ReadAll: %v", err) + } + if string(b) != "HELLO" { + t.Fatalf("message=%s, want %s", b, "HELLO") + } +} diff --git a/conn.go b/conn.go new file mode 100644 index 0000000..bef4860 --- /dev/null +++ b/conn.go @@ -0,0 +1,759 @@ +// Copyright 2013 Gary Burd. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "bufio" + "encoding/binary" + "errors" + "io" + "io/ioutil" + "math/rand" + "net" + "strconv" + "time" +) + +// Close codes defined in RFC 6455, section 11.7. +const ( + CloseNormalClosure = 1000 + CloseGoingAway = 1001 + CloseProtocolError = 1002 + CloseUnsupportedData = 1003 + CloseNoStatusReceived = 1005 + CloseAbnormalClosure = 1006 + CloseInvalidFramePayloadData = 1007 + ClosePolicyViolation = 1008 + CloseMessageTooBig = 1009 + CloseMandatoryExtension = 1010 + CloseInternalServerErr = 1011 + CloseTLSHandshake = 1015 +) + +// The message types are defined in RFC 6455, section 11.8. +const ( + // TextMessage denotes a text message. The text message payload is + // interpreted as UTF-8 encoded text data. + TextMessage = 1 + + // BinaryMessage denotes a binary data message. + BinaryMessage = 2 + + // CloseMessage denotes a close control message. The optional message + // payload contains a numeric code and text. Use the FormatCloseMessage + // function to format a close message payload. + CloseMessage = 8 + + // PingMessage denotes a ping control message. The optional message payload + // is UTF-8 encoded text. + PingMessage = 9 + + // PongMessage denotes a ping control message. The optional message payload + // is UTF-8 encoded text. + PongMessage = 10 +) + +var ( + continuationFrame = 0 + noFrame = -1 +) + +var ( + ErrCloseSent = errors.New("websocket: close sent") + ErrReadLimit = errors.New("websocket: read limit exceeded") +) + +var ( + errBadWriteOpCode = errors.New("websocket: bad write message type") + errWriteTimeout = errors.New("websocket: write timeout") + errWriteClosed = errors.New("websocket: write closed") + errInvalidControlFrame = errors.New("websocket: invalid control frame") +) + +const ( + maxFrameHeaderSize = 2 + 8 + 4 // Fixed header + length + mask + maxControlFramePayloadSize = 125 + finalBit = 1 << 7 + maskBit = 1 << 7 + writeWait = time.Second +) + +func isControl(frameType int) bool { + return frameType == CloseMessage || frameType == PingMessage || frameType == PongMessage +} + +func isData(frameType int) bool { + return frameType == TextMessage || frameType == BinaryMessage +} + +func maskBytes(key [4]byte, pos int, b []byte) int { + for i := range b { + b[i] ^= key[pos&3] + pos += 1 + } + return pos & 3 +} + +func newMaskKey() [4]byte { + n := rand.Uint32() + return [4]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 32)} +} + +// Conn represents a WebSocket connection. +type Conn struct { + conn net.Conn + isServer bool + + // Write fields + mu chan bool // used as mutex to protect write to conn and closeSent + closeSent bool // true if close message was sent + + // Message writer fields. + writeErr error + writeBuf []byte // frame is constructed in this buffer. + writePos int // end of data in writeBuf. + writeFrameType int // type of the current frame. + writeSeq int // incremented to invalidate message writers. + writeDeadline time.Time + + // Read fields + readErr error + br *bufio.Reader + readRemaining int64 // bytes remaining in current frame. + readFinal bool // true the current message has more frames. + readSeq int // incremented to invalidate message readers. + readLength int64 // Message size. + readLimit int64 // Maximum message size. + readMaskPos int + readMaskKey [4]byte + handlePong func(string) error + handlePing func(string) error +} + +func newConn(conn net.Conn, isServer bool, readBufSize, writeBufSize int) *Conn { + mu := make(chan bool, 1) + mu <- true + + c := &Conn{ + isServer: isServer, + br: bufio.NewReaderSize(conn, readBufSize), + conn: conn, + mu: mu, + readFinal: true, + writeBuf: make([]byte, writeBufSize+maxFrameHeaderSize), + writeFrameType: noFrame, + writePos: maxFrameHeaderSize, + } + c.SetPingHandler(nil) + c.SetPongHandler(nil) + return c +} + +// Close closes the underlying network connection without sending or waiting for a close frame. +func (c *Conn) Close() error { + return c.conn.Close() +} + +// LocalAddr returns the local network address. +func (c *Conn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +// RemoteAddr returns the remote network address. +func (c *Conn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +// Write methods + +func (c *Conn) write(frameType int, deadline time.Time, bufs ...[]byte) error { + <-c.mu + defer func() { c.mu <- true }() + + if c.closeSent { + return ErrCloseSent + } else if frameType == CloseMessage { + c.closeSent = true + } + + c.conn.SetWriteDeadline(deadline) + for _, buf := range bufs { + if len(buf) > 0 { + n, err := c.conn.Write(buf) + if n != len(buf) { + // Close on partial write. + c.conn.Close() + } + if err != nil { + return err + } + } + } + return nil +} + +// WriteControl writes a control message with the given deadline. The allowed +// message types are CloseMessage, PingMessage and PongMessage. +func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) error { + if !isControl(messageType) { + return errBadWriteOpCode + } + if len(data) > maxControlFramePayloadSize { + return errInvalidControlFrame + } + + b0 := byte(messageType) | finalBit + b1 := byte(len(data)) + if !c.isServer { + b1 |= maskBit + } + + buf := make([]byte, 0, maxFrameHeaderSize+maxControlFramePayloadSize) + buf = append(buf, b0, b1) + + if c.isServer { + buf = append(buf, data...) + } else { + key := newMaskKey() + buf = append(buf, key[:]...) + buf = append(buf, data...) + maskBytes(key, 0, buf[6:]) + } + + d := time.Hour * 1000 + if !deadline.IsZero() { + d = deadline.Sub(time.Now()) + if d < 0 { + return errWriteTimeout + } + } + + timer := time.NewTimer(d) + select { + case <-c.mu: + timer.Stop() + case <-timer.C: + return errWriteTimeout + } + defer func() { c.mu <- true }() + + if c.closeSent { + return ErrCloseSent + } else if messageType == CloseMessage { + c.closeSent = true + } + + c.conn.SetWriteDeadline(deadline) + n, err := c.conn.Write(buf) + if n != 0 && n != len(buf) { + c.conn.Close() + } + return err +} + +// NextWriter returns a writer for the next message to send. The writer's +// Close method flushes the complete message to the network. +// +// There can be at most one open writer on a connection. NextWriter closes the +// previous writer if the application has not already done so. +// +// The NextWriter method and the writers returned from the method cannot be +// accessed by more than one goroutine at a time. +func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) { + if c.writeErr != nil { + return nil, c.writeErr + } + + if c.writeFrameType != noFrame { + if err := c.flushFrame(true, nil); err != nil { + return nil, err + } + } + + if !isControl(messageType) && !isData(messageType) { + return nil, errBadWriteOpCode + } + + c.writeFrameType = messageType + return messageWriter{c, c.writeSeq}, nil +} + +func (c *Conn) flushFrame(final bool, extra []byte) error { + length := c.writePos - maxFrameHeaderSize + len(extra) + + // Check for invalid control frames. + if isControl(c.writeFrameType) && + (!final || length > maxControlFramePayloadSize) { + c.writeSeq += 1 + c.writeFrameType = noFrame + c.writePos = maxFrameHeaderSize + return errInvalidControlFrame + } + + b0 := byte(c.writeFrameType) + if final { + b0 |= finalBit + } + b1 := byte(0) + if !c.isServer { + b1 |= maskBit + } + + // Assume that the frame starts at beginning of c.writeBuf. + framePos := 0 + if c.isServer { + // Adjust up if mask not included in the header. + framePos = 4 + } + + switch { + case length >= 65536: + c.writeBuf[framePos] = b0 + c.writeBuf[framePos+1] = b1 | 127 + binary.BigEndian.PutUint64(c.writeBuf[framePos+2:], uint64(length)) + case length > 125: + framePos += 6 + c.writeBuf[framePos] = b0 + c.writeBuf[framePos+1] = b1 | 126 + binary.BigEndian.PutUint16(c.writeBuf[framePos+2:], uint16(length)) + default: + framePos += 8 + c.writeBuf[framePos] = b0 + c.writeBuf[framePos+1] = b1 | byte(length) + } + + if !c.isServer { + key := newMaskKey() + copy(c.writeBuf[maxFrameHeaderSize-4:], key[:]) + maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:c.writePos]) + if len(extra) > 0 { + c.writeErr = errors.New("websocket: internal error, extra used in client mode") + return c.writeErr + } + } + + // Write the buffers to the connection. + c.writeErr = c.write(c.writeFrameType, c.writeDeadline, c.writeBuf[framePos:c.writePos], extra) + + // Setup for next frame. + c.writePos = maxFrameHeaderSize + c.writeFrameType = continuationFrame + if final { + c.writeSeq += 1 + c.writeFrameType = noFrame + } + return c.writeErr +} + +type messageWriter struct { + c *Conn + seq int +} + +func (w messageWriter) err() error { + c := w.c + if c.writeSeq != w.seq { + return errWriteClosed + } + if c.writeErr != nil { + return c.writeErr + } + return nil +} + +func (w messageWriter) ncopy(max int) (int, error) { + n := len(w.c.writeBuf) - w.c.writePos + if n <= 0 { + if err := w.c.flushFrame(false, nil); err != nil { + return 0, err + } + n = len(w.c.writeBuf) - w.c.writePos + } + if n > max { + n = max + } + return n, nil +} + +func (w messageWriter) write(final bool, p []byte) (int, error) { + if err := w.err(); err != nil { + return 0, err + } + + if len(p) > 2*len(w.c.writeBuf) && w.c.isServer { + // Don't buffer large messages. + err := w.c.flushFrame(final, p) + if err != nil { + return 0, err + } + return len(p), nil + } + + nn := len(p) + for len(p) > 0 { + n, err := w.ncopy(len(p)) + if err != nil { + return 0, err + } + copy(w.c.writeBuf[w.c.writePos:], p[:n]) + w.c.writePos += n + p = p[n:] + } + return nn, nil +} + +func (w messageWriter) Write(p []byte) (int, error) { + return w.write(false, p) +} + +func (w messageWriter) WriteString(p string) (int, error) { + if err := w.err(); err != nil { + return 0, err + } + + nn := len(p) + for len(p) > 0 { + n, err := w.ncopy(len(p)) + if err != nil { + return 0, err + } + copy(w.c.writeBuf[w.c.writePos:], p[:n]) + w.c.writePos += n + p = p[n:] + } + return nn, nil +} + +func (w messageWriter) ReadFrom(r io.Reader) (nn int64, err error) { + if err := w.err(); err != nil { + return 0, err + } + for { + if w.c.writePos == len(w.c.writeBuf) { + err = w.c.flushFrame(false, nil) + if err != nil { + break + } + } + var n int + n, err = r.Read(w.c.writeBuf[w.c.writePos:]) + w.c.writePos += n + nn += int64(n) + if err != nil { + if err == io.EOF { + err = nil + } + break + } + } + return nn, err +} + +func (w messageWriter) Close() error { + if err := w.err(); err != nil { + return err + } + return w.c.flushFrame(true, nil) +} + +// WriteMessage is a helper method for getting a writer using NextWriter, +// writing the message and closing the writer. +func (c *Conn) WriteMessage(messageType int, data []byte) error { + wr, err := c.NextWriter(messageType) + if err != nil { + return err + } + w := wr.(messageWriter) + if _, err := w.write(true, data); err != nil { + return err + } + if c.writeSeq == w.seq { + if err := c.flushFrame(true, nil); err != nil { + return err + } + } + return nil +} + +// SetWriteDeadline sets the deadline for future calls to NextWriter and the +// io.WriteCloser returned from NextWriter. If the deadline is reached, the +// call will fail with a timeout instead of blocking. A zero value for t means +// Write will not time out. Even if Write times out, it may return n > 0, +// indicating that some of the data was successfully written. +func (c *Conn) SetWriteDeadline(t time.Time) error { + c.writeDeadline = t + return nil +} + +// Read methods + +func (c *Conn) advanceFrame() (int, error) { + + // 1. Skip remainder of previous frame. + + if c.readRemaining > 0 { + if _, err := io.CopyN(ioutil.Discard, c.br, c.readRemaining); err != nil { + return noFrame, err + } + } + + // 2. Read and parse first two bytes of frame header. + + var b [8]byte + if err := c.read(b[:2]); err != nil { + return noFrame, err + } + + final := b[0]&finalBit != 0 + frameType := int(b[0] & 0xf) + reserved := int((b[0] >> 4) & 0x7) + mask := b[1]&maskBit != 0 + c.readRemaining = int64(b[1] & 0x7f) + + if reserved != 0 { + return noFrame, c.handleProtocolError("unexpected reserved bits " + strconv.Itoa(reserved)) + } + + switch frameType { + case CloseMessage, PingMessage, PongMessage: + if c.readRemaining > maxControlFramePayloadSize { + return noFrame, c.handleProtocolError("control frame length > 125") + } + if !final { + return noFrame, c.handleProtocolError("control frame not final") + } + case TextMessage, BinaryMessage: + if !c.readFinal { + return noFrame, c.handleProtocolError("message start before final message frame") + } + c.readFinal = final + case continuationFrame: + if c.readFinal { + return noFrame, c.handleProtocolError("continuation after final message frame") + } + c.readFinal = final + default: + return noFrame, c.handleProtocolError("unknown opcode " + strconv.Itoa(frameType)) + } + + // 3. Read and parse frame length. + + switch c.readRemaining { + case 126: + if err := c.read(b[:2]); err != nil { + return noFrame, err + } + c.readRemaining = int64(binary.BigEndian.Uint16(b[:2])) + case 127: + if err := c.read(b[:8]); err != nil { + return noFrame, err + } + c.readRemaining = int64(binary.BigEndian.Uint64(b[:8])) + } + + // 4. Handle frame masking. + + if mask != c.isServer { + return noFrame, c.handleProtocolError("incorrect mask flag") + } + + if mask { + c.readMaskPos = 0 + if err := c.read(c.readMaskKey[:]); err != nil { + return noFrame, err + } + } + + // 5. For text and binary messages, enforce read limit and return. + + if frameType == continuationFrame || frameType == TextMessage || frameType == BinaryMessage { + + c.readLength += c.readRemaining + if c.readLimit > 0 && c.readLength > c.readLimit { + c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait)) + return noFrame, ErrReadLimit + } + + return frameType, nil + } + + // 6. Read control frame payload. + + payload := make([]byte, c.readRemaining) + c.readRemaining = 0 + if err := c.read(payload); err != nil { + return noFrame, err + } + maskBytes(c.readMaskKey, 0, payload) + + // 7. Process control frame payload. + + switch frameType { + case PongMessage: + if err := c.handlePong(string(payload)); err != nil { + return noFrame, err + } + case PingMessage: + if err := c.handlePing(string(payload)); err != nil { + return noFrame, err + } + case CloseMessage: + c.WriteControl(CloseMessage, []byte{}, time.Now().Add(writeWait)) + if len(payload) < 2 { + return noFrame, io.EOF + } + closeCode := binary.BigEndian.Uint16(payload) + switch closeCode { + case CloseNormalClosure, CloseGoingAway: + return noFrame, io.EOF + default: + return noFrame, errors.New("websocket: close " + + strconv.Itoa(int(closeCode)) + " " + + string(payload[2:])) + } + } + + return frameType, nil +} + +func (c *Conn) handleProtocolError(message string) error { + c.WriteControl(CloseMessage, FormatCloseMessage(CloseProtocolError, message), time.Now().Add(writeWait)) + return errors.New("websocket: " + message) +} + +func (c *Conn) read(buf []byte) error { + var err error + for len(buf) > 0 && err == nil { + var nn int + nn, err = c.br.Read(buf) + buf = buf[nn:] + } + if err == io.EOF { + if len(buf) == 0 { + err = nil + } else { + err = io.ErrUnexpectedEOF + } + } + return err +} + +// NextReader returns the next data message received from the peer. The +// returned messageType is either TextMessage or BinaryMessage. +// +// There can be at most one open reader on a connection. NextReader discards +// the previous message if the application has not already consumed it. +// +// The NextReader method and the readers returned from the method cannot be +// accessed by more than one goroutine at a time. +func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { + + c.readSeq += 1 + c.readLength = 0 + + for c.readErr == nil { + var frameType int + frameType, c.readErr = c.advanceFrame() + if frameType == TextMessage || frameType == BinaryMessage { + return frameType, messageReader{c, c.readSeq}, nil + } + } + return noFrame, nil, c.readErr +} + +type messageReader struct { + c *Conn + seq int +} + +func (r messageReader) Read(b []byte) (n int, err error) { + + if r.seq != r.c.readSeq { + return 0, io.EOF + } + + for r.c.readErr == nil { + + if r.c.readRemaining > 0 { + if int64(len(b)) > r.c.readRemaining { + b = b[:r.c.readRemaining] + } + r.c.readErr = r.c.read(b) + r.c.readMaskPos = maskBytes(r.c.readMaskKey, r.c.readMaskPos, b) + r.c.readRemaining -= int64(len(b)) + return len(b), r.c.readErr + } + + if r.c.readFinal { + r.c.readSeq += 1 + return 0, io.EOF + } + + var frameType int + frameType, r.c.readErr = r.c.advanceFrame() + + if frameType == TextMessage || frameType == BinaryMessage { + r.c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader") + } + } + return 0, r.c.readErr +} + +// ReadMessage is a helper method for getting a reader using NextReader and +// reading from that reader to a buffer. +func (c *Conn) ReadMessage() (messageType int, p []byte, err error) { + var r io.Reader + messageType, r, err = c.NextReader() + if err != nil { + return messageType, nil, err + } + p, err = ioutil.ReadAll(r) + return messageType, p, err +} + +// SetReadDeadline sets the deadline for future calls to NextReader and the +// io.Reader returned from NextReader. If the deadline is reached, the call +// will fail with a timeout instead of blocking. A zero value for t means that +// the methods will not time out. +func (c *Conn) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +// SetReadLimit sets the maximum size for a message read from the peer. If a +// message exceeds the limit, the connection sends a close frame to the peer +// and returns ErrReadLimit to the application. +func (c *Conn) SetReadLimit(limit int64) { + c.readLimit = limit +} + +// SetPingHandler sets the handler for ping messages received from the peer. +// The default ping handler sends a pong to the peer. +func (c *Conn) SetPingHandler(h func(string) error) { + if h == nil { + h = func(message string) error { + c.WriteControl(PongMessage, []byte(message), time.Now().Add(writeWait)) + return nil + } + } + c.handlePing = h +} + +// SetPongHandler sets then handler for pong messages received from the peer. +// The default pong handler does nothing. +func (c *Conn) SetPongHandler(h func(string) error) { + if h == nil { + h = func(string) error { return nil } + } + c.handlePong = h +} + +// SetPongHandler sets the handler for +// FormatCloseMessage formats closeCode and text as a WebSocket close message. +func FormatCloseMessage(closeCode int, text string) []byte { + buf := make([]byte, 2+len(text)) + binary.BigEndian.PutUint16(buf, uint16(closeCode)) + copy(buf[2:], text) + return buf +} diff --git a/conn_test.go b/conn_test.go new file mode 100644 index 0000000..3a06d34 --- /dev/null +++ b/conn_test.go @@ -0,0 +1,140 @@ +// Copyright 2013 Gary Burd. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "net" + "testing" + "testing/iotest" + "time" +) + +type fakeNetConn struct { + io.Reader + io.Writer +} + +func (c fakeNetConn) Close() error { return nil } +func (c fakeNetConn) LocalAddr() net.Addr { return nil } +func (c fakeNetConn) RemoteAddr() net.Addr { return nil } +func (c fakeNetConn) SetDeadline(t time.Time) error { return nil } +func (c fakeNetConn) SetReadDeadline(t time.Time) error { return nil } +func (c fakeNetConn) SetWriteDeadline(t time.Time) error { return nil } + +func TestFraming(t *testing.T) { + frameSizes := []int{0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, 65536, 65537} + var readChunkers = []struct { + name string + f func(io.Reader) io.Reader + }{ + {"half", iotest.HalfReader}, + {"one", iotest.OneByteReader}, + {"asis", func(r io.Reader) io.Reader { return r }}, + } + + writeBuf := make([]byte, 65537) + for i := range writeBuf { + writeBuf[i] = byte(i) + } + + for _, isServer := range []bool{true, false} { + for _, chunker := range readChunkers { + + var connBuf bytes.Buffer + wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024) + rc := newConn(fakeNetConn{Reader: chunker.f(&connBuf), Writer: nil}, !isServer, 1024, 1024) + + for _, n := range frameSizes { + for _, iocopy := range []bool{true, false} { + name := fmt.Sprintf("s:%v, r:%s, n:%d c:%v", isServer, chunker.name, n, iocopy) + + w, err := wc.NextWriter(TextMessage) + if err != nil { + t.Errorf("%s: wc.NextWriter() returned %v", name, err) + continue + } + var nn int + if iocopy { + var n64 int64 + n64, err = io.Copy(w, bytes.NewReader(writeBuf[:n])) + nn = int(n64) + } else { + nn, err = w.Write(writeBuf[:n]) + } + if err != nil || nn != n { + t.Errorf("%s: w.Write(writeBuf[:n]) returned %d, %v", name, nn, err) + continue + } + err = w.Close() + if err != nil { + t.Errorf("%s: w.Close() returned %v", name, err) + continue + } + + opCode, r, err := rc.NextReader() + if err != nil || opCode != TextMessage { + t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err) + continue + } + rbuf, err := ioutil.ReadAll(r) + if err != nil { + t.Errorf("%s: ReadFull() returned rbuf, %v", name, err) + continue + } + + if len(rbuf) != n { + t.Errorf("%s: len(rbuf) is %d, want %d", name, len(rbuf), n) + continue + } + + for i, b := range rbuf { + if byte(i) != b { + t.Errorf("%s: bad byte at offset %d", name, i) + break + } + } + } + } + } + } +} + +func TestReadLimit(t *testing.T) { + + const readLimit = 512 + message := make([]byte, readLimit+1) + + var b1, b2 bytes.Buffer + wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, readLimit-2) + rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024) + rc.SetReadLimit(readLimit) + + // Send message at the limit with interleaved pong. + w, _ := wc.NextWriter(BinaryMessage) + w.Write(message[:readLimit-1]) + wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second)) + w.Write(message[:1]) + w.Close() + + // Send message larger than the limit. + wc.WriteMessage(BinaryMessage, message[:readLimit+1]) + + op, _, err := rc.NextReader() + if op != BinaryMessage || err != nil { + t.Fatalf("1: NextReader() returned %d, %v", op, err) + } + op, r, err := rc.NextReader() + if op != BinaryMessage || err != nil { + t.Fatalf("2: NextReader() returned %d, %v", op, err) + } + _, err = io.Copy(ioutil.Discard, r) + if err != ErrReadLimit { + t.Fatalf("io.Copy() returned %v", err) + } +} diff --git a/doc.go b/doc.go new file mode 100644 index 0000000..fa9b1cf --- /dev/null +++ b/doc.go @@ -0,0 +1,98 @@ +// Copyright 2013 Gary Burd. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package websocket implements the WebSocket protocol defined in RFC 6455. +// +// Overview +// +// The Conn type represents a WebSocket connection. +// +// A server application calls the Upgrade function to get a pointer to a Conn: +// +// func handler(w http.ResponseWriter, r *http.Request) { +// conn, err := websocket.Upgrade(w, r.Header, nil, 1024, 1024) +// if _, ok := err.(websocket.HandshakeError); ok { +// http.Error(w, "Not a websocket handshake", 400) +// return +// } else if err != nil { +// log.Println(err) +// return +// } +// ... Use conn to send and receive messages. +// } +// +// WebSocket messages are represented by the io.Reader interface when receiving +// a message and by the io.WriteCloser interface when sending a message. An +// application receives a message by calling the Conn.NextReader method and +// reading the returned io.Reader to EOF. An application sends a message by +// calling the Conn.NextWriter method and writing the message to the returned +// io.WriteCloser. The application terminates the message by closing the +// io.WriteCloser. +// +// The following example shows how to use the connection NextReader and +// NextWriter method to echo messages: +// +// for { +// mt, r, err := conn.NextReader() +// if err != nil { +// return +// } +// w, err := conn.NextWriter(mt) +// if err != nil { +// return err +// } +// if _, err := io.Copy(w, r); err != nil { +// return err +// } +// if err := w.Close(); err != nil { +// return err +// } +// } +// +// The connection ReadMessage and WriteMessage methods are helpers for reading +// or writing an entire message in one method call. The following example shows +// how to echo messages using these connection helper methods: +// +// for { +// mt, p, err := conn.ReadMessage() +// if err != nil { +// return +// } +// if _, err := conn.WriteMessaage(mt, p); err != nil { +// return err +// } +// } +// +// Concurrency +// +// A Conn supports a single concurrent caller to the write methods (NextWriter, +// SetWriteDeadline, WriteMessage) and a single concurrent caller to the read +// methods (NextReader, SetReadDeadline, ReadMessage). The Close and +// WriteControl methods can be called concurrently with all other methods. +// +// Data Messages +// +// The WebSocket protocol distinguishes between text and binary data messages. +// Text messages are interpreted as UTF-8 encoded text. The interpretation of +// binary messages is left to the application. +// +// This package uses the same types and methods to work with both types of data +// messages. It is the application's reponsiblity to ensure that text messages +// are valid UTF-8 encoded text. +// +// Control Messages +// +// The WebSocket protocol defines three types of control messages: close, ping +// and pong. Call the connection WriteControl, WriteMessage or NextWriter +// methods to send a control message to the peer. +// +// Connections handle received ping and pong messages by invoking a callback +// function set with SetPingHandler and SetPongHandler methods. These callback +// functions can be invoked from the ReadMessage method, the NextReader method +// or from a call to the data message reader returned from NextReader. +// +// Connections handle received close messages by returning an error from the +// ReadMessage method, the NextReader method or from a call to the data message +// reader returned from NextReader. +package websocket diff --git a/examples/autobahn/README.md b/examples/autobahn/README.md new file mode 100644 index 0000000..68a7ea3 --- /dev/null +++ b/examples/autobahn/README.md @@ -0,0 +1,18 @@ +# Test + +Clients and servers for the [Autobahn WebSockets Test Suite](http://autobahn.ws/testsuite). + +To test different code paths in the package, the test server echoes messages two ways: + +- Read the entire message using io.ReadAll and write the message in one chunk. +- Copy the message in parts using io.Copy + +To test the server, run it + + go run server.go + +and start the client test driver + + wstest -m fuzzingclient -s fuzzingclient.json + +When the client completes, it writes a report to reports/servers/index.html. diff --git a/examples/autobahn/fuzzingclient.json b/examples/autobahn/fuzzingclient.json new file mode 100644 index 0000000..27d5a5b --- /dev/null +++ b/examples/autobahn/fuzzingclient.json @@ -0,0 +1,14 @@ + +{ + "options": {"failByDrop": false}, + "outdir": "./reports/clients", + "servers": [ + {"agent": "ReadAllWriteMessage", "url": "ws://localhost:9000/m", "options": {"version": 18}}, + {"agent": "ReadAllWrite", "url": "ws://localhost:9000/r", "options": {"version": 18}}, + {"agent": "CopyFull", "url": "ws://localhost:9000/f", "options": {"version": 18}}, + {"agent": "CopyWriterOnly", "url": "ws://localhost:9000/c", "options": {"version": 18}} + ], + "cases": ["*"], + "exclude-cases": [], + "exclude-agent-cases": {} +} diff --git a/examples/autobahn/server.go b/examples/autobahn/server.go new file mode 100644 index 0000000..bca8e83 --- /dev/null +++ b/examples/autobahn/server.go @@ -0,0 +1,250 @@ +// Copyright 2012 Gary Burd +// +// Licensed under the Apache License, Version 2.0 (the "License"): you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +// Command server is a test server for the Autobahn WebSockets Test Suite. +package main + +import ( + "errors" + "flag" + "github.com/gorilla/websocket" + "io" + "log" + "net/http" + "time" + "unicode/utf8" +) + +// echoCopy echoes messages from the client using io.Copy. +func echoCopy(w http.ResponseWriter, r *http.Request, writerOnly bool) { + conn, err := websocket.Upgrade(w, r, nil, 4096, 4096) + if err != nil { + log.Println("Upgrade:", err) + http.Error(w, "Bad request", 400) + return + } + defer conn.Close() + for { + mt, r, err := conn.NextReader() + if err != nil { + if err != io.EOF { + log.Println("NextReader:", err) + } + return + } + if mt == websocket.TextMessage { + r = &validator{r: r} + } + w, err := conn.NextWriter(mt) + if err != nil { + log.Println("NextWriter:", err) + return + } + if mt == websocket.TextMessage { + r = &validator{r: r} + } + if writerOnly { + _, err = io.Copy(struct{ io.Writer }{w}, r) + } else { + _, err = io.Copy(w, r) + } + if err != nil { + if err == errInvalidUTF8 { + conn.WriteControl(websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.CloseInvalidFramePayloadData, ""), + time.Time{}) + } + log.Println("Copy:", err) + return + } + err = w.Close() + if err != nil { + log.Println("Close:", err) + return + } + } +} + +func echoCopyWriterOnly(w http.ResponseWriter, r *http.Request) { + echoCopy(w, r, true) +} + +func echoCopyFull(w http.ResponseWriter, r *http.Request) { + echoCopy(w, r, false) +} + +// echoReadAll echoes messages from the client by reading the entire message +// with ioutil.ReadAll. +func echoReadAll(w http.ResponseWriter, r *http.Request, writeMessage bool) { + conn, err := websocket.Upgrade(w, r, nil, 4096, 4096) + if err != nil { + log.Println("Upgrade:", err) + http.Error(w, "Bad request", 400) + return + } + defer conn.Close() + for { + mt, b, err := conn.ReadMessage() + if err != nil { + if err != io.EOF { + log.Println("NextReader:", err) + } + return + } + if mt == websocket.TextMessage { + if !utf8.Valid(b) { + conn.WriteControl(websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.CloseInvalidFramePayloadData, ""), + time.Time{}) + log.Println("ReadAll: invalid utf8") + } + } + if writeMessage { + err = conn.WriteMessage(mt, b) + if err != nil { + log.Println("WriteMessage:", err) + } + } else { + w, err := conn.NextWriter(mt) + if err != nil { + log.Println("NextWriter:", err) + return + } + if _, err := w.Write(b); err != nil { + log.Println("Writer:", err) + return + } + if err := w.Close(); err != nil { + log.Println("Close:", err) + return + } + } + } +} + +func echoReadAllWriter(w http.ResponseWriter, r *http.Request) { + echoReadAll(w, r, false) +} + +func echoReadAllWriteMessage(w http.ResponseWriter, r *http.Request) { + echoReadAll(w, r, true) +} + +func serveHome(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/" { + http.Error(w, "Not found.", 404) + return + } + if r.Method != "GET" { + http.Error(w, "Method not allowed", 405) + return + } + w.Header().Set("Content-Type", "text/html; charset=utf-8") + io.WriteString(w, "Echo Server +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to +// deal in the Software without restriction, including without limitation the +// rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +// sell copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +// IN THE SOFTWARE. +var utf8d = [...]byte{ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 00..1f + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 20..3f + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 40..5f + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 60..7f + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, // 80..9f + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, // a0..bf + 8, 8, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, // c0..df + 0xa, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x4, 0x3, 0x3, // e0..ef + 0xb, 0x6, 0x6, 0x6, 0x5, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, // f0..ff + 0x0, 0x1, 0x2, 0x3, 0x5, 0x8, 0x7, 0x1, 0x1, 0x1, 0x4, 0x6, 0x1, 0x1, 0x1, 0x1, // s0..s0 + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, // s1..s2 + 1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, // s3..s4 + 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 3, 1, 1, 1, 1, 1, 1, // s5..s6 + 1, 3, 1, 1, 1, 1, 1, 3, 1, 3, 1, 1, 1, 1, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // s7..s8 +} + +const ( + utf8Accept = 0 + utf8Reject = 1 +) + +func decode(state int, x rune, b byte) (int, rune) { + t := utf8d[b] + if state != utf8Accept { + x = rune(b&0x3f) | (x << 6) + } else { + x = rune((0xff >> t) & b) + } + state = int(utf8d[256+state*16+int(t)]) + return state, x +} diff --git a/examples/chat/README.md b/examples/chat/README.md new file mode 100644 index 0000000..08fc3e6 --- /dev/null +++ b/examples/chat/README.md @@ -0,0 +1,19 @@ +# Chat Example + +This application shows how to use use the +[websocket](https://github.com/gorilla/websocket) package and +[jQuery](http://jquery.com) to implement a simple web chat application. + +## Running the example + +The example requires a working Go development environment. The [Getting +Started](http://golang.org/doc/install) page describes how to install the +development environment. + +Once you have Go up and running, you can download, build and run the example +using the following commands. + + $ go get github.com/gorilla/websocket + $ cd `go list -f '{{.Dir}}' github.com/gorilla/websocket/examples/chat` + $ go run *.go + diff --git a/examples/chat/conn.go b/examples/chat/conn.go new file mode 100644 index 0000000..aaa617e --- /dev/null +++ b/examples/chat/conn.go @@ -0,0 +1,108 @@ +// Copyright 2013 Gary Burd. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "github.com/gorilla/websocket" + "log" + "net/http" + "time" +) + +const ( + // Time allowed to write a message to the peer. + writeWait = 10 * time.Second + + // Time allowed to read the next pong message from the peer. + pongWait = 60 * time.Second + + // Send pings to peer with this period. Must be less than pongWait. + pingPeriod = (pongWait * 9) / 10 + + // Maximum message size allowed from peer. + maxMessageSize = 512 +) + +// connection is an middleman between the websocket connection and the hub. +type connection struct { + // The websocket connection. + ws *websocket.Conn + + // Buffered channel of outbound messages. + send chan []byte +} + +// readPump pumps messages from the websocket connection to the hub. +func (c *connection) readPump() { + defer func() { + h.unregister <- c + c.ws.Close() + }() + c.ws.SetReadLimit(maxMessageSize) + c.ws.SetReadDeadline(time.Now().Add(pongWait)) + c.ws.SetPongHandler(func(string) error { c.ws.SetReadDeadline(time.Now().Add(pongWait)); return nil }) + for { + _, message, err := c.ws.ReadMessage() + if err != nil { + break + } + h.broadcast <- message + } +} + +// write writes a message with the given message type and payload. +func (c *connection) write(mt int, payload []byte) error { + c.ws.SetWriteDeadline(time.Now().Add(writeWait)) + return c.ws.WriteMessage(mt, payload) +} + +// writePump pumps messages from the hub to the websocket connection. +func (c *connection) writePump() { + ticker := time.NewTicker(pingPeriod) + defer func() { + ticker.Stop() + c.ws.Close() + }() + for { + select { + case message, ok := <-c.send: + if !ok { + c.write(websocket.CloseMessage, []byte{}) + return + } + if err := c.write(websocket.TextMessage, message); err != nil { + return + } + case <-ticker.C: + if err := c.write(websocket.PingMessage, []byte{}); err != nil { + return + } + } + } +} + +// serverWs handles webocket requests from the peer. +func serveWs(w http.ResponseWriter, r *http.Request) { + if r.Method != "GET" { + http.Error(w, "Method not allowed", 405) + return + } + if r.Header.Get("Origin") != "http://"+r.Host { + http.Error(w, "Origin not allowed", 403) + return + } + ws, err := websocket.Upgrade(w, r, nil, 1024, 1024) + if _, ok := err.(websocket.HandshakeError); ok { + http.Error(w, "Not a websocket handshake", 400) + return + } else if err != nil { + log.Println(err) + return + } + c := &connection{send: make(chan []byte, 256), ws: ws} + h.register <- c + go c.writePump() + c.readPump() +} diff --git a/examples/chat/home.html b/examples/chat/home.html new file mode 100644 index 0000000..2959922 --- /dev/null +++ b/examples/chat/home.html @@ -0,0 +1,92 @@ + + + +Chat Example + + + + + +
+
+ + +
+ + diff --git a/examples/chat/hub.go b/examples/chat/hub.go new file mode 100644 index 0000000..1bab0ed --- /dev/null +++ b/examples/chat/hub.go @@ -0,0 +1,49 @@ +// Copyright 2013 Gary Burd. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +// hub maintains the set of active connections and broadcasts messages to the +// connections. +type hub struct { + // Registered connections. + connections map[*connection]bool + + // Inbound messages from the connections. + broadcast chan []byte + + // Register requests from the connections. + register chan *connection + + // Unregister requests from connections. + unregister chan *connection +} + +var h = hub{ + broadcast: make(chan []byte), + register: make(chan *connection), + unregister: make(chan *connection), + connections: make(map[*connection]bool), +} + +func (h *hub) run() { + for { + select { + case c := <-h.register: + h.connections[c] = true + case c := <-h.unregister: + delete(h.connections, c) + close(c.send) + case m := <-h.broadcast: + for c := range h.connections { + select { + case c.send <- m: + default: + close(c.send) + delete(h.connections, c) + } + } + } + } +} diff --git a/examples/chat/main.go b/examples/chat/main.go new file mode 100644 index 0000000..7c2515b --- /dev/null +++ b/examples/chat/main.go @@ -0,0 +1,39 @@ +// Copyright 2013 Gary Burd. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "flag" + "log" + "net/http" + "text/template" +) + +var addr = flag.String("addr", ":8080", "http service address") +var homeTempl = template.Must(template.ParseFiles("home.html")) + +func serveHome(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/" { + http.Error(w, "Not found", 404) + return + } + if r.Method != "GET" { + http.Error(w, "Method nod allowed", 405) + return + } + w.Header().Set("Content-Type", "text/html; charset=utf-8") + homeTempl.Execute(w, r.Host) +} + +func main() { + flag.Parse() + go h.run() + http.HandleFunc("/", serveHome) + http.HandleFunc("/ws", serveWs) + err := http.ListenAndServe(*addr, nil) + if err != nil { + log.Fatal("ListenAndServe: ", err) + } +} diff --git a/json.go b/json.go new file mode 100644 index 0000000..4b340d7 --- /dev/null +++ b/json.go @@ -0,0 +1,39 @@ +// Copyright 2013 Gary Burd. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "encoding/json" +) + +// WriteJSON writes the JSON encoding of v to the connection. +// +// See the documentation for encoding/json Marshal for details about the +// conversion of Go values to JSON. +func WriteJSON(c *Conn, v interface{}) error { + w, err := c.NextWriter(TextMessage) + if err != nil { + return err + } + err1 := json.NewEncoder(w).Encode(v) + err2 := w.Close() + if err1 != nil { + return err1 + } + return err2 +} + +// ReadJSON reads the next JSON-encoded message from the connection and stores +// it in the value pointed to by v. +// +// See the documentation for the encoding/json Marshal function for details +// about the conversion of JSON to a Go value. +func ReadJSON(c *Conn, v interface{}) error { + _, r, err := c.NextReader() + if err != nil { + return err + } + return json.NewDecoder(r).Decode(v) +} diff --git a/json_test.go b/json_test.go new file mode 100644 index 0000000..a373844 --- /dev/null +++ b/json_test.go @@ -0,0 +1,37 @@ +// Copyright 2013 Gary Burd. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "bytes" + "reflect" + "testing" +) + +func TestJSON(t *testing.T) { + var buf bytes.Buffer + c := fakeNetConn{&buf, &buf} + wc := newConn(c, true, 1024, 1024) + rc := newConn(c, false, 1024, 1024) + + var actual, expect struct { + A int + B string + } + expect.A = 1 + expect.B = "hello" + + if err := WriteJSON(wc, &expect); err != nil { + t.Fatal("write", err) + } + + if err := ReadJSON(rc, &actual); err != nil { + t.Fatal("read", err) + } + + if !reflect.DeepEqual(&actual, &expect) { + t.Fatal("equal", actual, expect) + } +} diff --git a/server.go b/server.go new file mode 100644 index 0000000..a14c586 --- /dev/null +++ b/server.go @@ -0,0 +1,106 @@ +// Copyright 2013 Gary Burd. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "bufio" + "errors" + "net" + "net/http" +) + +// HandshakeError describes an error with the handshake from the peer. +type HandshakeError struct { + Err string +} + +func (e HandshakeError) Error() string { return e.Err } + +// Upgrade upgrades the HTTP server connection to the WebSocket protocol. +// +// Upgrade returns a HandshakeError if the request is not a WebSocket +// handshake. Applications should handle errors of this type by replying to the +// client with an HTTP response. +// +// The application is responsible for checking the request origin before +// calling Upgrade. An example implementation of the same origin policy is: +// +// if req.Header.Get("Origin") != "http://"+req.Host { +// http.Error(w, "Origin not allowed", 403) +// return +// } +// +// Use the responseHeader to specify cookies (Set-Cookie) and the subprotocol +// (Sec-WebSocket-Protocol). +func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header, readBufSize, writeBufSize int) (*Conn, error) { + + if values := r.Header["Sec-Websocket-Version"]; len(values) == 0 || values[0] != "13" { + return nil, HandshakeError{"websocket: version != 13"} + } + + if !tokenListContainsValue(r.Header, "Connection", "upgrade") { + return nil, HandshakeError{"websocket: connection header != upgrade"} + } + + if !tokenListContainsValue(r.Header, "Upgrade", "websocket") { + return nil, HandshakeError{"websocket: upgrade != websocket"} + } + + var challengeKey string + values := r.Header["Sec-Websocket-Key"] + if len(values) == 0 || values[0] == "" { + return nil, HandshakeError{"websocket: key missing or blank"} + } + challengeKey = values[0] + + var ( + netConn net.Conn + br *bufio.Reader + err error + ) + + h, ok := w.(http.Hijacker) + if !ok { + return nil, errors.New("websocket: response does not implement http.Hijacker") + } + var rw *bufio.ReadWriter + netConn, rw, err = h.Hijack() + br = rw.Reader + + if br.Buffered() > 0 { + netConn.Close() + return nil, errors.New("websocket: client sent data before handshake is complete") + } + + c := newConn(netConn, true, readBufSize, writeBufSize) + + p := c.writeBuf[:0] + p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...) + p = append(p, computeAcceptKey(challengeKey)...) + p = append(p, "\r\n"...) + for k, vs := range responseHeader { + for _, v := range vs { + p = append(p, k...) + p = append(p, ": "...) + for i := 0; i < len(v); i++ { + b := v[i] + if b <= 31 { + // prevent response splitting. + b = ' ' + } + p = append(p, b) + } + p = append(p, "\r\n"...) + } + } + p = append(p, "\r\n"...) + + if _, err = netConn.Write(p); err != nil { + netConn.Close() + return nil, err + } + + return c, nil +} diff --git a/util.go b/util.go new file mode 100644 index 0000000..a9cfbf9 --- /dev/null +++ b/util.go @@ -0,0 +1,44 @@ +// Copyright 2013 Gary Burd. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "crypto/rand" + "crypto/sha1" + "encoding/base64" + "io" + "net/http" + "strings" +) + +// tokenListContainsValue returns true if the 1#token header with the given +// name contains token. +func tokenListContainsValue(header http.Header, name string, value string) bool { + for _, v := range header[name] { + for _, s := range strings.Split(v, ",") { + if strings.EqualFold(value, strings.TrimSpace(s)) { + return true + } + } + } + return false +} + +var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") + +func computeAcceptKey(challengeKey string) string { + h := sha1.New() + h.Write([]byte(challengeKey)) + h.Write(keyGUID) + return base64.StdEncoding.EncodeToString(h.Sum(nil)) +} + +func generateChallengeKey() (string, error) { + p := make([]byte, 16) + if _, err := io.ReadFull(rand.Reader, p); err != nil { + return "", err + } + return base64.StdEncoding.EncodeToString(p), nil +}