From c6619a529f92e133601c84e900b4d55f65770d42 Mon Sep 17 00:00:00 2001 From: Josh Baker Date: Tue, 8 Mar 2016 08:35:43 -0700 Subject: [PATCH] replication auth --- client/conn.go | 85 +++++++++++++++++++++---------------- client/helper.go | 12 +++--- controller/aof.go | 2 +- controller/auth.go | 15 ------- controller/controller.go | 18 ++++---- controller/follow.go | 45 +++++++++++++++----- controller/live.go | 2 +- controller/server/server.go | 28 +++++++++--- 8 files changed, 123 insertions(+), 84 deletions(-) delete mode 100644 controller/auth.go diff --git a/client/conn.go b/client/conn.go index e89e57c5..fdf68c4c 100644 --- a/client/conn.go +++ b/client/conn.go @@ -83,7 +83,7 @@ func (conn *Conn) Do(command string) ([]byte, error) { conn.pool = nil return nil, err } - message, _, err := ReadMessage(conn.rd, nil) + message, _, _, err := ReadMessage(conn.rd, nil) if err != nil { conn.pool = nil return nil, err @@ -96,7 +96,7 @@ func (conn *Conn) Do(command string) ([]byte, error) { // ReadMessage returns the next message. Used when reading live connections func (conn *Conn) ReadMessage() (message []byte, err error) { - message, _, err = readMessage(conn.c, conn.rd) + message, _, _, err = readMessage(conn.c, conn.rd) if err != nil { conn.pool = nil return message, err @@ -160,10 +160,10 @@ func WriteWebSocket(conn net.Conn, data []byte) error { } // ReadMessage reads the next message from a bufio.Reader. -func readMessage(wr io.Writer, rd *bufio.Reader) (message []byte, proto Proto, err error) { +func readMessage(wr io.Writer, rd *bufio.Reader) (message []byte, proto Proto, auth string, err error) { h, err := rd.Peek(1) if err != nil { - return nil, proto, err + return nil, proto, auth, err } switch h[0] { case '$': @@ -171,41 +171,41 @@ func readMessage(wr io.Writer, rd *bufio.Reader) (message []byte, proto Proto, e } message, proto, err = readTelnetMessage(rd) if err != nil { - return nil, proto, err + return nil, proto, auth, err } if len(message) > 6 && string(message[len(message)-9:len(message)-2]) == " HTTP/1" { return readHTTPMessage(string(message), wr, rd) } - return message, proto, nil + return message, proto, auth, nil } -func ReadMessage(rd *bufio.Reader, wr io.Writer) (message []byte, proto Proto, err error) { +func ReadMessage(rd *bufio.Reader, wr io.Writer) (message []byte, proto Proto, auth string, err error) { return readMessage(wr, rd) } -func readProtoMessage(rd *bufio.Reader) (message []byte, proto Proto, err error) { +func readProtoMessage(rd *bufio.Reader) (message []byte, proto Proto, auth string, err error) { b, err := rd.ReadBytes(' ') if err != nil { - return nil, Native, err + return nil, Native, auth, err } if len(b) > 0 && b[0] != '$' { - return nil, Native, errors.New("not a proto message") + return nil, Native, auth, errors.New("not a proto message") } n, err := strconv.ParseUint(string(b[1:len(b)-1]), 10, 32) if err != nil { - return nil, Native, errors.New("invalid size") + return nil, Native, auth, errors.New("invalid size") } if n > MaxMessageSize { - return nil, Native, errors.New("message too big") + return nil, Native, auth, errors.New("message too big") } b = make([]byte, int(n)+2) if _, err := io.ReadFull(rd, b); err != nil { - return nil, Native, err + return nil, Native, auth, err } if b[len(b)-2] != '\r' || b[len(b)-1] != '\n' { - return nil, Native, errors.New("expecting crlf suffix") + return nil, Native, auth, errors.New("expecting crlf suffix") } - return b[:len(b)-2], Native, nil + return b[:len(b)-2], Native, auth, nil } func readTelnetMessage(rd *bufio.Reader) (command []byte, proto Proto, err error) { @@ -221,45 +221,56 @@ func readTelnetMessage(rd *bufio.Reader) (command []byte, proto Proto, err error return line, Telnet, nil } -func readHTTPMessage(line string, wr io.Writer, rd *bufio.Reader) (command []byte, proto Proto, err error) { +func readHTTPMessage(line string, wr io.Writer, rd *bufio.Reader) (command []byte, proto Proto, auth string, err error) { + proto = HTTP parts := strings.Split(line, " ") if len(parts) != 3 { - return nil, HTTP, errors.New("invalid HTTP request") + err = errors.New("invalid HTTP request") + return } method := parts[0] path := parts[1] if len(path) == 0 || path[0] != '/' { - return nil, HTTP, errors.New("invalid HTTP request") + err = errors.New("invalid HTTP request") + return } path, err = url.QueryUnescape(path[1:]) if err != nil { - return nil, HTTP, errors.New("invalid HTTP request") + err = errors.New("invalid HTTP request") + return } if method != "GET" && method != "POST" { - return nil, HTTP, errors.New("invalid HTTP method") + err = errors.New("invalid HTTP method") + return } contentLength := 0 websocket := false websocketVersion := 0 websocketKey := "" for { - b, _, err := readTelnetMessage(rd) // read a header line + var b []byte + b, _, err = readTelnetMessage(rd) // read a header line if err != nil { - return nil, HTTP, nil + return } header := string(b) if header == "" { break // end of headers } - if header[0] == 'u' || header[0] == 'U' { + if header[0] == 'a' || header[0] == 'A' { + if strings.HasPrefix(strings.ToLower(header), "authorization:") { + auth = strings.TrimSpace(header[len("authorization:"):]) + } + } else if header[0] == 'u' || header[0] == 'U' { if strings.HasPrefix(strings.ToLower(header), "upgrade:") && strings.ToLower(strings.TrimSpace(header[len("upgrade:"):])) == "websocket" { websocket = true } } else if header[0] == 's' || header[0] == 'S' { if strings.HasPrefix(strings.ToLower(header), "sec-websocket-version:") { - n, err := strconv.ParseUint(strings.TrimSpace(header[len("sec-websocket-version:"):]), 10, 64) + var n uint64 + n, err = strconv.ParseUint(strings.TrimSpace(header[len("sec-websocket-version:"):]), 10, 64) if err != nil { - return nil, HTTP, err + return } websocketVersion = int(n) } else if strings.HasPrefix(strings.ToLower(header), "sec-websocket-key:") { @@ -267,33 +278,35 @@ func readHTTPMessage(line string, wr io.Writer, rd *bufio.Reader) (command []byt } } else if header[0] == 'c' || header[0] == 'C' { if strings.HasPrefix(strings.ToLower(header), "content-length:") { - n, err := strconv.ParseUint(strings.TrimSpace(header[len("content-length:"):]), 10, 64) + var n uint64 + n, err = strconv.ParseUint(strings.TrimSpace(header[len("content-length:"):]), 10, 64) if err != nil { - return nil, HTTP, err + return } contentLength = int(n) } } } if websocket && websocketVersion >= 13 && websocketKey != "" { + proto = WebSocket if wr == nil { - return nil, WebSocket, errors.New("connection is nil") + err = errors.New("connection is nil") + return } sum := sha1.Sum([]byte(websocketKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")) accept := base64.StdEncoding.EncodeToString(sum[:]) wshead := "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: " + accept + "\r\n\r\n" - _, err := wr.Write([]byte(wshead)) - if err != nil { - return nil, WebSocket, err + if _, err = wr.Write([]byte(wshead)); err != nil { + return } - return []byte(path), WebSocket, nil } else if contentLength > 0 { + proto = HTTP buf := make([]byte, contentLength) - _, err := io.ReadFull(rd, buf) - if err != nil { - return nil, HTTP, err + if _, err = io.ReadFull(rd, buf); err != nil { + return } path += string(buf) } - return []byte(path), HTTP, nil + command = []byte(path) + return } diff --git a/client/helper.go b/client/helper.go index fecd7067..6aa14b84 100644 --- a/client/helper.go +++ b/client/helper.go @@ -12,8 +12,8 @@ type Standard struct { Elapsed string `json:"elapsed"` } -// Stats represents tile38 server statistics. -type Stats struct { +// Server represents tile38 server statistics. +type ServerStats struct { Standard Stats struct { ServerID string `json:"id"` @@ -30,9 +30,9 @@ type Stats struct { } // Stats returns tile38 server statistics. -func (conn *Conn) Stats() (Stats, error) { - var stats Stats - msg, err := conn.Do("stats") +func (conn *Conn) Server() (ServerStats, error) { + var stats ServerStats + msg, err := conn.Do("server") if err != nil { return stats, err } @@ -40,7 +40,7 @@ func (conn *Conn) Stats() (Stats, error) { return stats, err } if !stats.OK { - if stats.Err == "" { + if stats.Err != "" { return stats, errors.New(stats.Err) } return stats, errors.New("not ok") diff --git a/controller/aof.go b/controller/aof.go index 11dd7c74..007b86e9 100644 --- a/controller/aof.go +++ b/controller/aof.go @@ -251,7 +251,7 @@ func (c *Controller) liveAOF(pos int64, conn net.Conn, rd *bufio.Reader) error { cond.L.Unlock() }() for { - command, _, err := client.ReadMessage(rd, nil) + command, _, _, err := client.ReadMessage(rd, nil) if err != nil { if err != io.EOF { log.Error(err) diff --git a/controller/auth.go b/controller/auth.go deleted file mode 100644 index a8660e7b..00000000 --- a/controller/auth.go +++ /dev/null @@ -1,15 +0,0 @@ -package controller - -func (c *Controller) cmdAuth(line string) error { - var password string - if line, password = token(line); password == "" { - return errInvalidNumberOfArguments - } - if line != "" { - return errInvalidNumberOfArguments - } - - println(password) - - return nil -} diff --git a/controller/controller.go b/controller/controller.go index af592336..bc9ff863 100644 --- a/controller/controller.go +++ b/controller/controller.go @@ -183,7 +183,7 @@ func (c *Controller) handleInputCommand(conn *server.Conn, line string, w io.Wri return writeErr(errors.New("empty command")) } - if !conn.Authenticated { + if !conn.Authenticated || cmd == "auth" { c.mu.RLock() requirePass := c.config.RequirePass c.mu.RUnlock() @@ -194,11 +194,14 @@ func (c *Controller) handleInputCommand(conn *server.Conn, line string, w io.Wri return writeErr(errors.New("authentication required")) } password, _ := token(line) - if requirePass == strings.TrimSpace(password) { - conn.Authenticated = true - } else { + if requirePass != strings.TrimSpace(password) { return writeErr(errors.New("invalid password")) } + conn.Authenticated = true + w.Write([]byte(`{"ok":true,"elapsed":"` + time.Now().Sub(start).String() + "\"}")) + return nil + } else if cmd == "auth" { + return writeErr(errors.New("invalid password")) } } @@ -318,13 +321,10 @@ func (c *Controller) command(line string, w io.Writer) (resp string, d commandDe case "readonly": err = c.cmdReadOnly(nline) resp = okResp() - case "auth": - err = c.cmdAuth(nline) - resp = okResp() case "stats": - resp, err = c.cmdServer(nline) - case "server": resp, err = c.cmdStats(nline) + case "server": + resp, err = c.cmdServer(nline) case "scan": err = c.cmdScan(nline, w) case "nearby": diff --git a/controller/follow.go b/controller/follow.go index a6efcbf2..280daaf9 100644 --- a/controller/follow.go +++ b/controller/follow.go @@ -45,6 +45,7 @@ func (c *Controller) cmdFollow(line string) error { } port := int(n) update = c.config.FollowHost != host || c.config.FollowPort != port + auth := c.config.LeaderAuth if update { c.mu.Unlock() conn, err := client.DialTimeout(fmt.Sprintf("%s:%d", host, port), time.Second*2) @@ -53,7 +54,12 @@ func (c *Controller) cmdFollow(line string) error { return fmt.Errorf("cannot follow: %v", err) } defer conn.Close() - msg, err := conn.Stats() + if auth != "" { + if err := c.followDoLeaderAuth(conn, auth); err != nil { + return fmt.Errorf("cannot follow: %v", err) + } + } + msg, err := conn.Server() if err != nil { c.mu.Lock() return fmt.Errorf("cannot follow: %v", err) @@ -103,6 +109,25 @@ func (c *Controller) followHandleCommand(line string, followc uint64, w io.Write return c.aofsz, nil } +func (c *Controller) followDoLeaderAuth(conn *client.Conn, auth string) error { + data, err := conn.Do("AUTH " + auth) + if err != nil { + return err + } + var msg client.Standard + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + if !msg.OK { + if msg.Err != "" { + return errors.New(msg.Err) + } else { + return errors.New("cannot follow: auth no ok") + } + } + return nil +} + func (c *Controller) followStep(host string, port int, followc uint64) error { c.mu.Lock() if c.followc != followc { @@ -110,6 +135,7 @@ func (c *Controller) followStep(host string, port int, followc uint64) error { return errNoLongerFollowing } c.fcup = false + auth := c.config.LeaderAuth c.mu.Unlock() addr := fmt.Sprintf("%s:%d", host, port) // check if we are following self @@ -118,7 +144,13 @@ func (c *Controller) followStep(host string, port int, followc uint64) error { return fmt.Errorf("cannot follow: %v", err) } defer conn.Close() - stats, err := conn.Stats() + if auth != "" { + if err := c.followDoLeaderAuth(conn, auth); err != nil { + return fmt.Errorf("cannot follow: %v", err) + } + } + + stats, err := conn.Server() if err != nil { return fmt.Errorf("cannot follow: %v", err) } @@ -134,13 +166,6 @@ func (c *Controller) followStep(host string, port int, followc uint64) error { return err } - // make real connection - conn, err = client.Dial(addr) - if err != nil { - return err - } - defer conn.Close() - msg, err := conn.Do(fmt.Sprintf("aof %d", pos)) if err != nil { return err @@ -198,7 +223,7 @@ func (c *Controller) follow(host string, port int, followc uint64) { return } if err != nil && err != io.EOF { - log.Debug("follow: " + err.Error()) + log.Error("follow: " + err.Error()) } time.Sleep(time.Second) } diff --git a/controller/live.go b/controller/live.go index 0270a438..b0830150 100644 --- a/controller/live.go +++ b/controller/live.go @@ -103,7 +103,7 @@ func (c *Controller) goLive(inerr error, conn net.Conn, rd *bufio.Reader, websoc conn.Close() }() for { - command, _, err := client.ReadMessage(rd, nil) + command, _, _, err := client.ReadMessage(rd, nil) if err != nil { if err != io.EOF { log.Error(err) diff --git a/controller/server/server.go b/controller/server/server.go index bf0f00a6..dcab4bea 100644 --- a/controller/server/server.go +++ b/controller/server/server.go @@ -68,6 +68,13 @@ func ListenAndServe( } } +func writeCommandErr(proto client.Proto, conn *Conn, err error) error { + if proto == client.HTTP || proto == client.WebSocket { + conn.Write([]byte(`HTTP/1.1 500 ` + err.Error() + "\r\nConnection: close\r\n\r\n")) + } + return err +} + func handleConn( conn *Conn, protected func() bool, @@ -92,7 +99,7 @@ func handleConn( rd := bufio.NewReader(conn) for i := 0; ; i++ { err := func() error { - command, proto, err := client.ReadMessage(rd, conn) + command, proto, auth, err := client.ReadMessage(rd, conn) if err != nil { return err } @@ -100,12 +107,21 @@ func handleConn( return io.EOF } var b bytes.Buffer - - if err := handler(conn, command, rd, &b, proto == client.WebSocket); err != nil { - if proto == client.HTTP { - conn.Write([]byte(`HTTP/1.1 500 ` + err.Error() + "\r\nConnection: close\r\n\r\n")) + var denied bool + if (proto == client.HTTP || proto == client.WebSocket) && auth != "" { + if err := handler(conn, []byte("AUTH "+auth), rd, &b, proto == client.WebSocket); err != nil { + return writeCommandErr(proto, conn, err) + } + if strings.HasPrefix(b.String(), `{"ok":false`) { + denied = true + } else { + b.Reset() + } + } + if !denied { + if err := handler(conn, command, rd, &b, proto == client.WebSocket); err != nil { + return writeCommandErr(proto, conn, err) } - return err } switch proto { case client.Native: