mirror of https://github.com/tidwall/tile38.git
replication auth
This commit is contained in:
parent
4ef89db63b
commit
c6619a529f
|
@ -83,7 +83,7 @@ func (conn *Conn) Do(command string) ([]byte, error) {
|
||||||
conn.pool = nil
|
conn.pool = nil
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
message, _, err := ReadMessage(conn.rd, nil)
|
message, _, _, err := ReadMessage(conn.rd, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.pool = nil
|
conn.pool = nil
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -96,7 +96,7 @@ func (conn *Conn) Do(command string) ([]byte, error) {
|
||||||
|
|
||||||
// ReadMessage returns the next message. Used when reading live connections
|
// ReadMessage returns the next message. Used when reading live connections
|
||||||
func (conn *Conn) ReadMessage() (message []byte, err error) {
|
func (conn *Conn) ReadMessage() (message []byte, err error) {
|
||||||
message, _, err = readMessage(conn.c, conn.rd)
|
message, _, _, err = readMessage(conn.c, conn.rd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.pool = nil
|
conn.pool = nil
|
||||||
return message, err
|
return message, err
|
||||||
|
@ -160,10 +160,10 @@ func WriteWebSocket(conn net.Conn, data []byte) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReadMessage reads the next message from a bufio.Reader.
|
// ReadMessage reads the next message from a bufio.Reader.
|
||||||
func readMessage(wr io.Writer, rd *bufio.Reader) (message []byte, proto Proto, err error) {
|
func readMessage(wr io.Writer, rd *bufio.Reader) (message []byte, proto Proto, auth string, err error) {
|
||||||
h, err := rd.Peek(1)
|
h, err := rd.Peek(1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, proto, err
|
return nil, proto, auth, err
|
||||||
}
|
}
|
||||||
switch h[0] {
|
switch h[0] {
|
||||||
case '$':
|
case '$':
|
||||||
|
@ -171,41 +171,41 @@ func readMessage(wr io.Writer, rd *bufio.Reader) (message []byte, proto Proto, e
|
||||||
}
|
}
|
||||||
message, proto, err = readTelnetMessage(rd)
|
message, proto, err = readTelnetMessage(rd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, proto, err
|
return nil, proto, auth, err
|
||||||
}
|
}
|
||||||
if len(message) > 6 && string(message[len(message)-9:len(message)-2]) == " HTTP/1" {
|
if len(message) > 6 && string(message[len(message)-9:len(message)-2]) == " HTTP/1" {
|
||||||
return readHTTPMessage(string(message), wr, rd)
|
return readHTTPMessage(string(message), wr, rd)
|
||||||
}
|
}
|
||||||
return message, proto, nil
|
return message, proto, auth, nil
|
||||||
|
|
||||||
}
|
}
|
||||||
func ReadMessage(rd *bufio.Reader, wr io.Writer) (message []byte, proto Proto, err error) {
|
func ReadMessage(rd *bufio.Reader, wr io.Writer) (message []byte, proto Proto, auth string, err error) {
|
||||||
return readMessage(wr, rd)
|
return readMessage(wr, rd)
|
||||||
}
|
}
|
||||||
|
|
||||||
func readProtoMessage(rd *bufio.Reader) (message []byte, proto Proto, err error) {
|
func readProtoMessage(rd *bufio.Reader) (message []byte, proto Proto, auth string, err error) {
|
||||||
b, err := rd.ReadBytes(' ')
|
b, err := rd.ReadBytes(' ')
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, Native, err
|
return nil, Native, auth, err
|
||||||
}
|
}
|
||||||
if len(b) > 0 && b[0] != '$' {
|
if len(b) > 0 && b[0] != '$' {
|
||||||
return nil, Native, errors.New("not a proto message")
|
return nil, Native, auth, errors.New("not a proto message")
|
||||||
}
|
}
|
||||||
n, err := strconv.ParseUint(string(b[1:len(b)-1]), 10, 32)
|
n, err := strconv.ParseUint(string(b[1:len(b)-1]), 10, 32)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, Native, errors.New("invalid size")
|
return nil, Native, auth, errors.New("invalid size")
|
||||||
}
|
}
|
||||||
if n > MaxMessageSize {
|
if n > MaxMessageSize {
|
||||||
return nil, Native, errors.New("message too big")
|
return nil, Native, auth, errors.New("message too big")
|
||||||
}
|
}
|
||||||
b = make([]byte, int(n)+2)
|
b = make([]byte, int(n)+2)
|
||||||
if _, err := io.ReadFull(rd, b); err != nil {
|
if _, err := io.ReadFull(rd, b); err != nil {
|
||||||
return nil, Native, err
|
return nil, Native, auth, err
|
||||||
}
|
}
|
||||||
if b[len(b)-2] != '\r' || b[len(b)-1] != '\n' {
|
if b[len(b)-2] != '\r' || b[len(b)-1] != '\n' {
|
||||||
return nil, Native, errors.New("expecting crlf suffix")
|
return nil, Native, auth, errors.New("expecting crlf suffix")
|
||||||
}
|
}
|
||||||
return b[:len(b)-2], Native, nil
|
return b[:len(b)-2], Native, auth, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func readTelnetMessage(rd *bufio.Reader) (command []byte, proto Proto, err error) {
|
func readTelnetMessage(rd *bufio.Reader) (command []byte, proto Proto, err error) {
|
||||||
|
@ -221,45 +221,56 @@ func readTelnetMessage(rd *bufio.Reader) (command []byte, proto Proto, err error
|
||||||
return line, Telnet, nil
|
return line, Telnet, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func readHTTPMessage(line string, wr io.Writer, rd *bufio.Reader) (command []byte, proto Proto, err error) {
|
func readHTTPMessage(line string, wr io.Writer, rd *bufio.Reader) (command []byte, proto Proto, auth string, err error) {
|
||||||
|
proto = HTTP
|
||||||
parts := strings.Split(line, " ")
|
parts := strings.Split(line, " ")
|
||||||
if len(parts) != 3 {
|
if len(parts) != 3 {
|
||||||
return nil, HTTP, errors.New("invalid HTTP request")
|
err = errors.New("invalid HTTP request")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
method := parts[0]
|
method := parts[0]
|
||||||
path := parts[1]
|
path := parts[1]
|
||||||
if len(path) == 0 || path[0] != '/' {
|
if len(path) == 0 || path[0] != '/' {
|
||||||
return nil, HTTP, errors.New("invalid HTTP request")
|
err = errors.New("invalid HTTP request")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
path, err = url.QueryUnescape(path[1:])
|
path, err = url.QueryUnescape(path[1:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, HTTP, errors.New("invalid HTTP request")
|
err = errors.New("invalid HTTP request")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
if method != "GET" && method != "POST" {
|
if method != "GET" && method != "POST" {
|
||||||
return nil, HTTP, errors.New("invalid HTTP method")
|
err = errors.New("invalid HTTP method")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
contentLength := 0
|
contentLength := 0
|
||||||
websocket := false
|
websocket := false
|
||||||
websocketVersion := 0
|
websocketVersion := 0
|
||||||
websocketKey := ""
|
websocketKey := ""
|
||||||
for {
|
for {
|
||||||
b, _, err := readTelnetMessage(rd) // read a header line
|
var b []byte
|
||||||
|
b, _, err = readTelnetMessage(rd) // read a header line
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, HTTP, nil
|
return
|
||||||
}
|
}
|
||||||
header := string(b)
|
header := string(b)
|
||||||
if header == "" {
|
if header == "" {
|
||||||
break // end of headers
|
break // end of headers
|
||||||
}
|
}
|
||||||
if header[0] == 'u' || header[0] == 'U' {
|
if header[0] == 'a' || header[0] == 'A' {
|
||||||
|
if strings.HasPrefix(strings.ToLower(header), "authorization:") {
|
||||||
|
auth = strings.TrimSpace(header[len("authorization:"):])
|
||||||
|
}
|
||||||
|
} else if header[0] == 'u' || header[0] == 'U' {
|
||||||
if strings.HasPrefix(strings.ToLower(header), "upgrade:") && strings.ToLower(strings.TrimSpace(header[len("upgrade:"):])) == "websocket" {
|
if strings.HasPrefix(strings.ToLower(header), "upgrade:") && strings.ToLower(strings.TrimSpace(header[len("upgrade:"):])) == "websocket" {
|
||||||
websocket = true
|
websocket = true
|
||||||
}
|
}
|
||||||
} else if header[0] == 's' || header[0] == 'S' {
|
} else if header[0] == 's' || header[0] == 'S' {
|
||||||
if strings.HasPrefix(strings.ToLower(header), "sec-websocket-version:") {
|
if strings.HasPrefix(strings.ToLower(header), "sec-websocket-version:") {
|
||||||
n, err := strconv.ParseUint(strings.TrimSpace(header[len("sec-websocket-version:"):]), 10, 64)
|
var n uint64
|
||||||
|
n, err = strconv.ParseUint(strings.TrimSpace(header[len("sec-websocket-version:"):]), 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, HTTP, err
|
return
|
||||||
}
|
}
|
||||||
websocketVersion = int(n)
|
websocketVersion = int(n)
|
||||||
} else if strings.HasPrefix(strings.ToLower(header), "sec-websocket-key:") {
|
} else if strings.HasPrefix(strings.ToLower(header), "sec-websocket-key:") {
|
||||||
|
@ -267,33 +278,35 @@ func readHTTPMessage(line string, wr io.Writer, rd *bufio.Reader) (command []byt
|
||||||
}
|
}
|
||||||
} else if header[0] == 'c' || header[0] == 'C' {
|
} else if header[0] == 'c' || header[0] == 'C' {
|
||||||
if strings.HasPrefix(strings.ToLower(header), "content-length:") {
|
if strings.HasPrefix(strings.ToLower(header), "content-length:") {
|
||||||
n, err := strconv.ParseUint(strings.TrimSpace(header[len("content-length:"):]), 10, 64)
|
var n uint64
|
||||||
|
n, err = strconv.ParseUint(strings.TrimSpace(header[len("content-length:"):]), 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, HTTP, err
|
return
|
||||||
}
|
}
|
||||||
contentLength = int(n)
|
contentLength = int(n)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if websocket && websocketVersion >= 13 && websocketKey != "" {
|
if websocket && websocketVersion >= 13 && websocketKey != "" {
|
||||||
|
proto = WebSocket
|
||||||
if wr == nil {
|
if wr == nil {
|
||||||
return nil, WebSocket, errors.New("connection is nil")
|
err = errors.New("connection is nil")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
sum := sha1.Sum([]byte(websocketKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"))
|
sum := sha1.Sum([]byte(websocketKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"))
|
||||||
accept := base64.StdEncoding.EncodeToString(sum[:])
|
accept := base64.StdEncoding.EncodeToString(sum[:])
|
||||||
wshead := "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: " + accept + "\r\n\r\n"
|
wshead := "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: " + accept + "\r\n\r\n"
|
||||||
_, err := wr.Write([]byte(wshead))
|
if _, err = wr.Write([]byte(wshead)); err != nil {
|
||||||
if err != nil {
|
return
|
||||||
return nil, WebSocket, err
|
|
||||||
}
|
}
|
||||||
return []byte(path), WebSocket, nil
|
|
||||||
} else if contentLength > 0 {
|
} else if contentLength > 0 {
|
||||||
|
proto = HTTP
|
||||||
buf := make([]byte, contentLength)
|
buf := make([]byte, contentLength)
|
||||||
_, err := io.ReadFull(rd, buf)
|
if _, err = io.ReadFull(rd, buf); err != nil {
|
||||||
if err != nil {
|
return
|
||||||
return nil, HTTP, err
|
|
||||||
}
|
}
|
||||||
path += string(buf)
|
path += string(buf)
|
||||||
}
|
}
|
||||||
return []byte(path), HTTP, nil
|
command = []byte(path)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,8 +12,8 @@ type Standard struct {
|
||||||
Elapsed string `json:"elapsed"`
|
Elapsed string `json:"elapsed"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stats represents tile38 server statistics.
|
// Server represents tile38 server statistics.
|
||||||
type Stats struct {
|
type ServerStats struct {
|
||||||
Standard
|
Standard
|
||||||
Stats struct {
|
Stats struct {
|
||||||
ServerID string `json:"id"`
|
ServerID string `json:"id"`
|
||||||
|
@ -30,9 +30,9 @@ type Stats struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stats returns tile38 server statistics.
|
// Stats returns tile38 server statistics.
|
||||||
func (conn *Conn) Stats() (Stats, error) {
|
func (conn *Conn) Server() (ServerStats, error) {
|
||||||
var stats Stats
|
var stats ServerStats
|
||||||
msg, err := conn.Do("stats")
|
msg, err := conn.Do("server")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return stats, err
|
return stats, err
|
||||||
}
|
}
|
||||||
|
@ -40,7 +40,7 @@ func (conn *Conn) Stats() (Stats, error) {
|
||||||
return stats, err
|
return stats, err
|
||||||
}
|
}
|
||||||
if !stats.OK {
|
if !stats.OK {
|
||||||
if stats.Err == "" {
|
if stats.Err != "" {
|
||||||
return stats, errors.New(stats.Err)
|
return stats, errors.New(stats.Err)
|
||||||
}
|
}
|
||||||
return stats, errors.New("not ok")
|
return stats, errors.New("not ok")
|
||||||
|
|
|
@ -251,7 +251,7 @@ func (c *Controller) liveAOF(pos int64, conn net.Conn, rd *bufio.Reader) error {
|
||||||
cond.L.Unlock()
|
cond.L.Unlock()
|
||||||
}()
|
}()
|
||||||
for {
|
for {
|
||||||
command, _, err := client.ReadMessage(rd, nil)
|
command, _, _, err := client.ReadMessage(rd, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err != io.EOF {
|
if err != io.EOF {
|
||||||
log.Error(err)
|
log.Error(err)
|
||||||
|
|
|
@ -1,15 +0,0 @@
|
||||||
package controller
|
|
||||||
|
|
||||||
func (c *Controller) cmdAuth(line string) error {
|
|
||||||
var password string
|
|
||||||
if line, password = token(line); password == "" {
|
|
||||||
return errInvalidNumberOfArguments
|
|
||||||
}
|
|
||||||
if line != "" {
|
|
||||||
return errInvalidNumberOfArguments
|
|
||||||
}
|
|
||||||
|
|
||||||
println(password)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
|
@ -183,7 +183,7 @@ func (c *Controller) handleInputCommand(conn *server.Conn, line string, w io.Wri
|
||||||
return writeErr(errors.New("empty command"))
|
return writeErr(errors.New("empty command"))
|
||||||
}
|
}
|
||||||
|
|
||||||
if !conn.Authenticated {
|
if !conn.Authenticated || cmd == "auth" {
|
||||||
c.mu.RLock()
|
c.mu.RLock()
|
||||||
requirePass := c.config.RequirePass
|
requirePass := c.config.RequirePass
|
||||||
c.mu.RUnlock()
|
c.mu.RUnlock()
|
||||||
|
@ -194,11 +194,14 @@ func (c *Controller) handleInputCommand(conn *server.Conn, line string, w io.Wri
|
||||||
return writeErr(errors.New("authentication required"))
|
return writeErr(errors.New("authentication required"))
|
||||||
}
|
}
|
||||||
password, _ := token(line)
|
password, _ := token(line)
|
||||||
if requirePass == strings.TrimSpace(password) {
|
if requirePass != strings.TrimSpace(password) {
|
||||||
conn.Authenticated = true
|
|
||||||
} else {
|
|
||||||
return writeErr(errors.New("invalid password"))
|
return writeErr(errors.New("invalid password"))
|
||||||
}
|
}
|
||||||
|
conn.Authenticated = true
|
||||||
|
w.Write([]byte(`{"ok":true,"elapsed":"` + time.Now().Sub(start).String() + "\"}"))
|
||||||
|
return nil
|
||||||
|
} else if cmd == "auth" {
|
||||||
|
return writeErr(errors.New("invalid password"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -318,13 +321,10 @@ func (c *Controller) command(line string, w io.Writer) (resp string, d commandDe
|
||||||
case "readonly":
|
case "readonly":
|
||||||
err = c.cmdReadOnly(nline)
|
err = c.cmdReadOnly(nline)
|
||||||
resp = okResp()
|
resp = okResp()
|
||||||
case "auth":
|
|
||||||
err = c.cmdAuth(nline)
|
|
||||||
resp = okResp()
|
|
||||||
case "stats":
|
case "stats":
|
||||||
resp, err = c.cmdServer(nline)
|
|
||||||
case "server":
|
|
||||||
resp, err = c.cmdStats(nline)
|
resp, err = c.cmdStats(nline)
|
||||||
|
case "server":
|
||||||
|
resp, err = c.cmdServer(nline)
|
||||||
case "scan":
|
case "scan":
|
||||||
err = c.cmdScan(nline, w)
|
err = c.cmdScan(nline, w)
|
||||||
case "nearby":
|
case "nearby":
|
||||||
|
|
|
@ -45,6 +45,7 @@ func (c *Controller) cmdFollow(line string) error {
|
||||||
}
|
}
|
||||||
port := int(n)
|
port := int(n)
|
||||||
update = c.config.FollowHost != host || c.config.FollowPort != port
|
update = c.config.FollowHost != host || c.config.FollowPort != port
|
||||||
|
auth := c.config.LeaderAuth
|
||||||
if update {
|
if update {
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
conn, err := client.DialTimeout(fmt.Sprintf("%s:%d", host, port), time.Second*2)
|
conn, err := client.DialTimeout(fmt.Sprintf("%s:%d", host, port), time.Second*2)
|
||||||
|
@ -53,7 +54,12 @@ func (c *Controller) cmdFollow(line string) error {
|
||||||
return fmt.Errorf("cannot follow: %v", err)
|
return fmt.Errorf("cannot follow: %v", err)
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
msg, err := conn.Stats()
|
if auth != "" {
|
||||||
|
if err := c.followDoLeaderAuth(conn, auth); err != nil {
|
||||||
|
return fmt.Errorf("cannot follow: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
msg, err := conn.Server()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
return fmt.Errorf("cannot follow: %v", err)
|
return fmt.Errorf("cannot follow: %v", err)
|
||||||
|
@ -103,6 +109,25 @@ func (c *Controller) followHandleCommand(line string, followc uint64, w io.Write
|
||||||
return c.aofsz, nil
|
return c.aofsz, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Controller) followDoLeaderAuth(conn *client.Conn, auth string) error {
|
||||||
|
data, err := conn.Do("AUTH " + auth)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
var msg client.Standard
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !msg.OK {
|
||||||
|
if msg.Err != "" {
|
||||||
|
return errors.New(msg.Err)
|
||||||
|
} else {
|
||||||
|
return errors.New("cannot follow: auth no ok")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Controller) followStep(host string, port int, followc uint64) error {
|
func (c *Controller) followStep(host string, port int, followc uint64) error {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
if c.followc != followc {
|
if c.followc != followc {
|
||||||
|
@ -110,6 +135,7 @@ func (c *Controller) followStep(host string, port int, followc uint64) error {
|
||||||
return errNoLongerFollowing
|
return errNoLongerFollowing
|
||||||
}
|
}
|
||||||
c.fcup = false
|
c.fcup = false
|
||||||
|
auth := c.config.LeaderAuth
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
addr := fmt.Sprintf("%s:%d", host, port)
|
addr := fmt.Sprintf("%s:%d", host, port)
|
||||||
// check if we are following self
|
// check if we are following self
|
||||||
|
@ -118,7 +144,13 @@ func (c *Controller) followStep(host string, port int, followc uint64) error {
|
||||||
return fmt.Errorf("cannot follow: %v", err)
|
return fmt.Errorf("cannot follow: %v", err)
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
stats, err := conn.Stats()
|
if auth != "" {
|
||||||
|
if err := c.followDoLeaderAuth(conn, auth); err != nil {
|
||||||
|
return fmt.Errorf("cannot follow: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
stats, err := conn.Server()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("cannot follow: %v", err)
|
return fmt.Errorf("cannot follow: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -134,13 +166,6 @@ func (c *Controller) followStep(host string, port int, followc uint64) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// make real connection
|
|
||||||
conn, err = client.Dial(addr)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
msg, err := conn.Do(fmt.Sprintf("aof %d", pos))
|
msg, err := conn.Do(fmt.Sprintf("aof %d", pos))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -198,7 +223,7 @@ func (c *Controller) follow(host string, port int, followc uint64) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err != nil && err != io.EOF {
|
if err != nil && err != io.EOF {
|
||||||
log.Debug("follow: " + err.Error())
|
log.Error("follow: " + err.Error())
|
||||||
}
|
}
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}
|
}
|
||||||
|
|
|
@ -103,7 +103,7 @@ func (c *Controller) goLive(inerr error, conn net.Conn, rd *bufio.Reader, websoc
|
||||||
conn.Close()
|
conn.Close()
|
||||||
}()
|
}()
|
||||||
for {
|
for {
|
||||||
command, _, err := client.ReadMessage(rd, nil)
|
command, _, _, err := client.ReadMessage(rd, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err != io.EOF {
|
if err != io.EOF {
|
||||||
log.Error(err)
|
log.Error(err)
|
||||||
|
|
|
@ -68,6 +68,13 @@ func ListenAndServe(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func writeCommandErr(proto client.Proto, conn *Conn, err error) error {
|
||||||
|
if proto == client.HTTP || proto == client.WebSocket {
|
||||||
|
conn.Write([]byte(`HTTP/1.1 500 ` + err.Error() + "\r\nConnection: close\r\n\r\n"))
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func handleConn(
|
func handleConn(
|
||||||
conn *Conn,
|
conn *Conn,
|
||||||
protected func() bool,
|
protected func() bool,
|
||||||
|
@ -92,7 +99,7 @@ func handleConn(
|
||||||
rd := bufio.NewReader(conn)
|
rd := bufio.NewReader(conn)
|
||||||
for i := 0; ; i++ {
|
for i := 0; ; i++ {
|
||||||
err := func() error {
|
err := func() error {
|
||||||
command, proto, err := client.ReadMessage(rd, conn)
|
command, proto, auth, err := client.ReadMessage(rd, conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -100,12 +107,21 @@ func handleConn(
|
||||||
return io.EOF
|
return io.EOF
|
||||||
}
|
}
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
|
var denied bool
|
||||||
if err := handler(conn, command, rd, &b, proto == client.WebSocket); err != nil {
|
if (proto == client.HTTP || proto == client.WebSocket) && auth != "" {
|
||||||
if proto == client.HTTP {
|
if err := handler(conn, []byte("AUTH "+auth), rd, &b, proto == client.WebSocket); err != nil {
|
||||||
conn.Write([]byte(`HTTP/1.1 500 ` + err.Error() + "\r\nConnection: close\r\n\r\n"))
|
return writeCommandErr(proto, conn, err)
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(b.String(), `{"ok":false`) {
|
||||||
|
denied = true
|
||||||
|
} else {
|
||||||
|
b.Reset()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !denied {
|
||||||
|
if err := handler(conn, command, rd, &b, proto == client.WebSocket); err != nil {
|
||||||
|
return writeCommandErr(proto, conn, err)
|
||||||
}
|
}
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
switch proto {
|
switch proto {
|
||||||
case client.Native:
|
case client.Native:
|
||||||
|
|
Loading…
Reference in New Issue