replication auth

This commit is contained in:
Josh Baker 2016-03-08 08:35:43 -07:00
parent 4ef89db63b
commit c6619a529f
8 changed files with 123 additions and 84 deletions

View File

@ -83,7 +83,7 @@ func (conn *Conn) Do(command string) ([]byte, error) {
conn.pool = nil conn.pool = nil
return nil, err return nil, err
} }
message, _, err := ReadMessage(conn.rd, nil) message, _, _, err := ReadMessage(conn.rd, nil)
if err != nil { if err != nil {
conn.pool = nil conn.pool = nil
return nil, err return nil, err
@ -96,7 +96,7 @@ func (conn *Conn) Do(command string) ([]byte, error) {
// ReadMessage returns the next message. Used when reading live connections // ReadMessage returns the next message. Used when reading live connections
func (conn *Conn) ReadMessage() (message []byte, err error) { func (conn *Conn) ReadMessage() (message []byte, err error) {
message, _, err = readMessage(conn.c, conn.rd) message, _, _, err = readMessage(conn.c, conn.rd)
if err != nil { if err != nil {
conn.pool = nil conn.pool = nil
return message, err return message, err
@ -160,10 +160,10 @@ func WriteWebSocket(conn net.Conn, data []byte) error {
} }
// ReadMessage reads the next message from a bufio.Reader. // ReadMessage reads the next message from a bufio.Reader.
func readMessage(wr io.Writer, rd *bufio.Reader) (message []byte, proto Proto, err error) { func readMessage(wr io.Writer, rd *bufio.Reader) (message []byte, proto Proto, auth string, err error) {
h, err := rd.Peek(1) h, err := rd.Peek(1)
if err != nil { if err != nil {
return nil, proto, err return nil, proto, auth, err
} }
switch h[0] { switch h[0] {
case '$': case '$':
@ -171,41 +171,41 @@ func readMessage(wr io.Writer, rd *bufio.Reader) (message []byte, proto Proto, e
} }
message, proto, err = readTelnetMessage(rd) message, proto, err = readTelnetMessage(rd)
if err != nil { if err != nil {
return nil, proto, err return nil, proto, auth, err
} }
if len(message) > 6 && string(message[len(message)-9:len(message)-2]) == " HTTP/1" { if len(message) > 6 && string(message[len(message)-9:len(message)-2]) == " HTTP/1" {
return readHTTPMessage(string(message), wr, rd) return readHTTPMessage(string(message), wr, rd)
} }
return message, proto, nil return message, proto, auth, nil
} }
func ReadMessage(rd *bufio.Reader, wr io.Writer) (message []byte, proto Proto, err error) { func ReadMessage(rd *bufio.Reader, wr io.Writer) (message []byte, proto Proto, auth string, err error) {
return readMessage(wr, rd) return readMessage(wr, rd)
} }
func readProtoMessage(rd *bufio.Reader) (message []byte, proto Proto, err error) { func readProtoMessage(rd *bufio.Reader) (message []byte, proto Proto, auth string, err error) {
b, err := rd.ReadBytes(' ') b, err := rd.ReadBytes(' ')
if err != nil { if err != nil {
return nil, Native, err return nil, Native, auth, err
} }
if len(b) > 0 && b[0] != '$' { if len(b) > 0 && b[0] != '$' {
return nil, Native, errors.New("not a proto message") return nil, Native, auth, errors.New("not a proto message")
} }
n, err := strconv.ParseUint(string(b[1:len(b)-1]), 10, 32) n, err := strconv.ParseUint(string(b[1:len(b)-1]), 10, 32)
if err != nil { if err != nil {
return nil, Native, errors.New("invalid size") return nil, Native, auth, errors.New("invalid size")
} }
if n > MaxMessageSize { if n > MaxMessageSize {
return nil, Native, errors.New("message too big") return nil, Native, auth, errors.New("message too big")
} }
b = make([]byte, int(n)+2) b = make([]byte, int(n)+2)
if _, err := io.ReadFull(rd, b); err != nil { if _, err := io.ReadFull(rd, b); err != nil {
return nil, Native, err return nil, Native, auth, err
} }
if b[len(b)-2] != '\r' || b[len(b)-1] != '\n' { if b[len(b)-2] != '\r' || b[len(b)-1] != '\n' {
return nil, Native, errors.New("expecting crlf suffix") return nil, Native, auth, errors.New("expecting crlf suffix")
} }
return b[:len(b)-2], Native, nil return b[:len(b)-2], Native, auth, nil
} }
func readTelnetMessage(rd *bufio.Reader) (command []byte, proto Proto, err error) { func readTelnetMessage(rd *bufio.Reader) (command []byte, proto Proto, err error) {
@ -221,45 +221,56 @@ func readTelnetMessage(rd *bufio.Reader) (command []byte, proto Proto, err error
return line, Telnet, nil return line, Telnet, nil
} }
func readHTTPMessage(line string, wr io.Writer, rd *bufio.Reader) (command []byte, proto Proto, err error) { func readHTTPMessage(line string, wr io.Writer, rd *bufio.Reader) (command []byte, proto Proto, auth string, err error) {
proto = HTTP
parts := strings.Split(line, " ") parts := strings.Split(line, " ")
if len(parts) != 3 { if len(parts) != 3 {
return nil, HTTP, errors.New("invalid HTTP request") err = errors.New("invalid HTTP request")
return
} }
method := parts[0] method := parts[0]
path := parts[1] path := parts[1]
if len(path) == 0 || path[0] != '/' { if len(path) == 0 || path[0] != '/' {
return nil, HTTP, errors.New("invalid HTTP request") err = errors.New("invalid HTTP request")
return
} }
path, err = url.QueryUnescape(path[1:]) path, err = url.QueryUnescape(path[1:])
if err != nil { if err != nil {
return nil, HTTP, errors.New("invalid HTTP request") err = errors.New("invalid HTTP request")
return
} }
if method != "GET" && method != "POST" { if method != "GET" && method != "POST" {
return nil, HTTP, errors.New("invalid HTTP method") err = errors.New("invalid HTTP method")
return
} }
contentLength := 0 contentLength := 0
websocket := false websocket := false
websocketVersion := 0 websocketVersion := 0
websocketKey := "" websocketKey := ""
for { for {
b, _, err := readTelnetMessage(rd) // read a header line var b []byte
b, _, err = readTelnetMessage(rd) // read a header line
if err != nil { if err != nil {
return nil, HTTP, nil return
} }
header := string(b) header := string(b)
if header == "" { if header == "" {
break // end of headers break // end of headers
} }
if header[0] == 'u' || header[0] == 'U' { if header[0] == 'a' || header[0] == 'A' {
if strings.HasPrefix(strings.ToLower(header), "authorization:") {
auth = strings.TrimSpace(header[len("authorization:"):])
}
} else if header[0] == 'u' || header[0] == 'U' {
if strings.HasPrefix(strings.ToLower(header), "upgrade:") && strings.ToLower(strings.TrimSpace(header[len("upgrade:"):])) == "websocket" { if strings.HasPrefix(strings.ToLower(header), "upgrade:") && strings.ToLower(strings.TrimSpace(header[len("upgrade:"):])) == "websocket" {
websocket = true websocket = true
} }
} else if header[0] == 's' || header[0] == 'S' { } else if header[0] == 's' || header[0] == 'S' {
if strings.HasPrefix(strings.ToLower(header), "sec-websocket-version:") { if strings.HasPrefix(strings.ToLower(header), "sec-websocket-version:") {
n, err := strconv.ParseUint(strings.TrimSpace(header[len("sec-websocket-version:"):]), 10, 64) var n uint64
n, err = strconv.ParseUint(strings.TrimSpace(header[len("sec-websocket-version:"):]), 10, 64)
if err != nil { if err != nil {
return nil, HTTP, err return
} }
websocketVersion = int(n) websocketVersion = int(n)
} else if strings.HasPrefix(strings.ToLower(header), "sec-websocket-key:") { } else if strings.HasPrefix(strings.ToLower(header), "sec-websocket-key:") {
@ -267,33 +278,35 @@ func readHTTPMessage(line string, wr io.Writer, rd *bufio.Reader) (command []byt
} }
} else if header[0] == 'c' || header[0] == 'C' { } else if header[0] == 'c' || header[0] == 'C' {
if strings.HasPrefix(strings.ToLower(header), "content-length:") { if strings.HasPrefix(strings.ToLower(header), "content-length:") {
n, err := strconv.ParseUint(strings.TrimSpace(header[len("content-length:"):]), 10, 64) var n uint64
n, err = strconv.ParseUint(strings.TrimSpace(header[len("content-length:"):]), 10, 64)
if err != nil { if err != nil {
return nil, HTTP, err return
} }
contentLength = int(n) contentLength = int(n)
} }
} }
} }
if websocket && websocketVersion >= 13 && websocketKey != "" { if websocket && websocketVersion >= 13 && websocketKey != "" {
proto = WebSocket
if wr == nil { if wr == nil {
return nil, WebSocket, errors.New("connection is nil") err = errors.New("connection is nil")
return
} }
sum := sha1.Sum([]byte(websocketKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")) sum := sha1.Sum([]byte(websocketKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"))
accept := base64.StdEncoding.EncodeToString(sum[:]) accept := base64.StdEncoding.EncodeToString(sum[:])
wshead := "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: " + accept + "\r\n\r\n" wshead := "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: " + accept + "\r\n\r\n"
_, err := wr.Write([]byte(wshead)) if _, err = wr.Write([]byte(wshead)); err != nil {
if err != nil { return
return nil, WebSocket, err
} }
return []byte(path), WebSocket, nil
} else if contentLength > 0 { } else if contentLength > 0 {
proto = HTTP
buf := make([]byte, contentLength) buf := make([]byte, contentLength)
_, err := io.ReadFull(rd, buf) if _, err = io.ReadFull(rd, buf); err != nil {
if err != nil { return
return nil, HTTP, err
} }
path += string(buf) path += string(buf)
} }
return []byte(path), HTTP, nil command = []byte(path)
return
} }

View File

@ -12,8 +12,8 @@ type Standard struct {
Elapsed string `json:"elapsed"` Elapsed string `json:"elapsed"`
} }
// Stats represents tile38 server statistics. // Server represents tile38 server statistics.
type Stats struct { type ServerStats struct {
Standard Standard
Stats struct { Stats struct {
ServerID string `json:"id"` ServerID string `json:"id"`
@ -30,9 +30,9 @@ type Stats struct {
} }
// Stats returns tile38 server statistics. // Stats returns tile38 server statistics.
func (conn *Conn) Stats() (Stats, error) { func (conn *Conn) Server() (ServerStats, error) {
var stats Stats var stats ServerStats
msg, err := conn.Do("stats") msg, err := conn.Do("server")
if err != nil { if err != nil {
return stats, err return stats, err
} }
@ -40,7 +40,7 @@ func (conn *Conn) Stats() (Stats, error) {
return stats, err return stats, err
} }
if !stats.OK { if !stats.OK {
if stats.Err == "" { if stats.Err != "" {
return stats, errors.New(stats.Err) return stats, errors.New(stats.Err)
} }
return stats, errors.New("not ok") return stats, errors.New("not ok")

View File

@ -251,7 +251,7 @@ func (c *Controller) liveAOF(pos int64, conn net.Conn, rd *bufio.Reader) error {
cond.L.Unlock() cond.L.Unlock()
}() }()
for { for {
command, _, err := client.ReadMessage(rd, nil) command, _, _, err := client.ReadMessage(rd, nil)
if err != nil { if err != nil {
if err != io.EOF { if err != io.EOF {
log.Error(err) log.Error(err)

View File

@ -1,15 +0,0 @@
package controller
func (c *Controller) cmdAuth(line string) error {
var password string
if line, password = token(line); password == "" {
return errInvalidNumberOfArguments
}
if line != "" {
return errInvalidNumberOfArguments
}
println(password)
return nil
}

View File

@ -183,7 +183,7 @@ func (c *Controller) handleInputCommand(conn *server.Conn, line string, w io.Wri
return writeErr(errors.New("empty command")) return writeErr(errors.New("empty command"))
} }
if !conn.Authenticated { if !conn.Authenticated || cmd == "auth" {
c.mu.RLock() c.mu.RLock()
requirePass := c.config.RequirePass requirePass := c.config.RequirePass
c.mu.RUnlock() c.mu.RUnlock()
@ -194,11 +194,14 @@ func (c *Controller) handleInputCommand(conn *server.Conn, line string, w io.Wri
return writeErr(errors.New("authentication required")) return writeErr(errors.New("authentication required"))
} }
password, _ := token(line) password, _ := token(line)
if requirePass == strings.TrimSpace(password) { if requirePass != strings.TrimSpace(password) {
conn.Authenticated = true
} else {
return writeErr(errors.New("invalid password")) return writeErr(errors.New("invalid password"))
} }
conn.Authenticated = true
w.Write([]byte(`{"ok":true,"elapsed":"` + time.Now().Sub(start).String() + "\"}"))
return nil
} else if cmd == "auth" {
return writeErr(errors.New("invalid password"))
} }
} }
@ -318,13 +321,10 @@ func (c *Controller) command(line string, w io.Writer) (resp string, d commandDe
case "readonly": case "readonly":
err = c.cmdReadOnly(nline) err = c.cmdReadOnly(nline)
resp = okResp() resp = okResp()
case "auth":
err = c.cmdAuth(nline)
resp = okResp()
case "stats": case "stats":
resp, err = c.cmdServer(nline)
case "server":
resp, err = c.cmdStats(nline) resp, err = c.cmdStats(nline)
case "server":
resp, err = c.cmdServer(nline)
case "scan": case "scan":
err = c.cmdScan(nline, w) err = c.cmdScan(nline, w)
case "nearby": case "nearby":

View File

@ -45,6 +45,7 @@ func (c *Controller) cmdFollow(line string) error {
} }
port := int(n) port := int(n)
update = c.config.FollowHost != host || c.config.FollowPort != port update = c.config.FollowHost != host || c.config.FollowPort != port
auth := c.config.LeaderAuth
if update { if update {
c.mu.Unlock() c.mu.Unlock()
conn, err := client.DialTimeout(fmt.Sprintf("%s:%d", host, port), time.Second*2) conn, err := client.DialTimeout(fmt.Sprintf("%s:%d", host, port), time.Second*2)
@ -53,7 +54,12 @@ func (c *Controller) cmdFollow(line string) error {
return fmt.Errorf("cannot follow: %v", err) return fmt.Errorf("cannot follow: %v", err)
} }
defer conn.Close() defer conn.Close()
msg, err := conn.Stats() if auth != "" {
if err := c.followDoLeaderAuth(conn, auth); err != nil {
return fmt.Errorf("cannot follow: %v", err)
}
}
msg, err := conn.Server()
if err != nil { if err != nil {
c.mu.Lock() c.mu.Lock()
return fmt.Errorf("cannot follow: %v", err) return fmt.Errorf("cannot follow: %v", err)
@ -103,6 +109,25 @@ func (c *Controller) followHandleCommand(line string, followc uint64, w io.Write
return c.aofsz, nil return c.aofsz, nil
} }
func (c *Controller) followDoLeaderAuth(conn *client.Conn, auth string) error {
data, err := conn.Do("AUTH " + auth)
if err != nil {
return err
}
var msg client.Standard
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
if !msg.OK {
if msg.Err != "" {
return errors.New(msg.Err)
} else {
return errors.New("cannot follow: auth no ok")
}
}
return nil
}
func (c *Controller) followStep(host string, port int, followc uint64) error { func (c *Controller) followStep(host string, port int, followc uint64) error {
c.mu.Lock() c.mu.Lock()
if c.followc != followc { if c.followc != followc {
@ -110,6 +135,7 @@ func (c *Controller) followStep(host string, port int, followc uint64) error {
return errNoLongerFollowing return errNoLongerFollowing
} }
c.fcup = false c.fcup = false
auth := c.config.LeaderAuth
c.mu.Unlock() c.mu.Unlock()
addr := fmt.Sprintf("%s:%d", host, port) addr := fmt.Sprintf("%s:%d", host, port)
// check if we are following self // check if we are following self
@ -118,7 +144,13 @@ func (c *Controller) followStep(host string, port int, followc uint64) error {
return fmt.Errorf("cannot follow: %v", err) return fmt.Errorf("cannot follow: %v", err)
} }
defer conn.Close() defer conn.Close()
stats, err := conn.Stats() if auth != "" {
if err := c.followDoLeaderAuth(conn, auth); err != nil {
return fmt.Errorf("cannot follow: %v", err)
}
}
stats, err := conn.Server()
if err != nil { if err != nil {
return fmt.Errorf("cannot follow: %v", err) return fmt.Errorf("cannot follow: %v", err)
} }
@ -134,13 +166,6 @@ func (c *Controller) followStep(host string, port int, followc uint64) error {
return err return err
} }
// make real connection
conn, err = client.Dial(addr)
if err != nil {
return err
}
defer conn.Close()
msg, err := conn.Do(fmt.Sprintf("aof %d", pos)) msg, err := conn.Do(fmt.Sprintf("aof %d", pos))
if err != nil { if err != nil {
return err return err
@ -198,7 +223,7 @@ func (c *Controller) follow(host string, port int, followc uint64) {
return return
} }
if err != nil && err != io.EOF { if err != nil && err != io.EOF {
log.Debug("follow: " + err.Error()) log.Error("follow: " + err.Error())
} }
time.Sleep(time.Second) time.Sleep(time.Second)
} }

View File

@ -103,7 +103,7 @@ func (c *Controller) goLive(inerr error, conn net.Conn, rd *bufio.Reader, websoc
conn.Close() conn.Close()
}() }()
for { for {
command, _, err := client.ReadMessage(rd, nil) command, _, _, err := client.ReadMessage(rd, nil)
if err != nil { if err != nil {
if err != io.EOF { if err != io.EOF {
log.Error(err) log.Error(err)

View File

@ -68,6 +68,13 @@ func ListenAndServe(
} }
} }
func writeCommandErr(proto client.Proto, conn *Conn, err error) error {
if proto == client.HTTP || proto == client.WebSocket {
conn.Write([]byte(`HTTP/1.1 500 ` + err.Error() + "\r\nConnection: close\r\n\r\n"))
}
return err
}
func handleConn( func handleConn(
conn *Conn, conn *Conn,
protected func() bool, protected func() bool,
@ -92,7 +99,7 @@ func handleConn(
rd := bufio.NewReader(conn) rd := bufio.NewReader(conn)
for i := 0; ; i++ { for i := 0; ; i++ {
err := func() error { err := func() error {
command, proto, err := client.ReadMessage(rd, conn) command, proto, auth, err := client.ReadMessage(rd, conn)
if err != nil { if err != nil {
return err return err
} }
@ -100,12 +107,21 @@ func handleConn(
return io.EOF return io.EOF
} }
var b bytes.Buffer var b bytes.Buffer
var denied bool
if err := handler(conn, command, rd, &b, proto == client.WebSocket); err != nil { if (proto == client.HTTP || proto == client.WebSocket) && auth != "" {
if proto == client.HTTP { if err := handler(conn, []byte("AUTH "+auth), rd, &b, proto == client.WebSocket); err != nil {
conn.Write([]byte(`HTTP/1.1 500 ` + err.Error() + "\r\nConnection: close\r\n\r\n")) return writeCommandErr(proto, conn, err)
}
if strings.HasPrefix(b.String(), `{"ok":false`) {
denied = true
} else {
b.Reset()
}
}
if !denied {
if err := handler(conn, command, rd, &b, proto == client.WebSocket); err != nil {
return writeCommandErr(proto, conn, err)
} }
return err
} }
switch proto { switch proto {
case client.Native: case client.Native: