mirror of https://github.com/tidwall/tile38.git
replication auth
This commit is contained in:
parent
4ef89db63b
commit
c6619a529f
|
@ -83,7 +83,7 @@ func (conn *Conn) Do(command string) ([]byte, error) {
|
|||
conn.pool = nil
|
||||
return nil, err
|
||||
}
|
||||
message, _, err := ReadMessage(conn.rd, nil)
|
||||
message, _, _, err := ReadMessage(conn.rd, nil)
|
||||
if err != nil {
|
||||
conn.pool = nil
|
||||
return nil, err
|
||||
|
@ -96,7 +96,7 @@ func (conn *Conn) Do(command string) ([]byte, error) {
|
|||
|
||||
// ReadMessage returns the next message. Used when reading live connections
|
||||
func (conn *Conn) ReadMessage() (message []byte, err error) {
|
||||
message, _, err = readMessage(conn.c, conn.rd)
|
||||
message, _, _, err = readMessage(conn.c, conn.rd)
|
||||
if err != nil {
|
||||
conn.pool = nil
|
||||
return message, err
|
||||
|
@ -160,10 +160,10 @@ func WriteWebSocket(conn net.Conn, data []byte) error {
|
|||
}
|
||||
|
||||
// ReadMessage reads the next message from a bufio.Reader.
|
||||
func readMessage(wr io.Writer, rd *bufio.Reader) (message []byte, proto Proto, err error) {
|
||||
func readMessage(wr io.Writer, rd *bufio.Reader) (message []byte, proto Proto, auth string, err error) {
|
||||
h, err := rd.Peek(1)
|
||||
if err != nil {
|
||||
return nil, proto, err
|
||||
return nil, proto, auth, err
|
||||
}
|
||||
switch h[0] {
|
||||
case '$':
|
||||
|
@ -171,41 +171,41 @@ func readMessage(wr io.Writer, rd *bufio.Reader) (message []byte, proto Proto, e
|
|||
}
|
||||
message, proto, err = readTelnetMessage(rd)
|
||||
if err != nil {
|
||||
return nil, proto, err
|
||||
return nil, proto, auth, err
|
||||
}
|
||||
if len(message) > 6 && string(message[len(message)-9:len(message)-2]) == " HTTP/1" {
|
||||
return readHTTPMessage(string(message), wr, rd)
|
||||
}
|
||||
return message, proto, nil
|
||||
return message, proto, auth, nil
|
||||
|
||||
}
|
||||
func ReadMessage(rd *bufio.Reader, wr io.Writer) (message []byte, proto Proto, err error) {
|
||||
func ReadMessage(rd *bufio.Reader, wr io.Writer) (message []byte, proto Proto, auth string, err error) {
|
||||
return readMessage(wr, rd)
|
||||
}
|
||||
|
||||
func readProtoMessage(rd *bufio.Reader) (message []byte, proto Proto, err error) {
|
||||
func readProtoMessage(rd *bufio.Reader) (message []byte, proto Proto, auth string, err error) {
|
||||
b, err := rd.ReadBytes(' ')
|
||||
if err != nil {
|
||||
return nil, Native, err
|
||||
return nil, Native, auth, err
|
||||
}
|
||||
if len(b) > 0 && b[0] != '$' {
|
||||
return nil, Native, errors.New("not a proto message")
|
||||
return nil, Native, auth, errors.New("not a proto message")
|
||||
}
|
||||
n, err := strconv.ParseUint(string(b[1:len(b)-1]), 10, 32)
|
||||
if err != nil {
|
||||
return nil, Native, errors.New("invalid size")
|
||||
return nil, Native, auth, errors.New("invalid size")
|
||||
}
|
||||
if n > MaxMessageSize {
|
||||
return nil, Native, errors.New("message too big")
|
||||
return nil, Native, auth, errors.New("message too big")
|
||||
}
|
||||
b = make([]byte, int(n)+2)
|
||||
if _, err := io.ReadFull(rd, b); err != nil {
|
||||
return nil, Native, err
|
||||
return nil, Native, auth, err
|
||||
}
|
||||
if b[len(b)-2] != '\r' || b[len(b)-1] != '\n' {
|
||||
return nil, Native, errors.New("expecting crlf suffix")
|
||||
return nil, Native, auth, errors.New("expecting crlf suffix")
|
||||
}
|
||||
return b[:len(b)-2], Native, nil
|
||||
return b[:len(b)-2], Native, auth, nil
|
||||
}
|
||||
|
||||
func readTelnetMessage(rd *bufio.Reader) (command []byte, proto Proto, err error) {
|
||||
|
@ -221,45 +221,56 @@ func readTelnetMessage(rd *bufio.Reader) (command []byte, proto Proto, err error
|
|||
return line, Telnet, nil
|
||||
}
|
||||
|
||||
func readHTTPMessage(line string, wr io.Writer, rd *bufio.Reader) (command []byte, proto Proto, err error) {
|
||||
func readHTTPMessage(line string, wr io.Writer, rd *bufio.Reader) (command []byte, proto Proto, auth string, err error) {
|
||||
proto = HTTP
|
||||
parts := strings.Split(line, " ")
|
||||
if len(parts) != 3 {
|
||||
return nil, HTTP, errors.New("invalid HTTP request")
|
||||
err = errors.New("invalid HTTP request")
|
||||
return
|
||||
}
|
||||
method := parts[0]
|
||||
path := parts[1]
|
||||
if len(path) == 0 || path[0] != '/' {
|
||||
return nil, HTTP, errors.New("invalid HTTP request")
|
||||
err = errors.New("invalid HTTP request")
|
||||
return
|
||||
}
|
||||
path, err = url.QueryUnescape(path[1:])
|
||||
if err != nil {
|
||||
return nil, HTTP, errors.New("invalid HTTP request")
|
||||
err = errors.New("invalid HTTP request")
|
||||
return
|
||||
}
|
||||
if method != "GET" && method != "POST" {
|
||||
return nil, HTTP, errors.New("invalid HTTP method")
|
||||
err = errors.New("invalid HTTP method")
|
||||
return
|
||||
}
|
||||
contentLength := 0
|
||||
websocket := false
|
||||
websocketVersion := 0
|
||||
websocketKey := ""
|
||||
for {
|
||||
b, _, err := readTelnetMessage(rd) // read a header line
|
||||
var b []byte
|
||||
b, _, err = readTelnetMessage(rd) // read a header line
|
||||
if err != nil {
|
||||
return nil, HTTP, nil
|
||||
return
|
||||
}
|
||||
header := string(b)
|
||||
if header == "" {
|
||||
break // end of headers
|
||||
}
|
||||
if header[0] == 'u' || header[0] == 'U' {
|
||||
if header[0] == 'a' || header[0] == 'A' {
|
||||
if strings.HasPrefix(strings.ToLower(header), "authorization:") {
|
||||
auth = strings.TrimSpace(header[len("authorization:"):])
|
||||
}
|
||||
} else if header[0] == 'u' || header[0] == 'U' {
|
||||
if strings.HasPrefix(strings.ToLower(header), "upgrade:") && strings.ToLower(strings.TrimSpace(header[len("upgrade:"):])) == "websocket" {
|
||||
websocket = true
|
||||
}
|
||||
} else if header[0] == 's' || header[0] == 'S' {
|
||||
if strings.HasPrefix(strings.ToLower(header), "sec-websocket-version:") {
|
||||
n, err := strconv.ParseUint(strings.TrimSpace(header[len("sec-websocket-version:"):]), 10, 64)
|
||||
var n uint64
|
||||
n, err = strconv.ParseUint(strings.TrimSpace(header[len("sec-websocket-version:"):]), 10, 64)
|
||||
if err != nil {
|
||||
return nil, HTTP, err
|
||||
return
|
||||
}
|
||||
websocketVersion = int(n)
|
||||
} else if strings.HasPrefix(strings.ToLower(header), "sec-websocket-key:") {
|
||||
|
@ -267,33 +278,35 @@ func readHTTPMessage(line string, wr io.Writer, rd *bufio.Reader) (command []byt
|
|||
}
|
||||
} else if header[0] == 'c' || header[0] == 'C' {
|
||||
if strings.HasPrefix(strings.ToLower(header), "content-length:") {
|
||||
n, err := strconv.ParseUint(strings.TrimSpace(header[len("content-length:"):]), 10, 64)
|
||||
var n uint64
|
||||
n, err = strconv.ParseUint(strings.TrimSpace(header[len("content-length:"):]), 10, 64)
|
||||
if err != nil {
|
||||
return nil, HTTP, err
|
||||
return
|
||||
}
|
||||
contentLength = int(n)
|
||||
}
|
||||
}
|
||||
}
|
||||
if websocket && websocketVersion >= 13 && websocketKey != "" {
|
||||
proto = WebSocket
|
||||
if wr == nil {
|
||||
return nil, WebSocket, errors.New("connection is nil")
|
||||
err = errors.New("connection is nil")
|
||||
return
|
||||
}
|
||||
sum := sha1.Sum([]byte(websocketKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"))
|
||||
accept := base64.StdEncoding.EncodeToString(sum[:])
|
||||
wshead := "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: " + accept + "\r\n\r\n"
|
||||
_, err := wr.Write([]byte(wshead))
|
||||
if err != nil {
|
||||
return nil, WebSocket, err
|
||||
if _, err = wr.Write([]byte(wshead)); err != nil {
|
||||
return
|
||||
}
|
||||
return []byte(path), WebSocket, nil
|
||||
} else if contentLength > 0 {
|
||||
proto = HTTP
|
||||
buf := make([]byte, contentLength)
|
||||
_, err := io.ReadFull(rd, buf)
|
||||
if err != nil {
|
||||
return nil, HTTP, err
|
||||
if _, err = io.ReadFull(rd, buf); err != nil {
|
||||
return
|
||||
}
|
||||
path += string(buf)
|
||||
}
|
||||
return []byte(path), HTTP, nil
|
||||
command = []byte(path)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -12,8 +12,8 @@ type Standard struct {
|
|||
Elapsed string `json:"elapsed"`
|
||||
}
|
||||
|
||||
// Stats represents tile38 server statistics.
|
||||
type Stats struct {
|
||||
// Server represents tile38 server statistics.
|
||||
type ServerStats struct {
|
||||
Standard
|
||||
Stats struct {
|
||||
ServerID string `json:"id"`
|
||||
|
@ -30,9 +30,9 @@ type Stats struct {
|
|||
}
|
||||
|
||||
// Stats returns tile38 server statistics.
|
||||
func (conn *Conn) Stats() (Stats, error) {
|
||||
var stats Stats
|
||||
msg, err := conn.Do("stats")
|
||||
func (conn *Conn) Server() (ServerStats, error) {
|
||||
var stats ServerStats
|
||||
msg, err := conn.Do("server")
|
||||
if err != nil {
|
||||
return stats, err
|
||||
}
|
||||
|
@ -40,7 +40,7 @@ func (conn *Conn) Stats() (Stats, error) {
|
|||
return stats, err
|
||||
}
|
||||
if !stats.OK {
|
||||
if stats.Err == "" {
|
||||
if stats.Err != "" {
|
||||
return stats, errors.New(stats.Err)
|
||||
}
|
||||
return stats, errors.New("not ok")
|
||||
|
|
|
@ -251,7 +251,7 @@ func (c *Controller) liveAOF(pos int64, conn net.Conn, rd *bufio.Reader) error {
|
|||
cond.L.Unlock()
|
||||
}()
|
||||
for {
|
||||
command, _, err := client.ReadMessage(rd, nil)
|
||||
command, _, _, err := client.ReadMessage(rd, nil)
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
log.Error(err)
|
||||
|
|
|
@ -1,15 +0,0 @@
|
|||
package controller
|
||||
|
||||
func (c *Controller) cmdAuth(line string) error {
|
||||
var password string
|
||||
if line, password = token(line); password == "" {
|
||||
return errInvalidNumberOfArguments
|
||||
}
|
||||
if line != "" {
|
||||
return errInvalidNumberOfArguments
|
||||
}
|
||||
|
||||
println(password)
|
||||
|
||||
return nil
|
||||
}
|
|
@ -183,7 +183,7 @@ func (c *Controller) handleInputCommand(conn *server.Conn, line string, w io.Wri
|
|||
return writeErr(errors.New("empty command"))
|
||||
}
|
||||
|
||||
if !conn.Authenticated {
|
||||
if !conn.Authenticated || cmd == "auth" {
|
||||
c.mu.RLock()
|
||||
requirePass := c.config.RequirePass
|
||||
c.mu.RUnlock()
|
||||
|
@ -194,11 +194,14 @@ func (c *Controller) handleInputCommand(conn *server.Conn, line string, w io.Wri
|
|||
return writeErr(errors.New("authentication required"))
|
||||
}
|
||||
password, _ := token(line)
|
||||
if requirePass == strings.TrimSpace(password) {
|
||||
conn.Authenticated = true
|
||||
} else {
|
||||
if requirePass != strings.TrimSpace(password) {
|
||||
return writeErr(errors.New("invalid password"))
|
||||
}
|
||||
conn.Authenticated = true
|
||||
w.Write([]byte(`{"ok":true,"elapsed":"` + time.Now().Sub(start).String() + "\"}"))
|
||||
return nil
|
||||
} else if cmd == "auth" {
|
||||
return writeErr(errors.New("invalid password"))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -318,13 +321,10 @@ func (c *Controller) command(line string, w io.Writer) (resp string, d commandDe
|
|||
case "readonly":
|
||||
err = c.cmdReadOnly(nline)
|
||||
resp = okResp()
|
||||
case "auth":
|
||||
err = c.cmdAuth(nline)
|
||||
resp = okResp()
|
||||
case "stats":
|
||||
resp, err = c.cmdServer(nline)
|
||||
case "server":
|
||||
resp, err = c.cmdStats(nline)
|
||||
case "server":
|
||||
resp, err = c.cmdServer(nline)
|
||||
case "scan":
|
||||
err = c.cmdScan(nline, w)
|
||||
case "nearby":
|
||||
|
|
|
@ -45,6 +45,7 @@ func (c *Controller) cmdFollow(line string) error {
|
|||
}
|
||||
port := int(n)
|
||||
update = c.config.FollowHost != host || c.config.FollowPort != port
|
||||
auth := c.config.LeaderAuth
|
||||
if update {
|
||||
c.mu.Unlock()
|
||||
conn, err := client.DialTimeout(fmt.Sprintf("%s:%d", host, port), time.Second*2)
|
||||
|
@ -53,7 +54,12 @@ func (c *Controller) cmdFollow(line string) error {
|
|||
return fmt.Errorf("cannot follow: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
msg, err := conn.Stats()
|
||||
if auth != "" {
|
||||
if err := c.followDoLeaderAuth(conn, auth); err != nil {
|
||||
return fmt.Errorf("cannot follow: %v", err)
|
||||
}
|
||||
}
|
||||
msg, err := conn.Server()
|
||||
if err != nil {
|
||||
c.mu.Lock()
|
||||
return fmt.Errorf("cannot follow: %v", err)
|
||||
|
@ -103,6 +109,25 @@ func (c *Controller) followHandleCommand(line string, followc uint64, w io.Write
|
|||
return c.aofsz, nil
|
||||
}
|
||||
|
||||
func (c *Controller) followDoLeaderAuth(conn *client.Conn, auth string) error {
|
||||
data, err := conn.Do("AUTH " + auth)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var msg client.Standard
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
if !msg.OK {
|
||||
if msg.Err != "" {
|
||||
return errors.New(msg.Err)
|
||||
} else {
|
||||
return errors.New("cannot follow: auth no ok")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Controller) followStep(host string, port int, followc uint64) error {
|
||||
c.mu.Lock()
|
||||
if c.followc != followc {
|
||||
|
@ -110,6 +135,7 @@ func (c *Controller) followStep(host string, port int, followc uint64) error {
|
|||
return errNoLongerFollowing
|
||||
}
|
||||
c.fcup = false
|
||||
auth := c.config.LeaderAuth
|
||||
c.mu.Unlock()
|
||||
addr := fmt.Sprintf("%s:%d", host, port)
|
||||
// check if we are following self
|
||||
|
@ -118,7 +144,13 @@ func (c *Controller) followStep(host string, port int, followc uint64) error {
|
|||
return fmt.Errorf("cannot follow: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
stats, err := conn.Stats()
|
||||
if auth != "" {
|
||||
if err := c.followDoLeaderAuth(conn, auth); err != nil {
|
||||
return fmt.Errorf("cannot follow: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
stats, err := conn.Server()
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot follow: %v", err)
|
||||
}
|
||||
|
@ -134,13 +166,6 @@ func (c *Controller) followStep(host string, port int, followc uint64) error {
|
|||
return err
|
||||
}
|
||||
|
||||
// make real connection
|
||||
conn, err = client.Dial(addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
msg, err := conn.Do(fmt.Sprintf("aof %d", pos))
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -198,7 +223,7 @@ func (c *Controller) follow(host string, port int, followc uint64) {
|
|||
return
|
||||
}
|
||||
if err != nil && err != io.EOF {
|
||||
log.Debug("follow: " + err.Error())
|
||||
log.Error("follow: " + err.Error())
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
|
|
|
@ -103,7 +103,7 @@ func (c *Controller) goLive(inerr error, conn net.Conn, rd *bufio.Reader, websoc
|
|||
conn.Close()
|
||||
}()
|
||||
for {
|
||||
command, _, err := client.ReadMessage(rd, nil)
|
||||
command, _, _, err := client.ReadMessage(rd, nil)
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
log.Error(err)
|
||||
|
|
|
@ -68,6 +68,13 @@ func ListenAndServe(
|
|||
}
|
||||
}
|
||||
|
||||
func writeCommandErr(proto client.Proto, conn *Conn, err error) error {
|
||||
if proto == client.HTTP || proto == client.WebSocket {
|
||||
conn.Write([]byte(`HTTP/1.1 500 ` + err.Error() + "\r\nConnection: close\r\n\r\n"))
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func handleConn(
|
||||
conn *Conn,
|
||||
protected func() bool,
|
||||
|
@ -92,7 +99,7 @@ func handleConn(
|
|||
rd := bufio.NewReader(conn)
|
||||
for i := 0; ; i++ {
|
||||
err := func() error {
|
||||
command, proto, err := client.ReadMessage(rd, conn)
|
||||
command, proto, auth, err := client.ReadMessage(rd, conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -100,12 +107,21 @@ func handleConn(
|
|||
return io.EOF
|
||||
}
|
||||
var b bytes.Buffer
|
||||
|
||||
if err := handler(conn, command, rd, &b, proto == client.WebSocket); err != nil {
|
||||
if proto == client.HTTP {
|
||||
conn.Write([]byte(`HTTP/1.1 500 ` + err.Error() + "\r\nConnection: close\r\n\r\n"))
|
||||
var denied bool
|
||||
if (proto == client.HTTP || proto == client.WebSocket) && auth != "" {
|
||||
if err := handler(conn, []byte("AUTH "+auth), rd, &b, proto == client.WebSocket); err != nil {
|
||||
return writeCommandErr(proto, conn, err)
|
||||
}
|
||||
if strings.HasPrefix(b.String(), `{"ok":false`) {
|
||||
denied = true
|
||||
} else {
|
||||
b.Reset()
|
||||
}
|
||||
}
|
||||
if !denied {
|
||||
if err := handler(conn, command, rd, &b, proto == client.WebSocket); err != nil {
|
||||
return writeCommandErr(proto, conn, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
switch proto {
|
||||
case client.Native:
|
||||
|
|
Loading…
Reference in New Issue