Initial commit

This commit is contained in:
Gary Burd 2013-10-16 09:41:47 -07:00
commit 273ecadfca
20 changed files with 2066 additions and 0 deletions

22
.gitignore vendored Normal file
View File

@ -0,0 +1,22 @@
# Compiled Object files, Static and Dynamic libs (Shared Objects)
*.o
*.a
*.so
# Folders
_obj
_test
# Architecture specific extensions/prefixes
*.[568vq]
[568vq].out
*.cgo1.go
*.cgo2.c
_cgo_defun.c
_cgo_gotypes.go
_cgo_export.*
_testmain.go
*.exe

23
LICENSE Normal file
View File

@ -0,0 +1,23 @@
Copyright (c) 2013, Gorilla web toolkit
All rights reserved.
Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:
Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
Redistributions in binary form must reproduce the above copyright notice, this
list of conditions and the following disclaimer in the documentation and/or
other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

26
README.md Normal file
View File

@ -0,0 +1,26 @@
# WebSocket
This project is a [Go](http://golang.org/) implementation of the
[WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol.
The project passes the server tests in the [Autobahn WebSockets Test
Suite](http://autobahn.ws/testsuite) using the application in the [examples/autobahn
subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn).
## Documentation
* [Reference](http://godoc.org/github.com/gorilla/websocket)
* [Chat example](https://github.com/gorilla/websocket/tree/master/examples/chat)
## Features
- Send and receive ping, pong and close control messages.
- Limit size of received messages.
- Stream messages.
- Specify IO buffer sizes.
- Application has full control over origin checks and sub-protocol negotiation.
## Installation
go get github.com/gorilla/websocket

69
client.go Normal file
View File

@ -0,0 +1,69 @@
// Copyright 2013 Gary Burd. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"errors"
"net"
"net/http"
"net/url"
"strings"
)
// ErrBadHandshake is returned when the server response to opening handshake is
// invalid.
var ErrBadHandshake = errors.New("websocket: bad handshake")
// NewClient creates a new client connection using the given net connection.
// The URL u specifies the host and request URI. Use requestHeader to specify
// the origin (Origin), subprotocols (Set-WebSocket-Protocol) and cookies
// (Cookie). Use the response.Header to get the selected subprotocol
// (Sec-WebSocket-Protocol) and cookies (Set-Cookie).
//
// If the WebSocket handshake fails, ErrBadHandshake is returned along with a
// non-nil *http.Response so that callers can handle redirects, authentication,
// etc.
func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufSize, writeBufSize int) (c *Conn, response *http.Response, err error) {
challengeKey, err := generateChallengeKey()
if err != nil {
return nil, nil, err
}
acceptKey := computeAcceptKey(challengeKey)
c = newConn(netConn, false, readBufSize, writeBufSize)
p := c.writeBuf[:0]
p = append(p, "GET "...)
p = append(p, u.RequestURI()...)
p = append(p, " HTTP/1.1\r\nHost: "...)
p = append(p, u.Host...)
p = append(p, "\r\nUpgrade: websocket\r\nConnection: upgrade\r\nSec-WebSocket-Version: 13\r\nSec-WebSocket-Key: "...)
p = append(p, challengeKey...)
p = append(p, "\r\n"...)
for k, vs := range requestHeader {
for _, v := range vs {
p = append(p, k...)
p = append(p, ": "...)
p = append(p, v...)
p = append(p, "\r\n"...)
}
}
p = append(p, "\r\n"...)
if _, err := netConn.Write(p); err != nil {
return nil, nil, err
}
resp, err := http.ReadResponse(c.br, &http.Request{Method: "GET", URL: u})
if err != nil {
return nil, nil, err
}
if resp.StatusCode != 101 ||
!strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") ||
!strings.EqualFold(resp.Header.Get("Connection"), "upgrade") ||
resp.Header.Get("Sec-Websocket-Accept") != acceptKey {
return nil, resp, ErrBadHandshake
}
return c, resp, nil
}

114
client_server_test.go Normal file
View File

@ -0,0 +1,114 @@
// Copyright 2013 Gary Burd. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket_test
import (
"io"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"
"github.com/gorilla/websocket"
)
type wsHandler struct {
*testing.T
}
func (t wsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.Method != "GET" {
http.Error(w, "Method not allowed", 405)
t.Logf("bad method: %s", r.Method)
return
}
if r.Header.Get("Origin") != "http://"+r.Host {
http.Error(w, "Origin not allowed", 403)
t.Logf("bad origin: %s", r.Header.Get("Origin"))
return
}
ws, err := websocket.Upgrade(w, r, http.Header{"Set-Cookie": {"sessionID=1234"}}, 1024, 1024)
if _, ok := err.(websocket.HandshakeError); ok {
t.Logf("bad handshake: %v", err)
http.Error(w, "Not a websocket handshake", 400)
return
} else if err != nil {
t.Logf("upgrade error: %v", err)
return
}
defer ws.Close()
for {
op, r, err := ws.NextReader()
if err != nil {
if err != io.EOF {
t.Logf("NextReader: %v", err)
}
return
}
if op == websocket.PongMessage {
continue
}
w, err := ws.NextWriter(op)
if err != nil {
t.Logf("NextWriter: %v", err)
return
}
if _, err = io.Copy(w, r); err != nil {
t.Logf("Copy: %v", err)
return
}
if err := w.Close(); err != nil {
t.Logf("Close: %v", err)
return
}
}
}
func TestClientServer(t *testing.T) {
s := httptest.NewServer(wsHandler{t})
defer s.Close()
u, _ := url.Parse(s.URL)
c, err := net.Dial("tcp", u.Host)
if err != nil {
t.Fatalf("Dial: %v", err)
}
ws, resp, err := websocket.NewClient(c, u, http.Header{"Origin": {s.URL}}, 1024, 1024)
if err != nil {
t.Fatalf("NewClient: %v", err)
}
defer ws.Close()
var sessionID string
for _, c := range resp.Cookies() {
if c.Name == "sessionID" {
sessionID = c.Value
}
}
if sessionID != "1234" {
t.Error("Set-Cookie not received from the server.")
}
w, _ := ws.NextWriter(websocket.TextMessage)
io.WriteString(w, "HELLO")
w.Close()
ws.SetReadDeadline(time.Now().Add(1 * time.Second))
op, r, err := ws.NextReader()
if err != nil {
t.Fatalf("NextReader: %v", err)
}
if op != websocket.TextMessage {
t.Fatalf("op=%d, want %d", op, websocket.TextMessage)
}
b, err := ioutil.ReadAll(r)
if err != nil {
t.Fatalf("ReadAll: %v", err)
}
if string(b) != "HELLO" {
t.Fatalf("message=%s, want %s", b, "HELLO")
}
}

759
conn.go Normal file
View File

@ -0,0 +1,759 @@
// Copyright 2013 Gary Burd. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"bufio"
"encoding/binary"
"errors"
"io"
"io/ioutil"
"math/rand"
"net"
"strconv"
"time"
)
// Close codes defined in RFC 6455, section 11.7.
const (
CloseNormalClosure = 1000
CloseGoingAway = 1001
CloseProtocolError = 1002
CloseUnsupportedData = 1003
CloseNoStatusReceived = 1005
CloseAbnormalClosure = 1006
CloseInvalidFramePayloadData = 1007
ClosePolicyViolation = 1008
CloseMessageTooBig = 1009
CloseMandatoryExtension = 1010
CloseInternalServerErr = 1011
CloseTLSHandshake = 1015
)
// The message types are defined in RFC 6455, section 11.8.
const (
// TextMessage denotes a text message. The text message payload is
// interpreted as UTF-8 encoded text data.
TextMessage = 1
// BinaryMessage denotes a binary data message.
BinaryMessage = 2
// CloseMessage denotes a close control message. The optional message
// payload contains a numeric code and text. Use the FormatCloseMessage
// function to format a close message payload.
CloseMessage = 8
// PingMessage denotes a ping control message. The optional message payload
// is UTF-8 encoded text.
PingMessage = 9
// PongMessage denotes a ping control message. The optional message payload
// is UTF-8 encoded text.
PongMessage = 10
)
var (
continuationFrame = 0
noFrame = -1
)
var (
ErrCloseSent = errors.New("websocket: close sent")
ErrReadLimit = errors.New("websocket: read limit exceeded")
)
var (
errBadWriteOpCode = errors.New("websocket: bad write message type")
errWriteTimeout = errors.New("websocket: write timeout")
errWriteClosed = errors.New("websocket: write closed")
errInvalidControlFrame = errors.New("websocket: invalid control frame")
)
const (
maxFrameHeaderSize = 2 + 8 + 4 // Fixed header + length + mask
maxControlFramePayloadSize = 125
finalBit = 1 << 7
maskBit = 1 << 7
writeWait = time.Second
)
func isControl(frameType int) bool {
return frameType == CloseMessage || frameType == PingMessage || frameType == PongMessage
}
func isData(frameType int) bool {
return frameType == TextMessage || frameType == BinaryMessage
}
func maskBytes(key [4]byte, pos int, b []byte) int {
for i := range b {
b[i] ^= key[pos&3]
pos += 1
}
return pos & 3
}
func newMaskKey() [4]byte {
n := rand.Uint32()
return [4]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 32)}
}
// Conn represents a WebSocket connection.
type Conn struct {
conn net.Conn
isServer bool
// Write fields
mu chan bool // used as mutex to protect write to conn and closeSent
closeSent bool // true if close message was sent
// Message writer fields.
writeErr error
writeBuf []byte // frame is constructed in this buffer.
writePos int // end of data in writeBuf.
writeFrameType int // type of the current frame.
writeSeq int // incremented to invalidate message writers.
writeDeadline time.Time
// Read fields
readErr error
br *bufio.Reader
readRemaining int64 // bytes remaining in current frame.
readFinal bool // true the current message has more frames.
readSeq int // incremented to invalidate message readers.
readLength int64 // Message size.
readLimit int64 // Maximum message size.
readMaskPos int
readMaskKey [4]byte
handlePong func(string) error
handlePing func(string) error
}
func newConn(conn net.Conn, isServer bool, readBufSize, writeBufSize int) *Conn {
mu := make(chan bool, 1)
mu <- true
c := &Conn{
isServer: isServer,
br: bufio.NewReaderSize(conn, readBufSize),
conn: conn,
mu: mu,
readFinal: true,
writeBuf: make([]byte, writeBufSize+maxFrameHeaderSize),
writeFrameType: noFrame,
writePos: maxFrameHeaderSize,
}
c.SetPingHandler(nil)
c.SetPongHandler(nil)
return c
}
// Close closes the underlying network connection without sending or waiting for a close frame.
func (c *Conn) Close() error {
return c.conn.Close()
}
// LocalAddr returns the local network address.
func (c *Conn) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}
// RemoteAddr returns the remote network address.
func (c *Conn) RemoteAddr() net.Addr {
return c.conn.RemoteAddr()
}
// Write methods
func (c *Conn) write(frameType int, deadline time.Time, bufs ...[]byte) error {
<-c.mu
defer func() { c.mu <- true }()
if c.closeSent {
return ErrCloseSent
} else if frameType == CloseMessage {
c.closeSent = true
}
c.conn.SetWriteDeadline(deadline)
for _, buf := range bufs {
if len(buf) > 0 {
n, err := c.conn.Write(buf)
if n != len(buf) {
// Close on partial write.
c.conn.Close()
}
if err != nil {
return err
}
}
}
return nil
}
// WriteControl writes a control message with the given deadline. The allowed
// message types are CloseMessage, PingMessage and PongMessage.
func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) error {
if !isControl(messageType) {
return errBadWriteOpCode
}
if len(data) > maxControlFramePayloadSize {
return errInvalidControlFrame
}
b0 := byte(messageType) | finalBit
b1 := byte(len(data))
if !c.isServer {
b1 |= maskBit
}
buf := make([]byte, 0, maxFrameHeaderSize+maxControlFramePayloadSize)
buf = append(buf, b0, b1)
if c.isServer {
buf = append(buf, data...)
} else {
key := newMaskKey()
buf = append(buf, key[:]...)
buf = append(buf, data...)
maskBytes(key, 0, buf[6:])
}
d := time.Hour * 1000
if !deadline.IsZero() {
d = deadline.Sub(time.Now())
if d < 0 {
return errWriteTimeout
}
}
timer := time.NewTimer(d)
select {
case <-c.mu:
timer.Stop()
case <-timer.C:
return errWriteTimeout
}
defer func() { c.mu <- true }()
if c.closeSent {
return ErrCloseSent
} else if messageType == CloseMessage {
c.closeSent = true
}
c.conn.SetWriteDeadline(deadline)
n, err := c.conn.Write(buf)
if n != 0 && n != len(buf) {
c.conn.Close()
}
return err
}
// NextWriter returns a writer for the next message to send. The writer's
// Close method flushes the complete message to the network.
//
// There can be at most one open writer on a connection. NextWriter closes the
// previous writer if the application has not already done so.
//
// The NextWriter method and the writers returned from the method cannot be
// accessed by more than one goroutine at a time.
func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
if c.writeErr != nil {
return nil, c.writeErr
}
if c.writeFrameType != noFrame {
if err := c.flushFrame(true, nil); err != nil {
return nil, err
}
}
if !isControl(messageType) && !isData(messageType) {
return nil, errBadWriteOpCode
}
c.writeFrameType = messageType
return messageWriter{c, c.writeSeq}, nil
}
func (c *Conn) flushFrame(final bool, extra []byte) error {
length := c.writePos - maxFrameHeaderSize + len(extra)
// Check for invalid control frames.
if isControl(c.writeFrameType) &&
(!final || length > maxControlFramePayloadSize) {
c.writeSeq += 1
c.writeFrameType = noFrame
c.writePos = maxFrameHeaderSize
return errInvalidControlFrame
}
b0 := byte(c.writeFrameType)
if final {
b0 |= finalBit
}
b1 := byte(0)
if !c.isServer {
b1 |= maskBit
}
// Assume that the frame starts at beginning of c.writeBuf.
framePos := 0
if c.isServer {
// Adjust up if mask not included in the header.
framePos = 4
}
switch {
case length >= 65536:
c.writeBuf[framePos] = b0
c.writeBuf[framePos+1] = b1 | 127
binary.BigEndian.PutUint64(c.writeBuf[framePos+2:], uint64(length))
case length > 125:
framePos += 6
c.writeBuf[framePos] = b0
c.writeBuf[framePos+1] = b1 | 126
binary.BigEndian.PutUint16(c.writeBuf[framePos+2:], uint16(length))
default:
framePos += 8
c.writeBuf[framePos] = b0
c.writeBuf[framePos+1] = b1 | byte(length)
}
if !c.isServer {
key := newMaskKey()
copy(c.writeBuf[maxFrameHeaderSize-4:], key[:])
maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:c.writePos])
if len(extra) > 0 {
c.writeErr = errors.New("websocket: internal error, extra used in client mode")
return c.writeErr
}
}
// Write the buffers to the connection.
c.writeErr = c.write(c.writeFrameType, c.writeDeadline, c.writeBuf[framePos:c.writePos], extra)
// Setup for next frame.
c.writePos = maxFrameHeaderSize
c.writeFrameType = continuationFrame
if final {
c.writeSeq += 1
c.writeFrameType = noFrame
}
return c.writeErr
}
type messageWriter struct {
c *Conn
seq int
}
func (w messageWriter) err() error {
c := w.c
if c.writeSeq != w.seq {
return errWriteClosed
}
if c.writeErr != nil {
return c.writeErr
}
return nil
}
func (w messageWriter) ncopy(max int) (int, error) {
n := len(w.c.writeBuf) - w.c.writePos
if n <= 0 {
if err := w.c.flushFrame(false, nil); err != nil {
return 0, err
}
n = len(w.c.writeBuf) - w.c.writePos
}
if n > max {
n = max
}
return n, nil
}
func (w messageWriter) write(final bool, p []byte) (int, error) {
if err := w.err(); err != nil {
return 0, err
}
if len(p) > 2*len(w.c.writeBuf) && w.c.isServer {
// Don't buffer large messages.
err := w.c.flushFrame(final, p)
if err != nil {
return 0, err
}
return len(p), nil
}
nn := len(p)
for len(p) > 0 {
n, err := w.ncopy(len(p))
if err != nil {
return 0, err
}
copy(w.c.writeBuf[w.c.writePos:], p[:n])
w.c.writePos += n
p = p[n:]
}
return nn, nil
}
func (w messageWriter) Write(p []byte) (int, error) {
return w.write(false, p)
}
func (w messageWriter) WriteString(p string) (int, error) {
if err := w.err(); err != nil {
return 0, err
}
nn := len(p)
for len(p) > 0 {
n, err := w.ncopy(len(p))
if err != nil {
return 0, err
}
copy(w.c.writeBuf[w.c.writePos:], p[:n])
w.c.writePos += n
p = p[n:]
}
return nn, nil
}
func (w messageWriter) ReadFrom(r io.Reader) (nn int64, err error) {
if err := w.err(); err != nil {
return 0, err
}
for {
if w.c.writePos == len(w.c.writeBuf) {
err = w.c.flushFrame(false, nil)
if err != nil {
break
}
}
var n int
n, err = r.Read(w.c.writeBuf[w.c.writePos:])
w.c.writePos += n
nn += int64(n)
if err != nil {
if err == io.EOF {
err = nil
}
break
}
}
return nn, err
}
func (w messageWriter) Close() error {
if err := w.err(); err != nil {
return err
}
return w.c.flushFrame(true, nil)
}
// WriteMessage is a helper method for getting a writer using NextWriter,
// writing the message and closing the writer.
func (c *Conn) WriteMessage(messageType int, data []byte) error {
wr, err := c.NextWriter(messageType)
if err != nil {
return err
}
w := wr.(messageWriter)
if _, err := w.write(true, data); err != nil {
return err
}
if c.writeSeq == w.seq {
if err := c.flushFrame(true, nil); err != nil {
return err
}
}
return nil
}
// SetWriteDeadline sets the deadline for future calls to NextWriter and the
// io.WriteCloser returned from NextWriter. If the deadline is reached, the
// call will fail with a timeout instead of blocking. A zero value for t means
// Write will not time out. Even if Write times out, it may return n > 0,
// indicating that some of the data was successfully written.
func (c *Conn) SetWriteDeadline(t time.Time) error {
c.writeDeadline = t
return nil
}
// Read methods
func (c *Conn) advanceFrame() (int, error) {
// 1. Skip remainder of previous frame.
if c.readRemaining > 0 {
if _, err := io.CopyN(ioutil.Discard, c.br, c.readRemaining); err != nil {
return noFrame, err
}
}
// 2. Read and parse first two bytes of frame header.
var b [8]byte
if err := c.read(b[:2]); err != nil {
return noFrame, err
}
final := b[0]&finalBit != 0
frameType := int(b[0] & 0xf)
reserved := int((b[0] >> 4) & 0x7)
mask := b[1]&maskBit != 0
c.readRemaining = int64(b[1] & 0x7f)
if reserved != 0 {
return noFrame, c.handleProtocolError("unexpected reserved bits " + strconv.Itoa(reserved))
}
switch frameType {
case CloseMessage, PingMessage, PongMessage:
if c.readRemaining > maxControlFramePayloadSize {
return noFrame, c.handleProtocolError("control frame length > 125")
}
if !final {
return noFrame, c.handleProtocolError("control frame not final")
}
case TextMessage, BinaryMessage:
if !c.readFinal {
return noFrame, c.handleProtocolError("message start before final message frame")
}
c.readFinal = final
case continuationFrame:
if c.readFinal {
return noFrame, c.handleProtocolError("continuation after final message frame")
}
c.readFinal = final
default:
return noFrame, c.handleProtocolError("unknown opcode " + strconv.Itoa(frameType))
}
// 3. Read and parse frame length.
switch c.readRemaining {
case 126:
if err := c.read(b[:2]); err != nil {
return noFrame, err
}
c.readRemaining = int64(binary.BigEndian.Uint16(b[:2]))
case 127:
if err := c.read(b[:8]); err != nil {
return noFrame, err
}
c.readRemaining = int64(binary.BigEndian.Uint64(b[:8]))
}
// 4. Handle frame masking.
if mask != c.isServer {
return noFrame, c.handleProtocolError("incorrect mask flag")
}
if mask {
c.readMaskPos = 0
if err := c.read(c.readMaskKey[:]); err != nil {
return noFrame, err
}
}
// 5. For text and binary messages, enforce read limit and return.
if frameType == continuationFrame || frameType == TextMessage || frameType == BinaryMessage {
c.readLength += c.readRemaining
if c.readLimit > 0 && c.readLength > c.readLimit {
c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait))
return noFrame, ErrReadLimit
}
return frameType, nil
}
// 6. Read control frame payload.
payload := make([]byte, c.readRemaining)
c.readRemaining = 0
if err := c.read(payload); err != nil {
return noFrame, err
}
maskBytes(c.readMaskKey, 0, payload)
// 7. Process control frame payload.
switch frameType {
case PongMessage:
if err := c.handlePong(string(payload)); err != nil {
return noFrame, err
}
case PingMessage:
if err := c.handlePing(string(payload)); err != nil {
return noFrame, err
}
case CloseMessage:
c.WriteControl(CloseMessage, []byte{}, time.Now().Add(writeWait))
if len(payload) < 2 {
return noFrame, io.EOF
}
closeCode := binary.BigEndian.Uint16(payload)
switch closeCode {
case CloseNormalClosure, CloseGoingAway:
return noFrame, io.EOF
default:
return noFrame, errors.New("websocket: close " +
strconv.Itoa(int(closeCode)) + " " +
string(payload[2:]))
}
}
return frameType, nil
}
func (c *Conn) handleProtocolError(message string) error {
c.WriteControl(CloseMessage, FormatCloseMessage(CloseProtocolError, message), time.Now().Add(writeWait))
return errors.New("websocket: " + message)
}
func (c *Conn) read(buf []byte) error {
var err error
for len(buf) > 0 && err == nil {
var nn int
nn, err = c.br.Read(buf)
buf = buf[nn:]
}
if err == io.EOF {
if len(buf) == 0 {
err = nil
} else {
err = io.ErrUnexpectedEOF
}
}
return err
}
// NextReader returns the next data message received from the peer. The
// returned messageType is either TextMessage or BinaryMessage.
//
// There can be at most one open reader on a connection. NextReader discards
// the previous message if the application has not already consumed it.
//
// The NextReader method and the readers returned from the method cannot be
// accessed by more than one goroutine at a time.
func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
c.readSeq += 1
c.readLength = 0
for c.readErr == nil {
var frameType int
frameType, c.readErr = c.advanceFrame()
if frameType == TextMessage || frameType == BinaryMessage {
return frameType, messageReader{c, c.readSeq}, nil
}
}
return noFrame, nil, c.readErr
}
type messageReader struct {
c *Conn
seq int
}
func (r messageReader) Read(b []byte) (n int, err error) {
if r.seq != r.c.readSeq {
return 0, io.EOF
}
for r.c.readErr == nil {
if r.c.readRemaining > 0 {
if int64(len(b)) > r.c.readRemaining {
b = b[:r.c.readRemaining]
}
r.c.readErr = r.c.read(b)
r.c.readMaskPos = maskBytes(r.c.readMaskKey, r.c.readMaskPos, b)
r.c.readRemaining -= int64(len(b))
return len(b), r.c.readErr
}
if r.c.readFinal {
r.c.readSeq += 1
return 0, io.EOF
}
var frameType int
frameType, r.c.readErr = r.c.advanceFrame()
if frameType == TextMessage || frameType == BinaryMessage {
r.c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader")
}
}
return 0, r.c.readErr
}
// ReadMessage is a helper method for getting a reader using NextReader and
// reading from that reader to a buffer.
func (c *Conn) ReadMessage() (messageType int, p []byte, err error) {
var r io.Reader
messageType, r, err = c.NextReader()
if err != nil {
return messageType, nil, err
}
p, err = ioutil.ReadAll(r)
return messageType, p, err
}
// SetReadDeadline sets the deadline for future calls to NextReader and the
// io.Reader returned from NextReader. If the deadline is reached, the call
// will fail with a timeout instead of blocking. A zero value for t means that
// the methods will not time out.
func (c *Conn) SetReadDeadline(t time.Time) error {
return c.conn.SetReadDeadline(t)
}
// SetReadLimit sets the maximum size for a message read from the peer. If a
// message exceeds the limit, the connection sends a close frame to the peer
// and returns ErrReadLimit to the application.
func (c *Conn) SetReadLimit(limit int64) {
c.readLimit = limit
}
// SetPingHandler sets the handler for ping messages received from the peer.
// The default ping handler sends a pong to the peer.
func (c *Conn) SetPingHandler(h func(string) error) {
if h == nil {
h = func(message string) error {
c.WriteControl(PongMessage, []byte(message), time.Now().Add(writeWait))
return nil
}
}
c.handlePing = h
}
// SetPongHandler sets then handler for pong messages received from the peer.
// The default pong handler does nothing.
func (c *Conn) SetPongHandler(h func(string) error) {
if h == nil {
h = func(string) error { return nil }
}
c.handlePong = h
}
// SetPongHandler sets the handler for
// FormatCloseMessage formats closeCode and text as a WebSocket close message.
func FormatCloseMessage(closeCode int, text string) []byte {
buf := make([]byte, 2+len(text))
binary.BigEndian.PutUint16(buf, uint16(closeCode))
copy(buf[2:], text)
return buf
}

140
conn_test.go Normal file
View File

@ -0,0 +1,140 @@
// Copyright 2013 Gary Burd. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net"
"testing"
"testing/iotest"
"time"
)
type fakeNetConn struct {
io.Reader
io.Writer
}
func (c fakeNetConn) Close() error { return nil }
func (c fakeNetConn) LocalAddr() net.Addr { return nil }
func (c fakeNetConn) RemoteAddr() net.Addr { return nil }
func (c fakeNetConn) SetDeadline(t time.Time) error { return nil }
func (c fakeNetConn) SetReadDeadline(t time.Time) error { return nil }
func (c fakeNetConn) SetWriteDeadline(t time.Time) error { return nil }
func TestFraming(t *testing.T) {
frameSizes := []int{0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, 65536, 65537}
var readChunkers = []struct {
name string
f func(io.Reader) io.Reader
}{
{"half", iotest.HalfReader},
{"one", iotest.OneByteReader},
{"asis", func(r io.Reader) io.Reader { return r }},
}
writeBuf := make([]byte, 65537)
for i := range writeBuf {
writeBuf[i] = byte(i)
}
for _, isServer := range []bool{true, false} {
for _, chunker := range readChunkers {
var connBuf bytes.Buffer
wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024)
rc := newConn(fakeNetConn{Reader: chunker.f(&connBuf), Writer: nil}, !isServer, 1024, 1024)
for _, n := range frameSizes {
for _, iocopy := range []bool{true, false} {
name := fmt.Sprintf("s:%v, r:%s, n:%d c:%v", isServer, chunker.name, n, iocopy)
w, err := wc.NextWriter(TextMessage)
if err != nil {
t.Errorf("%s: wc.NextWriter() returned %v", name, err)
continue
}
var nn int
if iocopy {
var n64 int64
n64, err = io.Copy(w, bytes.NewReader(writeBuf[:n]))
nn = int(n64)
} else {
nn, err = w.Write(writeBuf[:n])
}
if err != nil || nn != n {
t.Errorf("%s: w.Write(writeBuf[:n]) returned %d, %v", name, nn, err)
continue
}
err = w.Close()
if err != nil {
t.Errorf("%s: w.Close() returned %v", name, err)
continue
}
opCode, r, err := rc.NextReader()
if err != nil || opCode != TextMessage {
t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err)
continue
}
rbuf, err := ioutil.ReadAll(r)
if err != nil {
t.Errorf("%s: ReadFull() returned rbuf, %v", name, err)
continue
}
if len(rbuf) != n {
t.Errorf("%s: len(rbuf) is %d, want %d", name, len(rbuf), n)
continue
}
for i, b := range rbuf {
if byte(i) != b {
t.Errorf("%s: bad byte at offset %d", name, i)
break
}
}
}
}
}
}
}
func TestReadLimit(t *testing.T) {
const readLimit = 512
message := make([]byte, readLimit+1)
var b1, b2 bytes.Buffer
wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, readLimit-2)
rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024)
rc.SetReadLimit(readLimit)
// Send message at the limit with interleaved pong.
w, _ := wc.NextWriter(BinaryMessage)
w.Write(message[:readLimit-1])
wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second))
w.Write(message[:1])
w.Close()
// Send message larger than the limit.
wc.WriteMessage(BinaryMessage, message[:readLimit+1])
op, _, err := rc.NextReader()
if op != BinaryMessage || err != nil {
t.Fatalf("1: NextReader() returned %d, %v", op, err)
}
op, r, err := rc.NextReader()
if op != BinaryMessage || err != nil {
t.Fatalf("2: NextReader() returned %d, %v", op, err)
}
_, err = io.Copy(ioutil.Discard, r)
if err != ErrReadLimit {
t.Fatalf("io.Copy() returned %v", err)
}
}

