replication auth

This commit is contained in:
Josh Baker 2016-03-08 08:35:43 -07:00
parent 4ef89db63b
commit c6619a529f
8 changed files with 123 additions and 84 deletions

View File

@ -83,7 +83,7 @@ func (conn *Conn) Do(command string) ([]byte, error) {
conn.pool = nil
return nil, err
}
message, _, err := ReadMessage(conn.rd, nil)
message, _, _, err := ReadMessage(conn.rd, nil)
if err != nil {
conn.pool = nil
return nil, err
@ -96,7 +96,7 @@ func (conn *Conn) Do(command string) ([]byte, error) {
// ReadMessage returns the next message. Used when reading live connections
func (conn *Conn) ReadMessage() (message []byte, err error) {
message, _, err = readMessage(conn.c, conn.rd)
message, _, _, err = readMessage(conn.c, conn.rd)
if err != nil {
conn.pool = nil
return message, err
@ -160,10 +160,10 @@ func WriteWebSocket(conn net.Conn, data []byte) error {
}
// ReadMessage reads the next message from a bufio.Reader.
func readMessage(wr io.Writer, rd *bufio.Reader) (message []byte, proto Proto, err error) {
func readMessage(wr io.Writer, rd *bufio.Reader) (message []byte, proto Proto, auth string, err error) {
h, err := rd.Peek(1)
if err != nil {
return nil, proto, err
return nil, proto, auth, err
}
switch h[0] {
case '$':
@ -171,41 +171,41 @@ func readMessage(wr io.Writer, rd *bufio.Reader) (message []byte, proto Proto, e
}
message, proto, err = readTelnetMessage(rd)
if err != nil {
return nil, proto, err
return nil, proto, auth, err
}
if len(message) > 6 && string(message[len(message)-9:len(message)-2]) == " HTTP/1" {
return readHTTPMessage(string(message), wr, rd)
}
return message, proto, nil
return message, proto, auth, nil
}
func ReadMessage(rd *bufio.Reader, wr io.Writer) (message []byte, proto Proto, err error) {
func ReadMessage(rd *bufio.Reader, wr io.Writer) (message []byte, proto Proto, auth string, err error) {
return readMessage(wr, rd)
}
func readProtoMessage(rd *bufio.Reader) (message []byte, proto Proto, err error) {
func readProtoMessage(rd *bufio.Reader) (message []byte, proto Proto, auth string, err error) {
b, err := rd.ReadBytes(' ')
if err != nil {
return nil, Native, err
return nil, Native, auth, err
}
if len(b) > 0 && b[0] != '$' {
return nil, Native, errors.New("not a proto message")
return nil, Native, auth, errors.New("not a proto message")
}
n, err := strconv.ParseUint(string(b[1:len(b)-1]), 10, 32)
if err != nil {
return nil, Native, errors.New("invalid size")
return nil, Native, auth, errors.New("invalid size")
}
if n > MaxMessageSize {
return nil, Native, errors.New("message too big")
return nil, Native, auth, errors.New("message too big")
}
b = make([]byte, int(n)+2)
if _, err := io.ReadFull(rd, b); err != nil {
return nil, Native, err
return nil, Native, auth, err
}
if b[len(b)-2] != '\r' || b[len(b)-1] != '\n' {
return nil, Native, errors.New("expecting crlf suffix")
return nil, Native, auth, errors.New("expecting crlf suffix")
}
return b[:len(b)-2], Native, nil
return b[:len(b)-2], Native, auth, nil
}
func readTelnetMessage(rd *bufio.Reader) (command []byte, proto Proto, err error) {
@ -221,45 +221,56 @@ func readTelnetMessage(rd *bufio.Reader) (command []byte, proto Proto, err error
return line, Telnet, nil
}
func readHTTPMessage(line string, wr io.Writer, rd *bufio.Reader) (command []byte, proto Proto, err error) {
func readHTTPMessage(line string, wr io.Writer, rd *bufio.Reader) (command []byte, proto Proto, auth string, err error) {
proto = HTTP
parts := strings.Split(line, " ")
if len(parts) != 3 {
return nil, HTTP, errors.New("invalid HTTP request")
err = errors.New("invalid HTTP request")
return
}
method := parts[0]
path := parts[1]
if len(path) == 0 || path[0] != '/' {
return nil, HTTP, errors.New("invalid HTTP request")
err = errors.New("invalid HTTP request")
return
}
path, err = url.QueryUnescape(path[1:])
if err != nil {
return nil, HTTP, errors.New("invalid HTTP request")
err = errors.New("invalid HTTP request")
return
}
if method != "GET" && method != "POST" {
return nil, HTTP, errors.New("invalid HTTP method")
err = errors.New("invalid HTTP method")
return
}
contentLength := 0
websocket := false
websocketVersion := 0
websocketKey := ""
for {
b, _, err := readTelnetMessage(rd) // read a header line
var b []byte
b, _, err = readTelnetMessage(rd) // read a header line
if err != nil {
return nil, HTTP, nil
return
}
header := string(b)
if header == "" {
break // end of headers
}
if header[0] == 'u' || header[0] == 'U' {
if header[0] == 'a' || header[0] == 'A' {
if strings.HasPrefix(strings.ToLower(header), "authorization:") {
auth = strings.TrimSpace(header[len("authorization:"):])
}
} else if header[0] == 'u' || header[0] == 'U' {
if strings.HasPrefix(strings.ToLower(header), "upgrade:") && strings.ToLower(strings.TrimSpace(header[len("upgrade:"):])) == "websocket" {
websocket = true
}
} else if header[0] == 's' || header[0] == 'S' {
if strings.HasPrefix(strings.ToLower(header), "sec-websocket-version:") {
n, err := strconv.ParseUint(strings.TrimSpace(header[len("sec-websocket-version:"):]), 10, 64)
var n uint64
n, err = strconv.ParseUint(strings.TrimSpace(header[len("sec-websocket-version:"):]), 10, 64)
if err != nil {
return nil, HTTP, err
return
}
websocketVersion = int(n)
} else if strings.HasPrefix(strings.ToLower(header), "sec-websocket-key:") {
@ -267,33 +278,35 @@ func readHTTPMessage(line string, wr io.Writer, rd *bufio.Reader) (command []byt
}
} else if header[0] == 'c' || header[0] == 'C' {
if strings.HasPrefix(strings.ToLower(header), "content-length:") {
n, err := strconv.ParseUint(strings.TrimSpace(header[len("content-length:"):]), 10, 64)
var n uint64
n, err = strconv.ParseUint(strings.TrimSpace(header[len("content-length:"):]), 10, 64)
if err != nil {
return nil, HTTP, err
return
}
contentLength = int(n)
}
}
}
if websocket && websocketVersion >= 13 && websocketKey != "" {
proto = WebSocket
if wr == nil {
return nil, WebSocket, errors.New("connection is nil")
err = errors.New("connection is nil")
return
}
sum := sha1.Sum([]byte(websocketKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"))
accept := base64.StdEncoding.EncodeToString(sum[:])
wshead := "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: " + accept + "\r\n\r\n"
_, err := wr.Write([]byte(wshead))
if err != nil {
return nil, WebSocket, err
if _, err = wr.Write([]byte(wshead)); err != nil {
return
}
return []byte(path), WebSocket, nil
} else if contentLength > 0 {
proto = HTTP
buf := make([]byte, contentLength)
_, err := io.ReadFull(rd, buf)
if err != nil {
return nil, HTTP, err
if _, err = io.ReadFull(rd, buf); err != nil {
return
}
path += string(buf)
}
return []byte(path), HTTP, nil
command = []byte(path)
return
}

View File

@ -12,8 +12,8 @@ type Standard struct {
Elapsed string `json:"elapsed"`
}
// Stats represents tile38 server statistics.
type Stats struct {
// Server represents tile38 server statistics.
type ServerStats struct {
Standard
Stats struct {
ServerID string `json:"id"`
@ -30,9 +30,9 @@ type Stats struct {
}
// Stats returns tile38 server statistics.
func (conn *Conn) Stats() (Stats, error) {
var stats Stats
msg, err := conn.Do("stats")
func (conn *Conn) Server() (ServerStats, error) {
var stats ServerStats
msg, err := conn.Do("server")
if err != nil {
return stats, err
}
@ -40,7 +40,7 @@ func (conn *Conn) Stats() (Stats, error) {
return stats, err
}
if !stats.OK {
if stats.Err == "" {
if stats.Err != "" {
return stats, errors.New(stats.Err)
}
return stats, errors.New("not ok")

View File

@ -251,7 +251,7 @@ func (c *Controller) liveAOF(pos int64, conn net.Conn, rd *bufio.Reader) error {
cond.L.Unlock()
}()
for {
command, _, err := client.ReadMessage(rd, nil)
command, _, _, err := client.ReadMessage(rd, nil)
if err != nil {
if err != io.EOF {
log.Error(err)

View File

@ -1,15 +0,0 @@
package controller
func (c *Controller) cmdAuth(line string) error {
var password string
if line, password = token(line); password == "" {
return errInvalidNumberOfArguments
}
if line != "" {
return errInvalidNumberOfArguments
}
println(password)
return nil
}

View File

@ -183,7 +183,7 @@ func (c *Controller) handleInputCommand(conn *server.Conn, line string, w io.Wri
return writeErr(errors.New("empty command"))
}
if !conn.Authenticated {
if !conn.Authenticated || cmd == "auth" {
c.mu.RLock()
requirePass := c.config.RequirePass
c.mu.RUnlock()
@ -194,11 +194,14 @@ func (c *Controller) handleInputCommand(conn *server.Conn, line string, w io.Wri
return writeErr(errors.New("authentication required"))
}
password, _ := token(line)
if requirePass == strings.TrimSpace(password) {
conn.Authenticated = true
} else {
if requirePass != strings.TrimSpace(password) {
return writeErr(errors.New("invalid password"))
}
conn.Authenticated = true
w.Write([]byte(`{"ok":true,"elapsed":"` + time.Now().Sub(start).String() + "\"}"))
return nil
} else if cmd == "auth" {
return writeErr(errors.New("invalid password"))
}
}
@ -318,13 +321,10 @@ func (c *Controller) command(line string, w io.Writer) (resp string, d commandDe
case "readonly":
err = c.cmdReadOnly(nline)
resp = okResp()
case "auth":
err = c.cmdAuth(nline)
resp = okResp()
case "stats":
resp, err = c.cmdServer(nline)
case "server":
resp, err = c.cmdStats(nline)
case "server":
resp, err = c.cmdServer(nline)
case "scan":
err = c.cmdScan(nline, w)
case "nearby":

View File

@ -45,6 +45,7 @@ func (c *Controller) cmdFollow(line string) error {
}
port := int(n)
update = c.config.FollowHost != host || c.config.FollowPort != port
auth := c.config.LeaderAuth
if update {
c.mu.Unlock()
conn, err := client.DialTimeout(fmt.Sprintf("%s:%d", host, port), time.Second*2)
@ -53,7 +54,12 @@ func (c *Controller) cmdFollow(line string) error {
return fmt.Errorf("cannot follow: %v", err)
}
defer conn.Close()
msg, err := conn.Stats()
if auth != "" {
if err := c.followDoLeaderAuth(conn, auth); err != nil {
return fmt.Errorf("cannot follow: %v", err)
}
}
msg, err := conn.Server()
if err != nil {
c.mu.Lock()
return fmt.Errorf("cannot follow: %v", err)
@ -103,6 +109,25 @@ func (c *Controller) followHandleCommand(line string, followc uint64, w io.Write
return c.aofsz, nil
}
func (c *Controller) followDoLeaderAuth(conn *client.Conn, auth string) error {
data, err := conn.Do("AUTH " + auth)
if err != nil {
return err
}
var msg client.Standard
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
if !msg.OK {
if msg.Err != "" {
return errors.New(msg.Err)
} else {
return errors.New("cannot follow: auth no ok")
}
}
return nil
}
func (c *Controller) followStep(host string, port int, followc uint64) error {
c.mu.Lock()
if c.followc != followc {
@ -110,6 +135,7 @@ func (c *Controller) followStep(host string, port int, followc uint64) error {
return errNoLongerFollowing
}
c.fcup = false
auth := c.config.LeaderAuth
c.mu.Unlock()
addr := fmt.Sprintf("%s:%d", host, port)
// check if we are following self
@ -118,7 +144,13 @@ func (c *Controller) followStep(host string, port int, followc uint64) error {
return fmt.Errorf("cannot follow: %v", err)
}
defer conn.Close()
stats, err := conn.Stats()
if auth != "" {
if err := c.followDoLeaderAuth(conn, auth); err != nil {
return fmt.Errorf("cannot follow: %v", err)
}
}
stats, err := conn.Server()
if err != nil {
return fmt.Errorf("cannot follow: %v", err)
}
@ -134,13 +166,6 @@ func (c *Controller) followStep(host string, port int, followc uint64) error {
return err
}
// make real connection
conn, err = client.Dial(addr)
if err != nil {
return err
}
defer conn.Close()
msg, err := conn.Do(fmt.Sprintf("aof %d", pos))
if err != nil {
return err
@ -198,7 +223,7 @@ func (c *Controller) follow(host string, port int, followc uint64) {
return
}
if err != nil && err != io.EOF {
log.Debug("follow: " + err.Error())
log.Error("follow: " + err.Error())
}
time.Sleep(time.Second)
}

View File

@ -103,7 +103,7 @@ func (c *Controller) goLive(inerr error, conn net.Conn, rd *bufio.Reader, websoc
conn.Close()
}()
for {
command, _, err := client.ReadMessage(rd, nil)
command, _, _, err := client.ReadMessage(rd, nil)
if err != nil {
if err != io.EOF {
log.Error(err)

View File

@ -68,6 +68,13 @@ func ListenAndServe(
}
}
func writeCommandErr(proto client.Proto, conn *Conn, err error) error {
if proto == client.HTTP || proto == client.WebSocket {
conn.Write([]byte(`HTTP/1.1 500 ` + err.Error() + "\r\nConnection: close\r\n\r\n"))
}
return err
}
func handleConn(
conn *Conn,
protected func() bool,
@ -92,7 +99,7 @@ func handleConn(
rd := bufio.NewReader(conn)
for i := 0; ; i++ {
err := func() error {
command, proto, err := client.ReadMessage(rd, conn)
command, proto, auth, err := client.ReadMessage(rd, conn)
if err != nil {
return err
}
@ -100,12 +107,21 @@ func handleConn(
return io.EOF
}
var b bytes.Buffer
if err := handler(conn, command, rd, &b, proto == client.WebSocket); err != nil {
if proto == client.HTTP {
conn.Write([]byte(`HTTP/1.1 500 ` + err.Error() + "\r\nConnection: close\r\n\r\n"))
var denied bool
if (proto == client.HTTP || proto == client.WebSocket) && auth != "" {
if err := handler(conn, []byte("AUTH "+auth), rd, &b, proto == client.WebSocket); err != nil {
return writeCommandErr(proto, conn, err)
}
if strings.HasPrefix(b.String(), `{"ok":false`) {
denied = true
} else {
b.Reset()
}
}
if !denied {
if err := handler(conn, command, rd, &b, proto == client.WebSocket); err != nil {
return writeCommandErr(proto, conn, err)
}
return err
}
switch proto {
case client.Native: