format with goimports

This commit is contained in:
siddontang 2015-05-05 08:45:01 +08:00
parent 530a231625
commit b151716326
12 changed files with 793 additions and 792 deletions

View File

@ -1,18 +1,18 @@
// BSON library for Go // BSON library for Go
// //
// Copyright (c) 2010-2012 - Gustavo Niemeyer <gustavo@niemeyer.net> // Copyright (c) 2010-2012 - Gustavo Niemeyer <gustavo@niemeyer.net>
// //
// All rights reserved. // All rights reserved.
// //
// Redistribution and use in source and binary forms, with or without // Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met: // modification, are permitted provided that the following conditions are met:
// //
// 1. Redistributions of source code must retain the above copyright notice, this // 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer. // list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright notice, // 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation // this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution. // and/or other materials provided with the distribution.
// //
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED // ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE // WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE

View File

@ -1,18 +1,18 @@
// BSON library for Go // BSON library for Go
// //
// Copyright (c) 2010-2012 - Gustavo Niemeyer <gustavo@niemeyer.net> // Copyright (c) 2010-2012 - Gustavo Niemeyer <gustavo@niemeyer.net>
// //
// All rights reserved. // All rights reserved.
// //
// Redistribution and use in source and binary forms, with or without // Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met: // modification, are permitted provided that the following conditions are met:
// //
// 1. Redistributions of source code must retain the above copyright notice, this // 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer. // list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright notice, // 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation // this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution. // and/or other materials provided with the distribution.
// //
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED // ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE // WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
@ -182,7 +182,7 @@ func isZero(v reflect.Value) bool {
if v.Type() == typeTime { if v.Type() == typeTime {
return v.Interface().(time.Time).IsZero() return v.Interface().(time.Time).IsZero()
} }
for i := v.NumField()-1; i >= 0; i-- { for i := v.NumField() - 1; i >= 0; i-- {
if !isZero(v.Field(i)) { if !isZero(v.Field(i)) {
return false return false
} }
@ -207,7 +207,7 @@ func (e *encoder) addSlice(v reflect.Value) {
return return
} }
l := v.Len() l := v.Len()
et := v.Type().Elem() et := v.Type().Elem()
if et == typeDocElem { if et == typeDocElem {
for i := 0; i < l; i++ { for i := 0; i < l; i++ {
elem := v.Index(i).Interface().(DocElem) elem := v.Index(i).Interface().(DocElem)
@ -401,7 +401,7 @@ func (e *encoder) addElem(name string, v reflect.Value, minSize bool) {
case time.Time: case time.Time:
// MongoDB handles timestamps as milliseconds. // MongoDB handles timestamps as milliseconds.
e.addElemName('\x09', name) e.addElemName('\x09', name)
e.addInt64(s.Unix() * 1000 + int64(s.Nanosecond() / 1e6)) e.addInt64(s.Unix()*1000 + int64(s.Nanosecond()/1e6))
case url.URL: case url.URL:
e.addElemName('\x02', name) e.addElemName('\x02', name)

View File

@ -36,7 +36,7 @@ func (h *FileHandler) Close() error {
return h.fd.Close() return h.fd.Close()
} }
//RotatingFileHandler writes log a file, if file size exceeds maxBytes, //RotatingFileHandler writes log a file, if file size exceeds maxBytes,
//it will backup current file and open a new one. //it will backup current file and open a new one.
// //
//max backup file number is set by backupCount, it will delete oldest if backups too many. //max backup file number is set by backupCount, it will delete oldest if backups too many.
@ -112,7 +112,7 @@ func (h *RotatingFileHandler) doRollover() {
} }
} }
//TimeRotatingFileHandler writes log to a file, //TimeRotatingFileHandler writes log to a file,
//it will backup current and open a new one, with a period time you sepecified. //it will backup current and open a new one, with a period time you sepecified.
// //
//refer: http://docs.python.org/2/library/logging.handlers.html. //refer: http://docs.python.org/2/library/logging.handlers.html.

View File

@ -31,8 +31,7 @@ func (h *StreamHandler) Close() error {
return nil return nil
} }
//NullHandler does nothing, it discards anything.
//NullHandler does nothing, it discards anything.
type NullHandler struct { type NullHandler struct {
} }

View File

@ -7,8 +7,8 @@ import (
) )
//SocketHandler writes log to a connectionl. //SocketHandler writes log to a connectionl.
//Network protocol is simple: log length + log | log length + log. log length is uint32, bigendian. //Network protocol is simple: log length + log | log length + log. log length is uint32, bigendian.
//you must implement your own log server, maybe you can use logd instead simply. //you must implement your own log server, maybe you can use logd instead simply.
type SocketHandler struct { type SocketHandler struct {
c net.Conn c net.Conn
protocol string protocol string

View File

@ -1,60 +1,60 @@
package websocket package websocket
import ( import (
"bytes" "bytes"
"errors" "errors"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
) )
var ( var (
ErrBadHandshake = errors.New("bad handshake") ErrBadHandshake = errors.New("bad handshake")
) )
func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header) (c *Conn, response *http.Response, err error) { func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header) (c *Conn, response *http.Response, err error) {
key, err := calcKey() key, err := calcKey()
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
acceptKey := calcAcceptKey(key) acceptKey := calcAcceptKey(key)
c = NewConn(netConn, false) c = NewConn(netConn, false)
buf := bytes.NewBufferString("GET ") buf := bytes.NewBufferString("GET ")
buf.WriteString(u.RequestURI()) buf.WriteString(u.RequestURI())
buf.WriteString(" HTTP/1.1\r\nHost: ") buf.WriteString(" HTTP/1.1\r\nHost: ")
buf.WriteString(u.Host) buf.WriteString(u.Host)
buf.WriteString("\r\nUpgrade: websocket\r\nConnection: upgrade\r\nSec-WebSocket-Version: 13\r\nSec-WebSocket-Key: ") buf.WriteString("\r\nUpgrade: websocket\r\nConnection: upgrade\r\nSec-WebSocket-Version: 13\r\nSec-WebSocket-Key: ")
buf.WriteString(key) buf.WriteString(key)
buf.WriteString("\r\n") buf.WriteString("\r\n")
for k, vs := range requestHeader { for k, vs := range requestHeader {
for _, v := range vs { for _, v := range vs {
buf.WriteString(k) buf.WriteString(k)
buf.WriteString(": ") buf.WriteString(": ")
buf.WriteString(v) buf.WriteString(v)
buf.WriteString("\r\n") buf.WriteString("\r\n")
} }
} }
buf.WriteString("\r\n") buf.WriteString("\r\n")
p := buf.Bytes() p := buf.Bytes()
if _, err := netConn.Write(p); err != nil { if _, err := netConn.Write(p); err != nil {
return nil, nil, err return nil, nil, err
} }
resp, err := http.ReadResponse(c.br, &http.Request{Method: "GET", URL: u}) resp, err := http.ReadResponse(c.br, &http.Request{Method: "GET", URL: u})
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
if resp.StatusCode != 101 || if resp.StatusCode != 101 ||
!strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") || !strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") ||
!strings.EqualFold(resp.Header.Get("Connection"), "upgrade") || !strings.EqualFold(resp.Header.Get("Connection"), "upgrade") ||
resp.Header.Get("Sec-Websocket-Accept") != acceptKey { resp.Header.Get("Sec-Websocket-Accept") != acceptKey {
return nil, resp, ErrBadHandshake return nil, resp, ErrBadHandshake
} }
return c, resp, nil return c, resp, nil
} }

