diff --git a/websocket/client.go b/websocket/client.go new file mode 100644 index 0000000..6860fa8 --- /dev/null +++ b/websocket/client.go @@ -0,0 +1,60 @@ +package websocket + +import ( + "bytes" + "errors" + "net" + "net/http" + "net/url" + "strings" +) + +var ( + ErrBadHandshake = errors.New("bad handshake") +) + +func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header) (c *Conn, response *http.Response, err error) { + key, err := calcKey() + if err != nil { + return nil, nil, err + } + acceptKey := calcAcceptKey(key) + + c = NewConn(netConn, false) + + buf := bytes.NewBufferString("GET ") + buf.WriteString(u.RequestURI()) + buf.WriteString(" HTTP/1.1\r\nHost: ") + buf.WriteString(u.Host) + buf.WriteString("\r\nUpgrade: websocket\r\nConnection: upgrade\r\nSec-WebSocket-Version: 13\r\nSec-WebSocket-Key: ") + buf.WriteString(key) + buf.WriteString("\r\n") + + for k, vs := range requestHeader { + for _, v := range vs { + buf.WriteString(k) + buf.WriteString(": ") + buf.WriteString(v) + buf.WriteString("\r\n") + } + } + + buf.WriteString("\r\n") + p := buf.Bytes() + if _, err := netConn.Write(p); err != nil { + return nil, nil, err + } + + resp, err := http.ReadResponse(c.br, &http.Request{Method: "GET", URL: u}) + if err != nil { + return nil, nil, err + } + + if resp.StatusCode != 101 || + !strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") || + !strings.EqualFold(resp.Header.Get("Connection"), "upgrade") || + resp.Header.Get("Sec-Websocket-Accept") != acceptKey { + return nil, resp, ErrBadHandshake + } + return c, resp, nil +} diff --git a/websocket/conn.go b/websocket/conn.go new file mode 100644 index 0000000..7519e23 --- /dev/null +++ b/websocket/conn.go @@ -0,0 +1,323 @@ +package websocket + +import ( + "bufio" + "encoding/binary" + "errors" + "io" + "lib/log" + "math/rand" + "net" + "time" +) + +//refer RFC6455 + +const ( + TextMessage byte = 1 + BinaryMessage byte = 2 + CloseMessage byte = 8 + PingMessage byte = 9 + PongMessage byte = 10 +) + +var ( + ErrControlTooLong = errors.New("control message too long") + ErrRSVNotSupport = errors.New("reserved bit not support") + ErrPayloadError = errors.New("payload length error") + ErrControlFragmented = errors.New("control message can not be fragmented") + ErrNotTCPConn = errors.New("not a tcp connection") + ErrWriteError = errors.New("write error") +) + +type Conn struct { + conn net.Conn + + br *bufio.Reader + + isServer bool +} + +func NewConn(conn net.Conn, isServer bool) *Conn { + c := new(Conn) + + c.conn = conn + + c.br = bufio.NewReader(conn) + + c.isServer = isServer + + return c +} + +func (c *Conn) ReadMessage() (messageType byte, message []byte, err error) { + return c.Read() +} + +func (c *Conn) Read() (messageType byte, message []byte, err error) { + buf := make([]byte, 8, 8) + + message = []byte{} + + messageType = 0 + + for { + opcode, data, err := c.readFrame(buf) + + if err != nil { + return messageType, message, err + } + + message = append(message, data...) + + if opcode&0x80 != 0 { + //final + if opcode&0x0F > 0 { + //not continue frame + messageType = opcode & 0x0F + } + return messageType, message, nil + + } else { + if opcode&0x0F > 0 { + //first continue frame + messageType = opcode & 0x0F + } + } + } + + return +} + +func (c *Conn) Write(message []byte, binary bool) error { + if binary { + return c.sendFrame(BinaryMessage, message) + } else { + return c.sendFrame(TextMessage, message) + } +} + +func (c *Conn) WriteMessage(messageType byte, message []byte) error { + return c.sendFrame(messageType, message) +} + +//write utf-8 text message +func (c *Conn) WriteString(message []byte) error { + return c.Write(message, false) +} + +//write binary message +func (c *Conn) WriteBinary(message []byte) error { + return c.Write(message, true) +} + +func (c *Conn) Ping(message []byte) error { + return c.sendFrame(PingMessage, message) +} + +func (c *Conn) Pong(message []byte) error { + return c.sendFrame(PongMessage, message) +} + +//close socket, not send websocket close message +func (c *Conn) Close() error { + return c.conn.Close() +} + +func (c *Conn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *Conn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +func (c *Conn) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +func (c *Conn) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) +} + +func (c *Conn) SetReadBuffer(bytes int) error { + if tcpConn, ok := c.conn.(*net.TCPConn); ok { + return tcpConn.SetReadBuffer(bytes) + } else { + return ErrNotTCPConn + } +} + +func (c *Conn) SetWriteBuffer(bytes int) error { + if tcpConn, ok := c.conn.(*net.TCPConn); ok { + return tcpConn.SetWriteBuffer(bytes) + } else { + return ErrNotTCPConn + } +} + +func (c *Conn) readPayloadLen(length byte, buf []byte) (payloadLen uint64, err error) { + if length < 126 { + payloadLen = uint64(length) + } else if length == 126 { + err = c.read(buf[:2]) + if err != nil { + return + } + payloadLen = uint64(binary.BigEndian.Uint16(buf[:2])) + } else if length == 127 { + err = c.read(buf[:8]) + if err != nil { + return + } + payloadLen = uint64(binary.BigEndian.Uint16(buf[:8])) + } + + return +} + +func (c *Conn) readFrame(buf []byte) (opcode byte, messsage []byte, err error) { + //minimum head may 2 byte + + err = c.read(buf[:2]) + if err != nil { + return + } + + opcode = buf[0] + + if opcode&0x70 > 0 { + err = ErrRSVNotSupport + return + } + + //isMasking := (0x80 & buf[1]) > 0 + isMasking := (0x80 & buf[1]) > 0 + + var payloadLen uint64 + payloadLen, err = c.readPayloadLen(buf[1]&0x7F, buf) + if err != nil { + return + } + + if opcode&0x08 > 0 && payloadLen > 125 { + err = ErrControlTooLong + return + } + + var masking []byte + + if isMasking { + err = c.read(buf[:4]) + if err != nil { + return + } + + masking = buf[:4] + } + + messsage = make([]byte, payloadLen) + err = c.read(messsage) + + if err != nil { + return + } + + if isMasking { + //maskingKey := c.newMaskingKey() + c.maskingData(messsage, masking) + } + + return +} + +func (c *Conn) sendFrame(opcode byte, message []byte) error { + //max frame header may 14 length + buf := make([]byte, 0, len(message)+14) + //here we don not support continue frame, all are final + opcode |= 0x80 + + if opcode&0x08 > 0 && len(message) >= 126 { + return ErrControlTooLong + } + + buf = append(buf, opcode) + + //no mask, because chrome may not support + var mask byte = 0x00 + + if !c.isServer { + //for client, we will mask data + mask = 0x80 + } + + payloadLen := len(message) + + if payloadLen < 126 { + buf = append(buf, mask|byte(payloadLen)) + } else if payloadLen <= 0xFFFF { + buf = append(buf, mask|byte(126), 0, 0) + + binary.BigEndian.PutUint16(buf[len(buf)-2:], uint16(payloadLen)) + } else { + buf = append(buf, mask|byte(127), 0, 0, 0, 0, 0, 0, 0, 0) + + binary.BigEndian.PutUint64(buf[len(buf)-8:], uint64(payloadLen)) + } + + if !c.isServer { + maskingKey := c.newMaskingKey() + buf = append(buf, maskingKey...) + + pos := len(buf) + buf = append(buf, message...) + + c.maskingData(buf[pos:], maskingKey) + + } else { + buf = append(buf, message...) + } + + tmpBuf := buf + for i := 0; i < 3; i++ { + n, err := c.conn.Write(tmpBuf) + if err != nil { + return err + } + if n == len(tmpBuf) { + return nil + } else { + log.Warn("[conn write] buffer_size=%d return_size=%s", len(tmpBuf), n) + tmpBuf = tmpBuf[n:] + } + } + return ErrWriteError +} + +func (c *Conn) read(buf []byte) error { + var err error + for len(buf) > 0 && err == nil { + var nn int + nn, err = c.br.Read(buf) + buf = buf[nn:] + } + if err == io.EOF { + if len(buf) == 0 { + err = nil + } else { + err = io.ErrUnexpectedEOF + } + } + return err +} + +func (c *Conn) maskingData(data []byte, maskingKey []byte) { + for i := range data { + data[i] ^= maskingKey[i%4] + } +} + +func (c *Conn) newMaskingKey() []byte { + n := rand.Uint32() + return []byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 32)} +} diff --git a/websocket/server.go b/websocket/server.go new file mode 100644 index 0000000..ab09539 --- /dev/null +++ b/websocket/server.go @@ -0,0 +1,105 @@ +package websocket + +import ( + "bufio" + "bytes" + "errors" + "net" + "net/http" + "strings" +) + +var ( + ErrInvalidMethod = errors.New("Only GET Supported") + ErrInvalidVersion = errors.New("Sec-Websocket-Version: 13") + ErrInvalidUpgrade = errors.New("Can \"Upgrade\" only to \"WebSocket\"") + ErrInvalidConnection = errors.New("\"Connection\" must be \"Upgrade\"") + ErrMissingKey = errors.New("Missing Key") + ErrHijacker = errors.New("Not implement http.Hijacker") + ErrNoEmptyConn = errors.New("Conn ReadBuf must be empty") +) + +func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) { + if r.Method != "GET" { + return nil, ErrInvalidMethod + } + + if r.Header.Get("Sec-Websocket-Version") != "13" { + return nil, ErrInvalidVersion + } + + if strings.ToLower(r.Header.Get("Upgrade")) != "websocket" { + return nil, ErrInvalidUpgrade + } + + if strings.ToLower(r.Header.Get("Connection")) != "upgrade" { + return nil, ErrInvalidConnection + } + + var acceptKey string + + if key := r.Header.Get("Sec-Websocket-key"); len(key) == 0 { + return nil, ErrMissingKey + } else { + acceptKey = calcAcceptKey(key) + } + + var ( + netConn net.Conn + br *bufio.Reader + err error + ) + + h, ok := w.(http.Hijacker) + if !ok { + return nil, ErrHijacker + } + + var rw *bufio.ReadWriter + netConn, rw, err = h.Hijack() + br = rw.Reader + + if br.Buffered() > 0 { + netConn.Close() + return nil, ErrNoEmptyConn + } + + c := NewConn(netConn, true) + + buf := bytes.NewBufferString("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ") + + buf.WriteString(acceptKey) + buf.WriteString("\r\n") + + subProtol := selectSubProtocol(r) + if len(subProtol) > 0 { + buf.WriteString("Sec-Websocket-Protocol: ") + buf.WriteString(subProtol) + buf.WriteString("\r\n") + } + + for k, vs := range responseHeader { + for _, v := range vs { + buf.WriteString(k) + buf.WriteString(": ") + buf.WriteString(v) + buf.WriteString("\r\n") + } + } + buf.WriteString("\r\n") + + if _, err = netConn.Write(buf.Bytes()); err != nil { + netConn.Close() + return nil, err + } + + return c, nil +} + +func selectSubProtocol(r *http.Request) string { + h := r.Header.Get("Sec-Websocket-Protocol") + if len(h) == 0 { + return "" + } + return strings.Split(h, ",")[0] +} diff --git a/websocket/util.go b/websocket/util.go new file mode 100644 index 0000000..5e156cb --- /dev/null +++ b/websocket/util.go @@ -0,0 +1,26 @@ +package websocket + +import ( + "crypto/rand" + "crypto/sha1" + "encoding/base64" + "errors" + "io" +) + +var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") + +func calcAcceptKey(key string) string { + h := sha1.New() + h.Write([]byte(key)) + h.Write(keyGUID) + return base64.StdEncoding.EncodeToString(h.Sum(nil)) +} + +func calcKey() (string, error) { + p := make([]byte, 16) + if _, err := io.ReadFull(rand.Reader, p); err != nil { + return "", err + } + return base64.StdEncoding.EncodeToString(p), nil +}