mirror of https://github.com/siddontang/go.git
add websocket
This commit is contained in:
parent
a46d1de902
commit
8dc4fbcfad
|
@ -0,0 +1,60 @@
|
|||
package websocket
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrBadHandshake = errors.New("bad handshake")
|
||||
)
|
||||
|
||||
func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header) (c *Conn, response *http.Response, err error) {
|
||||
key, err := calcKey()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
acceptKey := calcAcceptKey(key)
|
||||
|
||||
c = NewConn(netConn, false)
|
||||
|
||||
buf := bytes.NewBufferString("GET ")
|
||||
buf.WriteString(u.RequestURI())
|
||||
buf.WriteString(" HTTP/1.1\r\nHost: ")
|
||||
buf.WriteString(u.Host)
|
||||
buf.WriteString("\r\nUpgrade: websocket\r\nConnection: upgrade\r\nSec-WebSocket-Version: 13\r\nSec-WebSocket-Key: ")
|
||||
buf.WriteString(key)
|
||||
buf.WriteString("\r\n")
|
||||
|
||||
for k, vs := range requestHeader {
|
||||
for _, v := range vs {
|
||||
buf.WriteString(k)
|
||||
buf.WriteString(": ")
|
||||
buf.WriteString(v)
|
||||
buf.WriteString("\r\n")
|
||||
}
|
||||
}
|
||||
|
||||
buf.WriteString("\r\n")
|
||||
p := buf.Bytes()
|
||||
if _, err := netConn.Write(p); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
resp, err := http.ReadResponse(c.br, &http.Request{Method: "GET", URL: u})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != 101 ||
|
||||
!strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") ||
|
||||
!strings.EqualFold(resp.Header.Get("Connection"), "upgrade") ||
|
||||
resp.Header.Get("Sec-Websocket-Accept") != acceptKey {
|
||||
return nil, resp, ErrBadHandshake
|
||||
}
|
||||
return c, resp, nil
|
||||
}
|
|
@ -0,0 +1,323 @@
|
|||
package websocket
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
"lib/log"
|
||||
"math/rand"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
//refer RFC6455
|
||||
|
||||
const (
|
||||
TextMessage byte = 1
|
||||
BinaryMessage byte = 2
|
||||
CloseMessage byte = 8
|
||||
PingMessage byte = 9
|
||||
PongMessage byte = 10
|
||||
)
|
||||
|
||||
var (
|
||||
ErrControlTooLong = errors.New("control message too long")
|
||||
ErrRSVNotSupport = errors.New("reserved bit not support")
|
||||
ErrPayloadError = errors.New("payload length error")
|
||||
ErrControlFragmented = errors.New("control message can not be fragmented")
|
||||
ErrNotTCPConn = errors.New("not a tcp connection")
|
||||
ErrWriteError = errors.New("write error")
|
||||
)
|
||||
|
||||
type Conn struct {
|
||||
conn net.Conn
|
||||
|
||||
br *bufio.Reader
|
||||
|
||||
isServer bool
|
||||
}
|
||||
|
||||
func NewConn(conn net.Conn, isServer bool) *Conn {
|
||||
c := new(Conn)
|
||||
|
||||
c.conn = conn
|
||||
|
||||
c.br = bufio.NewReader(conn)
|
||||
|
||||
c.isServer = isServer
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *Conn) ReadMessage() (messageType byte, message []byte, err error) {
|
||||
return c.Read()
|
||||
}
|
||||
|
||||
func (c *Conn) Read() (messageType byte, message []byte, err error) {
|
||||
buf := make([]byte, 8, 8)
|
||||
|
||||
message = []byte{}
|
||||
|
||||
messageType = 0
|
||||
|
||||
for {
|
||||
opcode, data, err := c.readFrame(buf)
|
||||
|
||||
if err != nil {
|
||||
return messageType, message, err
|
||||
}
|
||||
|
||||
message = append(message, data...)
|
||||
|
||||
if opcode&0x80 != 0 {
|
||||
//final
|
||||
if opcode&0x0F > 0 {
|
||||
//not continue frame
|
||||
messageType = opcode & 0x0F
|
||||
}
|
||||
return messageType, message, nil
|
||||
|
||||
} else {
|
||||
if opcode&0x0F > 0 {
|
||||
//first continue frame
|
||||
messageType = opcode & 0x0F
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Conn) Write(message []byte, binary bool) error {
|
||||
if binary {
|
||||
return c.sendFrame(BinaryMessage, message)
|
||||
} else {
|
||||
return c.sendFrame(TextMessage, message)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) WriteMessage(messageType byte, message []byte) error {
|
||||
return c.sendFrame(messageType, message)
|
||||
}
|
||||
|
||||
//write utf-8 text message
|
||||
func (c *Conn) WriteString(message []byte) error {
|
||||
return c.Write(message, false)
|
||||
}
|
||||
|
||||
//write binary message
|
||||
func (c *Conn) WriteBinary(message []byte) error {
|
||||
return c.Write(message, true)
|
||||
}
|
||||
|
||||
func (c *Conn) Ping(message []byte) error {
|
||||
return c.sendFrame(PingMessage, message)
|
||||
}
|
||||
|
||||
func (c *Conn) Pong(message []byte) error {
|
||||
return c.sendFrame(PongMessage, message)
|
||||
}
|
||||
|
||||
//close socket, not send websocket close message
|
||||
func (c *Conn) Close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
func (c *Conn) LocalAddr() net.Addr {
|
||||
return c.conn.LocalAddr()
|
||||
}
|
||||
|
||||
func (c *Conn) RemoteAddr() net.Addr {
|
||||
return c.conn.RemoteAddr()
|
||||
}
|
||||
|
||||
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||||
return c.conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||||
return c.conn.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
func (c *Conn) SetReadBuffer(bytes int) error {
|
||||
if tcpConn, ok := c.conn.(*net.TCPConn); ok {
|
||||
return tcpConn.SetReadBuffer(bytes)
|
||||
} else {
|
||||
return ErrNotTCPConn
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) SetWriteBuffer(bytes int) error {
|
||||
if tcpConn, ok := c.conn.(*net.TCPConn); ok {
|
||||
return tcpConn.SetWriteBuffer(bytes)
|
||||
} else {
|
||||
return ErrNotTCPConn
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) readPayloadLen(length byte, buf []byte) (payloadLen uint64, err error) {
|
||||
if length < 126 {
|
||||
payloadLen = uint64(length)
|
||||
} else if length == 126 {
|
||||
err = c.read(buf[:2])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
payloadLen = uint64(binary.BigEndian.Uint16(buf[:2]))
|
||||
} else if length == 127 {
|
||||
err = c.read(buf[:8])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
payloadLen = uint64(binary.BigEndian.Uint16(buf[:8]))
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Conn) readFrame(buf []byte) (opcode byte, messsage []byte, err error) {
|
||||
//minimum head may 2 byte
|
||||
|
||||
err = c.read(buf[:2])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
opcode = buf[0]
|
||||
|
||||
if opcode&0x70 > 0 {
|
||||
err = ErrRSVNotSupport
|
||||
return
|
||||
}
|
||||
|
||||
//isMasking := (0x80 & buf[1]) > 0
|
||||
isMasking := (0x80 & buf[1]) > 0
|
||||
|
||||
var payloadLen uint64
|
||||
payloadLen, err = c.readPayloadLen(buf[1]&0x7F, buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if opcode&0x08 > 0 && payloadLen > 125 {
|
||||
err = ErrControlTooLong
|
||||
return
|
||||
}
|
||||
|
||||
var masking []byte
|
||||
|
||||
if isMasking {
|
||||
err = c.read(buf[:4])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
masking = buf[:4]
|
||||
}
|
||||
|
||||
messsage = make([]byte, payloadLen)
|
||||
err = c.read(messsage)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if isMasking {
|
||||
//maskingKey := c.newMaskingKey()
|
||||
c.maskingData(messsage, masking)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Conn) sendFrame(opcode byte, message []byte) error {
|
||||
//max frame header may 14 length
|
||||
buf := make([]byte, 0, len(message)+14)
|
||||
//here we don not support continue frame, all are final
|
||||
opcode |= 0x80
|
||||
|
||||
if opcode&0x08 > 0 && len(message) >= 126 {
|
||||
return ErrControlTooLong
|
||||
}
|
||||
|
||||
buf = append(buf, opcode)
|
||||
|
||||
//no mask, because chrome may not support
|
||||
var mask byte = 0x00
|
||||
|
||||
if !c.isServer {
|
||||
//for client, we will mask data
|
||||
mask = 0x80
|
||||
}
|
||||
|
||||
payloadLen := len(message)
|
||||
|
||||
if payloadLen < 126 {
|
||||
buf = append(buf, mask|byte(payloadLen))
|
||||
} else if payloadLen <= 0xFFFF {
|
||||
buf = append(buf, mask|byte(126), 0, 0)
|
||||
|
||||
binary.BigEndian.PutUint16(buf[len(buf)-2:], uint16(payloadLen))
|
||||
} else {
|
||||
buf = append(buf, mask|byte(127), 0, 0, 0, 0, 0, 0, 0, 0)
|
||||
|
||||
binary.BigEndian.PutUint64(buf[len(buf)-8:], uint64(payloadLen))
|
||||
}
|
||||
|
||||
if !c.isServer {
|
||||
maskingKey := c.newMaskingKey()
|
||||
buf = append(buf, maskingKey...)
|
||||
|
||||
pos := len(buf)
|
||||
buf = append(buf, message...)
|
||||
|
||||
c.maskingData(buf[pos:], maskingKey)
|
||||
|
||||
} else {
|
||||
buf = append(buf, message...)
|
||||
}
|
||||
|
||||
tmpBuf := buf
|
||||
for i := 0; i < 3; i++ {
|
||||
n, err := c.conn.Write(tmpBuf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n == len(tmpBuf) {
|
||||
return nil
|
||||
} else {
|
||||
log.Warn("[conn write] buffer_size=%d return_size=%s", len(tmpBuf), n)
|
||||
tmpBuf = tmpBuf[n:]
|
||||
}
|
||||
}
|
||||
return ErrWriteError
|
||||
}
|
||||
|
||||
func (c *Conn) read(buf []byte) error {
|
||||
var err error
|
||||
for len(buf) > 0 && err == nil {
|
||||
var nn int
|
||||
nn, err = c.br.Read(buf)
|
||||
buf = buf[nn:]
|
||||
}
|
||||
if err == io.EOF {
|
||||
if len(buf) == 0 {
|
||||
err = nil
|
||||
} else {
|
||||
err = io.ErrUnexpectedEOF
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Conn) maskingData(data []byte, maskingKey []byte) {
|
||||
for i := range data {
|
||||
data[i] ^= maskingKey[i%4]
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) newMaskingKey() []byte {
|
||||
n := rand.Uint32()
|
||||
return []byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 32)}
|
||||
}
|
|
@ -0,0 +1,105 @@
|
|||
package websocket
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidMethod = errors.New("Only GET Supported")
|
||||
ErrInvalidVersion = errors.New("Sec-Websocket-Version: 13")
|
||||
ErrInvalidUpgrade = errors.New("Can \"Upgrade\" only to \"WebSocket\"")
|
||||
ErrInvalidConnection = errors.New("\"Connection\" must be \"Upgrade\"")
|
||||
ErrMissingKey = errors.New("Missing Key")
|
||||
ErrHijacker = errors.New("Not implement http.Hijacker")
|
||||
ErrNoEmptyConn = errors.New("Conn ReadBuf must be empty")
|
||||
)
|
||||
|
||||
func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) {
|
||||
if r.Method != "GET" {
|
||||
return nil, ErrInvalidMethod
|
||||
}
|
||||
|
||||
if r.Header.Get("Sec-Websocket-Version") != "13" {
|
||||
return nil, ErrInvalidVersion
|
||||
}
|
||||
|
||||
if strings.ToLower(r.Header.Get("Upgrade")) != "websocket" {
|
||||
return nil, ErrInvalidUpgrade
|
||||
}
|
||||
|
||||
if strings.ToLower(r.Header.Get("Connection")) != "upgrade" {
|
||||
return nil, ErrInvalidConnection
|
||||
}
|
||||
|
||||
var acceptKey string
|
||||
|
||||
if key := r.Header.Get("Sec-Websocket-key"); len(key) == 0 {
|
||||
return nil, ErrMissingKey
|
||||
} else {
|
||||
acceptKey = calcAcceptKey(key)
|
||||
}
|
||||
|
||||
var (
|
||||
netConn net.Conn
|
||||
br *bufio.Reader
|
||||
err error
|
||||
)
|
||||
|
||||
h, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
return nil, ErrHijacker
|
||||
}
|
||||
|
||||
var rw *bufio.ReadWriter
|
||||
netConn, rw, err = h.Hijack()
|
||||
br = rw.Reader
|
||||
|
||||
if br.Buffered() > 0 {
|
||||
netConn.Close()
|
||||
return nil, ErrNoEmptyConn
|
||||
}
|
||||
|
||||
c := NewConn(netConn, true)
|
||||
|
||||
buf := bytes.NewBufferString("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ")
|
||||
|
||||
buf.WriteString(acceptKey)
|
||||
buf.WriteString("\r\n")
|
||||
|
||||
subProtol := selectSubProtocol(r)
|
||||
if len(subProtol) > 0 {
|
||||
buf.WriteString("Sec-Websocket-Protocol: ")
|
||||
buf.WriteString(subProtol)
|
||||
buf.WriteString("\r\n")
|
||||
}
|
||||
|
||||
for k, vs := range responseHeader {
|
||||
for _, v := range vs {
|
||||
buf.WriteString(k)
|
||||
buf.WriteString(": ")
|
||||
buf.WriteString(v)
|
||||
buf.WriteString("\r\n")
|
||||
}
|
||||
}
|
||||
buf.WriteString("\r\n")
|
||||
|
||||
if _, err = netConn.Write(buf.Bytes()); err != nil {
|
||||
netConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func selectSubProtocol(r *http.Request) string {
|
||||
h := r.Header.Get("Sec-Websocket-Protocol")
|
||||
if len(h) == 0 {
|
||||
return ""
|
||||
}
|
||||
return strings.Split(h, ",")[0]
|
||||
}
|
|
@ -0,0 +1,26 @@
|
|||
package websocket
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"io"
|
||||
)
|
||||
|
||||
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
|
||||
|
||||
func calcAcceptKey(key string) string {
|
||||
h := sha1.New()
|
||||
h.Write([]byte(key))
|
||||
h.Write(keyGUID)
|
||||
return base64.StdEncoding.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
func calcKey() (string, error) {
|
||||
p := make([]byte, 16)
|
||||
if _, err := io.ReadFull(rand.Reader, p); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(p), nil
|
||||
}
|
Loading…
Reference in New Issue