View File

@ -1,99 +1,100 @@
package websocket package websocket
import ( import (
"github.com/gorilla/websocket" "net"
"net" "net/http"
"net/http" "net/url"
"net/url" "testing"
"testing" "time"
"time"
) "github.com/gorilla/websocket"
)
func TestWSClient(t *testing.T) {
http.HandleFunc("/test/client", func(w http.ResponseWriter, r *http.Request) { func TestWSClient(t *testing.T) {
conn, err := websocket.Upgrade(w, r, nil, 1024, 1024) http.HandleFunc("/test/client", func(w http.ResponseWriter, r *http.Request) {
if err != nil { conn, err := websocket.Upgrade(w, r, nil, 1024, 1024)
t.Fatal(err.Error()) if err != nil {
} t.Fatal(err.Error())
}
msgType, msg, err := conn.ReadMessage()
conn.WriteMessage(websocket.TextMessage, msg) msgType, msg, err := conn.ReadMessage()
conn.WriteMessage(websocket.TextMessage, msg)
if err != nil {
t.Fatal(err.Error()) if err != nil {
} t.Fatal(err.Error())
}
if msgType != websocket.TextMessage {
t.Fatal("invalid msg type", msgType) if msgType != websocket.TextMessage {
} t.Fatal("invalid msg type", msgType)
}
msgType, msg, err = conn.ReadMessage()
if err != nil { msgType, msg, err = conn.ReadMessage()
t.Fatal(err.Error()) if err != nil {
} t.Fatal(err.Error())
}
if msgType != websocket.PingMessage {
t.Fatal("invalid msg type", msgType) if msgType != websocket.PingMessage {
} t.Fatal("invalid msg type", msgType)
}
conn.WriteMessage(websocket.PongMessage, []byte{})
conn.WriteMessage(websocket.PongMessage, []byte{})
conn.WriteMessage(websocket.PingMessage, []byte{})
conn.WriteMessage(websocket.PingMessage, []byte{})
msgType, msg, err = conn.ReadMessage()
if err != nil { msgType, msg, err = conn.ReadMessage()
t.Fatal(err.Error()) if err != nil {
} t.Fatal(err.Error())
println(msgType) }
if msgType != websocket.PongMessage { println(msgType)
if msgType != websocket.PongMessage {
t.Fatal("invalid msg type", msgType)
} t.Fatal("invalid msg type", msgType)
}) }
})
go http.ListenAndServe(":65500", nil)
go http.ListenAndServe(":65500", nil)
time.Sleep(time.Second * 1)
time.Sleep(time.Second * 1)
conn, err := net.Dial("tcp", "127.0.0.1:65500")
conn, err := net.Dial("tcp", "127.0.0.1:65500")
if err != nil {
t.Fatal(err.Error()) if err != nil {
} t.Fatal(err.Error())
ws, _, err := NewClient(conn, &url.URL{Host: "127.0.0.1:65500", Path: "/test/client"}, nil) }
ws, _, err := NewClient(conn, &url.URL{Host: "127.0.0.1:65500", Path: "/test/client"}, nil)
if err != nil {
t.Fatal(err.Error()) if err != nil {
} t.Fatal(err.Error())
}
payload := make([]byte, 4*1024)
for i := 0; i < 4*1024; i++ { payload := make([]byte, 4*1024)
payload[i] = 'x' for i := 0; i < 4*1024; i++ {
} payload[i] = 'x'
}
ws.WriteString(payload)
ws.WriteString(payload)
msgType, msg, err := ws.Read()
if err != nil { msgType, msg, err := ws.Read()
t.Fatal(err.Error()) if err != nil {
} t.Fatal(err.Error())
if msgType != TextMessage { }
t.Fatal("invalid msg type", msgType) if msgType != TextMessage {
} t.Fatal("invalid msg type", msgType)
}
if string(msg) != string(payload) {
t.Fatal("invalid msg", string(msg)) if string(msg) != string(payload) {
t.Fatal("invalid msg", string(msg))
}
}
//test ping
ws.Ping([]byte{}) //test ping
msgType, msg, err = ws.ReadMessage() ws.Ping([]byte{})
if err != nil { msgType, msg, err = ws.ReadMessage()
t.Fatal(err.Error()) if err != nil {
} t.Fatal(err.Error())
if msgType != PongMessage { }
t.Fatal("invalid msg type", msgType) if msgType != PongMessage {
} t.Fatal("invalid msg type", msgType)
}
}
}

View File

@ -1,321 +1,321 @@
package websocket package websocket
import ( import (
"bufio" "bufio"
"encoding/binary" "encoding/binary"
"errors" "errors"
"io" "io"
"math/rand" "math/rand"
"net" "net"
"time" "time"
) )
//refer RFC6455 //refer RFC6455
const ( const (
TextMessage byte = 1 TextMessage byte = 1
BinaryMessage byte = 2 BinaryMessage byte = 2
CloseMessage byte = 8 CloseMessage byte = 8
PingMessage byte = 9 PingMessage byte = 9
PongMessage byte = 10 PongMessage byte = 10
) )
var ( var (
ErrControlTooLong = errors.New("control message too long") ErrControlTooLong = errors.New("control message too long")
ErrRSVNotSupport = errors.New("reserved bit not support") ErrRSVNotSupport = errors.New("reserved bit not support")
ErrPayloadError = errors.New("payload length error") ErrPayloadError = errors.New("payload length error")
ErrControlFragmented = errors.New("control message can not be fragmented") ErrControlFragmented = errors.New("control message can not be fragmented")
ErrNotTCPConn = errors.New("not a tcp connection") ErrNotTCPConn = errors.New("not a tcp connection")
ErrWriteError = errors.New("write error") ErrWriteError = errors.New("write error")
) )
type Conn struct { type Conn struct {
conn net.Conn conn net.Conn
br *bufio.Reader br *bufio.Reader
isServer bool isServer bool
} }
func NewConn(conn net.Conn, isServer bool) *Conn { func NewConn(conn net.Conn, isServer bool) *Conn {
c := new(Conn) c := new(Conn)
c.conn = conn c.conn = conn
c.br = bufio.NewReader(conn) c.br = bufio.NewReader(conn)
c.isServer = isServer c.isServer = isServer
return c return c
} }
func (c *Conn) ReadMessage() (messageType byte, message []byte, err error) { func (c *Conn) ReadMessage() (messageType byte, message []byte, err error) {
return c.Read() return c.Read()
} }
func (c *Conn) Read() (messageType byte, message []byte, err error) { func (c *Conn) Read() (messageType byte, message []byte, err error) {
buf := make([]byte, 8, 8) buf := make([]byte, 8, 8)
message = []byte{} message = []byte{}
messageType = 0 messageType = 0
for { for {
opcode, data, err := c.readFrame(buf) opcode, data, err := c.readFrame(buf)
if err != nil { if err != nil {
return messageType, message, err return messageType, message, err
} }
message = append(message, data...) message = append(message, data...)
if opcode&0x80 != 0 { if opcode&0x80 != 0 {
//final //final
if opcode&0x0F > 0 { if opcode&0x0F > 0 {
//not continue frame //not continue frame
messageType = opcode & 0x0F messageType = opcode & 0x0F
} }
return messageType, message, nil return messageType, message, nil
} else { } else {
if opcode&0x0F > 0 { if opcode&0x0F > 0 {
//first continue frame //first continue frame
messageType = opcode & 0x0F messageType = opcode & 0x0F
} }
} }
} }
return return
} }
func (c *Conn) Write(message []byte, binary bool) error { func (c *Conn) Write(message []byte, binary bool) error {
if binary { if binary {
return c.sendFrame(BinaryMessage, message) return c.sendFrame(BinaryMessage, message)
} else { } else {
return c.sendFrame(TextMessage, message) return c.sendFrame(TextMessage, message)
} }
} }
func (c *Conn) WriteMessage(messageType byte, message []byte) error { func (c *Conn) WriteMessage(messageType byte, message []byte) error {
return c.sendFrame(messageType, message) return c.sendFrame(messageType, message)
} }
//write utf-8 text message //write utf-8 text message
func (c *Conn) WriteString(message []byte) error { func (c *Conn) WriteString(message []byte) error {
return c.Write(message, false) return c.Write(message, false)
} }
//write binary message //write binary message
func (c *Conn) WriteBinary(message []byte) error { func (c *Conn) WriteBinary(message []byte) error {
return c.Write(message, true) return c.Write(message, true)
} }
func (c *Conn) Ping(message []byte) error { func (c *Conn) Ping(message []byte) error {
return c.sendFrame(PingMessage, message) return c.sendFrame(PingMessage, message)
} }
func (c *Conn) Pong(message []byte) error { func (c *Conn) Pong(message []byte) error {
return c.sendFrame(PongMessage, message) return c.sendFrame(PongMessage, message)
} }
//close socket, not send websocket close message //close socket, not send websocket close message
func (c *Conn) Close() error { func (c *Conn) Close() error {
return c.conn.Close() return c.conn.Close()
} }
func (c *Conn) LocalAddr() net.Addr { func (c *Conn) LocalAddr() net.Addr {
return c.conn.LocalAddr() return c.conn.LocalAddr()
} }
func (c *Conn) RemoteAddr() net.Addr { func (c *Conn) RemoteAddr() net.Addr {
return c.conn.RemoteAddr() return c.conn.RemoteAddr()
} }
func (c *Conn) SetReadDeadline(t time.Time) error { func (c *Conn) SetReadDeadline(t time.Time) error {
return c.conn.SetReadDeadline(t) return c.conn.SetReadDeadline(t)
} }
func (c *Conn) SetWriteDeadline(t time.Time) error { func (c *Conn) SetWriteDeadline(t time.Time) error {
return c.conn.SetWriteDeadline(t) return c.conn.SetWriteDeadline(t)
} }
func (c *Conn) SetReadBuffer(bytes int) error { func (c *Conn) SetReadBuffer(bytes int) error {
if tcpConn, ok := c.conn.(*net.TCPConn); ok { if tcpConn, ok := c.conn.(*net.TCPConn); ok {
return tcpConn.SetReadBuffer(bytes) return tcpConn.SetReadBuffer(bytes)
} else { } else {
return ErrNotTCPConn return ErrNotTCPConn
} }
} }
func (c *Conn) SetWriteBuffer(bytes int) error { func (c *Conn) SetWriteBuffer(bytes int) error {
if tcpConn, ok := c.conn.(*net.TCPConn); ok { if tcpConn, ok := c.conn.(*net.TCPConn); ok {
return tcpConn.SetWriteBuffer(bytes) return tcpConn.SetWriteBuffer(bytes)
} else { } else {
return ErrNotTCPConn return ErrNotTCPConn
} }
} }
func (c *Conn) readPayloadLen(length byte, buf []byte) (payloadLen uint64, err error) { func (c *Conn) readPayloadLen(length byte, buf []byte) (payloadLen uint64, err error) {
if length < 126 { if length < 126 {
payloadLen = uint64(length) payloadLen = uint64(length)
} else if length == 126 { } else if length == 126 {
err = c.read(buf[:2]) err = c.read(buf[:2])
if err != nil { if err != nil {
return return
} }
payloadLen = uint64(binary.BigEndian.Uint16(buf[:2])) payloadLen = uint64(binary.BigEndian.Uint16(buf[:2]))
} else if length == 127 { } else if length == 127 {
err = c.read(buf[:8]) err = c.read(buf[:8])
if err != nil { if err != nil {
return return
} }
payloadLen = uint64(binary.BigEndian.Uint16(buf[:8])) payloadLen = uint64(binary.BigEndian.Uint16(buf[:8]))
} }
return return
} }
func (c *Conn) readFrame(buf []byte) (opcode byte, messsage []byte, err error) { func (c *Conn) readFrame(buf []byte) (opcode byte, messsage []byte, err error) {
//minimum head may 2 byte //minimum head may 2 byte
err = c.read(buf[:2]) err = c.read(buf[:2])
if err != nil { if err != nil {
return return
} }
opcode = buf[0] opcode = buf[0]
if opcode&0x70 > 0 { if opcode&0x70 > 0 {
err = ErrRSVNotSupport err = ErrRSVNotSupport
return return
} }
//isMasking := (0x80 & buf[1]) > 0 //isMasking := (0x80 & buf[1]) > 0
isMasking := (0x80 & buf[1]) > 0 isMasking := (0x80 & buf[1]) > 0
var payloadLen uint64 var payloadLen uint64
payloadLen, err = c.readPayloadLen(buf[1]&0x7F, buf) payloadLen, err = c.readPayloadLen(buf[1]&0x7F, buf)
if err != nil { if err != nil {
return return
} }
if opcode&0x08 > 0 && payloadLen > 125 { if opcode&0x08 > 0 && payloadLen > 125 {
err = ErrControlTooLong err = ErrControlTooLong
return return
} }
var masking []byte var masking []byte
if isMasking { if isMasking {
err = c.read(buf[:4]) err = c.read(buf[:4])
if err != nil { if err != nil {
return return
} }
masking = buf[:4] masking = buf[:4]
} }
messsage = make([]byte, payloadLen) messsage = make([]byte, payloadLen)
err = c.read(messsage) err = c.read(messsage)
if err != nil { if err != nil {
return return
} }
if isMasking { if isMasking {
//maskingKey := c.newMaskingKey() //maskingKey := c.newMaskingKey()
c.maskingData(messsage, masking) c.maskingData(messsage, masking)
} }
return return
} }
func (c *Conn) sendFrame(opcode byte, message []byte) error { func (c *Conn) sendFrame(opcode byte, message []byte) error {
//max frame header may 14 length //max frame header may 14 length
buf := make([]byte, 0, len(message)+14) buf := make([]byte, 0, len(message)+14)
//here we don not support continue frame, all are final //here we don not support continue frame, all are final
opcode |= 0x80 opcode |= 0x80
if opcode&0x08 > 0 && len(message) >= 126 { if opcode&0x08 > 0 && len(message) >= 126 {
return ErrControlTooLong return ErrControlTooLong
} }
buf = append(buf, opcode) buf = append(buf, opcode)
//no mask, because chrome may not support //no mask, because chrome may not support
var mask byte = 0x00 var mask byte = 0x00
if !c.isServer { if !c.isServer {
//for client, we will mask data //for client, we will mask data
mask = 0x80 mask = 0x80
} }
payloadLen := len(message) payloadLen := len(message)
if payloadLen < 126 { if payloadLen < 126 {
buf = append(buf, mask|byte(payloadLen)) buf = append(buf, mask|byte(payloadLen))
} else if payloadLen <= 0xFFFF { } else if payloadLen <= 0xFFFF {
buf = append(buf, mask|byte(126), 0, 0) buf = append(buf, mask|byte(126), 0, 0)
binary.BigEndian.PutUint16(buf[len(buf)-2:], uint16(payloadLen)) binary.BigEndian.PutUint16(buf[len(buf)-2:], uint16(payloadLen))
} else { } else {
buf = append(buf, mask|byte(127), 0, 0, 0, 0, 0, 0, 0, 0) buf = append(buf, mask|byte(127), 0, 0, 0, 0, 0, 0, 0, 0)
binary.BigEndian.PutUint64(buf[len(buf)-8:], uint64(payloadLen)) binary.BigEndian.PutUint64(buf[len(buf)-8:], uint64(payloadLen))
} }
if !c.isServer { if !c.isServer {
maskingKey := c.newMaskingKey() maskingKey := c.newMaskingKey()
buf = append(buf, maskingKey...) buf = append(buf, maskingKey...)
pos := len(buf) pos := len(buf)
buf = append(buf, message...) buf = append(buf, message...)
c.maskingData(buf[pos:], maskingKey) c.maskingData(buf[pos:], maskingKey)
} else { } else {
buf = append(buf, message...) buf = append(buf, message...)
} }
tmpBuf := buf tmpBuf := buf
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
n, err := c.conn.Write(tmpBuf) n, err := c.conn.Write(tmpBuf)
if err != nil { if err != nil {
return err return err
} }
if n == len(tmpBuf) { if n == len(tmpBuf) {
return nil return nil
} else { } else {
tmpBuf = tmpBuf[n:] tmpBuf = tmpBuf[n:]
} }
} }
return ErrWriteError return ErrWriteError
} }
func (c *Conn) read(buf []byte) error { func (c *Conn) read(buf []byte) error {
var err error var err error
for len(buf) > 0 && err == nil { for len(buf) > 0 && err == nil {
var nn int var nn int
nn, err = c.br.Read(buf) nn, err = c.br.Read(buf)
buf = buf[nn:] buf = buf[nn:]
} }
if err == io.EOF { if err == io.EOF {
if len(buf) == 0 { if len(buf) == 0 {
err = nil err = nil
} else { } else {
err = io.ErrUnexpectedEOF err = io.ErrUnexpectedEOF
} }
} }
return err return err
} }
func (c *Conn) maskingData(data []byte, maskingKey []byte) { func (c *Conn) maskingData(data []byte, maskingKey []byte) {
for i := range data { for i := range data {
data[i] ^= maskingKey[i%4] data[i] ^= maskingKey[i%4]
} }
} }
func (c *Conn) newMaskingKey() []byte { func (c *Conn) newMaskingKey() []byte {
n := rand.Uint32() n := rand.Uint32()
return []byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 32)} return []byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 32)}
} }