98
doc.go Normal file
View File

@ -0,0 +1,98 @@
// Copyright 2013 Gary Burd. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package websocket implements the WebSocket protocol defined in RFC 6455.
//
// Overview
//
// The Conn type represents a WebSocket connection.
//
// A server application calls the Upgrade function to get a pointer to a Conn:
//
// func handler(w http.ResponseWriter, r *http.Request) {
// conn, err := websocket.Upgrade(w, r.Header, nil, 1024, 1024)
// if _, ok := err.(websocket.HandshakeError); ok {
// http.Error(w, "Not a websocket handshake", 400)
// return
// } else if err != nil {
// log.Println(err)
// return
// }
// ... Use conn to send and receive messages.
// }
//
// WebSocket messages are represented by the io.Reader interface when receiving
// a message and by the io.WriteCloser interface when sending a message. An
// application receives a message by calling the Conn.NextReader method and
// reading the returned io.Reader to EOF. An application sends a message by
// calling the Conn.NextWriter method and writing the message to the returned
// io.WriteCloser. The application terminates the message by closing the
// io.WriteCloser.
//
// The following example shows how to use the connection NextReader and
// NextWriter method to echo messages:
//
// for {
// mt, r, err := conn.NextReader()
// if err != nil {
// return
// }
// w, err := conn.NextWriter(mt)
// if err != nil {
// return err
// }
// if _, err := io.Copy(w, r); err != nil {
// return err
// }
// if err := w.Close(); err != nil {
// return err
// }
// }
//
// The connection ReadMessage and WriteMessage methods are helpers for reading
// or writing an entire message in one method call. The following example shows
// how to echo messages using these connection helper methods:
//
// for {
// mt, p, err := conn.ReadMessage()
// if err != nil {
// return
// }
// if _, err := conn.WriteMessaage(mt, p); err != nil {
// return err
// }
// }
//
// Concurrency
//
// A Conn supports a single concurrent caller to the write methods (NextWriter,
// SetWriteDeadline, WriteMessage) and a single concurrent caller to the read
// methods (NextReader, SetReadDeadline, ReadMessage). The Close and
// WriteControl methods can be called concurrently with all other methods.
//
// Data Messages
//
// The WebSocket protocol distinguishes between text and binary data messages.
// Text messages are interpreted as UTF-8 encoded text. The interpretation of
// binary messages is left to the application.
//
// This package uses the same types and methods to work with both types of data
// messages. It is the application's reponsiblity to ensure that text messages
// are valid UTF-8 encoded text.
//
// Control Messages
//
// The WebSocket protocol defines three types of control messages: close, ping
// and pong. Call the connection WriteControl, WriteMessage or NextWriter
// methods to send a control message to the peer.
//
// Connections handle received ping and pong messages by invoking a callback
// function set with SetPingHandler and SetPongHandler methods. These callback
// functions can be invoked from the ReadMessage method, the NextReader method
// or from a call to the data message reader returned from NextReader.
//
// Connections handle received close messages by returning an error from the
// ReadMessage method, the NextReader method or from a call to the data message
// reader returned from NextReader.
package websocket

