diff --git a/bson/decode.go b/bson/decode.go index f1c8b4f..fc991b7 100644 --- a/bson/decode.go +++ b/bson/decode.go @@ -1,18 +1,18 @@ // BSON library for Go -// +// // Copyright (c) 2010-2012 - Gustavo Niemeyer -// +// // All rights reserved. // // Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are met: -// +// modification, are permitted provided that the following conditions are met: +// // 1. Redistributions of source code must retain the above copyright notice, this -// list of conditions and the following disclaimer. +// list of conditions and the following disclaimer. // 2. Redistributions in binary form must reproduce the above copyright notice, // this list of conditions and the following disclaimer in the documentation -// and/or other materials provided with the distribution. -// +// and/or other materials provided with the distribution. +// // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND // ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED // WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE diff --git a/bson/encode.go b/bson/encode.go index 6544748..036a136 100644 --- a/bson/encode.go +++ b/bson/encode.go @@ -1,18 +1,18 @@ // BSON library for Go -// +// // Copyright (c) 2010-2012 - Gustavo Niemeyer -// +// // All rights reserved. // // Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are met: -// +// modification, are permitted provided that the following conditions are met: +// // 1. Redistributions of source code must retain the above copyright notice, this -// list of conditions and the following disclaimer. +// list of conditions and the following disclaimer. // 2. Redistributions in binary form must reproduce the above copyright notice, // this list of conditions and the following disclaimer in the documentation -// and/or other materials provided with the distribution. -// +// and/or other materials provided with the distribution. +// // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND // ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED // WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -182,7 +182,7 @@ func isZero(v reflect.Value) bool { if v.Type() == typeTime { return v.Interface().(time.Time).IsZero() } - for i := v.NumField()-1; i >= 0; i-- { + for i := v.NumField() - 1; i >= 0; i-- { if !isZero(v.Field(i)) { return false } @@ -207,7 +207,7 @@ func (e *encoder) addSlice(v reflect.Value) { return } l := v.Len() - et := v.Type().Elem() + et := v.Type().Elem() if et == typeDocElem { for i := 0; i < l; i++ { elem := v.Index(i).Interface().(DocElem) @@ -401,7 +401,7 @@ func (e *encoder) addElem(name string, v reflect.Value, minSize bool) { case time.Time: // MongoDB handles timestamps as milliseconds. e.addElemName('\x09', name) - e.addInt64(s.Unix() * 1000 + int64(s.Nanosecond() / 1e6)) + e.addInt64(s.Unix()*1000 + int64(s.Nanosecond()/1e6)) case url.URL: e.addElemName('\x02', name) diff --git a/log/filehandler.go b/log/filehandler.go index 783b652..308896a 100644 --- a/log/filehandler.go +++ b/log/filehandler.go @@ -36,7 +36,7 @@ func (h *FileHandler) Close() error { return h.fd.Close() } -//RotatingFileHandler writes log a file, if file size exceeds maxBytes, +//RotatingFileHandler writes log a file, if file size exceeds maxBytes, //it will backup current file and open a new one. // //max backup file number is set by backupCount, it will delete oldest if backups too many. @@ -112,7 +112,7 @@ func (h *RotatingFileHandler) doRollover() { } } -//TimeRotatingFileHandler writes log to a file, +//TimeRotatingFileHandler writes log to a file, //it will backup current and open a new one, with a period time you sepecified. // //refer: http://docs.python.org/2/library/logging.handlers.html. diff --git a/log/handler.go b/log/handler.go index 352e30c..4dc086f 100644 --- a/log/handler.go +++ b/log/handler.go @@ -31,8 +31,7 @@ func (h *StreamHandler) Close() error { return nil } - -//NullHandler does nothing, it discards anything. +//NullHandler does nothing, it discards anything. type NullHandler struct { } diff --git a/log/sockethandler.go b/log/sockethandler.go index ad81ccd..3e7494d 100644 --- a/log/sockethandler.go +++ b/log/sockethandler.go @@ -7,8 +7,8 @@ import ( ) //SocketHandler writes log to a connectionl. -//Network protocol is simple: log length + log | log length + log. log length is uint32, bigendian. -//you must implement your own log server, maybe you can use logd instead simply. +//Network protocol is simple: log length + log | log length + log. log length is uint32, bigendian. +//you must implement your own log server, maybe you can use logd instead simply. type SocketHandler struct { c net.Conn protocol string diff --git a/websocket/client.go b/websocket/client.go index baa92b5..6860fa8 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -1,60 +1,60 @@ -package websocket - -import ( - "bytes" - "errors" - "net" - "net/http" - "net/url" - "strings" -) - -var ( - ErrBadHandshake = errors.New("bad handshake") -) - -func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header) (c *Conn, response *http.Response, err error) { - key, err := calcKey() - if err != nil { - return nil, nil, err - } - acceptKey := calcAcceptKey(key) - - c = NewConn(netConn, false) - - buf := bytes.NewBufferString("GET ") - buf.WriteString(u.RequestURI()) - buf.WriteString(" HTTP/1.1\r\nHost: ") - buf.WriteString(u.Host) - buf.WriteString("\r\nUpgrade: websocket\r\nConnection: upgrade\r\nSec-WebSocket-Version: 13\r\nSec-WebSocket-Key: ") - buf.WriteString(key) - buf.WriteString("\r\n") - - for k, vs := range requestHeader { - for _, v := range vs { - buf.WriteString(k) - buf.WriteString(": ") - buf.WriteString(v) - buf.WriteString("\r\n") - } - } - - buf.WriteString("\r\n") - p := buf.Bytes() - if _, err := netConn.Write(p); err != nil { - return nil, nil, err - } - - resp, err := http.ReadResponse(c.br, &http.Request{Method: "GET", URL: u}) - if err != nil { - return nil, nil, err - } - - if resp.StatusCode != 101 || - !strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") || - !strings.EqualFold(resp.Header.Get("Connection"), "upgrade") || - resp.Header.Get("Sec-Websocket-Accept") != acceptKey { - return nil, resp, ErrBadHandshake - } - return c, resp, nil -} +package websocket + +import ( + "bytes" + "errors" + "net" + "net/http" + "net/url" + "strings" +) + +var ( + ErrBadHandshake = errors.New("bad handshake") +) + +func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header) (c *Conn, response *http.Response, err error) { + key, err := calcKey() + if err != nil { + return nil, nil, err + } + acceptKey := calcAcceptKey(key) + + c = NewConn(netConn, false) + + buf := bytes.NewBufferString("GET ") + buf.WriteString(u.RequestURI()) + buf.WriteString(" HTTP/1.1\r\nHost: ") + buf.WriteString(u.Host) + buf.WriteString("\r\nUpgrade: websocket\r\nConnection: upgrade\r\nSec-WebSocket-Version: 13\r\nSec-WebSocket-Key: ") + buf.WriteString(key) + buf.WriteString("\r\n") + + for k, vs := range requestHeader { + for _, v := range vs { + buf.WriteString(k) + buf.WriteString(": ") + buf.WriteString(v) + buf.WriteString("\r\n") + } + } + + buf.WriteString("\r\n") + p := buf.Bytes() + if _, err := netConn.Write(p); err != nil { + return nil, nil, err + } + + resp, err := http.ReadResponse(c.br, &http.Request{Method: "GET", URL: u}) + if err != nil { + return nil, nil, err + } + + if resp.StatusCode != 101 || + !strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") || + !strings.EqualFold(resp.Header.Get("Connection"), "upgrade") || + resp.Header.Get("Sec-Websocket-Accept") != acceptKey { + return nil, resp, ErrBadHandshake + } + return c, resp, nil +} diff --git a/websocket/client_test.go b/websocket/client_test.go index 9b58959..12d25b9 100644 --- a/websocket/client_test.go +++ b/websocket/client_test.go @@ -1,99 +1,100 @@ -package websocket - -import ( - "github.com/gorilla/websocket" - "net" - "net/http" - "net/url" - "testing" - "time" -) - -func TestWSClient(t *testing.T) { - http.HandleFunc("/test/client", func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Upgrade(w, r, nil, 1024, 1024) - if err != nil { - t.Fatal(err.Error()) - } - - msgType, msg, err := conn.ReadMessage() - conn.WriteMessage(websocket.TextMessage, msg) - - if err != nil { - t.Fatal(err.Error()) - } - - if msgType != websocket.TextMessage { - t.Fatal("invalid msg type", msgType) - } - - msgType, msg, err = conn.ReadMessage() - if err != nil { - t.Fatal(err.Error()) - } - - if msgType != websocket.PingMessage { - t.Fatal("invalid msg type", msgType) - } - - conn.WriteMessage(websocket.PongMessage, []byte{}) - - conn.WriteMessage(websocket.PingMessage, []byte{}) - - msgType, msg, err = conn.ReadMessage() - if err != nil { - t.Fatal(err.Error()) - } - println(msgType) - if msgType != websocket.PongMessage { - - t.Fatal("invalid msg type", msgType) - } - }) - - go http.ListenAndServe(":65500", nil) - - time.Sleep(time.Second * 1) - - conn, err := net.Dial("tcp", "127.0.0.1:65500") - - if err != nil { - t.Fatal(err.Error()) - } - ws, _, err := NewClient(conn, &url.URL{Host: "127.0.0.1:65500", Path: "/test/client"}, nil) - - if err != nil { - t.Fatal(err.Error()) - } - - payload := make([]byte, 4*1024) - for i := 0; i < 4*1024; i++ { - payload[i] = 'x' - } - - ws.WriteString(payload) - - msgType, msg, err := ws.Read() - if err != nil { - t.Fatal(err.Error()) - } - if msgType != TextMessage { - t.Fatal("invalid msg type", msgType) - } - - if string(msg) != string(payload) { - t.Fatal("invalid msg", string(msg)) - - } - - //test ping - ws.Ping([]byte{}) - msgType, msg, err = ws.ReadMessage() - if err != nil { - t.Fatal(err.Error()) - } - if msgType != PongMessage { - t.Fatal("invalid msg type", msgType) - } - -} +package websocket + +import ( + "net" + "net/http" + "net/url" + "testing" + "time" + + "github.com/gorilla/websocket" +) + +func TestWSClient(t *testing.T) { + http.HandleFunc("/test/client", func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Upgrade(w, r, nil, 1024, 1024) + if err != nil { + t.Fatal(err.Error()) + } + + msgType, msg, err := conn.ReadMessage() + conn.WriteMessage(websocket.TextMessage, msg) + + if err != nil { + t.Fatal(err.Error()) + } + + if msgType != websocket.TextMessage { + t.Fatal("invalid msg type", msgType) + } + + msgType, msg, err = conn.ReadMessage() + if err != nil { + t.Fatal(err.Error()) + } + + if msgType != websocket.PingMessage { + t.Fatal("invalid msg type", msgType) + } + + conn.WriteMessage(websocket.PongMessage, []byte{}) + + conn.WriteMessage(websocket.PingMessage, []byte{}) + + msgType, msg, err = conn.ReadMessage() + if err != nil { + t.Fatal(err.Error()) + } + println(msgType) + if msgType != websocket.PongMessage { + + t.Fatal("invalid msg type", msgType) + } + }) + + go http.ListenAndServe(":65500", nil) + + time.Sleep(time.Second * 1) + + conn, err := net.Dial("tcp", "127.0.0.1:65500") + + if err != nil { + t.Fatal(err.Error()) + } + ws, _, err := NewClient(conn, &url.URL{Host: "127.0.0.1:65500", Path: "/test/client"}, nil) + + if err != nil { + t.Fatal(err.Error()) + } + + payload := make([]byte, 4*1024) + for i := 0; i < 4*1024; i++ { + payload[i] = 'x' + } + + ws.WriteString(payload) + + msgType, msg, err := ws.Read() + if err != nil { + t.Fatal(err.Error()) + } + if msgType != TextMessage { + t.Fatal("invalid msg type", msgType) + } + + if string(msg) != string(payload) { + t.Fatal("invalid msg", string(msg)) + + } + + //test ping + ws.Ping([]byte{}) + msgType, msg, err = ws.ReadMessage() + if err != nil { + t.Fatal(err.Error()) + } + if msgType != PongMessage { + t.Fatal("invalid msg type", msgType) + } + +} diff --git a/websocket/conn.go b/websocket/conn.go index 05d63eb..264ce61 100644 --- a/websocket/conn.go +++ b/websocket/conn.go @@ -1,321 +1,321 @@ -package websocket - -import ( - "bufio" - "encoding/binary" - "errors" - "io" - "math/rand" - "net" - "time" -) - -//refer RFC6455 - -const ( - TextMessage byte = 1 - BinaryMessage byte = 2 - CloseMessage byte = 8 - PingMessage byte = 9 - PongMessage byte = 10 -) - -var ( - ErrControlTooLong = errors.New("control message too long") - ErrRSVNotSupport = errors.New("reserved bit not support") - ErrPayloadError = errors.New("payload length error") - ErrControlFragmented = errors.New("control message can not be fragmented") - ErrNotTCPConn = errors.New("not a tcp connection") - ErrWriteError = errors.New("write error") -) - -type Conn struct { - conn net.Conn - - br *bufio.Reader - - isServer bool -} - -func NewConn(conn net.Conn, isServer bool) *Conn { - c := new(Conn) - - c.conn = conn - - c.br = bufio.NewReader(conn) - - c.isServer = isServer - - return c -} - -func (c *Conn) ReadMessage() (messageType byte, message []byte, err error) { - return c.Read() -} - -func (c *Conn) Read() (messageType byte, message []byte, err error) { - buf := make([]byte, 8, 8) - - message = []byte{} - - messageType = 0 - - for { - opcode, data, err := c.readFrame(buf) - - if err != nil { - return messageType, message, err - } - - message = append(message, data...) - - if opcode&0x80 != 0 { - //final - if opcode&0x0F > 0 { - //not continue frame - messageType = opcode & 0x0F - } - return messageType, message, nil - - } else { - if opcode&0x0F > 0 { - //first continue frame - messageType = opcode & 0x0F - } - } - } - - return -} - -func (c *Conn) Write(message []byte, binary bool) error { - if binary { - return c.sendFrame(BinaryMessage, message) - } else { - return c.sendFrame(TextMessage, message) - } -} - -func (c *Conn) WriteMessage(messageType byte, message []byte) error { - return c.sendFrame(messageType, message) -} - -//write utf-8 text message -func (c *Conn) WriteString(message []byte) error { - return c.Write(message, false) -} - -//write binary message -func (c *Conn) WriteBinary(message []byte) error { - return c.Write(message, true) -} - -func (c *Conn) Ping(message []byte) error { - return c.sendFrame(PingMessage, message) -} - -func (c *Conn) Pong(message []byte) error { - return c.sendFrame(PongMessage, message) -} - -//close socket, not send websocket close message -func (c *Conn) Close() error { - return c.conn.Close() -} - -func (c *Conn) LocalAddr() net.Addr { - return c.conn.LocalAddr() -} - -func (c *Conn) RemoteAddr() net.Addr { - return c.conn.RemoteAddr() -} - -func (c *Conn) SetReadDeadline(t time.Time) error { - return c.conn.SetReadDeadline(t) -} - -func (c *Conn) SetWriteDeadline(t time.Time) error { - return c.conn.SetWriteDeadline(t) -} - -func (c *Conn) SetReadBuffer(bytes int) error { - if tcpConn, ok := c.conn.(*net.TCPConn); ok { - return tcpConn.SetReadBuffer(bytes) - } else { - return ErrNotTCPConn - } -} - -func (c *Conn) SetWriteBuffer(bytes int) error { - if tcpConn, ok := c.conn.(*net.TCPConn); ok { - return tcpConn.SetWriteBuffer(bytes) - } else { - return ErrNotTCPConn - } -} - -func (c *Conn) readPayloadLen(length byte, buf []byte) (payloadLen uint64, err error) { - if length < 126 { - payloadLen = uint64(length) - } else if length == 126 { - err = c.read(buf[:2]) - if err != nil { - return - } - payloadLen = uint64(binary.BigEndian.Uint16(buf[:2])) - } else if length == 127 { - err = c.read(buf[:8]) - if err != nil { - return - } - payloadLen = uint64(binary.BigEndian.Uint16(buf[:8])) - } - - return -} - -func (c *Conn) readFrame(buf []byte) (opcode byte, messsage []byte, err error) { - //minimum head may 2 byte - - err = c.read(buf[:2]) - if err != nil { - return - } - - opcode = buf[0] - - if opcode&0x70 > 0 { - err = ErrRSVNotSupport - return - } - - //isMasking := (0x80 & buf[1]) > 0 - isMasking := (0x80 & buf[1]) > 0 - - var payloadLen uint64 - payloadLen, err = c.readPayloadLen(buf[1]&0x7F, buf) - if err != nil { - return - } - - if opcode&0x08 > 0 && payloadLen > 125 { - err = ErrControlTooLong - return - } - - var masking []byte - - if isMasking { - err = c.read(buf[:4]) - if err != nil { - return - } - - masking = buf[:4] - } - - messsage = make([]byte, payloadLen) - err = c.read(messsage) - - if err != nil { - return - } - - if isMasking { - //maskingKey := c.newMaskingKey() - c.maskingData(messsage, masking) - } - - return -} - -func (c *Conn) sendFrame(opcode byte, message []byte) error { - //max frame header may 14 length - buf := make([]byte, 0, len(message)+14) - //here we don not support continue frame, all are final - opcode |= 0x80 - - if opcode&0x08 > 0 && len(message) >= 126 { - return ErrControlTooLong - } - - buf = append(buf, opcode) - - //no mask, because chrome may not support - var mask byte = 0x00 - - if !c.isServer { - //for client, we will mask data - mask = 0x80 - } - - payloadLen := len(message) - - if payloadLen < 126 { - buf = append(buf, mask|byte(payloadLen)) - } else if payloadLen <= 0xFFFF { - buf = append(buf, mask|byte(126), 0, 0) - - binary.BigEndian.PutUint16(buf[len(buf)-2:], uint16(payloadLen)) - } else { - buf = append(buf, mask|byte(127), 0, 0, 0, 0, 0, 0, 0, 0) - - binary.BigEndian.PutUint64(buf[len(buf)-8:], uint64(payloadLen)) - } - - if !c.isServer { - maskingKey := c.newMaskingKey() - buf = append(buf, maskingKey...) - - pos := len(buf) - buf = append(buf, message...) - - c.maskingData(buf[pos:], maskingKey) - - } else { - buf = append(buf, message...) - } - - tmpBuf := buf - for i := 0; i < 3; i++ { - n, err := c.conn.Write(tmpBuf) - if err != nil { - return err - } - if n == len(tmpBuf) { - return nil - } else { - tmpBuf = tmpBuf[n:] - } - } - return ErrWriteError -} - -func (c *Conn) read(buf []byte) error { - var err error - for len(buf) > 0 && err == nil { - var nn int - nn, err = c.br.Read(buf) - buf = buf[nn:] - } - if err == io.EOF { - if len(buf) == 0 { - err = nil - } else { - err = io.ErrUnexpectedEOF - } - } - return err -} - -func (c *Conn) maskingData(data []byte, maskingKey []byte) { - for i := range data { - data[i] ^= maskingKey[i%4] - } -} - -func (c *Conn) newMaskingKey() []byte { - n := rand.Uint32() - return []byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 32)} -} +package websocket + +import ( + "bufio" + "encoding/binary" + "errors" + "io" + "math/rand" + "net" + "time" +) + +//refer RFC6455 + +const ( + TextMessage byte = 1 + BinaryMessage byte = 2 + CloseMessage byte = 8 + PingMessage byte = 9 + PongMessage byte = 10 +) + +var ( + ErrControlTooLong = errors.New("control message too long") + ErrRSVNotSupport = errors.New("reserved bit not support") + ErrPayloadError = errors.New("payload length error") + ErrControlFragmented = errors.New("control message can not be fragmented") + ErrNotTCPConn = errors.New("not a tcp connection") + ErrWriteError = errors.New("write error") +) + +type Conn struct { + conn net.Conn + + br *bufio.Reader + + isServer bool +} + +func NewConn(conn net.Conn, isServer bool) *Conn { + c := new(Conn) + + c.conn = conn + + c.br = bufio.NewReader(conn) + + c.isServer = isServer + + return c +} + +func (c *Conn) ReadMessage() (messageType byte, message []byte, err error) { + return c.Read() +} + +func (c *Conn) Read() (messageType byte, message []byte, err error) { + buf := make([]byte, 8, 8) + + message = []byte{} + + messageType = 0 + + for { + opcode, data, err := c.readFrame(buf) + + if err != nil { + return messageType, message, err + } + + message = append(message, data...) + + if opcode&0x80 != 0 { + //final + if opcode&0x0F > 0 { + //not continue frame + messageType = opcode & 0x0F + } + return messageType, message, nil + + } else { + if opcode&0x0F > 0 { + //first continue frame + messageType = opcode & 0x0F + } + } + } + + return +} + +func (c *Conn) Write(message []byte, binary bool) error { + if binary { + return c.sendFrame(BinaryMessage, message) + } else { + return c.sendFrame(TextMessage, message) + } +} + +func (c *Conn) WriteMessage(messageType byte, message []byte) error { + return c.sendFrame(messageType, message) +} + +//write utf-8 text message +func (c *Conn) WriteString(message []byte) error { + return c.Write(message, false) +} + +//write binary message +func (c *Conn) WriteBinary(message []byte) error { + return c.Write(message, true) +} + +func (c *Conn) Ping(message []byte) error { + return c.sendFrame(PingMessage, message) +} + +func (c *Conn) Pong(message []byte) error { + return c.sendFrame(PongMessage, message) +} + +//close socket, not send websocket close message +func (c *Conn) Close() error { + return c.conn.Close() +} + +func (c *Conn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *Conn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +func (c *Conn) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +func (c *Conn) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) +} + +func (c *Conn) SetReadBuffer(bytes int) error { + if tcpConn, ok := c.conn.(*net.TCPConn); ok { + return tcpConn.SetReadBuffer(bytes) + } else { + return ErrNotTCPConn + } +} + +func (c *Conn) SetWriteBuffer(bytes int) error { + if tcpConn, ok := c.conn.(*net.TCPConn); ok { + return tcpConn.SetWriteBuffer(bytes) + } else { + return ErrNotTCPConn + } +} + +func (c *Conn) readPayloadLen(length byte, buf []byte) (payloadLen uint64, err error) { + if length < 126 { + payloadLen = uint64(length) + } else if length == 126 { + err = c.read(buf[:2]) + if err != nil { + return + } + payloadLen = uint64(binary.BigEndian.Uint16(buf[:2])) + } else if length == 127 { + err = c.read(buf[:8]) + if err != nil { + return + } + payloadLen = uint64(binary.BigEndian.Uint16(buf[:8])) + } + + return +} + +func (c *Conn) readFrame(buf []byte) (opcode byte, messsage []byte, err error) { + //minimum head may 2 byte + + err = c.read(buf[:2]) + if err != nil { + return + } + + opcode = buf[0] + + if opcode&0x70 > 0 { + err = ErrRSVNotSupport + return + } + + //isMasking := (0x80 & buf[1]) > 0 + isMasking := (0x80 & buf[1]) > 0 + + var payloadLen uint64 + payloadLen, err = c.readPayloadLen(buf[1]&0x7F, buf) + if err != nil { + return + } + + if opcode&0x08 > 0 && payloadLen > 125 { + err = ErrControlTooLong + return + } + + var masking []byte + + if isMasking { + err = c.read(buf[:4]) + if err != nil { + return + } + + masking = buf[:4] + } + + messsage = make([]byte, payloadLen) + err = c.read(messsage) + + if err != nil { + return + } + + if isMasking { + //maskingKey := c.newMaskingKey() + c.maskingData(messsage, masking) + } + + return +} + +func (c *Conn) sendFrame(opcode byte, message []byte) error { + //max frame header may 14 length + buf := make([]byte, 0, len(message)+14) + //here we don not support continue frame, all are final + opcode |= 0x80 + + if opcode&0x08 > 0 && len(message) >= 126 { + return ErrControlTooLong + } + + buf = append(buf, opcode) + + //no mask, because chrome may not support + var mask byte = 0x00 + + if !c.isServer { + //for client, we will mask data + mask = 0x80 + } + + payloadLen := len(message) + + if payloadLen < 126 { + buf = append(buf, mask|byte(payloadLen)) + } else if payloadLen <= 0xFFFF { + buf = append(buf, mask|byte(126), 0, 0) + + binary.BigEndian.PutUint16(buf[len(buf)-2:], uint16(payloadLen)) + } else { + buf = append(buf, mask|byte(127), 0, 0, 0, 0, 0, 0, 0, 0) + + binary.BigEndian.PutUint64(buf[len(buf)-8:], uint64(payloadLen)) + } + + if !c.isServer { + maskingKey := c.newMaskingKey() + buf = append(buf, maskingKey...) + + pos := len(buf) + buf = append(buf, message...) + + c.maskingData(buf[pos:], maskingKey) + + } else { + buf = append(buf, message...) + } + + tmpBuf := buf + for i := 0; i < 3; i++ { + n, err := c.conn.Write(tmpBuf) + if err != nil { + return err + } + if n == len(tmpBuf) { + return nil + } else { + tmpBuf = tmpBuf[n:] + } + } + return ErrWriteError +} + +func (c *Conn) read(buf []byte) error { + var err error + for len(buf) > 0 && err == nil { + var nn int + nn, err = c.br.Read(buf) + buf = buf[nn:] + } + if err == io.EOF { + if len(buf) == 0 { + err = nil + } else { + err = io.ErrUnexpectedEOF + } + } + return err +} + +func (c *Conn) maskingData(data []byte, maskingKey []byte) { + for i := range data { + data[i] ^= maskingKey[i%4] + } +} + +func (c *Conn) newMaskingKey() []byte { + n := rand.Uint32() + return []byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 32)} +} diff --git a/websocket/pingpong_test.go b/websocket/pingpong_test.go index e188b82..7476e24 100644 --- a/websocket/pingpong_test.go +++ b/websocket/pingpong_test.go @@ -1,51 +1,51 @@ -package websocket - -import ( - "net" - "net/http" - "net/url" - "testing" - "time" -) - -func TestWSPing(t *testing.T) { - http.HandleFunc("/test/ping", func(w http.ResponseWriter, r *http.Request) { - conn, err := Upgrade(w, r, nil) - if err != nil { - t.Fatal(err.Error()) - } - //conn := NewConn(c, true) - conn.Read() - conn.Pong([]byte{}) - conn.Ping([]byte{}) - msgType, _, _ := conn.Read() - println(msgType) - }) - - go http.ListenAndServe(":65500", nil) - time.Sleep(time.Second * 1) - - conn, err := net.Dial("tcp", "127.0.0.1:65500") - - if err != nil { - t.Fatal(err.Error()) - } - ws, _, err := NewClient(conn, &url.URL{Host: "127.0.0.1:65500", Path: "/test/ping"}, nil) - - if err != nil { - t.Fatal(err.Error()) - } - ws.Ping([]byte{}) - - msgType, _, _ := ws.Read() - if msgType != PongMessage { - t.Fatal("invalid msg type", msgType) - } - - msgType, _, _ = ws.Read() - if msgType != PingMessage { - t.Fatal("invalid msg type", msgType) - } - ws.Pong([]byte{}) - time.Sleep(time.Second * 1) -} +package websocket + +import ( + "net" + "net/http" + "net/url" + "testing" + "time" +) + +func TestWSPing(t *testing.T) { + http.HandleFunc("/test/ping", func(w http.ResponseWriter, r *http.Request) { + conn, err := Upgrade(w, r, nil) + if err != nil { + t.Fatal(err.Error()) + } + //conn := NewConn(c, true) + conn.Read() + conn.Pong([]byte{}) + conn.Ping([]byte{}) + msgType, _, _ := conn.Read() + println(msgType) + }) + + go http.ListenAndServe(":65500", nil) + time.Sleep(time.Second * 1) + + conn, err := net.Dial("tcp", "127.0.0.1:65500") + + if err != nil { + t.Fatal(err.Error()) + } + ws, _, err := NewClient(conn, &url.URL{Host: "127.0.0.1:65500", Path: "/test/ping"}, nil) + + if err != nil { + t.Fatal(err.Error()) + } + ws.Ping([]byte{}) + + msgType, _, _ := ws.Read() + if msgType != PongMessage { + t.Fatal("invalid msg type", msgType) + } + + msgType, _, _ = ws.Read() + if msgType != PingMessage { + t.Fatal("invalid msg type", msgType) + } + ws.Pong([]byte{}) + time.Sleep(time.Second * 1) +} diff --git a/websocket/server.go b/websocket/server.go index 360233e..ab09539 100644 --- a/websocket/server.go +++ b/websocket/server.go @@ -1,105 +1,105 @@ -package websocket - -import ( - "bufio" - "bytes" - "errors" - "net" - "net/http" - "strings" -) - -var ( - ErrInvalidMethod = errors.New("Only GET Supported") - ErrInvalidVersion = errors.New("Sec-Websocket-Version: 13") - ErrInvalidUpgrade = errors.New("Can \"Upgrade\" only to \"WebSocket\"") - ErrInvalidConnection = errors.New("\"Connection\" must be \"Upgrade\"") - ErrMissingKey = errors.New("Missing Key") - ErrHijacker = errors.New("Not implement http.Hijacker") - ErrNoEmptyConn = errors.New("Conn ReadBuf must be empty") -) - -func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) { - if r.Method != "GET" { - return nil, ErrInvalidMethod - } - - if r.Header.Get("Sec-Websocket-Version") != "13" { - return nil, ErrInvalidVersion - } - - if strings.ToLower(r.Header.Get("Upgrade")) != "websocket" { - return nil, ErrInvalidUpgrade - } - - if strings.ToLower(r.Header.Get("Connection")) != "upgrade" { - return nil, ErrInvalidConnection - } - - var acceptKey string - - if key := r.Header.Get("Sec-Websocket-key"); len(key) == 0 { - return nil, ErrMissingKey - } else { - acceptKey = calcAcceptKey(key) - } - - var ( - netConn net.Conn - br *bufio.Reader - err error - ) - - h, ok := w.(http.Hijacker) - if !ok { - return nil, ErrHijacker - } - - var rw *bufio.ReadWriter - netConn, rw, err = h.Hijack() - br = rw.Reader - - if br.Buffered() > 0 { - netConn.Close() - return nil, ErrNoEmptyConn - } - - c := NewConn(netConn, true) - - buf := bytes.NewBufferString("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ") - - buf.WriteString(acceptKey) - buf.WriteString("\r\n") - - subProtol := selectSubProtocol(r) - if len(subProtol) > 0 { - buf.WriteString("Sec-Websocket-Protocol: ") - buf.WriteString(subProtol) - buf.WriteString("\r\n") - } - - for k, vs := range responseHeader { - for _, v := range vs { - buf.WriteString(k) - buf.WriteString(": ") - buf.WriteString(v) - buf.WriteString("\r\n") - } - } - buf.WriteString("\r\n") - - if _, err = netConn.Write(buf.Bytes()); err != nil { - netConn.Close() - return nil, err - } - - return c, nil -} - -func selectSubProtocol(r *http.Request) string { - h := r.Header.Get("Sec-Websocket-Protocol") - if len(h) == 0 { - return "" - } - return strings.Split(h, ",")[0] -} +package websocket + +import ( + "bufio" + "bytes" + "errors" + "net" + "net/http" + "strings" +) + +var ( + ErrInvalidMethod = errors.New("Only GET Supported") + ErrInvalidVersion = errors.New("Sec-Websocket-Version: 13") + ErrInvalidUpgrade = errors.New("Can \"Upgrade\" only to \"WebSocket\"") + ErrInvalidConnection = errors.New("\"Connection\" must be \"Upgrade\"") + ErrMissingKey = errors.New("Missing Key") + ErrHijacker = errors.New("Not implement http.Hijacker") + ErrNoEmptyConn = errors.New("Conn ReadBuf must be empty") +) + +func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) { + if r.Method != "GET" { + return nil, ErrInvalidMethod + } + + if r.Header.Get("Sec-Websocket-Version") != "13" { + return nil, ErrInvalidVersion + } + + if strings.ToLower(r.Header.Get("Upgrade")) != "websocket" { + return nil, ErrInvalidUpgrade + } + + if strings.ToLower(r.Header.Get("Connection")) != "upgrade" { + return nil, ErrInvalidConnection + } + + var acceptKey string + + if key := r.Header.Get("Sec-Websocket-key"); len(key) == 0 { + return nil, ErrMissingKey + } else { + acceptKey = calcAcceptKey(key) + } + + var ( + netConn net.Conn + br *bufio.Reader + err error + ) + + h, ok := w.(http.Hijacker) + if !ok { + return nil, ErrHijacker + } + + var rw *bufio.ReadWriter + netConn, rw, err = h.Hijack() + br = rw.Reader + + if br.Buffered() > 0 { + netConn.Close() + return nil, ErrNoEmptyConn + } + + c := NewConn(netConn, true) + + buf := bytes.NewBufferString("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ") + + buf.WriteString(acceptKey) + buf.WriteString("\r\n") + + subProtol := selectSubProtocol(r) + if len(subProtol) > 0 { + buf.WriteString("Sec-Websocket-Protocol: ") + buf.WriteString(subProtol) + buf.WriteString("\r\n") + } + + for k, vs := range responseHeader { + for _, v := range vs { + buf.WriteString(k) + buf.WriteString(": ") + buf.WriteString(v) + buf.WriteString("\r\n") + } + } + buf.WriteString("\r\n") + + if _, err = netConn.Write(buf.Bytes()); err != nil { + netConn.Close() + return nil, err + } + + return c, nil +} + +func selectSubProtocol(r *http.Request) string { + h := r.Header.Get("Sec-Websocket-Protocol") + if len(h) == 0 { + return "" + } + return strings.Split(h, ",")[0] +} diff --git a/websocket/server_test.go b/websocket/server_test.go index 73eba96..3066f79 100644 --- a/websocket/server_test.go +++ b/websocket/server_test.go @@ -1,97 +1,98 @@ -package websocket - -import ( - "github.com/gorilla/websocket" - "net" - "net/http" - "net/url" - "testing" - "time" -) - -func TestWSServer(t *testing.T) { - http.HandleFunc("/test/server", func(w http.ResponseWriter, r *http.Request) { - conn, err := Upgrade(w, r, nil) - - if err != nil { - t.Fatal(err.Error()) - } - //err = conn.SetReadBuffer(1024 * 1024 * 4) - //if err != nil { - // println(err.Error()) - //} - //err = conn.SetWriteBuffer(1024 * 1024 * 4) - - //if err != nil { - // println(err.Error()) - //} - - msgType, msg, err := conn.Read() - conn.Write(msg, false) - - if err != nil { - t.Fatal(err.Error()) - } - - if msgType != TextMessage { - t.Fatal("wrong msg type", msgType) - } - - msgType, msg, err = conn.ReadMessage() - if err != nil { - t.Fatal(err.Error()) - } - - if msgType != PingMessage { - t.Fatal("wrong msg type", msgType) - } - - err = conn.Pong([]byte("abc")) - - if err != nil { - t.Fatal(err.Error()) - } - - }) - - go http.ListenAndServe(":65500", nil) - time.Sleep(time.Second * 1) - - conn, err := net.Dial("tcp", "127.0.0.1:65500") - - if err != nil { - t.Fatal(err.Error()) - } - ws, _, err := websocket.NewClient(conn, &url.URL{Host: "127.0.0.1:65500", Path: "/test/server"}, nil, 1024, 1024) - - ws.SetPongHandler(func(string) error { - println("pong") - return nil - }) - - if err != nil { - t.Fatal(err.Error()) - } - - payload := make([]byte, 4*1024*1024) - for i := 0; i < 4*1024*1024; i++ { - payload[i] = 'x' - } - - ws.WriteMessage(websocket.TextMessage, payload) - - msgType, msg, err := ws.ReadMessage() - if err != nil { - t.Fatal(err.Error()) - } - if msgType != websocket.TextMessage { - t.Fatal("invalid msg type", msgType) - } - - if string(msg) != string(payload) { - t.Fatal("invalid msg", string(msg)) - - } - - time.Sleep(time.Second * 1) -} +package websocket + +import ( + "net" + "net/http" + "net/url" + "testing" + "time" + + "github.com/gorilla/websocket" +) + +func TestWSServer(t *testing.T) { + http.HandleFunc("/test/server", func(w http.ResponseWriter, r *http.Request) { + conn, err := Upgrade(w, r, nil) + + if err != nil { + t.Fatal(err.Error()) + } + //err = conn.SetReadBuffer(1024 * 1024 * 4) + //if err != nil { + // println(err.Error()) + //} + //err = conn.SetWriteBuffer(1024 * 1024 * 4) + + //if err != nil { + // println(err.Error()) + //} + + msgType, msg, err := conn.Read() + conn.Write(msg, false) + + if err != nil { + t.Fatal(err.Error()) + } + + if msgType != TextMessage { + t.Fatal("wrong msg type", msgType) + } + + msgType, msg, err = conn.ReadMessage() + if err != nil { + t.Fatal(err.Error()) + } + + if msgType != PingMessage { + t.Fatal("wrong msg type", msgType) + } + + err = conn.Pong([]byte("abc")) + + if err != nil { + t.Fatal(err.Error()) + } + + }) + + go http.ListenAndServe(":65500", nil) + time.Sleep(time.Second * 1) + + conn, err := net.Dial("tcp", "127.0.0.1:65500") + + if err != nil { + t.Fatal(err.Error()) + } + ws, _, err := websocket.NewClient(conn, &url.URL{Host: "127.0.0.1:65500", Path: "/test/server"}, nil, 1024, 1024) + + ws.SetPongHandler(func(string) error { + println("pong") + return nil + }) + + if err != nil { + t.Fatal(err.Error()) + } + + payload := make([]byte, 4*1024*1024) + for i := 0; i < 4*1024*1024; i++ { + payload[i] = 'x' + } + + ws.WriteMessage(websocket.TextMessage, payload) + + msgType, msg, err := ws.ReadMessage() + if err != nil { + t.Fatal(err.Error()) + } + if msgType != websocket.TextMessage { + t.Fatal("invalid msg type", msgType) + } + + if string(msg) != string(payload) { + t.Fatal("invalid msg", string(msg)) + + } + + time.Sleep(time.Second * 1) +} diff --git a/websocket/util.go b/websocket/util.go index 9ad5ec9..cd0eddd 100644 --- a/websocket/util.go +++ b/websocket/util.go @@ -1,36 +1,36 @@ -package websocket - -import ( - "crypto/rand" - "crypto/sha1" - "encoding/base64" - "errors" - "io" -) - -var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") - -func calcAcceptKey(key string) string { - h := sha1.New() - h.Write([]byte(key)) - h.Write(keyGUID) - return base64.StdEncoding.EncodeToString(h.Sum(nil)) -} - -func calcKey() (string, error) { - p := make([]byte, 16) - if _, err := io.ReadFull(rand.Reader, p); err != nil { - return "", err - } - return base64.StdEncoding.EncodeToString(p), nil -} - -func HandleCloseFrame(buf []byte) (int16, string, error) { - - if len(buf) < 2 { - return 0, "", errors.New("close frame msg's length less than 2") - } - code := int16(buf[0])<<8 + int16(buf[1]) - reason := string(buf[2:]) - return code, reason, nil -} +package websocket + +import ( + "crypto/rand" + "crypto/sha1" + "encoding/base64" + "errors" + "io" +) + +var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") + +func calcAcceptKey(key string) string { + h := sha1.New() + h.Write([]byte(key)) + h.Write(keyGUID) + return base64.StdEncoding.EncodeToString(h.Sum(nil)) +} + +func calcKey() (string, error) { + p := make([]byte, 16) + if _, err := io.ReadFull(rand.Reader, p); err != nil { + return "", err + } + return base64.StdEncoding.EncodeToString(p), nil +} + +func HandleCloseFrame(buf []byte) (int16, string, error) { + + if len(buf) < 2 { + return 0, "", errors.New("close frame msg's length less than 2") + } + code := int16(buf[0])<<8 + int16(buf[1]) + reason := string(buf[2:]) + return code, reason, nil +}