View File

@ -1,51 +1,51 @@
package websocket package websocket
import ( import (
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"testing" "testing"
"time" "time"
) )
func TestWSPing(t *testing.T) { func TestWSPing(t *testing.T) {
http.HandleFunc("/test/ping", func(w http.ResponseWriter, r *http.Request) { http.HandleFunc("/test/ping", func(w http.ResponseWriter, r *http.Request) {
conn, err := Upgrade(w, r, nil) conn, err := Upgrade(w, r, nil)
if err != nil { if err != nil {
t.Fatal(err.Error()) t.Fatal(err.Error())
} }
//conn := NewConn(c, true) //conn := NewConn(c, true)
conn.Read() conn.Read()
conn.Pong([]byte{}) conn.Pong([]byte{})
conn.Ping([]byte{}) conn.Ping([]byte{})
msgType, _, _ := conn.Read() msgType, _, _ := conn.Read()
println(msgType) println(msgType)
}) })
go http.ListenAndServe(":65500", nil) go http.ListenAndServe(":65500", nil)
time.Sleep(time.Second * 1) time.Sleep(time.Second * 1)
conn, err := net.Dial("tcp", "127.0.0.1:65500") conn, err := net.Dial("tcp", "127.0.0.1:65500")
if err != nil { if err != nil {
t.Fatal(err.Error()) t.Fatal(err.Error())
} }
ws, _, err := NewClient(conn, &url.URL{Host: "127.0.0.1:65500", Path: "/test/ping"}, nil) ws, _, err := NewClient(conn, &url.URL{Host: "127.0.0.1:65500", Path: "/test/ping"}, nil)
if err != nil { if err != nil {
t.Fatal(err.Error()) t.Fatal(err.Error())
} }
ws.Ping([]byte{}) ws.Ping([]byte{})
msgType, _, _ := ws.Read() msgType, _, _ := ws.Read()
if msgType != PongMessage { if msgType != PongMessage {
t.Fatal("invalid msg type", msgType) t.Fatal("invalid msg type", msgType)
} }
msgType, _, _ = ws.Read() msgType, _, _ = ws.Read()
if msgType != PingMessage { if msgType != PingMessage {
t.Fatal("invalid msg type", msgType) t.Fatal("invalid msg type", msgType)
} }
ws.Pong([]byte{}) ws.Pong([]byte{})
time.Sleep(time.Second * 1) time.Sleep(time.Second * 1)
} }