View File

@ -0,0 +1,18 @@
# Test
Clients and servers for the [Autobahn WebSockets Test Suite](http://autobahn.ws/testsuite).
To test different code paths in the package, the test server echoes messages two ways:
- Read the entire message using io.ReadAll and write the message in one chunk.
- Copy the message in parts using io.Copy
To test the server, run it
go run server.go
and start the client test driver
wstest -m fuzzingclient -s fuzzingclient.json
When the client completes, it writes a report to reports/servers/index.html.

View File

@ -0,0 +1,14 @@
{
"options": {"failByDrop": false},
"outdir": "./reports/clients",
"servers": [
{"agent": "ReadAllWriteMessage", "url": "ws://localhost:9000/m", "options": {"version": 18}},
{"agent": "ReadAllWrite", "url": "ws://localhost:9000/r", "options": {"version": 18}},
{"agent": "CopyFull", "url": "ws://localhost:9000/f", "options": {"version": 18}},
{"agent": "CopyWriterOnly", "url": "ws://localhost:9000/c", "options": {"version": 18}}
],
"cases": ["*"],
"exclude-cases": [],
"exclude-agent-cases": {}
}

250
examples/autobahn/server.go Normal file
View File

@ -0,0 +1,250 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
// Command server is a test server for the Autobahn WebSockets Test Suite.
package main
import (
"errors"
"flag"
"github.com/gorilla/websocket"
"io"
"log"
"net/http"
"time"
"unicode/utf8"
)
// echoCopy echoes messages from the client using io.Copy.
func echoCopy(w http.ResponseWriter, r *http.Request, writerOnly bool) {
conn, err := websocket.Upgrade(w, r, nil, 4096, 4096)
if err != nil {
log.Println("Upgrade:", err)
http.Error(w, "Bad request", 400)
return
}
defer conn.Close()
for {
mt, r, err := conn.NextReader()
if err != nil {
if err != io.EOF {
log.Println("NextReader:", err)
}
return
}
if mt == websocket.TextMessage {
r = &validator{r: r}
}
w, err := conn.NextWriter(mt)
if err != nil {
log.Println("NextWriter:", err)
return
}
if mt == websocket.TextMessage {
r = &validator{r: r}
}
if writerOnly {
_, err = io.Copy(struct{ io.Writer }{w}, r)
} else {
_, err = io.Copy(w, r)
}
if err != nil {
if err == errInvalidUTF8 {
conn.WriteControl(websocket.CloseMessage,
websocket.FormatCloseMessage(websocket.CloseInvalidFramePayloadData, ""),
time.Time{})
}
log.Println("Copy:", err)
return
}
err = w.Close()
if err != nil {
log.Println("Close:", err)
return
}
}
}
func echoCopyWriterOnly(w http.ResponseWriter, r *http.Request) {
echoCopy(w, r, true)
}
func echoCopyFull(w http.ResponseWriter, r *http.Request) {
echoCopy(w, r, false)
}
// echoReadAll echoes messages from the client by reading the entire message
// with ioutil.ReadAll.
func echoReadAll(w http.ResponseWriter, r *http.Request, writeMessage bool) {
conn, err := websocket.Upgrade(w, r, nil, 4096, 4096)
if err != nil {
log.Println("Upgrade:", err)
http.Error(w, "Bad request", 400)
return
}
defer conn.Close()
for {
mt, b, err := conn.ReadMessage()
if err != nil {
if err != io.EOF {
log.Println("NextReader:", err)
}
return
}
if mt == websocket.TextMessage {
if !utf8.Valid(b) {
conn.WriteControl(websocket.CloseMessage,
websocket.FormatCloseMessage(websocket.CloseInvalidFramePayloadData, ""),
time.Time{})
log.Println("ReadAll: invalid utf8")
}
}
if writeMessage {
err = conn.WriteMessage(mt, b)
if err != nil {
log.Println("WriteMessage:", err)
}
} else {
w, err := conn.NextWriter(mt)
if err != nil {
log.Println("NextWriter:", err)
return
}
if _, err := w.Write(b); err != nil {
log.Println("Writer:", err)
return
}
if err := w.Close(); err != nil {
log.Println("Close:", err)
return
}
}
}
}
func echoReadAllWriter(w http.ResponseWriter, r *http.Request) {
echoReadAll(w, r, false)
}
func echoReadAllWriteMessage(w http.ResponseWriter, r *http.Request) {
echoReadAll(w, r, true)
}
func serveHome(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/" {
http.Error(w, "Not found.", 404)
return
}
if r.Method != "GET" {
http.Error(w, "Method not allowed", 405)
return
}
w.Header().Set("Content-Type", "text/html; charset=utf-8")
io.WriteString(w, "<html><body>Echo Server</body></html")
}
var addr = flag.String("addr", ":9000", "http service address")
func main() {
flag.Parse()
http.HandleFunc("/", serveHome)
http.HandleFunc("/c", echoCopyWriterOnly)
http.HandleFunc("/f", echoCopyFull)
http.HandleFunc("/r", echoReadAllWriter)
http.HandleFunc("/m", echoReadAllWriteMessage)
err := http.ListenAndServe(*addr, nil)
if err != nil {
log.Fatal("ListenAndServe: ", err)
}
}
type validator struct {
state int
x rune
r io.Reader
}
var errInvalidUTF8 = errors.New("invalid utf8")
func (r *validator) Read(p []byte) (int, error) {
n, err := r.r.Read(p)
state := r.state
x := r.x
for _, b := range p[:n] {
state, x = decode(state, x, b)
if state == utf8Reject {
break
}
}
r.state = state
r.x = x
if state == utf8Reject || (err == io.EOF && state != utf8Accept) {
return n, errInvalidUTF8
}
return n, err
}
// UTF-8 decoder from http://bjoern.hoehrmann.de/utf-8/decoder/dfa/
//
// Copyright (c) 2008-2009 Bjoern Hoehrmann <bjoern@hoehrmann.de>
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to
// deal in the Software without restriction, including without limitation the
// rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
// sell copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
// IN THE SOFTWARE.
var utf8d = [...]byte{
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 00..1f
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 20..3f
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 40..5f
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 60..7f
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, // 80..9f
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, // a0..bf
8, 8, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, // c0..df
0xa, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x4, 0x3, 0x3, // e0..ef
0xb, 0x6, 0x6, 0x6, 0x5, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, // f0..ff
0x0, 0x1, 0x2, 0x3, 0x5, 0x8, 0x7, 0x1, 0x1, 0x1, 0x4, 0x6, 0x1, 0x1, 0x1, 0x1, // s0..s0
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, // s1..s2
1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, // s3..s4
1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 3, 1, 1, 1, 1, 1, 1, // s5..s6
1, 3, 1, 1, 1, 1, 1, 3, 1, 3, 1, 1, 1, 1, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // s7..s8
}
const (
utf8Accept = 0
utf8Reject = 1
)
func decode(state int, x rune, b byte) (int, rune) {
t := utf8d[b]
if state != utf8Accept {
x = rune(b&0x3f) | (x << 6)
} else {
x = rune((0xff >> t) & b)
}
state = int(utf8d[256+state*16+int(t)])
return state, x
}

19
examples/chat/README.md Normal file
View File

@ -0,0 +1,19 @@
# Chat Example
This application shows how to use use the
[websocket](https://github.com/gorilla/websocket) package and
[jQuery](http://jquery.com) to implement a simple web chat application.
## Running the example
The example requires a working Go development environment. The [Getting
Started](http://golang.org/doc/install) page describes how to install the
development environment.
Once you have Go up and running, you can download, build and run the example
using the following commands.
$ go get github.com/gorilla/websocket
$ cd `go list -f '{{.Dir}}' github.com/gorilla/websocket/examples/chat`
$ go run *.go

108
examples/chat/conn.go Normal file
View File

@ -0,0 +1,108 @@
// Copyright 2013 Gary Burd. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"github.com/gorilla/websocket"
"log"
"net/http"
"time"
)
const (
// Time allowed to write a message to the peer.
writeWait = 10 * time.Second
// Time allowed to read the next pong message from the peer.
pongWait = 60 * time.Second
// Send pings to peer with this period. Must be less than pongWait.
pingPeriod = (pongWait * 9) / 10
// Maximum message size allowed from peer.
maxMessageSize = 512
)
// connection is an middleman between the websocket connection and the hub.
type connection struct {
// The websocket connection.
ws *websocket.Conn
// Buffered channel of outbound messages.
send chan []byte
}
// readPump pumps messages from the websocket connection to the hub.
func (c *connection) readPump() {
defer func() {
h.unregister <- c
c.ws.Close()
}()
c.ws.SetReadLimit(maxMessageSize)
c.ws.SetReadDeadline(time.Now().Add(pongWait))
c.ws.SetPongHandler(func(string) error { c.ws.SetReadDeadline(time.Now().Add(pongWait)); return nil })
for {
_, message, err := c.ws.ReadMessage()
if err != nil {
break
}
h.broadcast <- message
}
}
// write writes a message with the given message type and payload.
func (c *connection) write(mt int, payload []byte) error {
c.ws.SetWriteDeadline(time.Now().Add(writeWait))
return c.ws.WriteMessage(mt, payload)
}
// writePump pumps messages from the hub to the websocket connection.
func (c *connection) writePump() {
ticker := time.NewTicker(pingPeriod)
defer func() {
ticker.Stop()
c.ws.Close()
}()
for {
select {
case message, ok := <-c.send:
if !ok {
c.write(websocket.CloseMessage, []byte{})
return
}
if err := c.write(websocket.TextMessage, message); err != nil {
return
}
case <-ticker.C:
if err := c.write(websocket.PingMessage, []byte{}); err != nil {
return
}
}
}
}
// serverWs handles webocket requests from the peer.
func serveWs(w http.ResponseWriter, r *http.Request) {
if r.Method != "GET" {
http.Error(w, "Method not allowed", 405)
return
}
if r.Header.Get("Origin") != "http://"+r.Host {
http.Error(w, "Origin not allowed", 403)
return
}
ws, err := websocket.Upgrade(w, r, nil, 1024, 1024)
if _, ok := err.(websocket.HandshakeError); ok {
http.Error(w, "Not a websocket handshake", 400)
return
} else if err != nil {
log.Println(err)
return
}
c := &connection{send: make(chan []byte, 256), ws: ws}
h.register <- c
go c.writePump()
c.readPump()
}

92
examples/chat/home.html Normal file
View File

@ -0,0 +1,92 @@
<!DOCTYPE html>
<html lang="en">
<head>
<title>Chat Example</title>
<script src="//ajax.googleapis.com/ajax/libs/jquery/2.0.3/jquery.min.js"></script>
<script type="text/javascript">
$(function() {
var conn;
var msg = $("#msg");
var log = $("#log");
function appendLog(msg) {
var d = log[0]
var doScroll = d.scrollTop == d.scrollHeight - d.clientHeight;
msg.appendTo(log)
if (doScroll) {
d.scrollTop = d.scrollHeight - d.clientHeight;
}
}
$("#form").submit(function() {
if (!conn) {
return false;
}
if (!msg.val()) {
return false;
}
conn.send(msg.val());
msg.val("");
return false
});
if (window["WebSocket"]) {
conn = new WebSocket("ws://{{$}}/ws");
conn.onclose = function(evt) {
appendLog($("<div><b>Connection closed.</b></div>"))
}
conn.onmessage = function(evt) {
appendLog($("<div/>").text(evt.data))
}
} else {
appendLog($("<div><b>Your browser does not support WebSockets.</b></div>"))
}
});
</script>
<style type="text/css">
html {
overflow: hidden;
}
body {
overflow: hidden;
padding: 0;
margin: 0;
width: 100%;
height: 100%;
background: gray;
}
#log {
background: white;
margin: 0;
padding: 0.5em 0.5em 0.5em 0.5em;
position: absolute;
top: 0.5em;
left: 0.5em;
right: 0.5em;
bottom: 3em;
overflow: auto;
}
#form {
padding: 0 0.5em 0 0.5em;
margin: 0;
position: absolute;
bottom: 1em;
left: 0px;
width: 100%;
overflow: hidden;
}
</style>
</head>
<body>
<div id="log"></div>
<form id="form">
<input type="submit" value="Send" />
<input type="text" id="msg" size="64"/>
</form>
</body>
</html>

