format with goimports
This commit is contained in:
parent
530a231625
commit
b151716326
|
@ -1,18 +1,18 @@
|
||||||
// BSON library for Go
|
// BSON library for Go
|
||||||
//
|
//
|
||||||
// Copyright (c) 2010-2012 - Gustavo Niemeyer <gustavo@niemeyer.net>
|
// Copyright (c) 2010-2012 - Gustavo Niemeyer <gustavo@niemeyer.net>
|
||||||
//
|
//
|
||||||
// All rights reserved.
|
// All rights reserved.
|
||||||
//
|
//
|
||||||
// Redistribution and use in source and binary forms, with or without
|
// Redistribution and use in source and binary forms, with or without
|
||||||
// modification, are permitted provided that the following conditions are met:
|
// modification, are permitted provided that the following conditions are met:
|
||||||
//
|
//
|
||||||
// 1. Redistributions of source code must retain the above copyright notice, this
|
// 1. Redistributions of source code must retain the above copyright notice, this
|
||||||
// list of conditions and the following disclaimer.
|
// list of conditions and the following disclaimer.
|
||||||
// 2. Redistributions in binary form must reproduce the above copyright notice,
|
// 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
// this list of conditions and the following disclaimer in the documentation
|
// this list of conditions and the following disclaimer in the documentation
|
||||||
// and/or other materials provided with the distribution.
|
// and/or other materials provided with the distribution.
|
||||||
//
|
//
|
||||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||||
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||||
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
|
|
@ -1,18 +1,18 @@
|
||||||
// BSON library for Go
|
// BSON library for Go
|
||||||
//
|
//
|
||||||
// Copyright (c) 2010-2012 - Gustavo Niemeyer <gustavo@niemeyer.net>
|
// Copyright (c) 2010-2012 - Gustavo Niemeyer <gustavo@niemeyer.net>
|
||||||
//
|
//
|
||||||
// All rights reserved.
|
// All rights reserved.
|
||||||
//
|
//
|
||||||
// Redistribution and use in source and binary forms, with or without
|
// Redistribution and use in source and binary forms, with or without
|
||||||
// modification, are permitted provided that the following conditions are met:
|
// modification, are permitted provided that the following conditions are met:
|
||||||
//
|
//
|
||||||
// 1. Redistributions of source code must retain the above copyright notice, this
|
// 1. Redistributions of source code must retain the above copyright notice, this
|
||||||
// list of conditions and the following disclaimer.
|
// list of conditions and the following disclaimer.
|
||||||
// 2. Redistributions in binary form must reproduce the above copyright notice,
|
// 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
// this list of conditions and the following disclaimer in the documentation
|
// this list of conditions and the following disclaimer in the documentation
|
||||||
// and/or other materials provided with the distribution.
|
// and/or other materials provided with the distribution.
|
||||||
//
|
//
|
||||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||||
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||||
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
@ -182,7 +182,7 @@ func isZero(v reflect.Value) bool {
|
||||||
if v.Type() == typeTime {
|
if v.Type() == typeTime {
|
||||||
return v.Interface().(time.Time).IsZero()
|
return v.Interface().(time.Time).IsZero()
|
||||||
}
|
}
|
||||||
for i := v.NumField()-1; i >= 0; i-- {
|
for i := v.NumField() - 1; i >= 0; i-- {
|
||||||
if !isZero(v.Field(i)) {
|
if !isZero(v.Field(i)) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
@ -207,7 +207,7 @@ func (e *encoder) addSlice(v reflect.Value) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
l := v.Len()
|
l := v.Len()
|
||||||
et := v.Type().Elem()
|
et := v.Type().Elem()
|
||||||
if et == typeDocElem {
|
if et == typeDocElem {
|
||||||
for i := 0; i < l; i++ {
|
for i := 0; i < l; i++ {
|
||||||
elem := v.Index(i).Interface().(DocElem)
|
elem := v.Index(i).Interface().(DocElem)
|
||||||
|
@ -401,7 +401,7 @@ func (e *encoder) addElem(name string, v reflect.Value, minSize bool) {
|
||||||
case time.Time:
|
case time.Time:
|
||||||
// MongoDB handles timestamps as milliseconds.
|
// MongoDB handles timestamps as milliseconds.
|
||||||
e.addElemName('\x09', name)
|
e.addElemName('\x09', name)
|
||||||
e.addInt64(s.Unix() * 1000 + int64(s.Nanosecond() / 1e6))
|
e.addInt64(s.Unix()*1000 + int64(s.Nanosecond()/1e6))
|
||||||
|
|
||||||
case url.URL:
|
case url.URL:
|
||||||
e.addElemName('\x02', name)
|
e.addElemName('\x02', name)
|
||||||
|
|
|
@ -36,7 +36,7 @@ func (h *FileHandler) Close() error {
|
||||||
return h.fd.Close()
|
return h.fd.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
//RotatingFileHandler writes log a file, if file size exceeds maxBytes,
|
//RotatingFileHandler writes log a file, if file size exceeds maxBytes,
|
||||||
//it will backup current file and open a new one.
|
//it will backup current file and open a new one.
|
||||||
//
|
//
|
||||||
//max backup file number is set by backupCount, it will delete oldest if backups too many.
|
//max backup file number is set by backupCount, it will delete oldest if backups too many.
|
||||||
|
@ -112,7 +112,7 @@ func (h *RotatingFileHandler) doRollover() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//TimeRotatingFileHandler writes log to a file,
|
//TimeRotatingFileHandler writes log to a file,
|
||||||
//it will backup current and open a new one, with a period time you sepecified.
|
//it will backup current and open a new one, with a period time you sepecified.
|
||||||
//
|
//
|
||||||
//refer: http://docs.python.org/2/library/logging.handlers.html.
|
//refer: http://docs.python.org/2/library/logging.handlers.html.
|
||||||
|
|
|
@ -31,8 +31,7 @@ func (h *StreamHandler) Close() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//NullHandler does nothing, it discards anything.
|
||||||
//NullHandler does nothing, it discards anything.
|
|
||||||
type NullHandler struct {
|
type NullHandler struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -7,8 +7,8 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
//SocketHandler writes log to a connectionl.
|
//SocketHandler writes log to a connectionl.
|
||||||
//Network protocol is simple: log length + log | log length + log. log length is uint32, bigendian.
|
//Network protocol is simple: log length + log | log length + log. log length is uint32, bigendian.
|
||||||
//you must implement your own log server, maybe you can use logd instead simply.
|
//you must implement your own log server, maybe you can use logd instead simply.
|
||||||
type SocketHandler struct {
|
type SocketHandler struct {
|
||||||
c net.Conn
|
c net.Conn
|
||||||
protocol string
|
protocol string
|
||||||
|
|
|
@ -1,60 +1,60 @@
|
||||||
package websocket
|
package websocket
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrBadHandshake = errors.New("bad handshake")
|
ErrBadHandshake = errors.New("bad handshake")
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header) (c *Conn, response *http.Response, err error) {
|
func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header) (c *Conn, response *http.Response, err error) {
|
||||||
key, err := calcKey()
|
key, err := calcKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
acceptKey := calcAcceptKey(key)
|
acceptKey := calcAcceptKey(key)
|
||||||
|
|
||||||
c = NewConn(netConn, false)
|
c = NewConn(netConn, false)
|
||||||
|
|
||||||
buf := bytes.NewBufferString("GET ")
|
buf := bytes.NewBufferString("GET ")
|
||||||
buf.WriteString(u.RequestURI())
|
buf.WriteString(u.RequestURI())
|
||||||
buf.WriteString(" HTTP/1.1\r\nHost: ")
|
buf.WriteString(" HTTP/1.1\r\nHost: ")
|
||||||
buf.WriteString(u.Host)
|
buf.WriteString(u.Host)
|
||||||
buf.WriteString("\r\nUpgrade: websocket\r\nConnection: upgrade\r\nSec-WebSocket-Version: 13\r\nSec-WebSocket-Key: ")
|
buf.WriteString("\r\nUpgrade: websocket\r\nConnection: upgrade\r\nSec-WebSocket-Version: 13\r\nSec-WebSocket-Key: ")
|
||||||
buf.WriteString(key)
|
buf.WriteString(key)
|
||||||
buf.WriteString("\r\n")
|
buf.WriteString("\r\n")
|
||||||
|
|
||||||
for k, vs := range requestHeader {
|
for k, vs := range requestHeader {
|
||||||
for _, v := range vs {
|
for _, v := range vs {
|
||||||
buf.WriteString(k)
|
buf.WriteString(k)
|
||||||
buf.WriteString(": ")
|
buf.WriteString(": ")
|
||||||
buf.WriteString(v)
|
buf.WriteString(v)
|
||||||
buf.WriteString("\r\n")
|
buf.WriteString("\r\n")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
buf.WriteString("\r\n")
|
buf.WriteString("\r\n")
|
||||||
p := buf.Bytes()
|
p := buf.Bytes()
|
||||||
if _, err := netConn.Write(p); err != nil {
|
if _, err := netConn.Write(p); err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := http.ReadResponse(c.br, &http.Request{Method: "GET", URL: u})
|
resp, err := http.ReadResponse(c.br, &http.Request{Method: "GET", URL: u})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.StatusCode != 101 ||
|
if resp.StatusCode != 101 ||
|
||||||
!strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") ||
|
!strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") ||
|
||||||
!strings.EqualFold(resp.Header.Get("Connection"), "upgrade") ||
|
!strings.EqualFold(resp.Header.Get("Connection"), "upgrade") ||
|
||||||
resp.Header.Get("Sec-Websocket-Accept") != acceptKey {
|
resp.Header.Get("Sec-Websocket-Accept") != acceptKey {
|
||||||
return nil, resp, ErrBadHandshake
|
return nil, resp, ErrBadHandshake
|
||||||
}
|
}
|
||||||
return c, resp, nil
|
return c, resp, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,99 +1,100 @@
|
||||||
package websocket
|
package websocket
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/gorilla/websocket"
|
"net"
|
||||||
"net"
|
"net/http"
|
||||||
"net/http"
|
"net/url"
|
||||||
"net/url"
|
"testing"
|
||||||
"testing"
|
"time"
|
||||||
"time"
|
|
||||||
)
|
"github.com/gorilla/websocket"
|
||||||
|
)
|
||||||
func TestWSClient(t *testing.T) {
|
|
||||||
http.HandleFunc("/test/client", func(w http.ResponseWriter, r *http.Request) {
|
func TestWSClient(t *testing.T) {
|
||||||
conn, err := websocket.Upgrade(w, r, nil, 1024, 1024)
|
http.HandleFunc("/test/client", func(w http.ResponseWriter, r *http.Request) {
|
||||||
if err != nil {
|
conn, err := websocket.Upgrade(w, r, nil, 1024, 1024)
|
||||||
t.Fatal(err.Error())
|
if err != nil {
|
||||||
}
|
t.Fatal(err.Error())
|
||||||
|
}
|
||||||
msgType, msg, err := conn.ReadMessage()
|
|
||||||
conn.WriteMessage(websocket.TextMessage, msg)
|
msgType, msg, err := conn.ReadMessage()
|
||||||
|
conn.WriteMessage(websocket.TextMessage, msg)
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err.Error())
|
if err != nil {
|
||||||
}
|
t.Fatal(err.Error())
|
||||||
|
}
|
||||||
if msgType != websocket.TextMessage {
|
|
||||||
t.Fatal("invalid msg type", msgType)
|
if msgType != websocket.TextMessage {
|
||||||
}
|
t.Fatal("invalid msg type", msgType)
|
||||||
|
}
|
||||||
msgType, msg, err = conn.ReadMessage()
|
|
||||||
if err != nil {
|
msgType, msg, err = conn.ReadMessage()
|
||||||
t.Fatal(err.Error())
|
if err != nil {
|
||||||
}
|
t.Fatal(err.Error())
|
||||||
|
}
|
||||||
if msgType != websocket.PingMessage {
|
|
||||||
t.Fatal("invalid msg type", msgType)
|
if msgType != websocket.PingMessage {
|
||||||
}
|
t.Fatal("invalid msg type", msgType)
|
||||||
|
}
|
||||||
conn.WriteMessage(websocket.PongMessage, []byte{})
|
|
||||||
|
conn.WriteMessage(websocket.PongMessage, []byte{})
|
||||||
conn.WriteMessage(websocket.PingMessage, []byte{})
|
|
||||||
|
conn.WriteMessage(websocket.PingMessage, []byte{})
|
||||||
msgType, msg, err = conn.ReadMessage()
|
|
||||||
if err != nil {
|
msgType, msg, err = conn.ReadMessage()
|
||||||
t.Fatal(err.Error())
|
if err != nil {
|
||||||
}
|
t.Fatal(err.Error())
|
||||||
println(msgType)
|
}
|
||||||
if msgType != websocket.PongMessage {
|
println(msgType)
|
||||||
|
if msgType != websocket.PongMessage {
|
||||||
t.Fatal("invalid msg type", msgType)
|
|
||||||
}
|
t.Fatal("invalid msg type", msgType)
|
||||||
})
|
}
|
||||||
|
})
|
||||||
go http.ListenAndServe(":65500", nil)
|
|
||||||
|
go http.ListenAndServe(":65500", nil)
|
||||||
time.Sleep(time.Second * 1)
|
|
||||||
|
time.Sleep(time.Second * 1)
|
||||||
conn, err := net.Dial("tcp", "127.0.0.1:65500")
|
|
||||||
|
conn, err := net.Dial("tcp", "127.0.0.1:65500")
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err.Error())
|
if err != nil {
|
||||||
}
|
t.Fatal(err.Error())
|
||||||
ws, _, err := NewClient(conn, &url.URL{Host: "127.0.0.1:65500", Path: "/test/client"}, nil)
|
}
|
||||||
|
ws, _, err := NewClient(conn, &url.URL{Host: "127.0.0.1:65500", Path: "/test/client"}, nil)
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err.Error())
|
if err != nil {
|
||||||
}
|
t.Fatal(err.Error())
|
||||||
|
}
|
||||||
payload := make([]byte, 4*1024)
|
|
||||||
for i := 0; i < 4*1024; i++ {
|
payload := make([]byte, 4*1024)
|
||||||
payload[i] = 'x'
|
for i := 0; i < 4*1024; i++ {
|
||||||
}
|
payload[i] = 'x'
|
||||||
|
}
|
||||||
ws.WriteString(payload)
|
|
||||||
|
ws.WriteString(payload)
|
||||||
msgType, msg, err := ws.Read()
|
|
||||||
if err != nil {
|
msgType, msg, err := ws.Read()
|
||||||
t.Fatal(err.Error())
|
if err != nil {
|
||||||
}
|
t.Fatal(err.Error())
|
||||||
if msgType != TextMessage {
|
}
|
||||||
t.Fatal("invalid msg type", msgType)
|
if msgType != TextMessage {
|
||||||
}
|
t.Fatal("invalid msg type", msgType)
|
||||||
|
}
|
||||||
if string(msg) != string(payload) {
|
|
||||||
t.Fatal("invalid msg", string(msg))
|
if string(msg) != string(payload) {
|
||||||
|
t.Fatal("invalid msg", string(msg))
|
||||||
}
|
|
||||||
|
}
|
||||||
//test ping
|
|
||||||
ws.Ping([]byte{})
|
//test ping
|
||||||
msgType, msg, err = ws.ReadMessage()
|
ws.Ping([]byte{})
|
||||||
if err != nil {
|
msgType, msg, err = ws.ReadMessage()
|
||||||
t.Fatal(err.Error())
|
if err != nil {
|
||||||
}
|
t.Fatal(err.Error())
|
||||||
if msgType != PongMessage {
|
}
|
||||||
t.Fatal("invalid msg type", msgType)
|
if msgType != PongMessage {
|
||||||
}
|
t.Fatal("invalid msg type", msgType)
|
||||||
|
}
|
||||||
}
|
|
||||||
|
}
|
||||||
|
|
|
@ -1,321 +1,321 @@
|
||||||
package websocket
|
package websocket
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
//refer RFC6455
|
//refer RFC6455
|
||||||
|
|
||||||
const (
|
const (
|
||||||
TextMessage byte = 1
|
TextMessage byte = 1
|
||||||
BinaryMessage byte = 2
|
BinaryMessage byte = 2
|
||||||
CloseMessage byte = 8
|
CloseMessage byte = 8
|
||||||
PingMessage byte = 9
|
PingMessage byte = 9
|
||||||
PongMessage byte = 10
|
PongMessage byte = 10
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrControlTooLong = errors.New("control message too long")
|
ErrControlTooLong = errors.New("control message too long")
|
||||||
ErrRSVNotSupport = errors.New("reserved bit not support")
|
ErrRSVNotSupport = errors.New("reserved bit not support")
|
||||||
ErrPayloadError = errors.New("payload length error")
|
ErrPayloadError = errors.New("payload length error")
|
||||||
ErrControlFragmented = errors.New("control message can not be fragmented")
|
ErrControlFragmented = errors.New("control message can not be fragmented")
|
||||||
ErrNotTCPConn = errors.New("not a tcp connection")
|
ErrNotTCPConn = errors.New("not a tcp connection")
|
||||||
ErrWriteError = errors.New("write error")
|
ErrWriteError = errors.New("write error")
|
||||||
)
|
)
|
||||||
|
|
||||||
type Conn struct {
|
type Conn struct {
|
||||||
conn net.Conn
|
conn net.Conn
|
||||||
|
|
||||||
br *bufio.Reader
|
br *bufio.Reader
|
||||||
|
|
||||||
isServer bool
|
isServer bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConn(conn net.Conn, isServer bool) *Conn {
|
func NewConn(conn net.Conn, isServer bool) *Conn {
|
||||||
c := new(Conn)
|
c := new(Conn)
|
||||||
|
|
||||||
c.conn = conn
|
c.conn = conn
|
||||||
|
|
||||||
c.br = bufio.NewReader(conn)
|
c.br = bufio.NewReader(conn)
|
||||||
|
|
||||||
c.isServer = isServer
|
c.isServer = isServer
|
||||||
|
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) ReadMessage() (messageType byte, message []byte, err error) {
|
func (c *Conn) ReadMessage() (messageType byte, message []byte, err error) {
|
||||||
return c.Read()
|
return c.Read()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) Read() (messageType byte, message []byte, err error) {
|
func (c *Conn) Read() (messageType byte, message []byte, err error) {
|
||||||
buf := make([]byte, 8, 8)
|
buf := make([]byte, 8, 8)
|
||||||
|
|
||||||
message = []byte{}
|
message = []byte{}
|
||||||
|
|
||||||
messageType = 0
|
messageType = 0
|
||||||
|
|
||||||
for {
|
for {
|
||||||
opcode, data, err := c.readFrame(buf)
|
opcode, data, err := c.readFrame(buf)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return messageType, message, err
|
return messageType, message, err
|
||||||
}
|
}
|
||||||
|
|
||||||
message = append(message, data...)
|
message = append(message, data...)
|
||||||
|
|
||||||
if opcode&0x80 != 0 {
|
if opcode&0x80 != 0 {
|
||||||
//final
|
//final
|
||||||
if opcode&0x0F > 0 {
|
if opcode&0x0F > 0 {
|
||||||
//not continue frame
|
//not continue frame
|
||||||
messageType = opcode & 0x0F
|
messageType = opcode & 0x0F
|
||||||
}
|
}
|
||||||
return messageType, message, nil
|
return messageType, message, nil
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
if opcode&0x0F > 0 {
|
if opcode&0x0F > 0 {
|
||||||
//first continue frame
|
//first continue frame
|
||||||
messageType = opcode & 0x0F
|
messageType = opcode & 0x0F
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) Write(message []byte, binary bool) error {
|
func (c *Conn) Write(message []byte, binary bool) error {
|
||||||
if binary {
|
if binary {
|
||||||
return c.sendFrame(BinaryMessage, message)
|
return c.sendFrame(BinaryMessage, message)
|
||||||
} else {
|
} else {
|
||||||
return c.sendFrame(TextMessage, message)
|
return c.sendFrame(TextMessage, message)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) WriteMessage(messageType byte, message []byte) error {
|
func (c *Conn) WriteMessage(messageType byte, message []byte) error {
|
||||||
return c.sendFrame(messageType, message)
|
return c.sendFrame(messageType, message)
|
||||||
}
|
}
|
||||||
|
|
||||||
//write utf-8 text message
|
//write utf-8 text message
|
||||||
func (c *Conn) WriteString(message []byte) error {
|
func (c *Conn) WriteString(message []byte) error {
|
||||||
return c.Write(message, false)
|
return c.Write(message, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
//write binary message
|
//write binary message
|
||||||
func (c *Conn) WriteBinary(message []byte) error {
|
func (c *Conn) WriteBinary(message []byte) error {
|
||||||
return c.Write(message, true)
|
return c.Write(message, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) Ping(message []byte) error {
|
func (c *Conn) Ping(message []byte) error {
|
||||||
return c.sendFrame(PingMessage, message)
|
return c.sendFrame(PingMessage, message)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) Pong(message []byte) error {
|
func (c *Conn) Pong(message []byte) error {
|
||||||
return c.sendFrame(PongMessage, message)
|
return c.sendFrame(PongMessage, message)
|
||||||
}
|
}
|
||||||
|
|
||||||
//close socket, not send websocket close message
|
//close socket, not send websocket close message
|
||||||
func (c *Conn) Close() error {
|
func (c *Conn) Close() error {
|
||||||
return c.conn.Close()
|
return c.conn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) LocalAddr() net.Addr {
|
func (c *Conn) LocalAddr() net.Addr {
|
||||||
return c.conn.LocalAddr()
|
return c.conn.LocalAddr()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) RemoteAddr() net.Addr {
|
func (c *Conn) RemoteAddr() net.Addr {
|
||||||
return c.conn.RemoteAddr()
|
return c.conn.RemoteAddr()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) SetReadDeadline(t time.Time) error {
|
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||||||
return c.conn.SetReadDeadline(t)
|
return c.conn.SetReadDeadline(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||||||
return c.conn.SetWriteDeadline(t)
|
return c.conn.SetWriteDeadline(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) SetReadBuffer(bytes int) error {
|
func (c *Conn) SetReadBuffer(bytes int) error {
|
||||||
if tcpConn, ok := c.conn.(*net.TCPConn); ok {
|
if tcpConn, ok := c.conn.(*net.TCPConn); ok {
|
||||||
return tcpConn.SetReadBuffer(bytes)
|
return tcpConn.SetReadBuffer(bytes)
|
||||||
} else {
|
} else {
|
||||||
return ErrNotTCPConn
|
return ErrNotTCPConn
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) SetWriteBuffer(bytes int) error {
|
func (c *Conn) SetWriteBuffer(bytes int) error {
|
||||||
if tcpConn, ok := c.conn.(*net.TCPConn); ok {
|
if tcpConn, ok := c.conn.(*net.TCPConn); ok {
|
||||||
return tcpConn.SetWriteBuffer(bytes)
|
return tcpConn.SetWriteBuffer(bytes)
|
||||||
} else {
|
} else {
|
||||||
return ErrNotTCPConn
|
return ErrNotTCPConn
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) readPayloadLen(length byte, buf []byte) (payloadLen uint64, err error) {
|
func (c *Conn) readPayloadLen(length byte, buf []byte) (payloadLen uint64, err error) {
|
||||||
if length < 126 {
|
if length < 126 {
|
||||||
payloadLen = uint64(length)
|
payloadLen = uint64(length)
|
||||||
} else if length == 126 {
|
} else if length == 126 {
|
||||||
err = c.read(buf[:2])
|
err = c.read(buf[:2])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
payloadLen = uint64(binary.BigEndian.Uint16(buf[:2]))
|
payloadLen = uint64(binary.BigEndian.Uint16(buf[:2]))
|
||||||
} else if length == 127 {
|
} else if length == 127 {
|
||||||
err = c.read(buf[:8])
|
err = c.read(buf[:8])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
payloadLen = uint64(binary.BigEndian.Uint16(buf[:8]))
|
payloadLen = uint64(binary.BigEndian.Uint16(buf[:8]))
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) readFrame(buf []byte) (opcode byte, messsage []byte, err error) {
|
func (c *Conn) readFrame(buf []byte) (opcode byte, messsage []byte, err error) {
|
||||||
//minimum head may 2 byte
|
//minimum head may 2 byte
|
||||||
|
|
||||||
err = c.read(buf[:2])
|
err = c.read(buf[:2])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
opcode = buf[0]
|
opcode = buf[0]
|
||||||
|
|
||||||
if opcode&0x70 > 0 {
|
if opcode&0x70 > 0 {
|
||||||
err = ErrRSVNotSupport
|
err = ErrRSVNotSupport
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
//isMasking := (0x80 & buf[1]) > 0
|
//isMasking := (0x80 & buf[1]) > 0
|
||||||
isMasking := (0x80 & buf[1]) > 0
|
isMasking := (0x80 & buf[1]) > 0
|
||||||
|
|
||||||
var payloadLen uint64
|
var payloadLen uint64
|
||||||
payloadLen, err = c.readPayloadLen(buf[1]&0x7F, buf)
|
payloadLen, err = c.readPayloadLen(buf[1]&0x7F, buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if opcode&0x08 > 0 && payloadLen > 125 {
|
if opcode&0x08 > 0 && payloadLen > 125 {
|
||||||
err = ErrControlTooLong
|
err = ErrControlTooLong
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var masking []byte
|
var masking []byte
|
||||||
|
|
||||||
if isMasking {
|
if isMasking {
|
||||||
err = c.read(buf[:4])
|
err = c.read(buf[:4])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
masking = buf[:4]
|
masking = buf[:4]
|
||||||
}
|
}
|
||||||
|
|
||||||
messsage = make([]byte, payloadLen)
|
messsage = make([]byte, payloadLen)
|
||||||
err = c.read(messsage)
|
err = c.read(messsage)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if isMasking {
|
if isMasking {
|
||||||
//maskingKey := c.newMaskingKey()
|
//maskingKey := c.newMaskingKey()
|
||||||
c.maskingData(messsage, masking)
|
c.maskingData(messsage, masking)
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) sendFrame(opcode byte, message []byte) error {
|
func (c *Conn) sendFrame(opcode byte, message []byte) error {
|
||||||
//max frame header may 14 length
|
//max frame header may 14 length
|
||||||
buf := make([]byte, 0, len(message)+14)
|
buf := make([]byte, 0, len(message)+14)
|
||||||
//here we don not support continue frame, all are final
|
//here we don not support continue frame, all are final
|
||||||
opcode |= 0x80
|
opcode |= 0x80
|
||||||
|
|
||||||
if opcode&0x08 > 0 && len(message) >= 126 {
|
if opcode&0x08 > 0 && len(message) >= 126 {
|
||||||
return ErrControlTooLong
|
return ErrControlTooLong
|
||||||
}
|
}
|
||||||
|
|
||||||
buf = append(buf, opcode)
|
buf = append(buf, opcode)
|
||||||
|
|
||||||
//no mask, because chrome may not support
|
//no mask, because chrome may not support
|
||||||
var mask byte = 0x00
|
var mask byte = 0x00
|
||||||
|
|
||||||
if !c.isServer {
|
if !c.isServer {
|
||||||
//for client, we will mask data
|
//for client, we will mask data
|
||||||
mask = 0x80
|
mask = 0x80
|
||||||
}
|
}
|
||||||
|
|
||||||
payloadLen := len(message)
|
payloadLen := len(message)
|
||||||
|
|
||||||
if payloadLen < 126 {
|
if payloadLen < 126 {
|
||||||
buf = append(buf, mask|byte(payloadLen))
|
buf = append(buf, mask|byte(payloadLen))
|
||||||
} else if payloadLen <= 0xFFFF {
|
} else if payloadLen <= 0xFFFF {
|
||||||
buf = append(buf, mask|byte(126), 0, 0)
|
buf = append(buf, mask|byte(126), 0, 0)
|
||||||
|
|
||||||
binary.BigEndian.PutUint16(buf[len(buf)-2:], uint16(payloadLen))
|
binary.BigEndian.PutUint16(buf[len(buf)-2:], uint16(payloadLen))
|
||||||
} else {
|
} else {
|
||||||
buf = append(buf, mask|byte(127), 0, 0, 0, 0, 0, 0, 0, 0)
|
buf = append(buf, mask|byte(127), 0, 0, 0, 0, 0, 0, 0, 0)
|
||||||
|
|
||||||
binary.BigEndian.PutUint64(buf[len(buf)-8:], uint64(payloadLen))
|
binary.BigEndian.PutUint64(buf[len(buf)-8:], uint64(payloadLen))
|
||||||
}
|
}
|
||||||
|
|
||||||
if !c.isServer {
|
if !c.isServer {
|
||||||
maskingKey := c.newMaskingKey()
|
maskingKey := c.newMaskingKey()
|
||||||
buf = append(buf, maskingKey...)
|
buf = append(buf, maskingKey...)
|
||||||
|
|
||||||
pos := len(buf)
|
pos := len(buf)
|
||||||
buf = append(buf, message...)
|
buf = append(buf, message...)
|
||||||
|
|
||||||
c.maskingData(buf[pos:], maskingKey)
|
c.maskingData(buf[pos:], maskingKey)
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
buf = append(buf, message...)
|
buf = append(buf, message...)
|
||||||
}
|
}
|
||||||
|
|
||||||
tmpBuf := buf
|
tmpBuf := buf
|
||||||
for i := 0; i < 3; i++ {
|
for i := 0; i < 3; i++ {
|
||||||
n, err := c.conn.Write(tmpBuf)
|
n, err := c.conn.Write(tmpBuf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if n == len(tmpBuf) {
|
if n == len(tmpBuf) {
|
||||||
return nil
|
return nil
|
||||||
} else {
|
} else {
|
||||||
tmpBuf = tmpBuf[n:]
|
tmpBuf = tmpBuf[n:]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return ErrWriteError
|
return ErrWriteError
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) read(buf []byte) error {
|
func (c *Conn) read(buf []byte) error {
|
||||||
var err error
|
var err error
|
||||||
for len(buf) > 0 && err == nil {
|
for len(buf) > 0 && err == nil {
|
||||||
var nn int
|
var nn int
|
||||||
nn, err = c.br.Read(buf)
|
nn, err = c.br.Read(buf)
|
||||||
buf = buf[nn:]
|
buf = buf[nn:]
|
||||||
}
|
}
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
if len(buf) == 0 {
|
if len(buf) == 0 {
|
||||||
err = nil
|
err = nil
|
||||||
} else {
|
} else {
|
||||||
err = io.ErrUnexpectedEOF
|
err = io.ErrUnexpectedEOF
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) maskingData(data []byte, maskingKey []byte) {
|
func (c *Conn) maskingData(data []byte, maskingKey []byte) {
|
||||||
for i := range data {
|
for i := range data {
|
||||||
data[i] ^= maskingKey[i%4]
|
data[i] ^= maskingKey[i%4]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) newMaskingKey() []byte {
|
func (c *Conn) newMaskingKey() []byte {
|
||||||
n := rand.Uint32()
|
n := rand.Uint32()
|
||||||
return []byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 32)}
|
return []byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 32)}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,51 +1,51 @@
|
||||||
package websocket
|
package websocket
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestWSPing(t *testing.T) {
|
func TestWSPing(t *testing.T) {
|
||||||
http.HandleFunc("/test/ping", func(w http.ResponseWriter, r *http.Request) {
|
http.HandleFunc("/test/ping", func(w http.ResponseWriter, r *http.Request) {
|
||||||
conn, err := Upgrade(w, r, nil)
|
conn, err := Upgrade(w, r, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err.Error())
|
t.Fatal(err.Error())
|
||||||
}
|
}
|
||||||
//conn := NewConn(c, true)
|
//conn := NewConn(c, true)
|
||||||
conn.Read()
|
conn.Read()
|
||||||
conn.Pong([]byte{})
|
conn.Pong([]byte{})
|
||||||
conn.Ping([]byte{})
|
conn.Ping([]byte{})
|
||||||
msgType, _, _ := conn.Read()
|
msgType, _, _ := conn.Read()
|
||||||
println(msgType)
|
println(msgType)
|
||||||
})
|
})
|
||||||
|
|
||||||
go http.ListenAndServe(":65500", nil)
|
go http.ListenAndServe(":65500", nil)
|
||||||
time.Sleep(time.Second * 1)
|
time.Sleep(time.Second * 1)
|
||||||
|
|
||||||
conn, err := net.Dial("tcp", "127.0.0.1:65500")
|
conn, err := net.Dial("tcp", "127.0.0.1:65500")
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err.Error())
|
t.Fatal(err.Error())
|
||||||
}
|
}
|
||||||
ws, _, err := NewClient(conn, &url.URL{Host: "127.0.0.1:65500", Path: "/test/ping"}, nil)
|
ws, _, err := NewClient(conn, &url.URL{Host: "127.0.0.1:65500", Path: "/test/ping"}, nil)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err.Error())
|
t.Fatal(err.Error())
|
||||||
}
|
}
|
||||||
ws.Ping([]byte{})
|
ws.Ping([]byte{})
|
||||||
|
|
||||||
msgType, _, _ := ws.Read()
|
msgType, _, _ := ws.Read()
|
||||||
if msgType != PongMessage {
|
if msgType != PongMessage {
|
||||||
t.Fatal("invalid msg type", msgType)
|
t.Fatal("invalid msg type", msgType)
|
||||||
}
|
}
|
||||||
|
|
||||||
msgType, _, _ = ws.Read()
|
msgType, _, _ = ws.Read()
|
||||||
if msgType != PingMessage {
|
if msgType != PingMessage {
|
||||||
t.Fatal("invalid msg type", msgType)
|
t.Fatal("invalid msg type", msgType)
|
||||||
}
|
}
|
||||||
ws.Pong([]byte{})
|
ws.Pong([]byte{})
|
||||||
time.Sleep(time.Second * 1)
|
time.Sleep(time.Second * 1)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,105 +1,105 @@
|
||||||
package websocket
|
package websocket
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrInvalidMethod = errors.New("Only GET Supported")
|
ErrInvalidMethod = errors.New("Only GET Supported")
|
||||||
ErrInvalidVersion = errors.New("Sec-Websocket-Version: 13")
|
ErrInvalidVersion = errors.New("Sec-Websocket-Version: 13")
|
||||||
ErrInvalidUpgrade = errors.New("Can \"Upgrade\" only to \"WebSocket\"")
|
ErrInvalidUpgrade = errors.New("Can \"Upgrade\" only to \"WebSocket\"")
|
||||||
ErrInvalidConnection = errors.New("\"Connection\" must be \"Upgrade\"")
|
ErrInvalidConnection = errors.New("\"Connection\" must be \"Upgrade\"")
|
||||||
ErrMissingKey = errors.New("Missing Key")
|
ErrMissingKey = errors.New("Missing Key")
|
||||||
ErrHijacker = errors.New("Not implement http.Hijacker")
|
ErrHijacker = errors.New("Not implement http.Hijacker")
|
||||||
ErrNoEmptyConn = errors.New("Conn ReadBuf must be empty")
|
ErrNoEmptyConn = errors.New("Conn ReadBuf must be empty")
|
||||||
)
|
)
|
||||||
|
|
||||||
func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) {
|
func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) {
|
||||||
if r.Method != "GET" {
|
if r.Method != "GET" {
|
||||||
return nil, ErrInvalidMethod
|
return nil, ErrInvalidMethod
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.Header.Get("Sec-Websocket-Version") != "13" {
|
if r.Header.Get("Sec-Websocket-Version") != "13" {
|
||||||
return nil, ErrInvalidVersion
|
return nil, ErrInvalidVersion
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.ToLower(r.Header.Get("Upgrade")) != "websocket" {
|
if strings.ToLower(r.Header.Get("Upgrade")) != "websocket" {
|
||||||
return nil, ErrInvalidUpgrade
|
return nil, ErrInvalidUpgrade
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.ToLower(r.Header.Get("Connection")) != "upgrade" {
|
if strings.ToLower(r.Header.Get("Connection")) != "upgrade" {
|
||||||
return nil, ErrInvalidConnection
|
return nil, ErrInvalidConnection
|
||||||
}
|
}
|
||||||
|
|
||||||
var acceptKey string
|
var acceptKey string
|
||||||
|
|
||||||
if key := r.Header.Get("Sec-Websocket-key"); len(key) == 0 {
|
if key := r.Header.Get("Sec-Websocket-key"); len(key) == 0 {
|
||||||
return nil, ErrMissingKey
|
return nil, ErrMissingKey
|
||||||
} else {
|
} else {
|
||||||
acceptKey = calcAcceptKey(key)
|
acceptKey = calcAcceptKey(key)
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
netConn net.Conn
|
netConn net.Conn
|
||||||
br *bufio.Reader
|
br *bufio.Reader
|
||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
|
|
||||||
h, ok := w.(http.Hijacker)
|
h, ok := w.(http.Hijacker)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, ErrHijacker
|
return nil, ErrHijacker
|
||||||
}
|
}
|
||||||
|
|
||||||
var rw *bufio.ReadWriter
|
var rw *bufio.ReadWriter
|
||||||
netConn, rw, err = h.Hijack()
|
netConn, rw, err = h.Hijack()
|
||||||
br = rw.Reader
|
br = rw.Reader
|
||||||
|
|
||||||
if br.Buffered() > 0 {
|
if br.Buffered() > 0 {
|
||||||
netConn.Close()
|
netConn.Close()
|
||||||
return nil, ErrNoEmptyConn
|
return nil, ErrNoEmptyConn
|
||||||
}
|
}
|
||||||
|
|
||||||
c := NewConn(netConn, true)
|
c := NewConn(netConn, true)
|
||||||
|
|
||||||
buf := bytes.NewBufferString("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ")
|
buf := bytes.NewBufferString("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ")
|
||||||
|
|
||||||
buf.WriteString(acceptKey)
|
buf.WriteString(acceptKey)
|
||||||
buf.WriteString("\r\n")
|
buf.WriteString("\r\n")
|
||||||
|
|
||||||
subProtol := selectSubProtocol(r)
|
subProtol := selectSubProtocol(r)
|
||||||
if len(subProtol) > 0 {
|
if len(subProtol) > 0 {
|
||||||
buf.WriteString("Sec-Websocket-Protocol: ")
|
buf.WriteString("Sec-Websocket-Protocol: ")
|
||||||
buf.WriteString(subProtol)
|
buf.WriteString(subProtol)
|
||||||
buf.WriteString("\r\n")
|
buf.WriteString("\r\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
for k, vs := range responseHeader {
|
for k, vs := range responseHeader {
|
||||||
for _, v := range vs {
|
for _, v := range vs {
|
||||||
buf.WriteString(k)
|
buf.WriteString(k)
|
||||||
buf.WriteString(": ")
|
buf.WriteString(": ")
|
||||||
buf.WriteString(v)
|
buf.WriteString(v)
|
||||||
buf.WriteString("\r\n")
|
buf.WriteString("\r\n")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
buf.WriteString("\r\n")
|
buf.WriteString("\r\n")
|
||||||
|
|
||||||
if _, err = netConn.Write(buf.Bytes()); err != nil {
|
if _, err = netConn.Write(buf.Bytes()); err != nil {
|
||||||
netConn.Close()
|
netConn.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func selectSubProtocol(r *http.Request) string {
|
func selectSubProtocol(r *http.Request) string {
|
||||||
h := r.Header.Get("Sec-Websocket-Protocol")
|
h := r.Header.Get("Sec-Websocket-Protocol")
|
||||||
if len(h) == 0 {
|
if len(h) == 0 {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
return strings.Split(h, ",")[0]
|
return strings.Split(h, ",")[0]
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,97 +1,98 @@
|
||||||
package websocket
|
package websocket
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/gorilla/websocket"
|
"net"
|
||||||
"net"
|
"net/http"
|
||||||
"net/http"
|
"net/url"
|
||||||
"net/url"
|
"testing"
|
||||||
"testing"
|
"time"
|
||||||
"time"
|
|
||||||
)
|
"github.com/gorilla/websocket"
|
||||||
|
)
|
||||||
func TestWSServer(t *testing.T) {
|
|
||||||
http.HandleFunc("/test/server", func(w http.ResponseWriter, r *http.Request) {
|
func TestWSServer(t *testing.T) {
|
||||||
conn, err := Upgrade(w, r, nil)
|
http.HandleFunc("/test/server", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
conn, err := Upgrade(w, r, nil)
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err.Error())
|
if err != nil {
|
||||||
}
|
t.Fatal(err.Error())
|
||||||
//err = conn.SetReadBuffer(1024 * 1024 * 4)
|
}
|
||||||
//if err != nil {
|
//err = conn.SetReadBuffer(1024 * 1024 * 4)
|
||||||
// println(err.Error())
|
//if err != nil {
|
||||||
//}
|
// println(err.Error())
|
||||||
//err = conn.SetWriteBuffer(1024 * 1024 * 4)
|
//}
|
||||||
|
//err = conn.SetWriteBuffer(1024 * 1024 * 4)
|
||||||
//if err != nil {
|
|
||||||
// println(err.Error())
|
//if err != nil {
|
||||||
//}
|
// println(err.Error())
|
||||||
|
//}
|
||||||
msgType, msg, err := conn.Read()
|
|
||||||
conn.Write(msg, false)
|
msgType, msg, err := conn.Read()
|
||||||
|
conn.Write(msg, false)
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err.Error())
|
if err != nil {
|
||||||
}
|
t.Fatal(err.Error())
|
||||||
|
}
|
||||||
if msgType != TextMessage {
|
|
||||||
t.Fatal("wrong msg type", msgType)
|
if msgType != TextMessage {
|
||||||
}
|
t.Fatal("wrong msg type", msgType)
|
||||||
|
}
|
||||||
msgType, msg, err = conn.ReadMessage()
|
|
||||||
if err != nil {
|
msgType, msg, err = conn.ReadMessage()
|
||||||
t.Fatal(err.Error())
|
if err != nil {
|
||||||
}
|
t.Fatal(err.Error())
|
||||||
|
}
|
||||||
if msgType != PingMessage {
|
|
||||||
t.Fatal("wrong msg type", msgType)
|
if msgType != PingMessage {
|
||||||
}
|
t.Fatal("wrong msg type", msgType)
|
||||||
|
}
|
||||||
err = conn.Pong([]byte("abc"))
|
|
||||||
|
err = conn.Pong([]byte("abc"))
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err.Error())
|
if err != nil {
|
||||||
}
|
t.Fatal(err.Error())
|
||||||
|
}
|
||||||
})
|
|
||||||
|
})
|
||||||
go http.ListenAndServe(":65500", nil)
|
|
||||||
time.Sleep(time.Second * 1)
|
go http.ListenAndServe(":65500", nil)
|
||||||
|
time.Sleep(time.Second * 1)
|
||||||
conn, err := net.Dial("tcp", "127.0.0.1:65500")
|
|
||||||
|
conn, err := net.Dial("tcp", "127.0.0.1:65500")
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err.Error())
|
if err != nil {
|
||||||
}
|
t.Fatal(err.Error())
|
||||||
ws, _, err := websocket.NewClient(conn, &url.URL{Host: "127.0.0.1:65500", Path: "/test/server"}, nil, 1024, 1024)
|
}
|
||||||
|
ws, _, err := websocket.NewClient(conn, &url.URL{Host: "127.0.0.1:65500", Path: "/test/server"}, nil, 1024, 1024)
|
||||||
ws.SetPongHandler(func(string) error {
|
|
||||||
println("pong")
|
ws.SetPongHandler(func(string) error {
|
||||||
return nil
|
println("pong")
|
||||||
})
|
return nil
|
||||||
|
})
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err.Error())
|
if err != nil {
|
||||||
}
|
t.Fatal(err.Error())
|
||||||
|
}
|
||||||
payload := make([]byte, 4*1024*1024)
|
|
||||||
for i := 0; i < 4*1024*1024; i++ {
|
payload := make([]byte, 4*1024*1024)
|
||||||
payload[i] = 'x'
|
for i := 0; i < 4*1024*1024; i++ {
|
||||||
}
|
payload[i] = 'x'
|
||||||
|
}
|
||||||
ws.WriteMessage(websocket.TextMessage, payload)
|
|
||||||
|
ws.WriteMessage(websocket.TextMessage, payload)
|
||||||
msgType, msg, err := ws.ReadMessage()
|
|
||||||
if err != nil {
|
msgType, msg, err := ws.ReadMessage()
|
||||||
t.Fatal(err.Error())
|
if err != nil {
|
||||||
}
|
t.Fatal(err.Error())
|
||||||
if msgType != websocket.TextMessage {
|
}
|
||||||
t.Fatal("invalid msg type", msgType)
|
if msgType != websocket.TextMessage {
|
||||||
}
|
t.Fatal("invalid msg type", msgType)
|
||||||
|
}
|
||||||
if string(msg) != string(payload) {
|
|
||||||
t.Fatal("invalid msg", string(msg))
|
if string(msg) != string(payload) {
|
||||||
|
t.Fatal("invalid msg", string(msg))
|
||||||
}
|
|
||||||
|
}
|
||||||
time.Sleep(time.Second * 1)
|
|
||||||
}
|
time.Sleep(time.Second * 1)
|
||||||
|
}
|
||||||
|
|
|
@ -1,36 +1,36 @@
|
||||||
package websocket
|
package websocket
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/sha1"
|
"crypto/sha1"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
)
|
)
|
||||||
|
|
||||||
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
|
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
|
||||||
|
|
||||||
func calcAcceptKey(key string) string {
|
func calcAcceptKey(key string) string {
|
||||||
h := sha1.New()
|
h := sha1.New()
|
||||||
h.Write([]byte(key))
|
h.Write([]byte(key))
|
||||||
h.Write(keyGUID)
|
h.Write(keyGUID)
|
||||||
return base64.StdEncoding.EncodeToString(h.Sum(nil))
|
return base64.StdEncoding.EncodeToString(h.Sum(nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
func calcKey() (string, error) {
|
func calcKey() (string, error) {
|
||||||
p := make([]byte, 16)
|
p := make([]byte, 16)
|
||||||
if _, err := io.ReadFull(rand.Reader, p); err != nil {
|
if _, err := io.ReadFull(rand.Reader, p); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
return base64.StdEncoding.EncodeToString(p), nil
|
return base64.StdEncoding.EncodeToString(p), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func HandleCloseFrame(buf []byte) (int16, string, error) {
|
func HandleCloseFrame(buf []byte) (int16, string, error) {
|
||||||
|
|
||||||
if len(buf) < 2 {
|
if len(buf) < 2 {
|
||||||
return 0, "", errors.New("close frame msg's length less than 2")
|
return 0, "", errors.New("close frame msg's length less than 2")
|
||||||
}
|
}
|
||||||
code := int16(buf[0])<<8 + int16(buf[1])
|
code := int16(buf[0])<<8 + int16(buf[1])
|
||||||
reason := string(buf[2:])
|
reason := string(buf[2:])
|
||||||
return code, reason, nil
|
return code, reason, nil
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue