diff --git a/controller/controller.go b/controller/controller.go index 95ce5192..137f20a5 100644 --- a/controller/controller.go +++ b/controller/controller.go @@ -245,21 +245,24 @@ func (c *Controller) handleInputCommand(conn *server.Conn, msg *server.Message, requirePass := c.config.RequirePass c.mu.RUnlock() if requirePass != "" { - // This better be an AUTH command. - if msg.Command != "auth" { + password := "" + // This better be an AUTH command or the Message should contain an Auth + if msg.Command != "auth" && msg.Auth == "" { // Just shut down the pipeline now. The less the client connection knows the better. return writeErr(errors.New("authentication required")) } - password := "" - if len(msg.Values) > 1 { - password = msg.Values[1].String() + if msg.Auth != "" { + password = msg.Auth + } else { + if len(msg.Values) > 1 { + password = msg.Values[1].String() + } } if requirePass != strings.TrimSpace(password) { return writeErr(errors.New("invalid password")) } conn.Authenticated = true - w.Write([]byte(`{"ok":true,"elapsed":"` + time.Now().Sub(start).String() + "\"}")) - return nil + return writeOutput(server.OKMessage(msg, start)) } else if msg.Command == "auth" { return writeErr(errors.New("invalid password")) } diff --git a/controller/crud.go b/controller/crud.go index e9693193..bfd0d97a 100644 --- a/controller/crud.go +++ b/controller/crud.go @@ -60,6 +60,13 @@ func (c *Controller) cmdGet(msg *server.Message) (string, error) { if vs, id, ok = tokenval(vs); !ok || id == "" { return "", errInvalidNumberOfArguments } + + withfields := false + if _, peek, ok := tokenval(vs); ok && strings.ToLower(peek) == "withfields" { + withfields = true + vs = vs[1:] + } + col := c.getCol(key) if col == nil { if msg.OutputType == server.RESP { @@ -85,7 +92,7 @@ func (c *Controller) cmdGet(msg *server.Message) (string, error) { buf.WriteString(`,"object":`) buf.WriteString(o.JSON()) } else { - vals = append(vals, resp.ArrayValue([]resp.Value{resp.StringValue(o.JSON())})) + vals = append(vals, resp.StringValue(o.JSON())) } } else { switch strings.ToLower(typ) { @@ -128,7 +135,7 @@ func (c *Controller) cmdGet(msg *server.Message) (string, error) { if msg.OutputType == server.JSON { buf.WriteString(`"` + p + `"`) } else { - vals = append(vals, resp.ArrayValue([]resp.Value{resp.StringValue(p)})) + vals = append(vals, resp.StringValue(p)) } case "bounds": bbox := o.CalculatedBBox() @@ -152,40 +159,49 @@ func (c *Controller) cmdGet(msg *server.Message) (string, error) { if len(vs) != 0 { return "", errInvalidNumberOfArguments } - - fvs := orderFields(col.FieldMap(), fields) - if len(fvs) > 0 { - fvals := make([]resp.Value, 0, len(fvs)*2) - if msg.OutputType == server.JSON { - buf.WriteString(`,"fields":{`) - } - for i, fv := range fvs { + if withfields { + fvs := orderFields(col.FieldMap(), fields) + if len(fvs) > 0 { + fvals := make([]resp.Value, 0, len(fvs)*2) if msg.OutputType == server.JSON { - if i > 0 { - buf.WriteString(`,`) - } - buf.WriteString(jsonString(fv.field) + ":" + strconv.FormatFloat(fv.value, 'f', -1, 64)) - } else { - fvals = append(fvals, resp.StringValue(fv.field), resp.StringValue(strconv.FormatFloat(fv.value, 'f', -1, 64))) + buf.WriteString(`,"fields":{`) + } + for i, fv := range fvs { + if msg.OutputType == server.JSON { + if i > 0 { + buf.WriteString(`,`) + } + buf.WriteString(jsonString(fv.field) + ":" + strconv.FormatFloat(fv.value, 'f', -1, 64)) + } else { + fvals = append(fvals, resp.StringValue(fv.field), resp.StringValue(strconv.FormatFloat(fv.value, 'f', -1, 64))) + } + i++ + } + if msg.OutputType == server.JSON { + buf.WriteString(`}`) + } else { + vals = append(vals, resp.ArrayValue(fvals)) } - i++ - } - if msg.OutputType == server.JSON { - buf.WriteString(`}`) - } else { - vals = append(vals, resp.ArrayValue(fvals)) } } - if msg.OutputType == server.JSON { + switch msg.OutputType { + case server.JSON: buf.WriteString(`,"elapsed":"` + time.Now().Sub(start).String() + "\"}") return buf.String(), nil + case server.RESP: + var oval resp.Value + if withfields { + oval = resp.ArrayValue(vals) + } else { + oval = vals[0] + } + data, err := oval.MarshalRESP() + if err != nil { + return "", err + } + return string(data), nil } - data, err := resp.ArrayValue(vals).MarshalRESP() - if err != nil { - return "", err - } - return string(data), nil - + return "", nil } func (c *Controller) cmdDel(msg *server.Message) (res string, d commandDetailsT, err error) { diff --git a/controller/fence.go b/controller/fence.go index 8a361c0f..85410e57 100644 --- a/controller/fence.go +++ b/controller/fence.go @@ -8,10 +8,10 @@ import ( "github.com/tidwall/tile38/geojson" ) -func (c *Controller) FenceMatch(hookName string, sw *scanWriter, fence *liveFenceSwitches, details *commandDetailsT, mustLock bool) [][]byte { +func (c *Controller) FenceMatch(hookName string, sw *scanWriter, fence *liveFenceSwitches, details *commandDetailsT, mustLock bool) []string { glob := fence.glob if details.command == "drop" { - return [][]byte{[]byte(`{"cmd":"drop"}`)} + return []string{`{"cmd":"drop"}`} } match := true if glob != "" && glob != "*" { @@ -27,7 +27,6 @@ func (c *Controller) FenceMatch(hookName string, sw *scanWriter, fence *liveFenc match = false detect := "outside" if fence != nil { - match1 := fenceMatchObject(fence, details.oldObj) match2 := fenceMatchObject(fence, details.obj) if match1 && match2 { @@ -67,7 +66,7 @@ func (c *Controller) FenceMatch(hookName string, sw *scanWriter, fence *liveFenc } } if details.command == "del" { - return [][]byte{[]byte(`{"command":"del","id":` + jsonString(details.id) + `}`)} + return []string{`{"command":"del","id":` + jsonString(details.id) + `}`} } var fmap map[string]int if mustLock { @@ -101,15 +100,23 @@ func (c *Controller) FenceMatch(hookName string, sw *scanWriter, fence *liveFenc jskey := jsonString(details.key) jstime := time.Now().Format("2006-01-02T15:04:05.999999999Z07:00") jshookName := jsonString(hookName) - if strings.HasPrefix(res, "{") { - res = `{"command":"` + details.command + `","detect":"` + detect + `","hook":` + jshookName + `,"time":"` + jstime + `","key":` + jskey + `,` + res[1:] + ores := res + msgs := make([]string, 0, 2) + if fence.detect == nil || fence.detect[detect] { + if strings.HasPrefix(ores, "{") { + res = `{"command":"` + details.command + `","detect":"` + detect + `","hook":` + jshookName + `,"time":"` + jstime + `","key":` + jskey + `,` + ores[1:] + } + msgs = append(msgs, res) } - msgs := [][]byte{[]byte(res)} switch detect { case "enter": - msgs = append(msgs, []byte(`{"command":"`+details.command+`","detect":"inside","hook":`+jshookName+`,"time":"`+jstime+`","key":`+jskey+`,`+res[1:])) + if fence.detect == nil || fence.detect["inside"] { + msgs = append(msgs, `{"command":"`+details.command+`","detect":"inside","hook":`+jshookName+`,"time":"`+jstime+`","key":`+jskey+`,`+ores[1:]) + } case "exit", "cross": - msgs = append(msgs, []byte(`{"command":"`+details.command+`","detect":"outside","hook":`+jshookName+`,"time":"`+jstime+`","key":`+jskey+`,`+res[1:])) + if fence.detect == nil || fence.detect["outside"] { + msgs = append(msgs, `{"command":"`+details.command+`","detect":"outside","hook":`+jshookName+`,"time":"`+jstime+`","key":`+jskey+`,`+ores[1:]) + } } return msgs } diff --git a/controller/hooks.go b/controller/hooks.go index 13849d98..86061fd1 100644 --- a/controller/hooks.go +++ b/controller/hooks.go @@ -48,24 +48,29 @@ type Hook struct { func (c *Controller) DoHook(hook *Hook, details *commandDetailsT) error { var lerrs []error msgs := c.FenceMatch(hook.Name, hook.ScanWriter, hook.Fence, details, false) +nextMessage: for _, msg := range msgs { + nextEndpoint: for _, endpoint := range hook.Endpoints { switch endpoint.Protocol { case HTTP: - if err := c.sendHTTPMessage(endpoint, msg); err != nil { + if err := c.sendHTTPMessage(endpoint, []byte(msg)); err != nil { lerrs = append(lerrs, err) - continue + continue nextEndpoint } - return nil //sent + continue nextMessage // sent case Disque: - if err := c.sendDisqueMessage(endpoint, msg); err != nil { + if err := c.sendDisqueMessage(endpoint, []byte(msg)); err != nil { lerrs = append(lerrs, err) - continue + continue nextEndpoint } - return nil // sent + continue nextMessage // sent } } } + if len(lerrs) == 0 { + return nil + } var errmsgs []string for _, err := range lerrs { errmsgs = append(errmsgs, err.Error()) @@ -185,10 +190,12 @@ func (c *Controller) cmdSetHook(msg *server.Message) (res string, d commandDetai } endpoints = append(endpoints, endpoint) } + commandvs := vs if vs, cmd, ok = tokenval(vs); !ok || cmd == "" { return "", d, errInvalidNumberOfArguments } + cmdlc := strings.ToLower(cmd) var types []string switch cmdlc { @@ -226,7 +233,6 @@ func (c *Controller) cmdSetHook(msg *server.Message) (res string, d commandDetai return "", d, err } - // delete the previous hook if h, ok := c.hooks[name]; ok { // lets see if the previous hook matches the new hook if h.Key == hook.Key && h.Name == hook.Name { @@ -248,6 +254,8 @@ func (c *Controller) cmdSetHook(msg *server.Message) (res string, d commandDetai } } } + + // delete the previous hook if hm, ok := c.hookcols[h.Key]; ok { delete(hm, h.Name) } diff --git a/controller/live.go b/controller/live.go index 92548d18..03220638 100644 --- a/controller/live.go +++ b/controller/live.go @@ -3,6 +3,7 @@ package controller import ( "bytes" "errors" + "fmt" "io" "net" "sync" @@ -43,11 +44,25 @@ func (c *Controller) processLives() { } } -func writeMessage(conn net.Conn, message []byte, websocket bool) error { - if websocket { - return client.WriteWebSocket(conn, message) +func writeMessage(conn net.Conn, message []byte, wrapRESP bool, connType server.Type, websocket bool) error { + if len(message) == 0 { + return nil } - return client.WriteMessage(conn, message) + if websocket { + return server.WriteWebSocketMessage(conn, message) + } + var err error + switch connType { + case server.RESP: + if wrapRESP { + _, err = fmt.Fprintf(conn, "$%d\r\n%s\r\n", len(message), string(message)) + } else { + _, err = conn.Write(message) + } + case server.Native: + _, err = fmt.Fprintf(conn, "$%d\r\n%s\r\n", len(message), string(message)) + } + return err } func (c *Controller) goLive(inerr error, conn net.Conn, rd *server.AnyReaderWriter, msg *server.Message, websocket bool) error { @@ -73,8 +88,6 @@ func (c *Controller) goLive(inerr error, conn net.Conn, rd *server.AnyReaderWrit lb.key = s.key lb.fence = &s c.mu.RLock() - var msg *server.Message - panic("todo: goLive message must be defined") sw, err = c.newScanWriter(&wr, msg, s.key, s.output, s.precision, s.glob, s.limit, s.wheres, s.nofields) c.mu.RUnlock() } @@ -118,7 +131,19 @@ func (c *Controller) goLive(inerr error, conn net.Conn, rd *server.AnyReaderWrit } } }() - if err := writeMessage(conn, []byte(client.LiveJSON), websocket); err != nil { + outputType := msg.OutputType + connType := msg.ConnType + if websocket { + outputType = server.JSON + } + var livemsg []byte + switch outputType { + case server.JSON: + livemsg = []byte(client.LiveJSON) + case server.RESP: + livemsg = []byte("+OK\r\n") + } + if err := writeMessage(conn, livemsg, false, connType, websocket); err != nil { return nil // nil return is fine here } for { @@ -137,7 +162,7 @@ func (c *Controller) goLive(inerr error, conn net.Conn, rd *server.AnyReaderWrit lb.cond.L.Unlock() msgs := c.FenceMatch("", sw, fence, details, true) for _, msg := range msgs { - if err := writeMessage(conn, msg, websocket); err != nil { + if err := writeMessage(conn, []byte(msg), true, connType, websocket); err != nil { return nil // nil return is fine here } } diff --git a/controller/server/anyreader.go b/controller/server/anyreader.go index 24607d40..744e4753 100644 --- a/controller/server/anyreader.go +++ b/controller/server/anyreader.go @@ -2,7 +2,6 @@ package server import ( "bufio" - "bytes" "crypto/sha1" "encoding/base64" "errors" @@ -28,6 +27,27 @@ const ( JSON ) +func (t Type) String() string { + switch t { + default: + return "Unknown" + case Null: + return "Null" + case RESP: + return "RESP" + case Telnet: + return "Telnet" + case Native: + return "Native" + case HTTP: + return "HTTP" + case WebSocket: + return "WebSocket" + case JSON: + return "JSON" + } +} + type errRESPProtocolError struct { msg string } @@ -123,6 +143,30 @@ func (ar *AnyReaderWriter) ReadMessage() (*Message, error) { return ar.readMultiBulkMessage() } +func readNativeMessageLine(line []byte) (*Message, error) { + values := make([]resp.Value, 0, 16) +reading: + for len(line) != 0 { + if line[0] == '{' { + // The native protocol cannot understand json boundaries so it assumes that + // a json element must be at the end of the line. + values = append(values, resp.StringValue(string(line))) + break + } + i := 0 + for ; i < len(line); i++ { + if line[i] == ' ' { + values = append(values, resp.StringValue(string(line[:i]))) + line = line[i+1:] + continue reading + } + } + values = append(values, resp.StringValue(string(line))) + break + } + return &Message{Command: commandValues(values), Values: values, ConnType: Native, OutputType: JSON}, nil +} + func (ar *AnyReaderWriter) readNativeMessage() (*Message, error) { b, err := ar.rd.ReadBytes(' ') if err != nil { @@ -145,28 +189,8 @@ func (ar *AnyReaderWriter) readNativeMessage() (*Message, error) { if b[len(b)-2] != '\r' || b[len(b)-1] != '\n' { return nil, errors.New("expecting crlf") } - values := make([]resp.Value, 0, 16) - line := b[:len(b)-2] -reading: - for len(line) != 0 { - if line[0] == '{' { - // The native protocol cannot understand json boundaries so it assumes that - // a json element must be at the end of the line. - values = append(values, resp.StringValue(string(line))) - break - } - i := 0 - for ; i < len(line); i++ { - if line[i] == ' ' { - values = append(values, resp.StringValue(string(line[:i]))) - line = line[i+1:] - continue reading - } - } - values = append(values, resp.StringValue(string(line))) - break - } - return &Message{Command: commandValues(values), Values: values, ConnType: Native, OutputType: JSON}, nil + + return readNativeMessageLine(b[:len(b)-2]) } func commandValues(values []resp.Value) string { @@ -283,8 +307,8 @@ func (ar *AnyReaderWriter) readHTTPMessage() (*Message, error) { if !strings.HasSuffix(path, "\r\n") { path += "\r\n" } - rd := NewAnyReaderWriter(bytes.NewBufferString(path)) - nmsg, err := rd.ReadMessage() + + nmsg, err := readNativeMessageLine([]byte(path)) if err != nil { return nil, err } diff --git a/controller/server/server.go b/controller/server/server.go index 6dcfde4c..72a91ef6 100644 --- a/controller/server/server.go +++ b/controller/server/server.go @@ -9,8 +9,6 @@ import ( "strings" "time" - //"github.com/tidwall/tile38/client" - "github.com/tidwall/tile38/controller/log" "github.com/tidwall/tile38/core" ) @@ -171,68 +169,3 @@ func OKMessage(msg *Message, start time.Time) string { } return "" } - -//err := func() error { -// command, proto, auth, err := client.ReadMessage(rd, conn) -// if err != nil { -// return err -// } -// if len(command) > 0 && (command[0] == 'Q' || command[0] == 'q') && strings.ToLower(string(command)) == "quit" { -// return io.EOF -// } -// var b bytes.Buffer -// var denied bool -// if (proto == client.HTTP || proto == client.WebSocket) && auth != "" { -// if err := handler(conn, []byte("AUTH "+auth), rd, &b, proto == client.WebSocket); err != nil { -// return writeCommandErr(proto, conn, err) -// } -// if strings.HasPrefix(b.String(), `{"ok":false`) { -// denied = true -// } else { -// b.Reset() -// } -// } -// if !denied { -// if err := handler(conn, command, rd, &b, proto == client.WebSocket); err != nil { -// return writeCommandErr(proto, conn, err) -// } -// } -// switch proto { -// case client.Native: -// if err := client.WriteMessage(conn, b.Bytes()); err != nil { -// return err -// } -// case client.HTTP: -// if err := client.WriteHTTP(conn, b.Bytes()); err != nil { -// return err -// } -// return errCloseHTTP -// case client.WebSocket: -// if err := client.WriteWebSocket(conn, b.Bytes()); err != nil { -// return err -// } -// if _, err := conn.Write([]byte{137, 0}); err != nil { -// return err -// } -// return errCloseHTTP -// default: -// b.WriteString("\r\n") -// if _, err := conn.Write(b.Bytes()); err != nil { -// return err -// } -// } -// return nil -//}() -// if err != nil { -// if err == io.EOF { -// return -// } -// if err == errCloseHTTP || -// strings.Contains(err.Error(), "use of closed network connection") { -// return -// } -// log.Error(err) -// return -// } -// } -// } diff --git a/controller/token.go b/controller/token.go index 80d9cf59..e6e6fb9d 100644 --- a/controller/token.go +++ b/controller/token.go @@ -124,6 +124,7 @@ type searchScanBaseTokens struct { precision uint64 lineout string fence bool + detect map[string]bool glob string wheres []whereT nofields bool @@ -236,6 +237,43 @@ func parseSearchScanBaseTokens(cmd string, vs []resp.Value) (vsout []resp.Value, return } t.fence = true + continue + } else if (wtok[0] == 'D' || wtok[0] == 'd') && strings.ToLower(wtok) == "detect" { + vs = nvs + if t.detect != nil { + err = errDuplicateArgument(strings.ToUpper(wtok)) + return + } + t.detect = make(map[string]bool) + var peek string + if vs, peek, ok = tokenval(vs); !ok || peek == "" { + err = errInvalidNumberOfArguments + return + } + for _, s := range strings.Split(peek, ",") { + part := strings.TrimSpace(strings.ToLower(s)) + switch part { + default: + err = errInvalidArgument(peek) + return + case "inside", "outside", "enter", "exit", "cross": + } + if t.detect[part] { + err = errDuplicateArgument(s) + return + } + t.detect[part] = true + } + if len(t.detect) == 0 { + t.detect = map[string]bool{ + "inside": true, + "outside": true, + "enter": true, + "exit": true, + "cross": true, + } + } + continue } else if (wtok[0] == 'M' || wtok[0] == 'm') && strings.ToLower(wtok) == "match" { vs = nvs @@ -276,6 +314,10 @@ func parseSearchScanBaseTokens(cmd string, vs []resp.Value) (vsout []resp.Value, err = errors.New("CURSOR is not allowed when FENCE is specified") return } + if t.detect != nil && !t.fence { + err = errors.New("DETECT is not allowed when FENCE is not specified") + return + } t.output = defaultSearchOutput var nvs []resp.Value diff --git a/core/commands.json b/core/commands.json index 5877f7c5..65d3cd26 100644 --- a/core/commands.json +++ b/core/commands.json @@ -120,6 +120,12 @@ "name": "id", "type": "string" }, + { + "command": "WITHFIELDS", + "name": [], + "type": [], + "optional": true + }, { "name": "type", "optional": true, diff --git a/core/commands_gen.go b/core/commands_gen.go index 9ae7bad1..978dd9ee 100644 --- a/core/commands_gen.go +++ b/core/commands_gen.go @@ -273,6 +273,12 @@ var commandsJSON = `{ "name": "id", "type": "string" }, + { + "command": "WITHFIELDS", + "name": [], + "type": [], + "optional": true + }, { "name": "type", "optional": true,