49
examples/chat/hub.go Normal file
View File

@ -0,0 +1,49 @@
// Copyright 2013 Gary Burd. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
// hub maintains the set of active connections and broadcasts messages to the
// connections.
type hub struct {
// Registered connections.
connections map[*connection]bool
// Inbound messages from the connections.
broadcast chan []byte
// Register requests from the connections.
register chan *connection
// Unregister requests from connections.
unregister chan *connection
}
var h = hub{
broadcast: make(chan []byte),
register: make(chan *connection),
unregister: make(chan *connection),
connections: make(map[*connection]bool),
}
func (h *hub) run() {
for {
select {
case c := <-h.register:
h.connections[c] = true
case c := <-h.unregister:
delete(h.connections, c)
close(c.send)
case m := <-h.broadcast:
for c := range h.connections {
select {
case c.send <- m:
default:
close(c.send)
delete(h.connections, c)
}
}
}
}
}

39
examples/chat/main.go Normal file
View File

@ -0,0 +1,39 @@
// Copyright 2013 Gary Burd. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"flag"
"log"
"net/http"
"text/template"
)
var addr = flag.String("addr", ":8080", "http service address")
var homeTempl = template.Must(template.ParseFiles("home.html"))
func serveHome(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/" {
http.Error(w, "Not found", 404)
return
}
if r.Method != "GET" {
http.Error(w, "Method nod allowed", 405)
return
}
w.Header().Set("Content-Type", "text/html; charset=utf-8")
homeTempl.Execute(w, r.Host)
}
func main() {
flag.Parse()
go h.run()
http.HandleFunc("/", serveHome)
http.HandleFunc("/ws", serveWs)
err := http.ListenAndServe(*addr, nil)
if err != nil {
log.Fatal("ListenAndServe: ", err)
}
}

39
json.go Normal file
View File

@ -0,0 +1,39 @@
// Copyright 2013 Gary Burd. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"encoding/json"
)
// WriteJSON writes the JSON encoding of v to the connection.
//
// See the documentation for encoding/json Marshal for details about the
// conversion of Go values to JSON.
func WriteJSON(c *Conn, v interface{}) error {
w, err := c.NextWriter(TextMessage)
if err != nil {
return err
}
err1 := json.NewEncoder(w).Encode(v)
err2 := w.Close()
if err1 != nil {
return err1
}
return err2
}
// ReadJSON reads the next JSON-encoded message from the connection and stores
// it in the value pointed to by v.
//
// See the documentation for the encoding/json Marshal function for details
// about the conversion of JSON to a Go value.
func ReadJSON(c *Conn, v interface{}) error {
_, r, err := c.NextReader()
if err != nil {
return err
}
return json.NewDecoder(r).Decode(v)
}

37
json_test.go Normal file
View File

@ -0,0 +1,37 @@
// Copyright 2013 Gary Burd. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"bytes"
"reflect"
"testing"
)
func TestJSON(t *testing.T) {
var buf bytes.Buffer
c := fakeNetConn{&buf, &buf}
wc := newConn(c, true, 1024, 1024)
rc := newConn(c, false, 1024, 1024)
var actual, expect struct {
A int
B string
}
expect.A = 1
expect.B = "hello"
if err := WriteJSON(wc, &expect); err != nil {
t.Fatal("write", err)
}
if err := ReadJSON(rc, &actual); err != nil {
t.Fatal("read", err)
}
if !reflect.DeepEqual(&actual, &expect) {
t.Fatal("equal", actual, expect)
}
}