View File

@ -1,105 +1,105 @@
package websocket package websocket
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"errors" "errors"
"net" "net"
"net/http" "net/http"
"strings" "strings"
) )
var ( var (
ErrInvalidMethod = errors.New("Only GET Supported") ErrInvalidMethod = errors.New("Only GET Supported")
ErrInvalidVersion = errors.New("Sec-Websocket-Version: 13") ErrInvalidVersion = errors.New("Sec-Websocket-Version: 13")
ErrInvalidUpgrade = errors.New("Can \"Upgrade\" only to \"WebSocket\"") ErrInvalidUpgrade = errors.New("Can \"Upgrade\" only to \"WebSocket\"")
ErrInvalidConnection = errors.New("\"Connection\" must be \"Upgrade\"") ErrInvalidConnection = errors.New("\"Connection\" must be \"Upgrade\"")
ErrMissingKey = errors.New("Missing Key") ErrMissingKey = errors.New("Missing Key")
ErrHijacker = errors.New("Not implement http.Hijacker") ErrHijacker = errors.New("Not implement http.Hijacker")
ErrNoEmptyConn = errors.New("Conn ReadBuf must be empty") ErrNoEmptyConn = errors.New("Conn ReadBuf must be empty")
) )
func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) { func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) {
if r.Method != "GET" { if r.Method != "GET" {
return nil, ErrInvalidMethod return nil, ErrInvalidMethod
} }
if r.Header.Get("Sec-Websocket-Version") != "13" { if r.Header.Get("Sec-Websocket-Version") != "13" {
return nil, ErrInvalidVersion return nil, ErrInvalidVersion
} }
if strings.ToLower(r.Header.Get("Upgrade")) != "websocket" { if strings.ToLower(r.Header.Get("Upgrade")) != "websocket" {
return nil, ErrInvalidUpgrade return nil, ErrInvalidUpgrade
} }
if strings.ToLower(r.Header.Get("Connection")) != "upgrade" { if strings.ToLower(r.Header.Get("Connection")) != "upgrade" {
return nil, ErrInvalidConnection return nil, ErrInvalidConnection
} }
var acceptKey string var acceptKey string
if key := r.Header.Get("Sec-Websocket-key"); len(key) == 0 { if key := r.Header.Get("Sec-Websocket-key"); len(key) == 0 {
return nil, ErrMissingKey return nil, ErrMissingKey
} else { } else {
acceptKey = calcAcceptKey(key) acceptKey = calcAcceptKey(key)
} }
var ( var (
netConn net.Conn netConn net.Conn
br *bufio.Reader br *bufio.Reader
err error err error
) )
h, ok := w.(http.Hijacker) h, ok := w.(http.Hijacker)
if !ok { if !ok {
return nil, ErrHijacker return nil, ErrHijacker
} }
var rw *bufio.ReadWriter var rw *bufio.ReadWriter
netConn, rw, err = h.Hijack() netConn, rw, err = h.Hijack()
br = rw.Reader br = rw.Reader
if br.Buffered() > 0 { if br.Buffered() > 0 {
netConn.Close() netConn.Close()
return nil, ErrNoEmptyConn return nil, ErrNoEmptyConn
} }
c := NewConn(netConn, true) c := NewConn(netConn, true)
buf := bytes.NewBufferString("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ") buf := bytes.NewBufferString("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ")
buf.WriteString(acceptKey) buf.WriteString(acceptKey)
buf.WriteString("\r\n") buf.WriteString("\r\n")
subProtol := selectSubProtocol(r) subProtol := selectSubProtocol(r)
if len(subProtol) > 0 { if len(subProtol) > 0 {
buf.WriteString("Sec-Websocket-Protocol: ") buf.WriteString("Sec-Websocket-Protocol: ")
buf.WriteString(subProtol) buf.WriteString(subProtol)
buf.WriteString("\r\n") buf.WriteString("\r\n")
} }
for k, vs := range responseHeader { for k, vs := range responseHeader {
for _, v := range vs { for _, v := range vs {
buf.WriteString(k) buf.WriteString(k)
buf.WriteString(": ") buf.WriteString(": ")
buf.WriteString(v) buf.WriteString(v)
buf.WriteString("\r\n") buf.WriteString("\r\n")
} }
} }
buf.WriteString("\r\n") buf.WriteString("\r\n")
if _, err = netConn.Write(buf.Bytes()); err != nil { if _, err = netConn.Write(buf.Bytes()); err != nil {
netConn.Close() netConn.Close()
return nil, err return nil, err
} }
return c, nil return c, nil
} }
func selectSubProtocol(r *http.Request) string { func selectSubProtocol(r *http.Request) string {
h := r.Header.Get("Sec-Websocket-Protocol") h := r.Header.Get("Sec-Websocket-Protocol")
if len(h) == 0 { if len(h) == 0 {
return "" return ""
} }
return strings.Split(h, ",")[0] return strings.Split(h, ",")[0]
} }

