diff --git a/controller/aof.go b/controller/aof.go index 5b334ffa..8b4a5802 100644 --- a/controller/aof.go +++ b/controller/aof.go @@ -13,10 +13,11 @@ import ( "sync" "time" - "github.com/boltdb/bolt" "github.com/google/btree" + "github.com/tidwall/resp" "github.com/tidwall/tile38/client" "github.com/tidwall/tile38/controller/log" + "github.com/tidwall/tile38/controller/server" ) const backwardsBufferSize = 50000 @@ -104,61 +105,36 @@ func (c *Controller) loadAOF() error { ps := float64(count) / (float64(d) / float64(time.Second)) log.Infof("AOF loaded %d commands: %s: %.0f/sec", count, d, ps) }() - rd := NewAOFReader(c.f) + rd := resp.NewReader(c.f) for { - buf, err := rd.ReadCommand() + v, _, n, err := rd.ReadMultiBulk() if err != nil { if err == io.EOF { return nil } - if err == io.ErrUnexpectedEOF || err == errCorruptedAOF { - log.Warnf("aof is corrupted, likely data loss. Truncating to %d", c.aofsz) - fname := c.f.Name() - c.f.Close() - if err := os.Truncate(c.f.Name(), int64(c.aofsz)); err != nil { - log.Fatalf("could not truncate aof, possible data loss. %s", err.Error()) - return err - } - c.f, err = os.OpenFile(fname, os.O_CREATE|os.O_RDWR, 0600) - if err != nil { - log.Fatalf("could not create aof, possible data loss. %s", err.Error()) - return err - } - if _, err := c.f.Seek(int64(c.aofsz), 0); err != nil { - log.Fatalf("could not seek aof, possible data loss. %s", err.Error()) - return err - } - } return err } - empty := true - for i := 0; i < len(buf); i++ { - if buf[i] != 0 { - empty = false - break - } + values := v.Array() + if len(values) == 0 { + return errors.New("multibulk missing command component") } - if empty { - return nil + msg := &server.Message{ + Command: strings.ToLower(values[0].String()), + Values: values, } - if _, _, err := c.command(string(buf), nil); err != nil { + if _, _, err := c.command(msg, nil); err != nil { return err } - c.aofsz += 9 + len(buf) + c.aofsz += n count++ } } -func writeCommand(w io.Writer, line []byte) (n int, err error) { - b := make([]byte, len(line)+9) - binary.LittleEndian.PutUint32(b, uint32(len(line))) - copy(b[4:], line) - binary.LittleEndian.PutUint32(b[len(b)-5:], uint32(len(line))) - return w.Write(b) -} - -func (c *Controller) writeAOF(line string, d *commandDetailsT) error { +func (c *Controller) writeAOF(value resp.Value, d *commandDetailsT) error { if d != nil { + if !d.updated { + return nil // just ignore writes if the command did not update + } // process hooks if hm, ok := c.hookcols[d.key]; ok { for _, hook := range hm { @@ -168,8 +144,11 @@ func (c *Controller) writeAOF(line string, d *commandDetailsT) error { } } } - - n, err := writeCommand(c.f, []byte(line)) + data, err := value.MarshalRESP() + if err != nil { + return err + } + n, err := c.f.Write(data) if err != nil { return err } @@ -306,14 +285,18 @@ func (c *Controller) liveAOF(pos int64, conn net.Conn, rd *bufio.Reader) error { if err != nil { return err } - rd := NewAOFReader(f) + rd := resp.NewReader(f) for { - cmd, err := rd.ReadCommand() + v, _, err := rd.ReadValue() if err != io.EOF { if err != nil { return err } - if _, err := writeCommand(conn, cmd); err != nil { + data, err := v.MarshalRESP() + if err != nil { + return err + } + if _, err := conn.Write(data); err != nil { return err } continue @@ -387,413 +370,413 @@ func (k *treeKeyBoolT) Less(item btree.Item) bool { // - Stop shrinking, nothing left to do func (c *Controller) aofshrink() { - c.mu.Lock() - c.f.Sync() - if c.shrinking { - c.mu.Unlock() - return - } - c.shrinking = true - endpos := int64(c.aofsz) - start := time.Now() - log.Infof("aof shrink started at pos %d", endpos) + // c.mu.Lock() + // c.f.Sync() + // if c.shrinking { + // c.mu.Unlock() + // return + // } + // c.shrinking = true + // endpos := int64(c.aofsz) + // start := time.Now() + // log.Infof("aof shrink started at pos %d", endpos) - var hooks []string - for _, hook := range c.hooks { - var orgs []string - for _, endpoint := range hook.Endpoints { - orgs = append(orgs, endpoint.Original) - } + // var hooks []string + // for _, hook := range c.hooks { + // var orgs []string + // for _, endpoint := range hook.Endpoints { + // orgs = append(orgs, endpoint.Original) + // } - hooks = append(hooks, "SETHOOK "+hook.Name+" "+strings.Join(orgs, ",")+" "+hook.Command) - } + // hooks = append(hooks, "SETHOOK "+hook.Name+" "+strings.Join(orgs, ",")+" "+hook.Command) + // } - c.mu.Unlock() - var err error - defer func() { - c.mu.Lock() - c.shrinking = false - c.mu.Unlock() - os.RemoveAll(c.dir + "/shrink.db") - os.RemoveAll(c.dir + "/shrink") - if err != nil { - log.Error("aof shrink failed: " + err.Error()) - } else { - log.Info("aof shrink completed: " + time.Now().Sub(start).String()) - } - }() - var db *bolt.DB - db, err = bolt.Open(c.dir+"/shrink.db", 0600, nil) - if err != nil { - return - } - defer db.Close() - var nf *os.File - nf, err = os.Create(c.dir + "/shrink") - if err != nil { - return - } - defer nf.Close() - defer func() { - c.mu.Lock() - defer c.mu.Unlock() - if err == nil { - c.f.Sync() - _, err = nf.Seek(0, 2) - if err == nil { - var f *os.File - f, err = os.Open(c.dir + "/aof") - if err != nil { - return - } - defer f.Close() - _, err = f.Seek(endpos, 0) - if err == nil { - _, err = io.Copy(nf, f) - if err == nil { - f.Close() - nf.Close() - // At this stage we need to kill all aof followers. To do so we will - // write a KILLAOF command to the stream. KILLAOF isn't really a command. - // This will cause the followers will close their connection and then - // automatically reconnect. The reconnection will force a sync of the aof. - err = c.writeAOF("KILLAOF", nil) - if err == nil { - c.f.Close() - err = os.Rename(c.dir+"/shrink", c.dir+"/aof") - if err != nil { - log.Fatal("shink rename fatal operation") - } - c.f, err = os.OpenFile(c.dir+"/aof", os.O_CREATE|os.O_RDWR, 0600) - if err != nil { - log.Fatal("shink openfile fatal operation") - } - var n int64 - n, err = c.f.Seek(0, 2) - if err != nil { - log.Fatal("shink seek end fatal operation") - } - c.aofsz = int(n) - } - } - } - } - } - }() - var f *os.File - f, err = os.Open(c.dir + "/aof") - if err != nil { - return - } - defer f.Close() + // c.mu.Unlock() + // var err error + // defer func() { + // c.mu.Lock() + // c.shrinking = false + // c.mu.Unlock() + // os.RemoveAll(c.dir + "/shrink.db") + // os.RemoveAll(c.dir + "/shrink") + // if err != nil { + // log.Error("aof shrink failed: " + err.Error()) + // } else { + // log.Info("aof shrink completed: " + time.Now().Sub(start).String()) + // } + // }() + // var db *bolt.DB + // db, err = bolt.Open(c.dir+"/shrink.db", 0600, nil) + // if err != nil { + // return + // } + // defer db.Close() + // var nf *os.File + // nf, err = os.Create(c.dir + "/shrink") + // if err != nil { + // return + // } + // defer nf.Close() + // defer func() { + // c.mu.Lock() + // defer c.mu.Unlock() + // if err == nil { + // c.f.Sync() + // _, err = nf.Seek(0, 2) + // if err == nil { + // var f *os.File + // f, err = os.Open(c.dir + "/aof") + // if err != nil { + // return + // } + // defer f.Close() + // _, err = f.Seek(endpos, 0) + // if err == nil { + // _, err = io.Copy(nf, f) + // if err == nil { + // f.Close() + // nf.Close() + // // At this stage we need to kill all aof followers. To do so we will + // // write a KILLAOF command to the stream. KILLAOF isn't really a command. + // // This will cause the followers will close their connection and then + // // automatically reconnect. The reconnection will force a sync of the aof. + // err = c.writeAOF(resp.MultiBulkValue("KILLAOF"), nil) + // if err == nil { + // c.f.Close() + // err = os.Rename(c.dir+"/shrink", c.dir+"/aof") + // if err != nil { + // log.Fatal("shink rename fatal operation") + // } + // c.f, err = os.OpenFile(c.dir+"/aof", os.O_CREATE|os.O_RDWR, 0600) + // if err != nil { + // log.Fatal("shink openfile fatal operation") + // } + // var n int64 + // n, err = c.f.Seek(0, 2) + // if err != nil { + // log.Fatal("shink seek end fatal operation") + // } + // c.aofsz = int(n) + // } + // } + // } + // } + // } + // }() + // var f *os.File + // f, err = os.Open(c.dir + "/aof") + // if err != nil { + // return + // } + // defer f.Close() - var buf []byte - var pos int64 - pos, err = f.Seek(endpos, 0) - if err != nil { - return - } - var readPreviousCommand func() ([]byte, error) - readPreviousCommand = func() ([]byte, error) { - if len(buf) >= 5 { - if buf[len(buf)-1] != 0 { - return nil, errCorruptedAOF - } - sz2 := int(binary.LittleEndian.Uint32(buf[len(buf)-5:])) - if len(buf) >= sz2+9 { - sz1 := int(binary.LittleEndian.Uint32(buf[len(buf)-(sz2+9):])) - if sz1 != sz2 { - return nil, errCorruptedAOF - } - command := buf[len(buf)-(sz2+5) : len(buf)-5] - buf = buf[:len(buf)-(sz2+9)] - return command, nil - } - } - if pos == 0 { - if len(buf) > 0 { - return nil, io.ErrUnexpectedEOF - } else { - return nil, io.EOF - } - } - sz := int64(backwardsBufferSize) - offset := pos - sz - if offset < 0 { - sz = pos - offset = 0 - } - pos, err = f.Seek(offset, 0) - if err != nil { - return nil, err - } - nbuf := make([]byte, int(sz)) - _, err = io.ReadFull(f, nbuf) - if err != nil { - return nil, err - } - if len(buf) > 0 { - nbuf = append(nbuf, buf...) - } - buf = nbuf - return readPreviousCommand() - } - var tx *bolt.Tx - tx, err = db.Begin(true) - if err != nil { - return - } - defer func() { - tx.Rollback() - }() - var keyIgnoreM = map[string]bool{} - var keyBucketM = btree.New(16) - var cmd, key, id, field string - var line string - var command []byte - var val []byte - var b *bolt.Bucket -reading: - for i := 0; ; i++ { - if i%500 == 0 { - if err = tx.Commit(); err != nil { - return - } - tx, err = db.Begin(true) - if err != nil { - return - } - } - command, err = readPreviousCommand() - if err != nil { - if err == io.EOF { - err = nil - break - } - return - } - // quick path - if len(command) == 0 { - continue // ignore blank commands - } - line, cmd = token(string(command)) - cmd = strings.ToLower(cmd) - switch cmd { - case "flushdb": - break reading // all done - case "drop": - if line, key = token(line); key == "" { - err = errors.New("DROP is missing key") - return - } - if !keyIgnoreM[key] { - keyIgnoreM[key] = true - } - case "del": - if line, key = token(line); key == "" { - err = errors.New("DEL is missing key") - return - } - if keyIgnoreM[key] { - continue // ignore - } - if line, id = token(line); id == "" { - err = errors.New("DEL is missing id") - return - } - if keyBucketM.Get(&treeKeyBoolT{key}) == nil { - if _, err = tx.CreateBucket([]byte(key + ".ids")); err != nil { - return - } - if _, err = tx.CreateBucket([]byte(key + ".ignore_ids")); err != nil { - return - } - keyBucketM.ReplaceOrInsert(&treeKeyBoolT{key}) - } - b = tx.Bucket([]byte(key + ".ignore_ids")) - err = b.Put([]byte(id), []byte("2")) // 2 for hard ignore - if err != nil { - return - } + // var buf []byte + // var pos int64 + // pos, err = f.Seek(endpos, 0) + // if err != nil { + // return + // } + // var readPreviousCommand func() ([]byte, error) + // readPreviousCommand = func() ([]byte, error) { + // if len(buf) >= 5 { + // if buf[len(buf)-1] != 0 { + // return nil, errCorruptedAOF + // } + // sz2 := int(binary.LittleEndian.Uint32(buf[len(buf)-5:])) + // if len(buf) >= sz2+9 { + // sz1 := int(binary.LittleEndian.Uint32(buf[len(buf)-(sz2+9):])) + // if sz1 != sz2 { + // return nil, errCorruptedAOF + // } + // command := buf[len(buf)-(sz2+5) : len(buf)-5] + // buf = buf[:len(buf)-(sz2+9)] + // return command, nil + // } + // } + // if pos == 0 { + // if len(buf) > 0 { + // return nil, io.ErrUnexpectedEOF + // } else { + // return nil, io.EOF + // } + // } + // sz := int64(backwardsBufferSize) + // offset := pos - sz + // if offset < 0 { + // sz = pos + // offset = 0 + // } + // pos, err = f.Seek(offset, 0) + // if err != nil { + // return nil, err + // } + // nbuf := make([]byte, int(sz)) + // _, err = io.ReadFull(f, nbuf) + // if err != nil { + // return nil, err + // } + // if len(buf) > 0 { + // nbuf = append(nbuf, buf...) + // } + // buf = nbuf + // return readPreviousCommand() + // } + // var tx *bolt.Tx + // tx, err = db.Begin(true) + // if err != nil { + // return + // } + // defer func() { + // tx.Rollback() + // }() + // var keyIgnoreM = map[string]bool{} + // var keyBucketM = btree.New(16) + // var cmd, key, id, field string + // var line string + // var command []byte + // var val []byte + // var b *bolt.Bucket + // reading: + // for i := 0; ; i++ { + // if i%500 == 0 { + // if err = tx.Commit(); err != nil { + // return + // } + // tx, err = db.Begin(true) + // if err != nil { + // return + // } + // } + // command, err = readPreviousCommand() + // if err != nil { + // if err == io.EOF { + // err = nil + // break + // } + // return + // } + // // quick path + // if len(command) == 0 { + // continue // ignore blank commands + // } + // line, cmd = token(string(command)) + // cmd = strings.ToLower(cmd) + // switch cmd { + // case "flushdb": + // break reading // all done + // case "drop": + // if line, key = token(line); key == "" { + // err = errors.New("DROP is missing key") + // return + // } + // if !keyIgnoreM[key] { + // keyIgnoreM[key] = true + // } + // case "del": + // if line, key = token(line); key == "" { + // err = errors.New("DEL is missing key") + // return + // } + // if keyIgnoreM[key] { + // continue // ignore + // } + // if line, id = token(line); id == "" { + // err = errors.New("DEL is missing id") + // return + // } + // if keyBucketM.Get(&treeKeyBoolT{key}) == nil { + // if _, err = tx.CreateBucket([]byte(key + ".ids")); err != nil { + // return + // } + // if _, err = tx.CreateBucket([]byte(key + ".ignore_ids")); err != nil { + // return + // } + // keyBucketM.ReplaceOrInsert(&treeKeyBoolT{key}) + // } + // b = tx.Bucket([]byte(key + ".ignore_ids")) + // err = b.Put([]byte(id), []byte("2")) // 2 for hard ignore + // if err != nil { + // return + // } - case "set": - if line, key = token(line); key == "" { - err = errors.New("SET is missing key") - return - } - if keyIgnoreM[key] { - continue // ignore - } - if line, id = token(line); id == "" { - err = errors.New("SET is missing id") - return - } - if keyBucketM.Get(&treeKeyBoolT{key}) == nil { - if _, err = tx.CreateBucket([]byte(key + ".ids")); err != nil { - return - } - if _, err = tx.CreateBucket([]byte(key + ".ignore_ids")); err != nil { - return - } - keyBucketM.ReplaceOrInsert(&treeKeyBoolT{key}) - } - b = tx.Bucket([]byte(key + ".ignore_ids")) - val = b.Get([]byte(id)) - if val == nil { - if err = b.Put([]byte(id), []byte("1")); err != nil { - return - } - b = tx.Bucket([]byte(key + ".ids")) - if err = b.Put([]byte(id), command); err != nil { - return - } - } else { - switch string(val) { - default: - err = errors.New("invalid ignore") - case "1", "2": - continue // ignore - } - } - case "fset": - if line, key = token(line); key == "" { - err = errors.New("FSET is missing key") - return - } - if keyIgnoreM[key] { - continue // ignore - } - if line, id = token(line); id == "" { - err = errors.New("FSET is missing id") - return - } - if line, field = token(line); field == "" { - err = errors.New("FSET is missing field") - return - } - if keyBucketM.Get(&treeKeyBoolT{key}) == nil { - if _, err = tx.CreateBucket([]byte(key + ".ids")); err != nil { - return - } - if _, err = tx.CreateBucket([]byte(key + ".ignore_ids")); err != nil { - return - } - keyBucketM.ReplaceOrInsert(&treeKeyBoolT{key}) - } - b = tx.Bucket([]byte(key + ".ignore_ids")) - val = b.Get([]byte(id)) - if val == nil { - b = tx.Bucket([]byte(key + ":" + id + ":0")) - if b == nil { - if b, err = tx.CreateBucket([]byte(key + ":" + id + ":0")); err != nil { - return - } - } - if b.Get([]byte(field)) == nil { - if err = b.Put([]byte(field), command); err != nil { - return - } - } - } else { - switch string(val) { - default: - err = errors.New("invalid ignore") - case "1": - b = tx.Bucket([]byte(key + ":" + id + ":1")) - if b == nil { - if b, err = tx.CreateBucket([]byte(key + ":" + id + ":1")); err != nil { - return - } - } - if b.Get([]byte(field)) == nil { - if err = b.Put([]byte(field), command); err != nil { - return - } - } - case "2": - continue // ignore - } - } - } - } - if err = tx.Commit(); err != nil { - return - } - tx, err = db.Begin(false) - if err != nil { - return - } - keyBucketM.Ascend(func(item btree.Item) bool { - key := item.(*treeKeyBoolT).key - b := tx.Bucket([]byte(key + ".ids")) - if b != nil { - err = b.ForEach(func(id, command []byte) error { - // parse the SET command - _, fields, values, etype, eargs, err := c.parseSetArgs(string(command[4:])) - if err != nil { - return err - } - // store the fields in a map - var fieldm = map[string]float64{} - for i, field := range fields { - fieldm[field] = values[i] - } - // append old FSET values. these are FSETs that existed prior to the last SET. - f1 := tx.Bucket([]byte(key + ":" + string(id) + ":1")) - if f1 != nil { - err = f1.ForEach(func(field, command []byte) error { - d, err := c.parseFSetArgs(string(command[5:])) - if err != nil { - return err - } - if _, ok := fieldm[d.field]; !ok { - fieldm[d.field] = d.value - } - return nil - }) - if err != nil { - return err - } - } - // append new FSET values. these are FSETs that were added after the last SET. - f0 := tx.Bucket([]byte(key + ":" + string(id) + ":0")) - if f0 != nil { - f0.ForEach(func(field, command []byte) error { - d, err := c.parseFSetArgs(string(command[5:])) - if err != nil { - return err - } - fieldm[d.field] = d.value - return nil - }) - } - // rebuild the SET command - ncommand := "set " + key + " " + string(id) - for field, value := range fieldm { - if value != 0 { - ncommand += " field " + field + " " + strconv.FormatFloat(value, 'f', -1, 64) - } - } - ncommand += " " + strings.ToUpper(etype) + " " + eargs - _, err = writeCommand(nf, []byte(ncommand)) - if err != nil { - return err - } - return nil - }) - if err != nil { - return false - } - } - return true - }) - if err == nil { - // add all of the hooks - for _, line := range hooks { - _, err = writeCommand(nf, []byte(line)) - if err != nil { - return - } - } - } + // case "set": + // if line, key = token(line); key == "" { + // err = errors.New("SET is missing key") + // return + // } + // if keyIgnoreM[key] { + // continue // ignore + // } + // if line, id = token(line); id == "" { + // err = errors.New("SET is missing id") + // return + // } + // if keyBucketM.Get(&treeKeyBoolT{key}) == nil { + // if _, err = tx.CreateBucket([]byte(key + ".ids")); err != nil { + // return + // } + // if _, err = tx.CreateBucket([]byte(key + ".ignore_ids")); err != nil { + // return + // } + // keyBucketM.ReplaceOrInsert(&treeKeyBoolT{key}) + // } + // b = tx.Bucket([]byte(key + ".ignore_ids")) + // val = b.Get([]byte(id)) + // if val == nil { + // if err = b.Put([]byte(id), []byte("1")); err != nil { + // return + // } + // b = tx.Bucket([]byte(key + ".ids")) + // if err = b.Put([]byte(id), command); err != nil { + // return + // } + // } else { + // switch string(val) { + // default: + // err = errors.New("invalid ignore") + // case "1", "2": + // continue // ignore + // } + // } + // case "fset": + // if line, key = token(line); key == "" { + // err = errors.New("FSET is missing key") + // return + // } + // if keyIgnoreM[key] { + // continue // ignore + // } + // if line, id = token(line); id == "" { + // err = errors.New("FSET is missing id") + // return + // } + // if line, field = token(line); field == "" { + // err = errors.New("FSET is missing field") + // return + // } + // if keyBucketM.Get(&treeKeyBoolT{key}) == nil { + // if _, err = tx.CreateBucket([]byte(key + ".ids")); err != nil { + // return + // } + // if _, err = tx.CreateBucket([]byte(key + ".ignore_ids")); err != nil { + // return + // } + // keyBucketM.ReplaceOrInsert(&treeKeyBoolT{key}) + // } + // b = tx.Bucket([]byte(key + ".ignore_ids")) + // val = b.Get([]byte(id)) + // if val == nil { + // b = tx.Bucket([]byte(key + ":" + id + ":0")) + // if b == nil { + // if b, err = tx.CreateBucket([]byte(key + ":" + id + ":0")); err != nil { + // return + // } + // } + // if b.Get([]byte(field)) == nil { + // if err = b.Put([]byte(field), command); err != nil { + // return + // } + // } + // } else { + // switch string(val) { + // default: + // err = errors.New("invalid ignore") + // case "1": + // b = tx.Bucket([]byte(key + ":" + id + ":1")) + // if b == nil { + // if b, err = tx.CreateBucket([]byte(key + ":" + id + ":1")); err != nil { + // return + // } + // } + // if b.Get([]byte(field)) == nil { + // if err = b.Put([]byte(field), command); err != nil { + // return + // } + // } + // case "2": + // continue // ignore + // } + // } + // } + // } + // if err = tx.Commit(); err != nil { + // return + // } + // tx, err = db.Begin(false) + // if err != nil { + // return + // } + // keyBucketM.Ascend(func(item btree.Item) bool { + // key := item.(*treeKeyBoolT).key + // b := tx.Bucket([]byte(key + ".ids")) + // if b != nil { + // err = b.ForEach(func(id, command []byte) error { + // // parse the SET command + // _, fields, values, etype, eargs, err := c.parseSetArgs(string(command[4:])) + // if err != nil { + // return err + // } + // // store the fields in a map + // var fieldm = map[string]float64{} + // for i, field := range fields { + // fieldm[field] = values[i] + // } + // // append old FSET values. these are FSETs that existed prior to the last SET. + // f1 := tx.Bucket([]byte(key + ":" + string(id) + ":1")) + // if f1 != nil { + // err = f1.ForEach(func(field, command []byte) error { + // d, err := c.parseFSetArgs(string(command[5:])) + // if err != nil { + // return err + // } + // if _, ok := fieldm[d.field]; !ok { + // fieldm[d.field] = d.value + // } + // return nil + // }) + // if err != nil { + // return err + // } + // } + // // append new FSET values. these are FSETs that were added after the last SET. + // f0 := tx.Bucket([]byte(key + ":" + string(id) + ":0")) + // if f0 != nil { + // f0.ForEach(func(field, command []byte) error { + // d, err := c.parseFSetArgs(string(command[5:])) + // if err != nil { + // return err + // } + // fieldm[d.field] = d.value + // return nil + // }) + // } + // // rebuild the SET command + // ncommand := "set " + key + " " + string(id) + // for field, value := range fieldm { + // if value != 0 { + // ncommand += " field " + field + " " + strconv.FormatFloat(value, 'f', -1, 64) + // } + // } + // ncommand += " " + strings.ToUpper(etype) + " " + eargs + // _, err = writeCommand(nf, []byte(ncommand)) + // if err != nil { + // return err + // } + // return nil + // }) + // if err != nil { + // return false + // } + // } + // return true + // }) + // if err == nil { + // // add all of the hooks + // for _, line := range hooks { + // _, err = writeCommand(nf, []byte(line)) + // if err != nil { + // return + // } + // } + // } } diff --git a/controller/collection/collection.go b/controller/collection/collection.go index e759f509..4abf5d96 100644 --- a/controller/collection/collection.go +++ b/controller/collection/collection.go @@ -151,18 +151,18 @@ func (c *Collection) Get(id string) (obj geojson.Object, fields []float64, ok bo // SetField set a field value for an object and returns that object. // If the object does not exist then the 'ok' return value will be false. -func (c *Collection) SetField(id, field string, value float64) (obj geojson.Object, fields []float64, ok bool) { +func (c *Collection) SetField(id, field string, value float64) (obj geojson.Object, fields []float64, updated bool, ok bool) { i := c.items.Get(&itemT{ID: id}) if i == nil { ok = false return } item := i.(*itemT) - c.setField(item, field, value) - return item.Object, item.Fields, true + updated = c.setField(item, field, value) + return item.Object, item.Fields, updated, true } -func (c *Collection) setField(item *itemT, field string, value float64) { +func (c *Collection) setField(item *itemT, field string, value float64) (updated bool) { idx, ok := c.fieldMap[field] if !ok { idx = len(c.fieldMap) @@ -173,7 +173,12 @@ func (c *Collection) setField(item *itemT, field string, value float64) { item.Fields = append(item.Fields, math.NaN()) } c.weight += len(item.Fields) * 8 + ovalue := item.Fields[idx] + if math.IsNaN(ovalue) { + ovalue = 0 + } item.Fields[idx] = value + return ovalue != value } // FieldMap return a maps of the field names. diff --git a/controller/controller.go b/controller/controller.go index 327507d0..3a9f80fb 100644 --- a/controller/controller.go +++ b/controller/controller.go @@ -8,12 +8,12 @@ import ( "fmt" "io" "os" - "runtime" "strings" "sync" "time" "github.com/google/btree" + "github.com/tidwall/resp" "github.com/tidwall/tile38/controller/collection" "github.com/tidwall/tile38/controller/log" "github.com/tidwall/tile38/controller/server" @@ -35,6 +35,7 @@ type commandDetailsT struct { fields []float64 oldObj geojson.Object oldFields []float64 + updated bool } func (col *collectionT) Less(item btree.Item) bool { @@ -105,8 +106,8 @@ func ListenAndServe(host string, port int, dir string) error { c.mu.Unlock() }() go c.processLives() - handler := func(conn *server.Conn, command []byte, rd *bufio.Reader, w io.Writer, websocket bool) error { - err := c.handleInputCommand(conn, string(command), w) + handler := func(conn *server.Conn, msg *server.Message, rd *bufio.Reader, w io.Writer, websocket bool) error { + err := c.handleInputCommand(conn, msg, w) if err != nil { if err.Error() == "going live" { return c.goLive(err, conn, rd, websocket) @@ -162,55 +163,73 @@ func isReservedFieldName(field string) bool { return false } -func (c *Controller) handleInputCommand(conn *server.Conn, line string, w io.Writer) error { - if core.ShowDebugMessages && line != "pInG" { - log.Debug(line) +func (c *Controller) handleInputCommand(conn *server.Conn, msg *server.Message, w io.Writer) error { + var words []string + for _, v := range msg.Values { + words = append(words, v.String()) } + // line := strings.Join(words, " ") + + // if core.ShowDebugMessages && line != "pInG" { + // log.Debug(line) + // } start := time.Now() + // Ping. Just send back the response. No need to put through the pipeline. - if len(line) == 4 && (line[0] == 'p' || line[0] == 'P') && lc(line, "ping") { - w.Write([]byte(`{"ok":true,"ping":"pong","elapsed":"` + time.Now().Sub(start).String() + `"}`)) + if msg.Command == "ping" { + switch msg.OutputType { + case server.JSON: + w.Write([]byte(`{"ok":true,"ping":"pong","elapsed":"` + time.Now().Sub(start).String() + `"}`)) + case server.RESP: + io.WriteString(w, "+PONG\r\n") + } return nil } writeErr := func(err error) error { - js := `{"ok":false,"err":` + jsonString(err.Error()) + `,"elapsed":"` + time.Now().Sub(start).String() + "\"}" - if _, err := w.Write([]byte(js)); err != nil { - return err + switch msg.OutputType { + case server.JSON: + io.WriteString(w, `{"ok":false,"err":`+jsonString(err.Error())+`,"elapsed":"`+time.Now().Sub(start).String()+"\"}") + case server.RESP: + if err == errInvalidNumberOfArguments { + io.WriteString(w, "-ERR wrong number of arguments for '"+msg.Command+"' command\r\n") + } else { + v, _ := resp.ErrorValue(errors.New("ERR " + err.Error())).MarshalRESP() + io.WriteString(w, string(v)) + } } return nil } var write bool - _, cmd := tokenlc(line) - if cmd == "" { - return writeErr(errors.New("empty command")) - } - if !conn.Authenticated || cmd == "auth" { + if !conn.Authenticated || msg.Command == "auth" { c.mu.RLock() requirePass := c.config.RequirePass c.mu.RUnlock() if requirePass != "" { // This better be an AUTH command. - if cmd != "auth" { + if msg.Command != "auth" { // Just shut down the pipeline now. The less the client connection knows the better. return writeErr(errors.New("authentication required")) } - password, _ := token(line) + password := "" + if len(msg.Values) > 1 { + password = msg.Values[1].String() + } if requirePass != strings.TrimSpace(password) { return writeErr(errors.New("invalid password")) } conn.Authenticated = true w.Write([]byte(`{"ok":true,"elapsed":"` + time.Now().Sub(start).String() + "\"}")) return nil - } else if cmd == "auth" { + } else if msg.Command == "auth" { return writeErr(errors.New("invalid password")) } } // choose the locking strategy - switch cmd { + switch msg.Command { default: c.mu.RLock() defer c.mu.RUnlock() @@ -243,7 +262,7 @@ func (c *Controller) handleInputCommand(conn *server.Conn, line string, w io.Wri // no locks! DEV MODE ONLY } - resp, d, err := c.command(line, w) + res, d, err := c.command(msg, w) if err != nil { if err.Error() == "going live" { return err @@ -251,7 +270,7 @@ func (c *Controller) handleInputCommand(conn *server.Conn, line string, w io.Wri return writeErr(err) } if write { - if err := c.writeAOF(line, &d); err != nil { + if err := c.writeAOF(resp.ArrayValue(msg.Values), &d); err != nil { if _, ok := err.(errAOFHook); ok { return writeErr(err) } @@ -259,8 +278,8 @@ func (c *Controller) handleInputCommand(conn *server.Conn, line string, w io.Wri return err } } - if resp != "" { - if _, err := io.WriteString(w, resp); err != nil { + if res != "" { + if _, err := io.WriteString(w, res); err != nil { return err } } @@ -284,84 +303,85 @@ func (c *Controller) reset() { c.cols = btree.New(16) } -func (c *Controller) command(line string, w io.Writer) (resp string, d commandDetailsT, err error) { +func (c *Controller) command(msg *server.Message, w io.Writer) (res string, d commandDetailsT, err error) { start := time.Now() okResp := func() string { - if w == nil { - return "" + if w != nil { + switch msg.OutputType { + case server.JSON: + return `{"ok":true,"elapsed":"` + time.Now().Sub(start).String() + "\"}" + case server.RESP: + return "+OK\r\n" + } } - return `{"ok":true,"elapsed":"` + time.Now().Sub(start).String() + "\"}" + return "" } - nline, cmd := tokenlc(line) - switch cmd { + okResp = okResp + switch msg.Command { default: - err = fmt.Errorf("unknown command '%s'", cmd) + err = fmt.Errorf("unknown command '%s'", msg.Values[0]) return // lock case "set": - d, err = c.cmdSet(nline) - resp = okResp() + res, d, err = c.cmdSet(msg) case "fset": - d, err = c.cmdFset(nline) - resp = okResp() + res, d, err = c.cmdFset(msg) case "del": - d, err = c.cmdDel(nline) - resp = okResp() + res, d, err = c.cmdDel(msg) case "drop": - d, err = c.cmdDrop(nline) - resp = okResp() + res, d, err = c.cmdDrop(msg) case "flushdb": - d, err = c.cmdFlushDB(nline) - resp = okResp() - case "sethook": - err = c.cmdSetHook(nline) - resp = okResp() - case "delhook": - err = c.cmdDelHook(nline) - resp = okResp() - case "hooks": - err = c.cmdHooks(nline, w) - case "massinsert": - if !core.DevMode { - err = fmt.Errorf("unknown command '%s'", cmd) - return - } - err = c.cmdMassInsert(nline) - resp = okResp() - case "follow": - err = c.cmdFollow(nline) - resp = okResp() - case "config": - resp, err = c.cmdConfig(nline) - case "readonly": - err = c.cmdReadOnly(nline) - resp = okResp() - case "stats": - resp, err = c.cmdStats(nline) - case "server": - resp, err = c.cmdServer(nline) - case "scan": - err = c.cmdScan(nline, w) - case "nearby": - err = c.cmdNearby(nline, w) - case "within": - err = c.cmdWithin(nline, w) - case "intersects": - err = c.cmdIntersects(nline, w) + res, d, err = c.cmdFlushDB(msg) + // case "sethook": + // err = c.cmdSetHook(nline) + // resp = okResp() + // case "delhook": + // err = c.cmdDelHook(nline) + // resp = okResp() + // case "hooks": + // err = c.cmdHooks(nline, w) + // case "massinsert": + // if !core.DevMode { + // err = fmt.Errorf("unknown command '%s'", cmd) + // return + // } + // err = c.cmdMassInsert(nline) + // resp = okResp() + // case "follow": + // err = c.cmdFollow(nline) + // resp = okResp() + // case "config": + // resp, err = c.cmdConfig(nline) + // case "readonly": + // err = c.cmdReadOnly(nline) + // resp = okResp() + // case "stats": + // resp, err = c.cmdStats(nline) + // case "server": + // resp, err = c.cmdServer(nline) + // case "scan": + // err = c.cmdScan(nline, w) + // case "nearby": + // err = c.cmdNearby(nline, w) + // case "within": + // err = c.cmdWithin(nline, w) + // case "intersects": + // err = c.cmdIntersects(nline, w) + case "get": - resp, err = c.cmdGet(nline) - case "keys": - err = c.cmdKeys(nline, w) - case "aof": - err = c.cmdAOF(nline, w) - case "aofmd5": - resp, err = c.cmdAOFMD5(nline) - case "gc": - go runtime.GC() - resp = okResp() - case "aofshrink": - go c.aofshrink() - resp = okResp() + res, err = c.cmdGet(msg) + // case "keys": + // err = c.cmdKeys(nline, w) + // case "aof": + // err = c.cmdAOF(nline, w) + // case "aofmd5": + // resp, err = c.cmdAOFMD5(nline) + // case "gc": + // go runtime.GC() + // resp = okResp() + // case "aofshrink": + // go c.aofshrink() + // resp = okResp() } return } diff --git a/controller/crud.go b/controller/crud.go index d1ab9e84..dbe92c49 100644 --- a/controller/crud.go +++ b/controller/crud.go @@ -3,51 +3,121 @@ package controller import ( "bytes" "math" + "sort" "strconv" "strings" "time" "github.com/google/btree" + "github.com/tidwall/resp" "github.com/tidwall/tile38/controller/collection" + "github.com/tidwall/tile38/controller/server" "github.com/tidwall/tile38/geojson" "github.com/tidwall/tile38/geojson/geohash" ) -func (c *Controller) cmdGet(line string) (string, error) { +type fvt struct { + field string + value float64 +} + +type byField []fvt + +func (a byField) Len() int { + return len(a) +} +func (a byField) Less(i, j int) bool { + return a[i].field < a[j].field +} +func (a byField) Swap(i, j int) { + a[i], a[j] = a[j], a[i] +} + +func orderFields(fmap map[string]int, fields []float64) []fvt { + var fv fvt + fvs := make([]fvt, 0, len(fmap)) + for field, idx := range fmap { + if idx < len(fields) { + fv.field = field + fv.value = fields[idx] + if !math.IsNaN(fv.value) && fv.value != 0 { + fvs = append(fvs, fv) + } + } + } + sort.Sort(byField(fvs)) + return fvs +} + +func (c *Controller) cmdGet(msg *server.Message) (string, error) { start := time.Now() + vs := msg.Values[1:] + + var ok bool var key, id, typ, sprecision string - if line, key = token(line); key == "" { + if vs, key, ok = tokenval(vs); !ok || key == "" { return "", errInvalidNumberOfArguments } - if line, id = token(line); id == "" { + if vs, id, ok = tokenval(vs); !ok || id == "" { return "", errInvalidNumberOfArguments } col := c.getCol(key) if col == nil { + if msg.OutputType == server.RESP { + return "$-1\r\n", nil + } return "", errKeyNotFound } o, fields, ok := col.Get(id) if !ok { + if msg.OutputType == server.RESP { + return "$-1\r\n", nil + } return "", errIDNotFound } + + vals := make([]resp.Value, 0, 2) var buf bytes.Buffer - buf.WriteString(`{"ok":true`) - if line, typ = token(line); typ == "" || strings.ToLower(typ) == "object" { - buf.WriteString(`,"object":`) - buf.WriteString(o.JSON()) + if msg.OutputType == server.JSON { + buf.WriteString(`{"ok":true`) + } + if vs, typ, ok = tokenval(vs); !ok || strings.ToLower(typ) == "object" { + if msg.OutputType == server.JSON { + buf.WriteString(`,"object":`) + buf.WriteString(o.JSON()) + } else { + vals = append(vals, resp.ArrayValue([]resp.Value{resp.StringValue(o.JSON())})) + } } else { - ltyp := strings.ToLower(typ) - switch ltyp { + switch strings.ToLower(typ) { default: return "", errInvalidArgument(typ) case "point": - buf.WriteString(`,"point":`) - buf.WriteString(o.CalculatedPoint().ExternalJSON()) + point := o.CalculatedPoint() + if msg.OutputType == server.JSON { + buf.WriteString(`,"point":`) + buf.WriteString(point.ExternalJSON()) + } else { + if point.Z != 0 { + vals = append(vals, resp.ArrayValue([]resp.Value{ + resp.StringValue(strconv.FormatFloat(point.Y, 'f', -1, 64)), + resp.StringValue(strconv.FormatFloat(point.X, 'f', -1, 64)), + resp.StringValue(strconv.FormatFloat(point.Z, 'f', -1, 64)), + })) + } else { + vals = append(vals, resp.ArrayValue([]resp.Value{ + resp.StringValue(strconv.FormatFloat(point.Y, 'f', -1, 64)), + resp.StringValue(strconv.FormatFloat(point.X, 'f', -1, 64)), + })) + } + } case "hash": - if line, sprecision = token(line); sprecision == "" { + if vs, sprecision, ok = tokenval(vs); !ok || sprecision == "" { return "", errInvalidNumberOfArguments } - buf.WriteString(`,"hash":`) + if msg.OutputType == server.JSON { + buf.WriteString(`,"hash":`) + } precision, err := strconv.ParseInt(sprecision, 10, 64) if err != nil || precision < 1 || precision > 64 { return "", errInvalidArgument(sprecision) @@ -56,81 +126,145 @@ func (c *Controller) cmdGet(line string) (string, error) { if err != nil { return "", err } - buf.WriteString(`"` + p + `"`) + if msg.OutputType == server.JSON { + buf.WriteString(`"` + p + `"`) + } else { + vals = append(vals, resp.ArrayValue([]resp.Value{resp.StringValue(p)})) + } case "bounds": - buf.WriteString(`,"bounds":`) - buf.WriteString(o.CalculatedBBox().ExternalJSON()) - } - } - if line != "" { - return "", errInvalidNumberOfArguments - } - fmap := col.FieldMap() - if len(fmap) > 0 { - buf.WriteString(`,"fields":{`) - var i int - for field, idx := range fmap { - if len(fields) > idx { - if !math.IsNaN(fields[idx]) { - if i > 0 { - buf.WriteString(`,`) - } - buf.WriteString(jsonString(field) + ":" + strconv.FormatFloat(fields[idx], 'f', -1, 64)) - i++ - } + bbox := o.CalculatedBBox() + if msg.OutputType == server.JSON { + buf.WriteString(`,"bounds":`) + buf.WriteString(bbox.ExternalJSON()) + } else { + vals = append(vals, resp.ArrayValue([]resp.Value{ + resp.StringValue(strconv.FormatFloat(bbox.Min.Y, 'f', -1, 64)), + resp.StringValue(strconv.FormatFloat(bbox.Min.X, 'f', -1, 64)), + resp.StringValue(strconv.FormatFloat(bbox.Max.Y, 'f', -1, 64)), + resp.StringValue(strconv.FormatFloat(bbox.Max.X, 'f', -1, 64)), + })) } } - buf.WriteString(`}`) } - buf.WriteString(`,"elapsed":"` + time.Now().Sub(start).String() + "\"}") - return buf.String(), nil + if len(vs) != 0 { + return "", errInvalidNumberOfArguments + } + + fvs := orderFields(col.FieldMap(), fields) + if len(fvs) > 0 { + fvals := make([]resp.Value, 0, len(fvs)*2) + if msg.OutputType == server.JSON { + buf.WriteString(`,"fields":{`) + } + for i, fv := range fvs { + if msg.OutputType == server.JSON { + if i > 0 { + buf.WriteString(`,`) + } + buf.WriteString(jsonString(fv.field) + ":" + strconv.FormatFloat(fv.value, 'f', -1, 64)) + } else { + fvals = append(fvals, resp.StringValue(fv.field), resp.StringValue(strconv.FormatFloat(fv.value, 'f', -1, 64))) + } + i++ + } + if msg.OutputType == server.JSON { + buf.WriteString(`}`) + } else { + vals = append(vals, resp.ArrayValue(fvals)) + } + } + if msg.OutputType == server.JSON { + buf.WriteString(`,"elapsed":"` + time.Now().Sub(start).String() + "\"}") + return buf.String(), nil + } + data, err := resp.ArrayValue(vals).MarshalRESP() + if err != nil { + return "", err + } + return string(data), nil + } -func (c *Controller) cmdDel(line string) (d commandDetailsT, err error) { - if line, d.key = token(line); d.key == "" { +func (c *Controller) cmdDel(msg *server.Message) (res string, d commandDetailsT, err error) { + start := time.Now() + vs := msg.Values[1:] + var ok bool + if vs, d.key, ok = tokenval(vs); !ok || d.key == "" { err = errInvalidNumberOfArguments return } - if line, d.id = token(line); d.id == "" { + if vs, d.id, ok = tokenval(vs); !ok || d.id == "" { err = errInvalidNumberOfArguments return } - if line != "" { + if len(vs) != 0 { err = errInvalidNumberOfArguments return } + found := false col := c.getCol(d.key) if col != nil { - d.obj, d.fields, _ = col.Remove(d.id) - if col.Count() == 0 { - c.deleteCol(d.key) + d.obj, d.fields, ok = col.Remove(d.id) + if ok { + if col.Count() == 0 { + c.deleteCol(d.key) + } + found = true } } d.command = "del" + d.updated = found + switch msg.OutputType { + case server.JSON: + res = `{"ok":true,"elapsed":"` + time.Now().Sub(start).String() + "\"}" + case server.RESP: + if d.updated { + res = ":1\r\n" + } else { + res = ":0\r\n" + } + } return } -func (c *Controller) cmdDrop(line string) (d commandDetailsT, err error) { - if line, d.key = token(line); d.key == "" { +func (c *Controller) cmdDrop(msg *server.Message) (res string, d commandDetailsT, err error) { + start := time.Now() + vs := msg.Values[1:] + var ok bool + if vs, d.key, ok = tokenval(vs); !ok || d.key == "" { err = errInvalidNumberOfArguments return } - if line != "" { + if len(vs) != 0 { err = errInvalidNumberOfArguments return } col := c.getCol(d.key) if col != nil { c.deleteCol(d.key) + d.updated = true } else { d.key = "" // ignore the details + d.updated = false } d.command = "drop" + switch msg.OutputType { + case server.JSON: + res = `{"ok":true,"elapsed":"` + time.Now().Sub(start).String() + "\"}" + case server.RESP: + if d.updated { + res = ":1\r\n" + } else { + res = ":0\r\n" + } + } return } -func (c *Controller) cmdFlushDB(line string) (d commandDetailsT, err error) { - if line != "" { +func (c *Controller) cmdFlushDB(msg *server.Message) (res string, d commandDetailsT, err error) { + start := time.Now() + vs := msg.Values[1:] + if len(vs) != 0 { err = errInvalidNumberOfArguments return } @@ -139,35 +273,43 @@ func (c *Controller) cmdFlushDB(line string) (d commandDetailsT, err error) { c.hooks = make(map[string]*Hook) c.hookcols = make(map[string]map[string]*Hook) d.command = "flushdb" + d.updated = true + switch msg.OutputType { + case server.JSON: + res = `{"ok":true,"elapsed":"` + time.Now().Sub(start).String() + "\"}" + case server.RESP: + res = "+OK\r\n" + } return } -func (c *Controller) parseSetArgs(line string) (d commandDetailsT, fields []string, values []float64, etype, eline string, err error) { - +func (c *Controller) parseSetArgs(vs []resp.Value) (d commandDetailsT, fields []string, values []float64, etype string, evs []resp.Value, err error) { + var ok bool var typ string - if line, d.key = token(line); d.key == "" { + if vs, d.key, ok = tokenval(vs); !ok || d.key == "" { err = errInvalidNumberOfArguments return } - if line, d.id = token(line); d.id == "" { + if vs, d.id, ok = tokenval(vs); !ok || d.id == "" { err = errInvalidNumberOfArguments return } + var arg string - var nline string + var nvs []resp.Value fields = make([]string, 0, 8) values = make([]float64, 0, 8) for { - if nline, arg = token(line); arg == "" { + if nvs, arg, ok = tokenval(vs); !ok || arg == "" { err = errInvalidNumberOfArguments return } if lc(arg, "field") { - line = nline + vs = nvs var name string var svalue string var value float64 - if line, name = token(line); name == "" { + if vs, name, ok = tokenval(vs); !ok || name == "" { err = errInvalidNumberOfArguments return } @@ -175,7 +317,7 @@ func (c *Controller) parseSetArgs(line string) (d commandDetailsT, fields []stri err = errInvalidArgument(name) return } - if line, svalue = token(line); svalue == "" { + if vs, svalue, ok = tokenval(vs); !ok || svalue == "" { err = errInvalidNumberOfArguments return } @@ -190,16 +332,16 @@ func (c *Controller) parseSetArgs(line string) (d commandDetailsT, fields []stri } break } - if line, typ = token(line); typ == "" { + if vs, typ, ok = tokenval(vs); !ok || typ == "" { err = errInvalidNumberOfArguments return } - if line == "" { + if len(vs) == 0 { err = errInvalidNumberOfArguments return } etype = typ - eline = line + evs = vs switch { default: @@ -207,16 +349,16 @@ func (c *Controller) parseSetArgs(line string) (d commandDetailsT, fields []stri return case lc(typ, "point"): var slat, slon, sz string - if line, slat = token(line); slat == "" { + if vs, slat, ok = tokenval(vs); !ok || slat == "" { err = errInvalidNumberOfArguments return } - if line, slon = token(line); slon == "" { + if vs, slon, ok = tokenval(vs); !ok || slon == "" { err = errInvalidNumberOfArguments return } - line, sz = token(line) - if sz == "" { + vs, sz, ok = tokenval(vs) + if !ok || sz == "" { var sp geojson.SimplePoint sp.Y, err = strconv.ParseFloat(slat, 64) if err != nil { @@ -250,19 +392,19 @@ func (c *Controller) parseSetArgs(line string) (d commandDetailsT, fields []stri } case lc(typ, "bounds"): var sminlat, sminlon, smaxlat, smaxlon string - if line, sminlat = token(line); sminlat == "" { + if vs, sminlat, ok = tokenval(vs); !ok || sminlat == "" { err = errInvalidNumberOfArguments return } - if line, sminlon = token(line); sminlon == "" { + if vs, sminlon, ok = tokenval(vs); !ok || sminlon == "" { err = errInvalidNumberOfArguments return } - if line, smaxlat = token(line); smaxlat == "" { + if vs, smaxlat, ok = tokenval(vs); !ok || smaxlat == "" { err = errInvalidNumberOfArguments return } - if line, smaxlon = token(line); smaxlon == "" { + if vs, smaxlon, ok = tokenval(vs); !ok || smaxlon == "" { err = errInvalidNumberOfArguments return } @@ -302,7 +444,7 @@ func (c *Controller) parseSetArgs(line string) (d commandDetailsT, fields []stri case lc(typ, "hash"): var sp geojson.SimplePoint var shash string - if line, shash = token(line); shash == "" { + if vs, shash, ok = tokenval(vs); !ok || shash == "" { err = errInvalidNumberOfArguments return } @@ -315,18 +457,28 @@ func (c *Controller) parseSetArgs(line string) (d commandDetailsT, fields []stri sp.Y = lat d.obj = sp case lc(typ, "object"): - d.obj, err = geojson.ObjectJSON(line) + var object string + if vs, object, ok = tokenval(vs); !ok || object == "" { + err = errInvalidNumberOfArguments + return + } + d.obj, err = geojson.ObjectJSON(object) if err != nil { return } } + if len(vs) != 0 { + err = errInvalidNumberOfArguments + } return } -func (c *Controller) cmdSet(line string) (d commandDetailsT, err error) { +func (c *Controller) cmdSet(msg *server.Message) (res string, d commandDetailsT, err error) { + start := time.Now() + vs := msg.Values[1:] var fields []string var values []float64 - d, fields, values, _, _, err = c.parseSetArgs(line) + d, fields, values, _, _, err = c.parseSetArgs(vs) if err != nil { return } @@ -337,20 +489,28 @@ func (c *Controller) cmdSet(line string) (d commandDetailsT, err error) { } d.oldObj, d.oldFields, d.fields = col.ReplaceOrInsert(d.id, d.obj, fields, values) d.command = "set" + d.updated = true // perhaps we should do a diff on the previous object? + switch msg.OutputType { + case server.JSON: + res = `{"ok":true,"elapsed":"` + time.Now().Sub(start).String() + "\"}" + case server.RESP: + res = "+OK\r\n" + } return } -func (c *Controller) parseFSetArgs(line string) (d commandDetailsT, err error) { +func (c *Controller) parseFSetArgs(vs []resp.Value) (d commandDetailsT, err error) { var svalue string - if line, d.key = token(line); d.key == "" { + var ok bool + if vs, d.key, ok = tokenval(vs); !ok || d.key == "" { err = errInvalidNumberOfArguments return } - if line, d.id = token(line); d.id == "" { + if vs, d.id, ok = tokenval(vs); !ok || d.id == "" { err = errInvalidNumberOfArguments return } - if line, d.field = token(line); d.field == "" { + if vs, d.field, ok = tokenval(vs); !ok || d.field == "" { err = errInvalidNumberOfArguments return } @@ -358,11 +518,11 @@ func (c *Controller) parseFSetArgs(line string) (d commandDetailsT, err error) { err = errInvalidNumberOfArguments return } - if line, svalue = token(line); svalue == "" { + if vs, svalue, ok = tokenval(vs); !ok || svalue == "" { err = errInvalidNumberOfArguments return } - if line != "" { + if len(vs) != 0 { err = errInvalidNumberOfArguments return } @@ -374,19 +534,33 @@ func (c *Controller) parseFSetArgs(line string) (d commandDetailsT, err error) { return } -func (c *Controller) cmdFset(line string) (d commandDetailsT, err error) { - d, err = c.parseFSetArgs(line) +func (c *Controller) cmdFset(msg *server.Message) (res string, d commandDetailsT, err error) { + start := time.Now() + vs := msg.Values[1:] + d, err = c.parseFSetArgs(vs) col := c.getCol(d.key) if col == nil { err = errKeyNotFound return } var ok bool - d.obj, d.fields, ok = col.SetField(d.id, d.field, d.value) + var updated bool + d.obj, d.fields, updated, ok = col.SetField(d.id, d.field, d.value) if !ok { err = errIDNotFound return } d.command = "fset" + d.updated = updated + switch msg.OutputType { + case server.JSON: + res = `{"ok":true,"elapsed":"` + time.Now().Sub(start).String() + "\"}" + case server.RESP: + if updated { + res = ":1\r\n" + } else { + res = ":0\r\n" + } + } return } diff --git a/controller/dev.go b/controller/dev.go index 1c9bb582..b0c088c5 100644 --- a/controller/dev.go +++ b/controller/dev.go @@ -1,84 +1,72 @@ package controller -import ( - "errors" - "fmt" - "math/rand" - "strconv" - "sync" - "sync/atomic" - "time" - - "github.com/tidwall/tile38/controller/log" -) - func (c *Controller) cmdMassInsert(line string) error { - // massinsert simply forwards a bunch of cmdSets - var snumCols, snumPoints string - var cols, objs int - if line, snumCols = token(line); snumCols == "" { - return errors.New("invalid number of arguments") - } - if line, snumPoints = token(line); snumPoints == "" { - return errors.New("invalid number of arguments") - } - if line != "" { - return errors.New("invalid number of arguments") - } - n, err := strconv.ParseUint(snumCols, 10, 64) - if err != nil { - return errors.New("invalid argument '" + snumCols + "'") - } - cols = int(n) - n, err = strconv.ParseUint(snumPoints, 10, 64) - if err != nil { - return errors.New("invalid argument '" + snumPoints + "'") - } - docmd := func(cmd string) error { - c.mu.Lock() - defer c.mu.Unlock() - _, d, err := c.command(cmd, nil) - if err != nil { - return err - } - if err := c.writeAOF(cmd, &d); err != nil { - return err - } - return nil - } - rand.Seed(time.Now().UnixNano()) - objs = int(n) - var wg sync.WaitGroup - var k uint64 - wg.Add(cols) - for i := 0; i < cols; i++ { - key := "mi:" + strconv.FormatInt(int64(i), 10) - go func(key string) { - defer func() { - wg.Done() - }() - for j := 0; j < objs; j++ { - id := strconv.FormatInt(int64(j), 10) - lat, lon := rand.Float64()*180-90, rand.Float64()*360-180 - var line string - if true { - fields := fmt.Sprintf("FIELD field %f", rand.Float64()*10) - line = fmt.Sprintf(`set %s %s %s POINT %f %f`, key, id, fields, lat, lon) - } else { - line = fmt.Sprintf(`set %s %s POINT %f %f`, key, id, lat, lon) - } - if err := docmd(line); err != nil { - log.Fatal(err) - return - } - atomic.AddUint64(&k, 1) - if j%10000 == 10000-1 { - log.Infof("massinsert: %s %d/%d", key, atomic.LoadUint64(&k), cols*objs) - } - } - }(key) - } - wg.Wait() - log.Infof("massinsert: done %d objects", atomic.LoadUint64(&k)) + // // massinsert simply forwards a bunch of cmdSets + // var snumCols, snumPoints string + // var cols, objs int + // if line, snumCols = token(line); snumCols == "" { + // return errors.New("invalid number of arguments") + // } + // if line, snumPoints = token(line); snumPoints == "" { + // return errors.New("invalid number of arguments") + // } + // if line != "" { + // return errors.New("invalid number of arguments") + // } + // n, err := strconv.ParseUint(snumCols, 10, 64) + // if err != nil { + // return errors.New("invalid argument '" + snumCols + "'") + // } + // cols = int(n) + // n, err = strconv.ParseUint(snumPoints, 10, 64) + // if err != nil { + // return errors.New("invalid argument '" + snumPoints + "'") + // } + // docmd := func(cmd string) error { + // c.mu.Lock() + // defer c.mu.Unlock() + // _, d, err := c.command(cmd, nil) + // if err != nil { + // return err + // } + // if err := c.writeAOF(cmd, &d); err != nil { + // return err + // } + // return nil + // } + // rand.Seed(time.Now().UnixNano()) + // objs = int(n) + // var wg sync.WaitGroup + // var k uint64 + // wg.Add(cols) + // for i := 0; i < cols; i++ { + // key := "mi:" + strconv.FormatInt(int64(i), 10) + // go func(key string) { + // defer func() { + // wg.Done() + // }() + // for j := 0; j < objs; j++ { + // id := strconv.FormatInt(int64(j), 10) + // lat, lon := rand.Float64()*180-90, rand.Float64()*360-180 + // var line string + // if true { + // fields := fmt.Sprintf("FIELD field %f", rand.Float64()*10) + // line = fmt.Sprintf(`set %s %s %s POINT %f %f`, key, id, fields, lat, lon) + // } else { + // line = fmt.Sprintf(`set %s %s POINT %f %f`, key, id, lat, lon) + // } + // if err := docmd(line); err != nil { + // log.Fatal(err) + // return + // } + // atomic.AddUint64(&k, 1) + // if j%10000 == 10000-1 { + // log.Infof("massinsert: %s %d/%d", key, atomic.LoadUint64(&k), cols*objs) + // } + // } + // }(key) + // } + // wg.Wait() + // log.Infof("massinsert: done %d objects", atomic.LoadUint64(&k)) return nil } diff --git a/controller/follow.go b/controller/follow.go index 280daaf9..2a1440e9 100644 --- a/controller/follow.go +++ b/controller/follow.go @@ -96,16 +96,16 @@ func (c *Controller) cmdFollow(line string) error { func (c *Controller) followHandleCommand(line string, followc uint64, w io.Writer) (int, error) { c.mu.Lock() defer c.mu.Unlock() - if c.followc != followc { - return c.aofsz, errNoLongerFollowing - } - _, d, err := c.command(line, w) - if err != nil { - return c.aofsz, err - } - if err := c.writeAOF(line, &d); err != nil { - return c.aofsz, err - } + // if c.followc != followc { + // return c.aofsz, errNoLongerFollowing + // } + // _, d, err := c.command(line, w) + // if err != nil { + // return c.aofsz, err + // } + // if err := c.writeAOF(line, &d); err != nil { + // return c.aofsz, err + // } return c.aofsz, nil } diff --git a/controller/server/anyreader.go b/controller/server/anyreader.go new file mode 100644 index 00000000..b2100434 --- /dev/null +++ b/controller/server/anyreader.go @@ -0,0 +1,295 @@ +package server + +import ( + "bufio" + "bytes" + "crypto/sha1" + "encoding/base64" + "errors" + "io" + "net/url" + "strconv" + "strings" + + "github.com/tidwall/resp" +) + +const TelnetIsJSON = false + +type Type int + +const ( + Null Type = iota + RESP + Telnet + Native + HTTP + WebSocket + JSON +) + +type errRESPProtocolError struct { + msg string +} + +func (err errRESPProtocolError) Error() string { + return "Protocol error: " + err.msg +} + +type Message struct { + Command string + Values []resp.Value + ConnType Type + OutputType Type + Auth string +} + +type AnyReaderWriter struct { + rd *bufio.Reader + wr io.Writer + ws bool +} + +func NewAnyReaderWriter(rd io.Reader) *AnyReaderWriter { + ar := &AnyReaderWriter{} + if rd2, ok := rd.(*bufio.Reader); ok { + ar.rd = rd2 + } else { + ar.rd = bufio.NewReader(rd) + } + if wr, ok := rd.(io.Writer); ok { + ar.wr = wr + } + return ar +} + +func (ar *AnyReaderWriter) peekcrlfline() (string, error) { + // this is slow operation. + for i := 0; ; i++ { + bb, err := ar.rd.Peek(i) + if err != nil { + return "", err + } + if len(bb) > 2 && bb[len(bb)-2] == '\r' && bb[len(bb)-1] == '\n' { + return string(bb[:len(bb)-2]), nil + } + } +} + +func (ar *AnyReaderWriter) readcrlfline() (string, error) { + var line []byte + for { + bb, err := ar.rd.ReadBytes('\r') + if err != nil { + return "", err + } + if line == nil { + line = bb + } else { + line = append(line, bb...) + } + b, err := ar.rd.ReadByte() + if err != nil { + return "", err + } + if b == '\n' { + return string(line[:len(line)-1]), nil + } + line = append(line, b) + } +} + +func (ar *AnyReaderWriter) ReadMessage() (*Message, error) { + b, err := ar.rd.ReadByte() + if err != nil { + return nil, err + } + if err := ar.rd.UnreadByte(); err != nil { + return nil, err + } + switch b { + case 'G', 'P': + line, err := ar.peekcrlfline() + if err != nil { + return nil, err + } + if strings.HasSuffix(line, " HTTP/1.1") { + return ar.readHTTPMessage() + } + case '$': + return ar.readNativeMessage() + } + // MultiBulk also reads telnet + return ar.readMultiBulkMessage() +} + +func (ar *AnyReaderWriter) readNativeMessage() (*Message, error) { + b, err := ar.rd.ReadBytes(' ') + if err != nil { + return nil, err + } + if len(b) > 0 && b[0] != '$' { + return nil, errors.New("invalid message") + } + n, err := strconv.ParseUint(string(b[1:len(b)-1]), 10, 32) + if err != nil { + return nil, errors.New("invalid size") + } + if n > 0x1FFFFFFF { // 536,870,911 bytes + return nil, errors.New("message too big") + } + b = make([]byte, int(n)+2) + if _, err := io.ReadFull(ar.rd, b); err != nil { + return nil, err + } + if b[len(b)-2] != '\r' || b[len(b)-1] != '\n' { + return nil, errors.New("expecting crlf") + } + values := make([]resp.Value, 0, 16) + line := b[:len(b)-2] +reading: + for len(line) != 0 { + if line[0] == '{' { + // The native protocol cannot understand json boundaries so it assumes that + // a json element must be at the end of the line. + values = append(values, resp.StringValue(string(line))) + break + } + i := 0 + for ; i < len(line); i++ { + if line[i] == ' ' { + values = append(values, resp.StringValue(string(line[:i]))) + line = line[i+1:] + continue reading + } + } + values = append(values, resp.StringValue(string(line))) + break + } + return &Message{Command: commandValues(values), Values: values, ConnType: Native, OutputType: JSON}, nil +} + +func commandValues(values []resp.Value) string { + if len(values) == 0 { + return "" + } + return strings.ToLower(values[0].String()) +} + +func (ar *AnyReaderWriter) readMultiBulkMessage() (*Message, error) { + rd := resp.NewReader(ar.rd) + v, telnet, _, err := rd.ReadMultiBulk() + if err != nil { + return nil, err + } + values := v.Array() + if len(values) == 0 { + return nil, nil + } + if telnet && TelnetIsJSON { + return &Message{Command: commandValues(values), Values: values, ConnType: Telnet, OutputType: JSON}, nil + } + return &Message{Command: commandValues(values), Values: values, ConnType: RESP, OutputType: RESP}, nil + +} + +func (ar *AnyReaderWriter) readHTTPMessage() (*Message, error) { + msg := &Message{ConnType: HTTP, OutputType: JSON} + line, err := ar.readcrlfline() + if err != nil { + return nil, err + } + parts := strings.Split(line, " ") + if len(parts) != 3 { + return nil, errors.New("invalid HTTP request") + } + method := parts[0] + path := parts[1] + if len(path) == 0 || path[0] != '/' { + return nil, errors.New("invalid HTTP request") + } + path, err = url.QueryUnescape(path[1:]) + if err != nil { + return nil, errors.New("invalid HTTP request") + } + if method != "GET" && method != "POST" { + return nil, errors.New("invalid HTTP method") + } + contentLength := 0 + websocket := false + websocketVersion := 0 + websocketKey := "" + for { + header, err := ar.readcrlfline() + if err != nil { + return nil, err + } + if header == "" { + break // end of headers + } + if header[0] == 'a' || header[0] == 'A' { + if strings.HasPrefix(strings.ToLower(header), "authorization:") { + msg.Auth = strings.TrimSpace(header[len("authorization:"):]) + } + } else if header[0] == 'u' || header[0] == 'U' { + if strings.HasPrefix(strings.ToLower(header), "upgrade:") && strings.ToLower(strings.TrimSpace(header[len("upgrade:"):])) == "websocket" { + websocket = true + } + } else if header[0] == 's' || header[0] == 'S' { + if strings.HasPrefix(strings.ToLower(header), "sec-websocket-version:") { + var n uint64 + n, err = strconv.ParseUint(strings.TrimSpace(header[len("sec-websocket-version:"):]), 10, 64) + if err != nil { + return nil, err + } + websocketVersion = int(n) + } else if strings.HasPrefix(strings.ToLower(header), "sec-websocket-key:") { + websocketKey = strings.TrimSpace(header[len("sec-websocket-key:"):]) + } + } else if header[0] == 'c' || header[0] == 'C' { + if strings.HasPrefix(strings.ToLower(header), "content-length:") { + var n uint64 + n, err = strconv.ParseUint(strings.TrimSpace(header[len("content-length:"):]), 10, 64) + if err != nil { + return nil, err + } + contentLength = int(n) + } + } + } + if websocket && websocketVersion >= 13 && websocketKey != "" { + msg.ConnType = WebSocket + if ar.wr == nil { + return nil, errors.New("connection is nil") + } + sum := sha1.Sum([]byte(websocketKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")) + accept := base64.StdEncoding.EncodeToString(sum[:]) + wshead := "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: " + accept + "\r\n\r\n" + if _, err = ar.wr.Write([]byte(wshead)); err != nil { + return nil, err + } + ar.ws = true + } else if contentLength > 0 { + msg.ConnType = HTTP + buf := make([]byte, contentLength) + if _, err = io.ReadFull(ar.rd, buf); err != nil { + return nil, err + } + path += string(buf) + } + if path == "" { + return msg, nil + } + if !strings.HasSuffix(path, "\r\n") { + path += "\r\n" + } + rd := NewAnyReaderWriter(bytes.NewBufferString(path)) + nmsg, err := rd.ReadMessage() + if err != nil { + return nil, err + } + msg.OutputType = nmsg.OutputType + msg.Values = nmsg.Values + msg.Command = commandValues(nmsg.Values) + return msg, nil +} diff --git a/controller/server/server.go b/controller/server/server.go index dcab4bea..741bcc7c 100644 --- a/controller/server/server.go +++ b/controller/server/server.go @@ -2,14 +2,14 @@ package server import ( "bufio" - "bytes" "errors" "fmt" "io" "net" "strings" - "github.com/tidwall/tile38/client" + //"github.com/tidwall/tile38/client" + "github.com/tidwall/tile38/controller/log" "github.com/tidwall/tile38/core" ) @@ -51,7 +51,7 @@ var errCloseHTTP = errors.New("close http") func ListenAndServe( host string, port int, protected func() bool, - handler func(conn *Conn, command []byte, rd *bufio.Reader, w io.Writer, websocket bool) error, + handler func(conn *Conn, msg *Message, rd *bufio.Reader, w io.Writer, websocket bool) error, ) error { ln, err := net.Listen("tcp", fmt.Sprintf("%s:%d", host, port)) if err != nil { @@ -68,17 +68,17 @@ func ListenAndServe( } } -func writeCommandErr(proto client.Proto, conn *Conn, err error) error { - if proto == client.HTTP || proto == client.WebSocket { - conn.Write([]byte(`HTTP/1.1 500 ` + err.Error() + "\r\nConnection: close\r\n\r\n")) - } - return err -} +// func writeCommandErr(proto client.Proto, conn *Conn, err error) error { +// if proto == client.HTTP || proto == client.WebSocket { +// conn.Write([]byte(`HTTP/1.1 500 ` + err.Error() + "\r\nConnection: close\r\n\r\n")) +// } +// return err +// } func handleConn( conn *Conn, protected func() bool, - handler func(conn *Conn, command []byte, rd *bufio.Reader, w io.Writer, websocket bool) error, + handler func(conn *Conn, msg *Message, rd *bufio.Reader, w io.Writer, websocket bool) error, ) { addr := conn.RemoteAddr().String() if core.ShowDebugMessages { @@ -96,59 +96,10 @@ func handleConn( } } defer conn.Close() - rd := bufio.NewReader(conn) - for i := 0; ; i++ { - err := func() error { - command, proto, auth, err := client.ReadMessage(rd, conn) - if err != nil { - return err - } - if len(command) > 0 && (command[0] == 'Q' || command[0] == 'q') && strings.ToLower(string(command)) == "quit" { - return io.EOF - } - var b bytes.Buffer - var denied bool - if (proto == client.HTTP || proto == client.WebSocket) && auth != "" { - if err := handler(conn, []byte("AUTH "+auth), rd, &b, proto == client.WebSocket); err != nil { - return writeCommandErr(proto, conn, err) - } - if strings.HasPrefix(b.String(), `{"ok":false`) { - denied = true - } else { - b.Reset() - } - } - if !denied { - if err := handler(conn, command, rd, &b, proto == client.WebSocket); err != nil { - return writeCommandErr(proto, conn, err) - } - } - switch proto { - case client.Native: - if err := client.WriteMessage(conn, b.Bytes()); err != nil { - return err - } - case client.HTTP: - if err := client.WriteHTTP(conn, b.Bytes()); err != nil { - return err - } - return errCloseHTTP - case client.WebSocket: - if err := client.WriteWebSocket(conn, b.Bytes()); err != nil { - return err - } - if _, err := conn.Write([]byte{137, 0}); err != nil { - return err - } - return errCloseHTTP - default: - b.WriteString("\r\n") - if _, err := conn.Write(b.Bytes()); err != nil { - return err - } - } - return nil - }() + rd := NewAnyReaderWriter(conn) + brd := rd.rd + for { + msg, err := rd.ReadMessage() if err != nil { if err == io.EOF { return @@ -160,5 +111,83 @@ func handleConn( log.Error(err) return } + if msg != nil && msg.Command != "" { + if msg.Command == "quit" { + if msg.OutputType == RESP { + io.WriteString(conn, "+OK\r\n") + } + return + } + err := handler(conn, msg, brd, conn, msg.ConnType == WebSocket) + if err != nil { + log.Error(err) + return + } + } } } + +//err := func() error { +// command, proto, auth, err := client.ReadMessage(rd, conn) +// if err != nil { +// return err +// } +// if len(command) > 0 && (command[0] == 'Q' || command[0] == 'q') && strings.ToLower(string(command)) == "quit" { +// return io.EOF +// } +// var b bytes.Buffer +// var denied bool +// if (proto == client.HTTP || proto == client.WebSocket) && auth != "" { +// if err := handler(conn, []byte("AUTH "+auth), rd, &b, proto == client.WebSocket); err != nil { +// return writeCommandErr(proto, conn, err) +// } +// if strings.HasPrefix(b.String(), `{"ok":false`) { +// denied = true +// } else { +// b.Reset() +// } +// } +// if !denied { +// if err := handler(conn, command, rd, &b, proto == client.WebSocket); err != nil { +// return writeCommandErr(proto, conn, err) +// } +// } +// switch proto { +// case client.Native: +// if err := client.WriteMessage(conn, b.Bytes()); err != nil { +// return err +// } +// case client.HTTP: +// if err := client.WriteHTTP(conn, b.Bytes()); err != nil { +// return err +// } +// return errCloseHTTP +// case client.WebSocket: +// if err := client.WriteWebSocket(conn, b.Bytes()); err != nil { +// return err +// } +// if _, err := conn.Write([]byte{137, 0}); err != nil { +// return err +// } +// return errCloseHTTP +// default: +// b.WriteString("\r\n") +// if _, err := conn.Write(b.Bytes()); err != nil { +// return err +// } +// } +// return nil +//}() +// if err != nil { +// if err == io.EOF { +// return +// } +// if err == errCloseHTTP || +// strings.Contains(err.Error(), "use of closed network connection") { +// return +// } +// log.Error(err) +// return +// } +// } +// } diff --git a/controller/token.go b/controller/token.go index 37725fd1..6376e59a 100644 --- a/controller/token.go +++ b/controller/token.go @@ -6,6 +6,8 @@ import ( "math" "strconv" "strings" + + "github.com/tidwall/resp" ) const defaultSearchOutput = outputObjects @@ -29,6 +31,15 @@ func token(line string) (newLine, token string) { return "", line } +func tokenval(vs []resp.Value) (nvs []resp.Value, token string, ok bool) { + if len(vs) > 0 { + token = vs[0].String() + nvs = vs[1:] + ok = true + } + return +} + func tokenlc(line string) (newLine, token string) { for i := 0; i < len(line); i++ { ch := line[i]