106
server.go Normal file
View File

@ -0,0 +1,106 @@
// Copyright 2013 Gary Burd. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"bufio"
"errors"
"net"
"net/http"
)
// HandshakeError describes an error with the handshake from the peer.
type HandshakeError struct {
Err string
}
func (e HandshakeError) Error() string { return e.Err }
// Upgrade upgrades the HTTP server connection to the WebSocket protocol.
//
// Upgrade returns a HandshakeError if the request is not a WebSocket
// handshake. Applications should handle errors of this type by replying to the
// client with an HTTP response.
//
// The application is responsible for checking the request origin before
// calling Upgrade. An example implementation of the same origin policy is:
//
// if req.Header.Get("Origin") != "http://"+req.Host {
// http.Error(w, "Origin not allowed", 403)
// return
// }
//
// Use the responseHeader to specify cookies (Set-Cookie) and the subprotocol
// (Sec-WebSocket-Protocol).
func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header, readBufSize, writeBufSize int) (*Conn, error) {
if values := r.Header["Sec-Websocket-Version"]; len(values) == 0 || values[0] != "13" {
return nil, HandshakeError{"websocket: version != 13"}
}
if !tokenListContainsValue(r.Header, "Connection", "upgrade") {
return nil, HandshakeError{"websocket: connection header != upgrade"}
}
if !tokenListContainsValue(r.Header, "Upgrade", "websocket") {
return nil, HandshakeError{"websocket: upgrade != websocket"}
}
var challengeKey string
values := r.Header["Sec-Websocket-Key"]
if len(values) == 0 || values[0] == "" {
return nil, HandshakeError{"websocket: key missing or blank"}
}
challengeKey = values[0]
var (
netConn net.Conn
br *bufio.Reader
err error
)
h, ok := w.(http.Hijacker)
if !ok {
return nil, errors.New("websocket: response does not implement http.Hijacker")
}
var rw *bufio.ReadWriter
netConn, rw, err = h.Hijack()
br = rw.Reader
if br.Buffered() > 0 {
netConn.Close()
return nil, errors.New("websocket: client sent data before handshake is complete")
}
c := newConn(netConn, true, readBufSize, writeBufSize)
p := c.writeBuf[:0]
p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...)
p = append(p, computeAcceptKey(challengeKey)...)
p = append(p, "\r\n"...)
for k, vs := range responseHeader {
for _, v := range vs {
p = append(p, k...)
p = append(p, ": "...)
for i := 0; i < len(v); i++ {
b := v[i]
if b <= 31 {
// prevent response splitting.
b = ' '
}
p = append(p, b)
}
p = append(p, "\r\n"...)
}
}
p = append(p, "\r\n"...)
if _, err = netConn.Write(p); err != nil {
netConn.Close()
return nil, err
}
return c, nil
}

44
util.go Normal file
View File

@ -0,0 +1,44 @@
// Copyright 2013 Gary Burd. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"crypto/rand"
"crypto/sha1"
"encoding/base64"
"io"
"net/http"
"strings"
)
// tokenListContainsValue returns true if the 1#token header with the given
// name contains token.
func tokenListContainsValue(header http.Header, name string, value string) bool {
for _, v := range header[name] {
for _, s := range strings.Split(v, ",") {
if strings.EqualFold(value, strings.TrimSpace(s)) {
return true
}
}
}
return false
}
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
func computeAcceptKey(challengeKey string) string {
h := sha1.New()
h.Write([]byte(challengeKey))
h.Write(keyGUID)
return base64.StdEncoding.EncodeToString(h.Sum(nil))
}
func generateChallengeKey() (string, error) {
p := make([]byte, 16)
if _, err := io.ReadFull(rand.Reader, p); err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(p), nil
}