View File

@ -1,97 +1,98 @@
package websocket package websocket
import ( import (
"github.com/gorilla/websocket" "net"
"net" "net/http"
"net/http" "net/url"
"net/url" "testing"
"testing" "time"
"time"
) "github.com/gorilla/websocket"
)
func TestWSServer(t *testing.T) {
http.HandleFunc("/test/server", func(w http.ResponseWriter, r *http.Request) { func TestWSServer(t *testing.T) {
conn, err := Upgrade(w, r, nil) http.HandleFunc("/test/server", func(w http.ResponseWriter, r *http.Request) {
conn, err := Upgrade(w, r, nil)
if err != nil {
t.Fatal(err.Error()) if err != nil {
} t.Fatal(err.Error())
//err = conn.SetReadBuffer(1024 * 1024 * 4) }
//if err != nil { //err = conn.SetReadBuffer(1024 * 1024 * 4)
// println(err.Error()) //if err != nil {
//} // println(err.Error())
//err = conn.SetWriteBuffer(1024 * 1024 * 4) //}
//err = conn.SetWriteBuffer(1024 * 1024 * 4)
//if err != nil {
// println(err.Error()) //if err != nil {
//} // println(err.Error())
//}
msgType, msg, err := conn.Read()
conn.Write(msg, false) msgType, msg, err := conn.Read()
conn.Write(msg, false)
if err != nil {
t.Fatal(err.Error()) if err != nil {
} t.Fatal(err.Error())
}
if msgType != TextMessage {
t.Fatal("wrong msg type", msgType) if msgType != TextMessage {
} t.Fatal("wrong msg type", msgType)
}
msgType, msg, err = conn.ReadMessage()
if err != nil { msgType, msg, err = conn.ReadMessage()
t.Fatal(err.Error()) if err != nil {
} t.Fatal(err.Error())
}
if msgType != PingMessage {
t.Fatal("wrong msg type", msgType) if msgType != PingMessage {
} t.Fatal("wrong msg type", msgType)
}
err = conn.Pong([]byte("abc"))
err = conn.Pong([]byte("abc"))
if err != nil {
t.Fatal(err.Error()) if err != nil {
} t.Fatal(err.Error())
}
})
})
go http.ListenAndServe(":65500", nil)
time.Sleep(time.Second * 1) go http.ListenAndServe(":65500", nil)
time.Sleep(time.Second * 1)
conn, err := net.Dial("tcp", "127.0.0.1:65500")
conn, err := net.Dial("tcp", "127.0.0.1:65500")
if err != nil {
t.Fatal(err.Error()) if err != nil {
} t.Fatal(err.Error())
ws, _, err := websocket.NewClient(conn, &url.URL{Host: "127.0.0.1:65500", Path: "/test/server"}, nil, 1024, 1024) }
ws, _, err := websocket.NewClient(conn, &url.URL{Host: "127.0.0.1:65500", Path: "/test/server"}, nil, 1024, 1024)
ws.SetPongHandler(func(string) error {
println("pong") ws.SetPongHandler(func(string) error {
return nil println("pong")
}) return nil
})
if err != nil {
t.Fatal(err.Error()) if err != nil {
} t.Fatal(err.Error())
}
payload := make([]byte, 4*1024*1024)
for i := 0; i < 4*1024*1024; i++ { payload := make([]byte, 4*1024*1024)
payload[i] = 'x' for i := 0; i < 4*1024*1024; i++ {
} payload[i] = 'x'
}
ws.WriteMessage(websocket.TextMessage, payload)
ws.WriteMessage(websocket.TextMessage, payload)
msgType, msg, err := ws.ReadMessage()
if err != nil { msgType, msg, err := ws.ReadMessage()
t.Fatal(err.Error()) if err != nil {
} t.Fatal(err.Error())
if msgType != websocket.TextMessage { }
t.Fatal("invalid msg type", msgType) if msgType != websocket.TextMessage {
} t.Fatal("invalid msg type", msgType)
}
if string(msg) != string(payload) {
t.Fatal("invalid msg", string(msg)) if string(msg) != string(payload) {
t.Fatal("invalid msg", string(msg))
}
}
time.Sleep(time.Second * 1)
} time.Sleep(time.Second * 1)
}

View File

@ -1,36 +1,36 @@
package websocket package websocket
import ( import (
"crypto/rand" "crypto/rand"
"crypto/sha1" "crypto/sha1"
"encoding/base64" "encoding/base64"
"errors" "errors"
"io" "io"
) )
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
func calcAcceptKey(key string) string { func calcAcceptKey(key string) string {
h := sha1.New() h := sha1.New()
h.Write([]byte(key)) h.Write([]byte(key))
h.Write(keyGUID) h.Write(keyGUID)
return base64.StdEncoding.EncodeToString(h.Sum(nil)) return base64.StdEncoding.EncodeToString(h.Sum(nil))
} }
func calcKey() (string, error) { func calcKey() (string, error) {
p := make([]byte, 16) p := make([]byte, 16)
if _, err := io.ReadFull(rand.Reader, p); err != nil { if _, err := io.ReadFull(rand.Reader, p); err != nil {
return "", err return "", err
} }
return base64.StdEncoding.EncodeToString(p), nil return base64.StdEncoding.EncodeToString(p), nil
} }
func HandleCloseFrame(buf []byte) (int16, string, error) { func HandleCloseFrame(buf []byte) (int16, string, error) {
if len(buf) < 2 { if len(buf) < 2 {
return 0, "", errors.New("close frame msg's length less than 2") return 0, "", errors.New("close frame msg's length less than 2")
} }
code := int16(buf[0])<<8 + int16(buf[1]) code := int16(buf[0])<<8 + int16(buf[1])
reason := string(buf[2:]) reason := string(buf[2:])
return code, reason, nil return code, reason, nil
} }