resp crud

This commit is contained in:
Josh Baker 2016-03-28 08:57:41 -07:00
parent 572d0776bb
commit ba9139be02
9 changed files with 1281 additions and 776 deletions

View File

@ -13,10 +13,11 @@ import (
"sync" "sync"
"time" "time"
"github.com/boltdb/bolt"
"github.com/google/btree" "github.com/google/btree"
"github.com/tidwall/resp"
"github.com/tidwall/tile38/client" "github.com/tidwall/tile38/client"
"github.com/tidwall/tile38/controller/log" "github.com/tidwall/tile38/controller/log"
"github.com/tidwall/tile38/controller/server"
) )
const backwardsBufferSize = 50000 const backwardsBufferSize = 50000
@ -104,61 +105,36 @@ func (c *Controller) loadAOF() error {
ps := float64(count) / (float64(d) / float64(time.Second)) ps := float64(count) / (float64(d) / float64(time.Second))
log.Infof("AOF loaded %d commands: %s: %.0f/sec", count, d, ps) log.Infof("AOF loaded %d commands: %s: %.0f/sec", count, d, ps)
}() }()
rd := NewAOFReader(c.f) rd := resp.NewReader(c.f)
for { for {
buf, err := rd.ReadCommand() v, _, n, err := rd.ReadMultiBulk()
if err != nil { if err != nil {
if err == io.EOF { if err == io.EOF {
return nil return nil
} }
if err == io.ErrUnexpectedEOF || err == errCorruptedAOF {
log.Warnf("aof is corrupted, likely data loss. Truncating to %d", c.aofsz)
fname := c.f.Name()
c.f.Close()
if err := os.Truncate(c.f.Name(), int64(c.aofsz)); err != nil {
log.Fatalf("could not truncate aof, possible data loss. %s", err.Error())
return err return err
} }
c.f, err = os.OpenFile(fname, os.O_CREATE|os.O_RDWR, 0600) values := v.Array()
if err != nil { if len(values) == 0 {
log.Fatalf("could not create aof, possible data loss. %s", err.Error()) return errors.New("multibulk missing command component")
}
msg := &server.Message{
Command: strings.ToLower(values[0].String()),
Values: values,
}
if _, _, err := c.command(msg, nil); err != nil {
return err return err
} }
if _, err := c.f.Seek(int64(c.aofsz), 0); err != nil { c.aofsz += n
log.Fatalf("could not seek aof, possible data loss. %s", err.Error())
return err
}
}
return err
}
empty := true
for i := 0; i < len(buf); i++ {
if buf[i] != 0 {
empty = false
break
}
}
if empty {
return nil
}
if _, _, err := c.command(string(buf), nil); err != nil {
return err
}
c.aofsz += 9 + len(buf)
count++ count++
} }
} }
func writeCommand(w io.Writer, line []byte) (n int, err error) { func (c *Controller) writeAOF(value resp.Value, d *commandDetailsT) error {
b := make([]byte, len(line)+9)
binary.LittleEndian.PutUint32(b, uint32(len(line)))
copy(b[4:], line)
binary.LittleEndian.PutUint32(b[len(b)-5:], uint32(len(line)))
return w.Write(b)
}
func (c *Controller) writeAOF(line string, d *commandDetailsT) error {
if d != nil { if d != nil {
if !d.updated {
return nil // just ignore writes if the command did not update
}
// process hooks // process hooks
if hm, ok := c.hookcols[d.key]; ok { if hm, ok := c.hookcols[d.key]; ok {
for _, hook := range hm { for _, hook := range hm {
@ -168,8 +144,11 @@ func (c *Controller) writeAOF(line string, d *commandDetailsT) error {
} }
} }
} }
data, err := value.MarshalRESP()
n, err := writeCommand(c.f, []byte(line)) if err != nil {
return err
}
n, err := c.f.Write(data)
if err != nil { if err != nil {
return err return err
} }
@ -306,14 +285,18 @@ func (c *Controller) liveAOF(pos int64, conn net.Conn, rd *bufio.Reader) error {
if err != nil { if err != nil {
return err return err
} }
rd := NewAOFReader(f) rd := resp.NewReader(f)
for { for {
cmd, err := rd.ReadCommand() v, _, err := rd.ReadValue()
if err != io.EOF { if err != io.EOF {
if err != nil { if err != nil {
return err return err
} }
if _, err := writeCommand(conn, cmd); err != nil { data, err := v.MarshalRESP()
if err != nil {
return err
}
if _, err := conn.Write(data); err != nil {
return err return err
} }
continue continue
@ -387,413 +370,413 @@ func (k *treeKeyBoolT) Less(item btree.Item) bool {
// - Stop shrinking, nothing left to do // - Stop shrinking, nothing left to do
func (c *Controller) aofshrink() { func (c *Controller) aofshrink() {
c.mu.Lock() // c.mu.Lock()
c.f.Sync() // c.f.Sync()
if c.shrinking { // if c.shrinking {
c.mu.Unlock() // c.mu.Unlock()
return // return
} // }
c.shrinking = true // c.shrinking = true
endpos := int64(c.aofsz) // endpos := int64(c.aofsz)
start := time.Now() // start := time.Now()
log.Infof("aof shrink started at pos %d", endpos) // log.Infof("aof shrink started at pos %d", endpos)
var hooks []string // var hooks []string
for _, hook := range c.hooks { // for _, hook := range c.hooks {
var orgs []string // var orgs []string
for _, endpoint := range hook.Endpoints { // for _, endpoint := range hook.Endpoints {
orgs = append(orgs, endpoint.Original) // orgs = append(orgs, endpoint.Original)
} // }
hooks = append(hooks, "SETHOOK "+hook.Name+" "+strings.Join(orgs, ",")+" "+hook.Command) // hooks = append(hooks, "SETHOOK "+hook.Name+" "+strings.Join(orgs, ",")+" "+hook.Command)
} // }
c.mu.Unlock() // c.mu.Unlock()
var err error // var err error
defer func() { // defer func() {
c.mu.Lock() // c.mu.Lock()
c.shrinking = false // c.shrinking = false
c.mu.Unlock() // c.mu.Unlock()
os.RemoveAll(c.dir + "/shrink.db") // os.RemoveAll(c.dir + "/shrink.db")
os.RemoveAll(c.dir + "/shrink") // os.RemoveAll(c.dir + "/shrink")
if err != nil { // if err != nil {
log.Error("aof shrink failed: " + err.Error()) // log.Error("aof shrink failed: " + err.Error())
} else { // } else {
log.Info("aof shrink completed: " + time.Now().Sub(start).String()) // log.Info("aof shrink completed: " + time.Now().Sub(start).String())
} // }
}() // }()
var db *bolt.DB // var db *bolt.DB
db, err = bolt.Open(c.dir+"/shrink.db", 0600, nil) // db, err = bolt.Open(c.dir+"/shrink.db", 0600, nil)
if err != nil { // if err != nil {
return // return
} // }
defer db.Close() // defer db.Close()
var nf *os.File // var nf *os.File
nf, err = os.Create(c.dir + "/shrink") // nf, err = os.Create(c.dir + "/shrink")
if err != nil { // if err != nil {
return // return
} // }
defer nf.Close() // defer nf.Close()
defer func() { // defer func() {
c.mu.Lock() // c.mu.Lock()
defer c.mu.Unlock() // defer c.mu.Unlock()
if err == nil { // if err == nil {
c.f.Sync() // c.f.Sync()
_, err = nf.Seek(0, 2) // _, err = nf.Seek(0, 2)
if err == nil { // if err == nil {
var f *os.File // var f *os.File
f, err = os.Open(c.dir + "/aof") // f, err = os.Open(c.dir + "/aof")
if err != nil { // if err != nil {
return // return
} // }
defer f.Close() // defer f.Close()
_, err = f.Seek(endpos, 0) // _, err = f.Seek(endpos, 0)
if err == nil { // if err == nil {
_, err = io.Copy(nf, f) // _, err = io.Copy(nf, f)
if err == nil { // if err == nil {
f.Close() // f.Close()
nf.Close() // nf.Close()
// At this stage we need to kill all aof followers. To do so we will // // At this stage we need to kill all aof followers. To do so we will
// write a KILLAOF command to the stream. KILLAOF isn't really a command. // // write a KILLAOF command to the stream. KILLAOF isn't really a command.
// This will cause the followers will close their connection and then // // This will cause the followers will close their connection and then
// automatically reconnect. The reconnection will force a sync of the aof. // // automatically reconnect. The reconnection will force a sync of the aof.
err = c.writeAOF("KILLAOF", nil) // err = c.writeAOF(resp.MultiBulkValue("KILLAOF"), nil)
if err == nil { // if err == nil {
c.f.Close() // c.f.Close()
err = os.Rename(c.dir+"/shrink", c.dir+"/aof") // err = os.Rename(c.dir+"/shrink", c.dir+"/aof")
if err != nil { // if err != nil {
log.Fatal("shink rename fatal operation") // log.Fatal("shink rename fatal operation")
} // }
c.f, err = os.OpenFile(c.dir+"/aof", os.O_CREATE|os.O_RDWR, 0600) // c.f, err = os.OpenFile(c.dir+"/aof", os.O_CREATE|os.O_RDWR, 0600)
if err != nil { // if err != nil {
log.Fatal("shink openfile fatal operation") // log.Fatal("shink openfile fatal operation")
} // }
var n int64 // var n int64
n, err = c.f.Seek(0, 2) // n, err = c.f.Seek(0, 2)
if err != nil { // if err != nil {
log.Fatal("shink seek end fatal operation") // log.Fatal("shink seek end fatal operation")
} // }
c.aofsz = int(n) // c.aofsz = int(n)
} // }
} // }
} // }
} // }
} // }
}() // }()
var f *os.File // var f *os.File
f, err = os.Open(c.dir + "/aof") // f, err = os.Open(c.dir + "/aof")
if err != nil { // if err != nil {
return // return
} // }
defer f.Close() // defer f.Close()
var buf []byte // var buf []byte
var pos int64 // var pos int64
pos, err = f.Seek(endpos, 0) // pos, err = f.Seek(endpos, 0)
if err != nil { // if err != nil {
return // return
} // }
var readPreviousCommand func() ([]byte, error) // var readPreviousCommand func() ([]byte, error)
readPreviousCommand = func() ([]byte, error) { // readPreviousCommand = func() ([]byte, error) {
if len(buf) >= 5 { // if len(buf) >= 5 {
if buf[len(buf)-1] != 0 { // if buf[len(buf)-1] != 0 {
return nil, errCorruptedAOF // return nil, errCorruptedAOF
} // }
sz2 := int(binary.LittleEndian.Uint32(buf[len(buf)-5:])) // sz2 := int(binary.LittleEndian.Uint32(buf[len(buf)-5:]))
if len(buf) >= sz2+9 { // if len(buf) >= sz2+9 {
sz1 := int(binary.LittleEndian.Uint32(buf[len(buf)-(sz2+9):])) // sz1 := int(binary.LittleEndian.Uint32(buf[len(buf)-(sz2+9):]))
if sz1 != sz2 { // if sz1 != sz2 {
return nil, errCorruptedAOF // return nil, errCorruptedAOF
} // }
command := buf[len(buf)-(sz2+5) : len(buf)-5] // command := buf[len(buf)-(sz2+5) : len(buf)-5]
buf = buf[:len(buf)-(sz2+9)] // buf = buf[:len(buf)-(sz2+9)]
return command, nil // return command, nil
} // }
} // }
if pos == 0 { // if pos == 0 {
if len(buf) > 0 { // if len(buf) > 0 {
return nil, io.ErrUnexpectedEOF // return nil, io.ErrUnexpectedEOF
} else { // } else {
return nil, io.EOF // return nil, io.EOF
} // }
} // }
sz := int64(backwardsBufferSize) // sz := int64(backwardsBufferSize)
offset := pos - sz // offset := pos - sz
if offset < 0 { // if offset < 0 {
sz = pos // sz = pos
offset = 0 // offset = 0
} // }
pos, err = f.Seek(offset, 0) // pos, err = f.Seek(offset, 0)
if err != nil { // if err != nil {
return nil, err // return nil, err
} // }
nbuf := make([]byte, int(sz)) // nbuf := make([]byte, int(sz))
_, err = io.ReadFull(f, nbuf) // _, err = io.ReadFull(f, nbuf)
if err != nil { // if err != nil {
return nil, err // return nil, err
} // }
if len(buf) > 0 { // if len(buf) > 0 {
nbuf = append(nbuf, buf...) // nbuf = append(nbuf, buf...)
} // }
buf = nbuf // buf = nbuf
return readPreviousCommand() // return readPreviousCommand()
} // }
var tx *bolt.Tx // var tx *bolt.Tx
tx, err = db.Begin(true) // tx, err = db.Begin(true)
if err != nil { // if err != nil {
return // return
} // }
defer func() { // defer func() {
tx.Rollback() // tx.Rollback()
}() // }()
var keyIgnoreM = map[string]bool{} // var keyIgnoreM = map[string]bool{}
var keyBucketM = btree.New(16) // var keyBucketM = btree.New(16)
var cmd, key, id, field string // var cmd, key, id, field string
var line string // var line string
var command []byte // var command []byte
var val []byte // var val []byte
var b *bolt.Bucket // var b *bolt.Bucket
reading: // reading:
for i := 0; ; i++ { // for i := 0; ; i++ {
if i%500 == 0 { // if i%500 == 0 {
if err = tx.Commit(); err != nil { // if err = tx.Commit(); err != nil {
return // return
} // }
tx, err = db.Begin(true) // tx, err = db.Begin(true)
if err != nil { // if err != nil {
return // return
} // }
} // }
command, err = readPreviousCommand() // command, err = readPreviousCommand()
if err != nil { // if err != nil {
if err == io.EOF { // if err == io.EOF {
err = nil // err = nil
break // break
} // }
return // return
} // }
// quick path // // quick path
if len(command) == 0 { // if len(command) == 0 {
continue // ignore blank commands // continue // ignore blank commands
} // }
line, cmd = token(string(command)) // line, cmd = token(string(command))
cmd = strings.ToLower(cmd) // cmd = strings.ToLower(cmd)
switch cmd { // switch cmd {
case "flushdb": // case "flushdb":
break reading // all done // break reading // all done
case "drop": // case "drop":
if line, key = token(line); key == "" { // if line, key = token(line); key == "" {
err = errors.New("DROP is missing key") // err = errors.New("DROP is missing key")
return // return
} // }
if !keyIgnoreM[key] { // if !keyIgnoreM[key] {
keyIgnoreM[key] = true // keyIgnoreM[key] = true
} // }
case "del": // case "del":
if line, key = token(line); key == "" { // if line, key = token(line); key == "" {
err = errors.New("DEL is missing key") // err = errors.New("DEL is missing key")
return // return
} // }
if keyIgnoreM[key] { // if keyIgnoreM[key] {
continue // ignore // continue // ignore
} // }
if line, id = token(line); id == "" { // if line, id = token(line); id == "" {
err = errors.New("DEL is missing id") // err = errors.New("DEL is missing id")
return // return
} // }
if keyBucketM.Get(&treeKeyBoolT{key}) == nil { // if keyBucketM.Get(&treeKeyBoolT{key}) == nil {
if _, err = tx.CreateBucket([]byte(key + ".ids")); err != nil { // if _, err = tx.CreateBucket([]byte(key + ".ids")); err != nil {
return // return
} // }
if _, err = tx.CreateBucket([]byte(key + ".ignore_ids")); err != nil { // if _, err = tx.CreateBucket([]byte(key + ".ignore_ids")); err != nil {
return // return
} // }
keyBucketM.ReplaceOrInsert(&treeKeyBoolT{key}) // keyBucketM.ReplaceOrInsert(&treeKeyBoolT{key})
} // }
b = tx.Bucket([]byte(key + ".ignore_ids")) // b = tx.Bucket([]byte(key + ".ignore_ids"))
err = b.Put([]byte(id), []byte("2")) // 2 for hard ignore // err = b.Put([]byte(id), []byte("2")) // 2 for hard ignore
if err != nil { // if err != nil {
return // return
} // }
case "set": // case "set":
if line, key = token(line); key == "" { // if line, key = token(line); key == "" {
err = errors.New("SET is missing key") // err = errors.New("SET is missing key")
return // return
} // }
if keyIgnoreM[key] { // if keyIgnoreM[key] {
continue // ignore // continue // ignore
} // }
if line, id = token(line); id == "" { // if line, id = token(line); id == "" {
err = errors.New("SET is missing id") // err = errors.New("SET is missing id")
return // return
} // }
if keyBucketM.Get(&treeKeyBoolT{key}) == nil { // if keyBucketM.Get(&treeKeyBoolT{key}) == nil {
if _, err = tx.CreateBucket([]byte(key + ".ids")); err != nil { // if _, err = tx.CreateBucket([]byte(key + ".ids")); err != nil {
return // return
} // }
if _, err = tx.CreateBucket([]byte(key + ".ignore_ids")); err != nil { // if _, err = tx.CreateBucket([]byte(key + ".ignore_ids")); err != nil {
return // return
} // }
keyBucketM.ReplaceOrInsert(&treeKeyBoolT{key}) // keyBucketM.ReplaceOrInsert(&treeKeyBoolT{key})
} // }
b = tx.Bucket([]byte(key + ".ignore_ids")) // b = tx.Bucket([]byte(key + ".ignore_ids"))
val = b.Get([]byte(id)) // val = b.Get([]byte(id))
if val == nil { // if val == nil {
if err = b.Put([]byte(id), []byte("1")); err != nil { // if err = b.Put([]byte(id), []byte("1")); err != nil {
return // return
} // }
b = tx.Bucket([]byte(key + ".ids")) // b = tx.Bucket([]byte(key + ".ids"))
if err = b.Put([]byte(id), command); err != nil { // if err = b.Put([]byte(id), command); err != nil {
return // return
} // }
} else { // } else {
switch string(val) { // switch string(val) {
default: // default:
err = errors.New("invalid ignore") // err = errors.New("invalid ignore")
case "1", "2": // case "1", "2":
continue // ignore // continue // ignore
} // }
} // }
case "fset": // case "fset":
if line, key = token(line); key == "" { // if line, key = token(line); key == "" {
err = errors.New("FSET is missing key") // err = errors.New("FSET is missing key")
return // return
} // }
if keyIgnoreM[key] { // if keyIgnoreM[key] {
continue // ignore // continue // ignore
} // }
if line, id = token(line); id == "" { // if line, id = token(line); id == "" {
err = errors.New("FSET is missing id") // err = errors.New("FSET is missing id")
return // return
} // }
if line, field = token(line); field == "" { // if line, field = token(line); field == "" {
err = errors.New("FSET is missing field") // err = errors.New("FSET is missing field")
return // return
} // }
if keyBucketM.Get(&treeKeyBoolT{key}) == nil { // if keyBucketM.Get(&treeKeyBoolT{key}) == nil {
if _, err = tx.CreateBucket([]byte(key + ".ids")); err != nil { // if _, err = tx.CreateBucket([]byte(key + ".ids")); err != nil {
return // return
} // }
if _, err = tx.CreateBucket([]byte(key + ".ignore_ids")); err != nil { // if _, err = tx.CreateBucket([]byte(key + ".ignore_ids")); err != nil {
return // return
} // }
keyBucketM.ReplaceOrInsert(&treeKeyBoolT{key}) // keyBucketM.ReplaceOrInsert(&treeKeyBoolT{key})
} // }
b = tx.Bucket([]byte(key + ".ignore_ids")) // b = tx.Bucket([]byte(key + ".ignore_ids"))
val = b.Get([]byte(id)) // val = b.Get([]byte(id))
if val == nil { // if val == nil {
b = tx.Bucket([]byte(key + ":" + id + ":0")) // b = tx.Bucket([]byte(key + ":" + id + ":0"))
if b == nil { // if b == nil {
if b, err = tx.CreateBucket([]byte(key + ":" + id + ":0")); err != nil { // if b, err = tx.CreateBucket([]byte(key + ":" + id + ":0")); err != nil {
return // return
} // }
} // }
if b.Get([]byte(field)) == nil { // if b.Get([]byte(field)) == nil {
if err = b.Put([]byte(field), command); err != nil { // if err = b.Put([]byte(field), command); err != nil {
return // return
} // }
} // }
} else { // } else {
switch string(val) { // switch string(val) {
default: // default:
err = errors.New("invalid ignore") // err = errors.New("invalid ignore")
case "1": // case "1":
b = tx.Bucket([]byte(key + ":" + id + ":1")) // b = tx.Bucket([]byte(key + ":" + id + ":1"))
if b == nil { // if b == nil {
if b, err = tx.CreateBucket([]byte(key + ":" + id + ":1")); err != nil { // if b, err = tx.CreateBucket([]byte(key + ":" + id + ":1")); err != nil {
return // return
} // }
} // }
if b.Get([]byte(field)) == nil { // if b.Get([]byte(field)) == nil {
if err = b.Put([]byte(field), command); err != nil { // if err = b.Put([]byte(field), command); err != nil {
return // return
} // }
} // }
case "2": // case "2":
continue // ignore // continue // ignore
} // }
} // }
} // }
} // }
if err = tx.Commit(); err != nil { // if err = tx.Commit(); err != nil {
return // return
} // }
tx, err = db.Begin(false) // tx, err = db.Begin(false)
if err != nil { // if err != nil {
return // return
} // }
keyBucketM.Ascend(func(item btree.Item) bool { // keyBucketM.Ascend(func(item btree.Item) bool {
key := item.(*treeKeyBoolT).key // key := item.(*treeKeyBoolT).key
b := tx.Bucket([]byte(key + ".ids")) // b := tx.Bucket([]byte(key + ".ids"))
if b != nil { // if b != nil {
err = b.ForEach(func(id, command []byte) error { // err = b.ForEach(func(id, command []byte) error {
// parse the SET command // // parse the SET command
_, fields, values, etype, eargs, err := c.parseSetArgs(string(command[4:])) // _, fields, values, etype, eargs, err := c.parseSetArgs(string(command[4:]))
if err != nil { // if err != nil {
return err // return err
} // }
// store the fields in a map // // store the fields in a map
var fieldm = map[string]float64{} // var fieldm = map[string]float64{}
for i, field := range fields { // for i, field := range fields {
fieldm[field] = values[i] // fieldm[field] = values[i]
} // }
// append old FSET values. these are FSETs that existed prior to the last SET. // // append old FSET values. these are FSETs that existed prior to the last SET.
f1 := tx.Bucket([]byte(key + ":" + string(id) + ":1")) // f1 := tx.Bucket([]byte(key + ":" + string(id) + ":1"))
if f1 != nil { // if f1 != nil {
err = f1.ForEach(func(field, command []byte) error { // err = f1.ForEach(func(field, command []byte) error {
d, err := c.parseFSetArgs(string(command[5:])) // d, err := c.parseFSetArgs(string(command[5:]))
if err != nil { // if err != nil {
return err // return err
} // }
if _, ok := fieldm[d.field]; !ok { // if _, ok := fieldm[d.field]; !ok {
fieldm[d.field] = d.value // fieldm[d.field] = d.value
} // }
return nil // return nil
}) // })
if err != nil { // if err != nil {
return err // return err
} // }
} // }
// append new FSET values. these are FSETs that were added after the last SET. // // append new FSET values. these are FSETs that were added after the last SET.
f0 := tx.Bucket([]byte(key + ":" + string(id) + ":0")) // f0 := tx.Bucket([]byte(key + ":" + string(id) + ":0"))
if f0 != nil { // if f0 != nil {
f0.ForEach(func(field, command []byte) error { // f0.ForEach(func(field, command []byte) error {
d, err := c.parseFSetArgs(string(command[5:])) // d, err := c.parseFSetArgs(string(command[5:]))
if err != nil { // if err != nil {
return err // return err
} // }
fieldm[d.field] = d.value // fieldm[d.field] = d.value
return nil // return nil
}) // })
} // }
// rebuild the SET command // // rebuild the SET command
ncommand := "set " + key + " " + string(id) // ncommand := "set " + key + " " + string(id)
for field, value := range fieldm { // for field, value := range fieldm {
if value != 0 { // if value != 0 {
ncommand += " field " + field + " " + strconv.FormatFloat(value, 'f', -1, 64) // ncommand += " field " + field + " " + strconv.FormatFloat(value, 'f', -1, 64)
} // }
} // }
ncommand += " " + strings.ToUpper(etype) + " " + eargs // ncommand += " " + strings.ToUpper(etype) + " " + eargs
_, err = writeCommand(nf, []byte(ncommand)) // _, err = writeCommand(nf, []byte(ncommand))
if err != nil { // if err != nil {
return err // return err
} // }
return nil // return nil
}) // })
if err != nil { // if err != nil {
return false // return false
} // }
} // }
return true // return true
}) // })
if err == nil { // if err == nil {
// add all of the hooks // // add all of the hooks
for _, line := range hooks { // for _, line := range hooks {
_, err = writeCommand(nf, []byte(line)) // _, err = writeCommand(nf, []byte(line))
if err != nil { // if err != nil {
return // return
} // }
} // }
} // }
} }

View File

@ -151,18 +151,18 @@ func (c *Collection) Get(id string) (obj geojson.Object, fields []float64, ok bo
// SetField set a field value for an object and returns that object. // SetField set a field value for an object and returns that object.
// If the object does not exist then the 'ok' return value will be false. // If the object does not exist then the 'ok' return value will be false.
func (c *Collection) SetField(id, field string, value float64) (obj geojson.Object, fields []float64, ok bool) { func (c *Collection) SetField(id, field string, value float64) (obj geojson.Object, fields []float64, updated bool, ok bool) {
i := c.items.Get(&itemT{ID: id}) i := c.items.Get(&itemT{ID: id})
if i == nil { if i == nil {
ok = false ok = false
return return
} }
item := i.(*itemT) item := i.(*itemT)
c.setField(item, field, value) updated = c.setField(item, field, value)
return item.Object, item.Fields, true return item.Object, item.Fields, updated, true
} }
func (c *Collection) setField(item *itemT, field string, value float64) { func (c *Collection) setField(item *itemT, field string, value float64) (updated bool) {
idx, ok := c.fieldMap[field] idx, ok := c.fieldMap[field]
if !ok { if !ok {
idx = len(c.fieldMap) idx = len(c.fieldMap)
@ -173,7 +173,12 @@ func (c *Collection) setField(item *itemT, field string, value float64) {
item.Fields = append(item.Fields, math.NaN()) item.Fields = append(item.Fields, math.NaN())
} }
c.weight += len(item.Fields) * 8 c.weight += len(item.Fields) * 8
ovalue := item.Fields[idx]
if math.IsNaN(ovalue) {
ovalue = 0
}
item.Fields[idx] = value item.Fields[idx] = value
return ovalue != value
} }
// FieldMap return a maps of the field names. // FieldMap return a maps of the field names.

View File

@ -8,12 +8,12 @@ import (
"fmt" "fmt"
"io" "io"
"os" "os"
"runtime"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/google/btree" "github.com/google/btree"
"github.com/tidwall/resp"
"github.com/tidwall/tile38/controller/collection" "github.com/tidwall/tile38/controller/collection"
"github.com/tidwall/tile38/controller/log" "github.com/tidwall/tile38/controller/log"
"github.com/tidwall/tile38/controller/server" "github.com/tidwall/tile38/controller/server"
@ -35,6 +35,7 @@ type commandDetailsT struct {
fields []float64 fields []float64
oldObj geojson.Object oldObj geojson.Object
oldFields []float64 oldFields []float64
updated bool
} }
func (col *collectionT) Less(item btree.Item) bool { func (col *collectionT) Less(item btree.Item) bool {
@ -105,8 +106,8 @@ func ListenAndServe(host string, port int, dir string) error {
c.mu.Unlock() c.mu.Unlock()
}() }()
go c.processLives() go c.processLives()
handler := func(conn *server.Conn, command []byte, rd *bufio.Reader, w io.Writer, websocket bool) error { handler := func(conn *server.Conn, msg *server.Message, rd *bufio.Reader, w io.Writer, websocket bool) error {
err := c.handleInputCommand(conn, string(command), w) err := c.handleInputCommand(conn, msg, w)
if err != nil { if err != nil {
if err.Error() == "going live" { if err.Error() == "going live" {
return c.goLive(err, conn, rd, websocket) return c.goLive(err, conn, rd, websocket)
@ -162,55 +163,73 @@ func isReservedFieldName(field string) bool {
return false return false
} }
func (c *Controller) handleInputCommand(conn *server.Conn, line string, w io.Writer) error { func (c *Controller) handleInputCommand(conn *server.Conn, msg *server.Message, w io.Writer) error {
if core.ShowDebugMessages && line != "pInG" { var words []string
log.Debug(line) for _, v := range msg.Values {
words = append(words, v.String())
} }
// line := strings.Join(words, " ")
// if core.ShowDebugMessages && line != "pInG" {
// log.Debug(line)
// }
start := time.Now() start := time.Now()
// Ping. Just send back the response. No need to put through the pipeline. // Ping. Just send back the response. No need to put through the pipeline.
if len(line) == 4 && (line[0] == 'p' || line[0] == 'P') && lc(line, "ping") { if msg.Command == "ping" {
switch msg.OutputType {
case server.JSON:
w.Write([]byte(`{"ok":true,"ping":"pong","elapsed":"` + time.Now().Sub(start).String() + `"}`)) w.Write([]byte(`{"ok":true,"ping":"pong","elapsed":"` + time.Now().Sub(start).String() + `"}`))
case server.RESP:
io.WriteString(w, "+PONG\r\n")
}
return nil return nil
} }
writeErr := func(err error) error { writeErr := func(err error) error {
js := `{"ok":false,"err":` + jsonString(err.Error()) + `,"elapsed":"` + time.Now().Sub(start).String() + "\"}" switch msg.OutputType {
if _, err := w.Write([]byte(js)); err != nil { case server.JSON:
return err io.WriteString(w, `{"ok":false,"err":`+jsonString(err.Error())+`,"elapsed":"`+time.Now().Sub(start).String()+"\"}")
case server.RESP:
if err == errInvalidNumberOfArguments {
io.WriteString(w, "-ERR wrong number of arguments for '"+msg.Command+"' command\r\n")
} else {
v, _ := resp.ErrorValue(errors.New("ERR " + err.Error())).MarshalRESP()
io.WriteString(w, string(v))
}
} }
return nil return nil
} }
var write bool var write bool
_, cmd := tokenlc(line)
if cmd == "" {
return writeErr(errors.New("empty command"))
}
if !conn.Authenticated || cmd == "auth" { if !conn.Authenticated || msg.Command == "auth" {
c.mu.RLock() c.mu.RLock()
requirePass := c.config.RequirePass requirePass := c.config.RequirePass
c.mu.RUnlock() c.mu.RUnlock()
if requirePass != "" { if requirePass != "" {
// This better be an AUTH command. // This better be an AUTH command.
if cmd != "auth" { if msg.Command != "auth" {
// Just shut down the pipeline now. The less the client connection knows the better. // Just shut down the pipeline now. The less the client connection knows the better.
return writeErr(errors.New("authentication required")) return writeErr(errors.New("authentication required"))
} }
password, _ := token(line) password := ""
if len(msg.Values) > 1 {
password = msg.Values[1].String()
}
if requirePass != strings.TrimSpace(password) { if requirePass != strings.TrimSpace(password) {
return writeErr(errors.New("invalid password")) return writeErr(errors.New("invalid password"))
} }
conn.Authenticated = true conn.Authenticated = true
w.Write([]byte(`{"ok":true,"elapsed":"` + time.Now().Sub(start).String() + "\"}")) w.Write([]byte(`{"ok":true,"elapsed":"` + time.Now().Sub(start).String() + "\"}"))
return nil return nil
} else if cmd == "auth" { } else if msg.Command == "auth" {
return writeErr(errors.New("invalid password")) return writeErr(errors.New("invalid password"))
} }
} }
// choose the locking strategy // choose the locking strategy
switch cmd { switch msg.Command {
default: default:
c.mu.RLock() c.mu.RLock()
defer c.mu.RUnlock() defer c.mu.RUnlock()
@ -243,7 +262,7 @@ func (c *Controller) handleInputCommand(conn *server.Conn, line string, w io.Wri
// no locks! DEV MODE ONLY // no locks! DEV MODE ONLY
} }
resp, d, err := c.command(line, w) res, d, err := c.command(msg, w)
if err != nil { if err != nil {
if err.Error() == "going live" { if err.Error() == "going live" {
return err return err
@ -251,7 +270,7 @@ func (c *Controller) handleInputCommand(conn *server.Conn, line string, w io.Wri
return writeErr(err) return writeErr(err)
} }
if write { if write {
if err := c.writeAOF(line, &d); err != nil { if err := c.writeAOF(resp.ArrayValue(msg.Values), &d); err != nil {
if _, ok := err.(errAOFHook); ok { if _, ok := err.(errAOFHook); ok {
return writeErr(err) return writeErr(err)
} }
@ -259,8 +278,8 @@ func (c *Controller) handleInputCommand(conn *server.Conn, line string, w io.Wri
return err return err
} }
} }
if resp != "" { if res != "" {
if _, err := io.WriteString(w, resp); err != nil { if _, err := io.WriteString(w, res); err != nil {
return err return err
} }
} }
@ -284,84 +303,85 @@ func (c *Controller) reset() {
c.cols = btree.New(16) c.cols = btree.New(16)
} }
func (c *Controller) command(line string, w io.Writer) (resp string, d commandDetailsT, err error) { func (c *Controller) command(msg *server.Message, w io.Writer) (res string, d commandDetailsT, err error) {
start := time.Now() start := time.Now()
okResp := func() string { okResp := func() string {
if w == nil { if w != nil {
switch msg.OutputType {
case server.JSON:
return `{"ok":true,"elapsed":"` + time.Now().Sub(start).String() + "\"}"
case server.RESP:
return "+OK\r\n"
}
}
return "" return ""
} }
return `{"ok":true,"elapsed":"` + time.Now().Sub(start).String() + "\"}" okResp = okResp
} switch msg.Command {
nline, cmd := tokenlc(line)
switch cmd {
default: default:
err = fmt.Errorf("unknown command '%s'", cmd) err = fmt.Errorf("unknown command '%s'", msg.Values[0])
return return
// lock // lock
case "set": case "set":
d, err = c.cmdSet(nline) res, d, err = c.cmdSet(msg)
resp = okResp()
case "fset": case "fset":
d, err = c.cmdFset(nline) res, d, err = c.cmdFset(msg)
resp = okResp()
case "del": case "del":
d, err = c.cmdDel(nline) res, d, err = c.cmdDel(msg)
resp = okResp()
case "drop": case "drop":
d, err = c.cmdDrop(nline) res, d, err = c.cmdDrop(msg)
resp = okResp()
case "flushdb": case "flushdb":
d, err = c.cmdFlushDB(nline) res, d, err = c.cmdFlushDB(msg)
resp = okResp() // case "sethook":
case "sethook": // err = c.cmdSetHook(nline)
err = c.cmdSetHook(nline) // resp = okResp()
resp = okResp() // case "delhook":
case "delhook": // err = c.cmdDelHook(nline)
err = c.cmdDelHook(nline) // resp = okResp()
resp = okResp() // case "hooks":
case "hooks": // err = c.cmdHooks(nline, w)
err = c.cmdHooks(nline, w) // case "massinsert":
case "massinsert": // if !core.DevMode {
if !core.DevMode { // err = fmt.Errorf("unknown command '%s'", cmd)
err = fmt.Errorf("unknown command '%s'", cmd) // return
return // }
} // err = c.cmdMassInsert(nline)
err = c.cmdMassInsert(nline) // resp = okResp()
resp = okResp() // case "follow":
case "follow": // err = c.cmdFollow(nline)
err = c.cmdFollow(nline) // resp = okResp()
resp = okResp() // case "config":
case "config": // resp, err = c.cmdConfig(nline)
resp, err = c.cmdConfig(nline) // case "readonly":
case "readonly": // err = c.cmdReadOnly(nline)
err = c.cmdReadOnly(nline) // resp = okResp()
resp = okResp() // case "stats":
case "stats": // resp, err = c.cmdStats(nline)
resp, err = c.cmdStats(nline) // case "server":
case "server": // resp, err = c.cmdServer(nline)
resp, err = c.cmdServer(nline) // case "scan":
case "scan": // err = c.cmdScan(nline, w)
err = c.cmdScan(nline, w) // case "nearby":
case "nearby": // err = c.cmdNearby(nline, w)
err = c.cmdNearby(nline, w) // case "within":
case "within": // err = c.cmdWithin(nline, w)
err = c.cmdWithin(nline, w) // case "intersects":
case "intersects": // err = c.cmdIntersects(nline, w)
err = c.cmdIntersects(nline, w)
case "get": case "get":
resp, err = c.cmdGet(nline) res, err = c.cmdGet(msg)
case "keys": // case "keys":
err = c.cmdKeys(nline, w) // err = c.cmdKeys(nline, w)
case "aof": // case "aof":
err = c.cmdAOF(nline, w) // err = c.cmdAOF(nline, w)
case "aofmd5": // case "aofmd5":
resp, err = c.cmdAOFMD5(nline) // resp, err = c.cmdAOFMD5(nline)
case "gc": // case "gc":
go runtime.GC() // go runtime.GC()
resp = okResp() // resp = okResp()
case "aofshrink": // case "aofshrink":
go c.aofshrink() // go c.aofshrink()
resp = okResp() // resp = okResp()
} }
return return
} }

View File

@ -3,51 +3,121 @@ package controller
import ( import (
"bytes" "bytes"
"math" "math"
"sort"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/google/btree" "github.com/google/btree"
"github.com/tidwall/resp"
"github.com/tidwall/tile38/controller/collection" "github.com/tidwall/tile38/controller/collection"
"github.com/tidwall/tile38/controller/server"
"github.com/tidwall/tile38/geojson" "github.com/tidwall/tile38/geojson"
"github.com/tidwall/tile38/geojson/geohash" "github.com/tidwall/tile38/geojson/geohash"
) )
func (c *Controller) cmdGet(line string) (string, error) { type fvt struct {
field string
value float64
}
type byField []fvt
func (a byField) Len() int {
return len(a)
}
func (a byField) Less(i, j int) bool {
return a[i].field < a[j].field
}
func (a byField) Swap(i, j int) {
a[i], a[j] = a[j], a[i]
}
func orderFields(fmap map[string]int, fields []float64) []fvt {
var fv fvt
fvs := make([]fvt, 0, len(fmap))
for field, idx := range fmap {
if idx < len(fields) {
fv.field = field
fv.value = fields[idx]
if !math.IsNaN(fv.value) && fv.value != 0 {
fvs = append(fvs, fv)
}
}
}
sort.Sort(byField(fvs))
return fvs
}
func (c *Controller) cmdGet(msg *server.Message) (string, error) {
start := time.Now() start := time.Now()
vs := msg.Values[1:]
var ok bool
var key, id, typ, sprecision string var key, id, typ, sprecision string
if line, key = token(line); key == "" { if vs, key, ok = tokenval(vs); !ok || key == "" {
return "", errInvalidNumberOfArguments return "", errInvalidNumberOfArguments
} }
if line, id = token(line); id == "" { if vs, id, ok = tokenval(vs); !ok || id == "" {
return "", errInvalidNumberOfArguments return "", errInvalidNumberOfArguments
} }
col := c.getCol(key) col := c.getCol(key)
if col == nil { if col == nil {
if msg.OutputType == server.RESP {
return "$-1\r\n", nil
}
return "", errKeyNotFound return "", errKeyNotFound
} }
o, fields, ok := col.Get(id) o, fields, ok := col.Get(id)
if !ok { if !ok {
if msg.OutputType == server.RESP {
return "$-1\r\n", nil
}
return "", errIDNotFound return "", errIDNotFound
} }
vals := make([]resp.Value, 0, 2)
var buf bytes.Buffer var buf bytes.Buffer
if msg.OutputType == server.JSON {
buf.WriteString(`{"ok":true`) buf.WriteString(`{"ok":true`)
if line, typ = token(line); typ == "" || strings.ToLower(typ) == "object" { }
if vs, typ, ok = tokenval(vs); !ok || strings.ToLower(typ) == "object" {
if msg.OutputType == server.JSON {
buf.WriteString(`,"object":`) buf.WriteString(`,"object":`)
buf.WriteString(o.JSON()) buf.WriteString(o.JSON())
} else { } else {
ltyp := strings.ToLower(typ) vals = append(vals, resp.ArrayValue([]resp.Value{resp.StringValue(o.JSON())}))
switch ltyp { }
} else {
switch strings.ToLower(typ) {
default: default:
return "", errInvalidArgument(typ) return "", errInvalidArgument(typ)
case "point": case "point":
point := o.CalculatedPoint()
if msg.OutputType == server.JSON {
buf.WriteString(`,"point":`) buf.WriteString(`,"point":`)
buf.WriteString(o.CalculatedPoint().ExternalJSON()) buf.WriteString(point.ExternalJSON())
} else {
if point.Z != 0 {
vals = append(vals, resp.ArrayValue([]resp.Value{
resp.StringValue(strconv.FormatFloat(point.Y, 'f', -1, 64)),
resp.StringValue(strconv.FormatFloat(point.X, 'f', -1, 64)),
resp.StringValue(strconv.FormatFloat(point.Z, 'f', -1, 64)),
}))
} else {
vals = append(vals, resp.ArrayValue([]resp.Value{
resp.StringValue(strconv.FormatFloat(point.Y, 'f', -1, 64)),
resp.StringValue(strconv.FormatFloat(point.X, 'f', -1, 64)),
}))
}
}
case "hash": case "hash":
if line, sprecision = token(line); sprecision == "" { if vs, sprecision, ok = tokenval(vs); !ok || sprecision == "" {
return "", errInvalidNumberOfArguments return "", errInvalidNumberOfArguments
} }
if msg.OutputType == server.JSON {
buf.WriteString(`,"hash":`) buf.WriteString(`,"hash":`)
}
precision, err := strconv.ParseInt(sprecision, 10, 64) precision, err := strconv.ParseInt(sprecision, 10, 64)
if err != nil || precision < 1 || precision > 64 { if err != nil || precision < 1 || precision > 64 {
return "", errInvalidArgument(sprecision) return "", errInvalidArgument(sprecision)
@ -56,81 +126,145 @@ func (c *Controller) cmdGet(line string) (string, error) {
if err != nil { if err != nil {
return "", err return "", err
} }
if msg.OutputType == server.JSON {
buf.WriteString(`"` + p + `"`) buf.WriteString(`"` + p + `"`)
} else {
vals = append(vals, resp.ArrayValue([]resp.Value{resp.StringValue(p)}))
}
case "bounds": case "bounds":
bbox := o.CalculatedBBox()
if msg.OutputType == server.JSON {
buf.WriteString(`,"bounds":`) buf.WriteString(`,"bounds":`)
buf.WriteString(o.CalculatedBBox().ExternalJSON()) buf.WriteString(bbox.ExternalJSON())
} else {
vals = append(vals, resp.ArrayValue([]resp.Value{
resp.StringValue(strconv.FormatFloat(bbox.Min.Y, 'f', -1, 64)),
resp.StringValue(strconv.FormatFloat(bbox.Min.X, 'f', -1, 64)),
resp.StringValue(strconv.FormatFloat(bbox.Max.Y, 'f', -1, 64)),
resp.StringValue(strconv.FormatFloat(bbox.Max.X, 'f', -1, 64)),
}))
} }
} }
if line != "" { }
if len(vs) != 0 {
return "", errInvalidNumberOfArguments return "", errInvalidNumberOfArguments
} }
fmap := col.FieldMap()
if len(fmap) > 0 { fvs := orderFields(col.FieldMap(), fields)
if len(fvs) > 0 {
fvals := make([]resp.Value, 0, len(fvs)*2)
if msg.OutputType == server.JSON {
buf.WriteString(`,"fields":{`) buf.WriteString(`,"fields":{`)
var i int }
for field, idx := range fmap { for i, fv := range fvs {
if len(fields) > idx { if msg.OutputType == server.JSON {
if !math.IsNaN(fields[idx]) {
if i > 0 { if i > 0 {
buf.WriteString(`,`) buf.WriteString(`,`)
} }
buf.WriteString(jsonString(field) + ":" + strconv.FormatFloat(fields[idx], 'f', -1, 64)) buf.WriteString(jsonString(fv.field) + ":" + strconv.FormatFloat(fv.value, 'f', -1, 64))
} else {
fvals = append(fvals, resp.StringValue(fv.field), resp.StringValue(strconv.FormatFloat(fv.value, 'f', -1, 64)))
}
i++ i++
} }
} if msg.OutputType == server.JSON {
}
buf.WriteString(`}`) buf.WriteString(`}`)
} else {
vals = append(vals, resp.ArrayValue(fvals))
} }
}
if msg.OutputType == server.JSON {
buf.WriteString(`,"elapsed":"` + time.Now().Sub(start).String() + "\"}") buf.WriteString(`,"elapsed":"` + time.Now().Sub(start).String() + "\"}")
return buf.String(), nil return buf.String(), nil
}
data, err := resp.ArrayValue(vals).MarshalRESP()
if err != nil {
return "", err
}
return string(data), nil
} }
func (c *Controller) cmdDel(line string) (d commandDetailsT, err error) { func (c *Controller) cmdDel(msg *server.Message) (res string, d commandDetailsT, err error) {
if line, d.key = token(line); d.key == "" { start := time.Now()
vs := msg.Values[1:]
var ok bool
if vs, d.key, ok = tokenval(vs); !ok || d.key == "" {
err = errInvalidNumberOfArguments err = errInvalidNumberOfArguments
return return
} }
if line, d.id = token(line); d.id == "" { if vs, d.id, ok = tokenval(vs); !ok || d.id == "" {
err = errInvalidNumberOfArguments err = errInvalidNumberOfArguments
return return
} }
if line != "" { if len(vs) != 0 {
err = errInvalidNumberOfArguments err = errInvalidNumberOfArguments
return return
} }
found := false
col := c.getCol(d.key) col := c.getCol(d.key)
if col != nil { if col != nil {
d.obj, d.fields, _ = col.Remove(d.id) d.obj, d.fields, ok = col.Remove(d.id)
if ok {
if col.Count() == 0 { if col.Count() == 0 {
c.deleteCol(d.key) c.deleteCol(d.key)
} }
found = true
}
} }
d.command = "del" d.command = "del"
d.updated = found
switch msg.OutputType {
case server.JSON:
res = `{"ok":true,"elapsed":"` + time.Now().Sub(start).String() + "\"}"
case server.RESP:
if d.updated {
res = ":1\r\n"
} else {
res = ":0\r\n"
}
}
return return
} }
func (c *Controller) cmdDrop(line string) (d commandDetailsT, err error) { func (c *Controller) cmdDrop(msg *server.Message) (res string, d commandDetailsT, err error) {
if line, d.key = token(line); d.key == "" { start := time.Now()
vs := msg.Values[1:]
var ok bool
if vs, d.key, ok = tokenval(vs); !ok || d.key == "" {
err = errInvalidNumberOfArguments err = errInvalidNumberOfArguments
return return
} }
if line != "" { if len(vs) != 0 {
err = errInvalidNumberOfArguments err = errInvalidNumberOfArguments
return return
} }
col := c.getCol(d.key) col := c.getCol(d.key)
if col != nil { if col != nil {
c.deleteCol(d.key) c.deleteCol(d.key)
d.updated = true
} else { } else {
d.key = "" // ignore the details d.key = "" // ignore the details
d.updated = false
} }
d.command = "drop" d.command = "drop"
switch msg.OutputType {
case server.JSON:
res = `{"ok":true,"elapsed":"` + time.Now().Sub(start).String() + "\"}"
case server.RESP:
if d.updated {
res = ":1\r\n"
} else {
res = ":0\r\n"
}
}
return return
} }
func (c *Controller) cmdFlushDB(line string) (d commandDetailsT, err error) { func (c *Controller) cmdFlushDB(msg *server.Message) (res string, d commandDetailsT, err error) {
if line != "" { start := time.Now()
vs := msg.Values[1:]
if len(vs) != 0 {
err = errInvalidNumberOfArguments err = errInvalidNumberOfArguments
return return
} }
@ -139,35 +273,43 @@ func (c *Controller) cmdFlushDB(line string) (d commandDetailsT, err error) {
c.hooks = make(map[string]*Hook) c.hooks = make(map[string]*Hook)
c.hookcols = make(map[string]map[string]*Hook) c.hookcols = make(map[string]map[string]*Hook)
d.command = "flushdb" d.command = "flushdb"
d.updated = true
switch msg.OutputType {
case server.JSON:
res = `{"ok":true,"elapsed":"` + time.Now().Sub(start).String() + "\"}"
case server.RESP:
res = "+OK\r\n"
}
return return
} }
func (c *Controller) parseSetArgs(line string) (d commandDetailsT, fields []string, values []float64, etype, eline string, err error) { func (c *Controller) parseSetArgs(vs []resp.Value) (d commandDetailsT, fields []string, values []float64, etype string, evs []resp.Value, err error) {
var ok bool
var typ string var typ string
if line, d.key = token(line); d.key == "" { if vs, d.key, ok = tokenval(vs); !ok || d.key == "" {
err = errInvalidNumberOfArguments err = errInvalidNumberOfArguments
return return
} }
if line, d.id = token(line); d.id == "" { if vs, d.id, ok = tokenval(vs); !ok || d.id == "" {
err = errInvalidNumberOfArguments err = errInvalidNumberOfArguments
return return
} }
var arg string var arg string
var nline string var nvs []resp.Value
fields = make([]string, 0, 8) fields = make([]string, 0, 8)
values = make([]float64, 0, 8) values = make([]float64, 0, 8)
for { for {
if nline, arg = token(line); arg == "" { if nvs, arg, ok = tokenval(vs); !ok || arg == "" {
err = errInvalidNumberOfArguments err = errInvalidNumberOfArguments
return return
} }
if lc(arg, "field") { if lc(arg, "field") {
line = nline vs = nvs
var name string var name string
var svalue string var svalue string
var value float64 var value float64
if line, name = token(line); name == "" { if vs, name, ok = tokenval(vs); !ok || name == "" {
err = errInvalidNumberOfArguments err = errInvalidNumberOfArguments
return return
} }
@ -175,7 +317,7 @@ func (c *Controller) parseSetArgs(line string) (d commandDetailsT, fields []stri
err = errInvalidArgument(name) err = errInvalidArgument(name)
return return
} }
if line, svalue = token(line); svalue == "" { if vs, svalue, ok = tokenval(vs); !ok || svalue == "" {
err = errInvalidNumberOfArguments err = errInvalidNumberOfArguments
return return
} }
@ -190,16 +332,16 @@ func (c *Controller) parseSetArgs(line string) (d commandDetailsT, fields []stri
} }
break break
} }
if line, typ = token(line); typ == "" { if vs, typ, ok = tokenval(vs); !ok || typ == "" {
err = errInvalidNumberOfArguments err = errInvalidNumberOfArguments
return return
} }
if line == "" { if len(vs) == 0 {
err = errInvalidNumberOfArguments err = errInvalidNumberOfArguments
return return
} }
etype = typ etype = typ
eline = line evs = vs
switch { switch {
default: default:
@ -207,16 +349,16 @@ func (c *Controller) parseSetArgs(line string) (d commandDetailsT, fields []stri
return return
case lc(typ, "point"): case lc(typ, "point"):
var slat, slon, sz string var slat, slon, sz string
if line, slat = token(line); slat == "" { if vs, slat, ok = tokenval(vs); !ok || slat == "" {
err = errInvalidNumberOfArguments err = errInvalidNumberOfArguments
return return
} }
if line, slon = token(line); slon == "" { if vs, slon, ok = tokenval(vs); !ok || slon == "" {
err = errInvalidNumberOfArguments err = errInvalidNumberOfArguments
return return
} }
line, sz = token(line) vs, sz, ok = tokenval(vs)
if sz == "" { if !ok || sz == "" {
var sp geojson.SimplePoint var sp geojson.SimplePoint
sp.Y, err = strconv.ParseFloat(slat, 64) sp.Y, err = strconv.ParseFloat(slat, 64)
if err != nil { if err != nil {
@ -250,19 +392,19 @@ func (c *Controller) parseSetArgs(line string) (d commandDetailsT, fields []stri
} }
case lc(typ, "bounds"): case lc(typ, "bounds"):
var sminlat, sminlon, smaxlat, smaxlon string var sminlat, sminlon, smaxlat, smaxlon string
if line, sminlat = token(line); sminlat == "" { if vs, sminlat, ok = tokenval(vs); !ok || sminlat == "" {
err = errInvalidNumberOfArguments err = errInvalidNumberOfArguments
return return
} }
if line, sminlon = token(line); sminlon == "" { if vs, sminlon, ok = tokenval(vs); !ok || sminlon == "" {
err = errInvalidNumberOfArguments err = errInvalidNumberOfArguments
return return
} }
if line, smaxlat = token(line); smaxlat == "" { if vs, smaxlat, ok = tokenval(vs); !ok || smaxlat == "" {
err = errInvalidNumberOfArguments err = errInvalidNumberOfArguments
return return
} }
if line, smaxlon = token(line); smaxlon == "" { if vs, smaxlon, ok = tokenval(vs); !ok || smaxlon == "" {
err = errInvalidNumberOfArguments err = errInvalidNumberOfArguments
return return
} }
@ -302,7 +444,7 @@ func (c *Controller) parseSetArgs(line string) (d commandDetailsT, fields []stri
case lc(typ, "hash"): case lc(typ, "hash"):
var sp geojson.SimplePoint var sp geojson.SimplePoint
var shash string var shash string
if line, shash = token(line); shash == "" { if vs, shash, ok = tokenval(vs); !ok || shash == "" {
err = errInvalidNumberOfArguments err = errInvalidNumberOfArguments
return return
} }
@ -315,18 +457,28 @@ func (c *Controller) parseSetArgs(line string) (d commandDetailsT, fields []stri
sp.Y = lat sp.Y = lat
d.obj = sp d.obj = sp
case lc(typ, "object"): case lc(typ, "object"):
d.obj, err = geojson.ObjectJSON(line) var object string
if vs, object, ok = tokenval(vs); !ok || object == "" {
err = errInvalidNumberOfArguments
return
}
d.obj, err = geojson.ObjectJSON(object)
if err != nil { if err != nil {
return return
} }
} }
if len(vs) != 0 {
err = errInvalidNumberOfArguments
}
return return
} }
func (c *Controller) cmdSet(line string) (d commandDetailsT, err error) { func (c *Controller) cmdSet(msg *server.Message) (res string, d commandDetailsT, err error) {
start := time.Now()
vs := msg.Values[1:]
var fields []string var fields []string
var values []float64 var values []float64
d, fields, values, _, _, err = c.parseSetArgs(line) d, fields, values, _, _, err = c.parseSetArgs(vs)
if err != nil { if err != nil {
return return
} }
@ -337,20 +489,28 @@ func (c *Controller) cmdSet(line string) (d commandDetailsT, err error) {
} }
d.oldObj, d.oldFields, d.fields = col.ReplaceOrInsert(d.id, d.obj, fields, values) d.oldObj, d.oldFields, d.fields = col.ReplaceOrInsert(d.id, d.obj, fields, values)
d.command = "set" d.command = "set"
d.updated = true // perhaps we should do a diff on the previous object?
switch msg.OutputType {
case server.JSON:
res = `{"ok":true,"elapsed":"` + time.Now().Sub(start).String() + "\"}"
case server.RESP:
res = "+OK\r\n"
}
return return
} }
func (c *Controller) parseFSetArgs(line string) (d commandDetailsT, err error) { func (c *Controller) parseFSetArgs(vs []resp.Value) (d commandDetailsT, err error) {
var svalue string var svalue string
if line, d.key = token(line); d.key == "" { var ok bool
if vs, d.key, ok = tokenval(vs); !ok || d.key == "" {
err = errInvalidNumberOfArguments err = errInvalidNumberOfArguments
return return
} }
if line, d.id = token(line); d.id == "" { if vs, d.id, ok = tokenval(vs); !ok || d.id == "" {
err = errInvalidNumberOfArguments err = errInvalidNumberOfArguments
return return
} }
if line, d.field = token(line); d.field == "" { if vs, d.field, ok = tokenval(vs); !ok || d.field == "" {
err = errInvalidNumberOfArguments err = errInvalidNumberOfArguments
return return
} }
@ -358,11 +518,11 @@ func (c *Controller) parseFSetArgs(line string) (d commandDetailsT, err error) {
err = errInvalidNumberOfArguments err = errInvalidNumberOfArguments
return return
} }
if line, svalue = token(line); svalue == "" { if vs, svalue, ok = tokenval(vs); !ok || svalue == "" {
err = errInvalidNumberOfArguments err = errInvalidNumberOfArguments
return return
} }
if line != "" { if len(vs) != 0 {
err = errInvalidNumberOfArguments err = errInvalidNumberOfArguments
return return
} }
@ -374,19 +534,33 @@ func (c *Controller) parseFSetArgs(line string) (d commandDetailsT, err error) {
return return
} }
func (c *Controller) cmdFset(line string) (d commandDetailsT, err error) { func (c *Controller) cmdFset(msg *server.Message) (res string, d commandDetailsT, err error) {
d, err = c.parseFSetArgs(line) start := time.Now()
vs := msg.Values[1:]
d, err = c.parseFSetArgs(vs)
col := c.getCol(d.key) col := c.getCol(d.key)
if col == nil { if col == nil {
err = errKeyNotFound err = errKeyNotFound
return return
} }
var ok bool var ok bool
d.obj, d.fields, ok = col.SetField(d.id, d.field, d.value) var updated bool
d.obj, d.fields, updated, ok = col.SetField(d.id, d.field, d.value)
if !ok { if !ok {
err = errIDNotFound err = errIDNotFound
return return
} }
d.command = "fset" d.command = "fset"
d.updated = updated
switch msg.OutputType {
case server.JSON:
res = `{"ok":true,"elapsed":"` + time.Now().Sub(start).String() + "\"}"
case server.RESP:
if updated {
res = ":1\r\n"
} else {
res = ":0\r\n"
}
}
return return
} }

View File

@ -1,84 +1,72 @@
package controller package controller
import (
"errors"
"fmt"
"math/rand"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/tidwall/tile38/controller/log"
)
func (c *Controller) cmdMassInsert(line string) error { func (c *Controller) cmdMassInsert(line string) error {
// massinsert simply forwards a bunch of cmdSets // // massinsert simply forwards a bunch of cmdSets
var snumCols, snumPoints string // var snumCols, snumPoints string
var cols, objs int // var cols, objs int
if line, snumCols = token(line); snumCols == "" { // if line, snumCols = token(line); snumCols == "" {
return errors.New("invalid number of arguments") // return errors.New("invalid number of arguments")
} // }
if line, snumPoints = token(line); snumPoints == "" { // if line, snumPoints = token(line); snumPoints == "" {
return errors.New("invalid number of arguments") // return errors.New("invalid number of arguments")
} // }
if line != "" { // if line != "" {
return errors.New("invalid number of arguments") // return errors.New("invalid number of arguments")
} // }
n, err := strconv.ParseUint(snumCols, 10, 64) // n, err := strconv.ParseUint(snumCols, 10, 64)
if err != nil { // if err != nil {
return errors.New("invalid argument '" + snumCols + "'") // return errors.New("invalid argument '" + snumCols + "'")
} // }
cols = int(n) // cols = int(n)
n, err = strconv.ParseUint(snumPoints, 10, 64) // n, err = strconv.ParseUint(snumPoints, 10, 64)
if err != nil { // if err != nil {
return errors.New("invalid argument '" + snumPoints + "'") // return errors.New("invalid argument '" + snumPoints + "'")
} // }
docmd := func(cmd string) error { // docmd := func(cmd string) error {
c.mu.Lock() // c.mu.Lock()
defer c.mu.Unlock() // defer c.mu.Unlock()
_, d, err := c.command(cmd, nil) // _, d, err := c.command(cmd, nil)
if err != nil { // if err != nil {
return err // return err
} // }
if err := c.writeAOF(cmd, &d); err != nil { // if err := c.writeAOF(cmd, &d); err != nil {
return err // return err
} // }
return nil // return nil
} // }
rand.Seed(time.Now().UnixNano()) // rand.Seed(time.Now().UnixNano())
objs = int(n) // objs = int(n)
var wg sync.WaitGroup // var wg sync.WaitGroup
var k uint64 // var k uint64
wg.Add(cols) // wg.Add(cols)
for i := 0; i < cols; i++ { // for i := 0; i < cols; i++ {
key := "mi:" + strconv.FormatInt(int64(i), 10) // key := "mi:" + strconv.FormatInt(int64(i), 10)
go func(key string) { // go func(key string) {
defer func() { // defer func() {
wg.Done() // wg.Done()
}() // }()
for j := 0; j < objs; j++ { // for j := 0; j < objs; j++ {
id := strconv.FormatInt(int64(j), 10) // id := strconv.FormatInt(int64(j), 10)
lat, lon := rand.Float64()*180-90, rand.Float64()*360-180 // lat, lon := rand.Float64()*180-90, rand.Float64()*360-180
var line string // var line string
if true { // if true {
fields := fmt.Sprintf("FIELD field %f", rand.Float64()*10) // fields := fmt.Sprintf("FIELD field %f", rand.Float64()*10)
line = fmt.Sprintf(`set %s %s %s POINT %f %f`, key, id, fields, lat, lon) // line = fmt.Sprintf(`set %s %s %s POINT %f %f`, key, id, fields, lat, lon)
} else { // } else {
line = fmt.Sprintf(`set %s %s POINT %f %f`, key, id, lat, lon) // line = fmt.Sprintf(`set %s %s POINT %f %f`, key, id, lat, lon)
} // }
if err := docmd(line); err != nil { // if err := docmd(line); err != nil {
log.Fatal(err) // log.Fatal(err)
return // return
} // }
atomic.AddUint64(&k, 1) // atomic.AddUint64(&k, 1)
if j%10000 == 10000-1 { // if j%10000 == 10000-1 {
log.Infof("massinsert: %s %d/%d", key, atomic.LoadUint64(&k), cols*objs) // log.Infof("massinsert: %s %d/%d", key, atomic.LoadUint64(&k), cols*objs)
} // }
} // }
}(key) // }(key)
} // }
wg.Wait() // wg.Wait()
log.Infof("massinsert: done %d objects", atomic.LoadUint64(&k)) // log.Infof("massinsert: done %d objects", atomic.LoadUint64(&k))
return nil return nil
} }

View File

@ -96,16 +96,16 @@ func (c *Controller) cmdFollow(line string) error {
func (c *Controller) followHandleCommand(line string, followc uint64, w io.Writer) (int, error) { func (c *Controller) followHandleCommand(line string, followc uint64, w io.Writer) (int, error) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
if c.followc != followc { // if c.followc != followc {
return c.aofsz, errNoLongerFollowing // return c.aofsz, errNoLongerFollowing
} // }
_, d, err := c.command(line, w) // _, d, err := c.command(line, w)
if err != nil { // if err != nil {
return c.aofsz, err // return c.aofsz, err
} // }
if err := c.writeAOF(line, &d); err != nil { // if err := c.writeAOF(line, &d); err != nil {
return c.aofsz, err // return c.aofsz, err
} // }
return c.aofsz, nil return c.aofsz, nil
} }

View File

@ -0,0 +1,295 @@
package server
import (
"bufio"
"bytes"
"crypto/sha1"
"encoding/base64"
"errors"
"io"
"net/url"
"strconv"
"strings"
"github.com/tidwall/resp"
)
const TelnetIsJSON = false
type Type int
const (
Null Type = iota
RESP
Telnet
Native
HTTP
WebSocket
JSON
)
type errRESPProtocolError struct {
msg string
}
func (err errRESPProtocolError) Error() string {
return "Protocol error: " + err.msg
}
type Message struct {
Command string
Values []resp.Value
ConnType Type
OutputType Type
Auth string
}
type AnyReaderWriter struct {
rd *bufio.Reader
wr io.Writer
ws bool
}
func NewAnyReaderWriter(rd io.Reader) *AnyReaderWriter {
ar := &AnyReaderWriter{}
if rd2, ok := rd.(*bufio.Reader); ok {
ar.rd = rd2
} else {
ar.rd = bufio.NewReader(rd)
}
if wr, ok := rd.(io.Writer); ok {
ar.wr = wr
}
return ar
}
func (ar *AnyReaderWriter) peekcrlfline() (string, error) {
// this is slow operation.
for i := 0; ; i++ {
bb, err := ar.rd.Peek(i)
if err != nil {
return "", err
}
if len(bb) > 2 && bb[len(bb)-2] == '\r' && bb[len(bb)-1] == '\n' {
return string(bb[:len(bb)-2]), nil
}
}
}
func (ar *AnyReaderWriter) readcrlfline() (string, error) {
var line []byte
for {
bb, err := ar.rd.ReadBytes('\r')
if err != nil {
return "", err
}
if line == nil {
line = bb
} else {
line = append(line, bb...)
}
b, err := ar.rd.ReadByte()
if err != nil {
return "", err
}
if b == '\n' {
return string(line[:len(line)-1]), nil
}
line = append(line, b)
}
}
func (ar *AnyReaderWriter) ReadMessage() (*Message, error) {
b, err := ar.rd.ReadByte()
if err != nil {
return nil, err
}
if err := ar.rd.UnreadByte(); err != nil {
return nil, err
}
switch b {
case 'G', 'P':
line, err := ar.peekcrlfline()
if err != nil {
return nil, err
}
if strings.HasSuffix(line, " HTTP/1.1") {
return ar.readHTTPMessage()
}
case '$':
return ar.readNativeMessage()
}
// MultiBulk also reads telnet
return ar.readMultiBulkMessage()
}
func (ar *AnyReaderWriter) readNativeMessage() (*Message, error) {
b, err := ar.rd.ReadBytes(' ')
if err != nil {
return nil, err
}
if len(b) > 0 && b[0] != '$' {
return nil, errors.New("invalid message")
}
n, err := strconv.ParseUint(string(b[1:len(b)-1]), 10, 32)
if err != nil {
return nil, errors.New("invalid size")
}
if n > 0x1FFFFFFF { // 536,870,911 bytes
return nil, errors.New("message too big")
}
b = make([]byte, int(n)+2)
if _, err := io.ReadFull(ar.rd, b); err != nil {
return nil, err
}
if b[len(b)-2] != '\r' || b[len(b)-1] != '\n' {
return nil, errors.New("expecting crlf")
}
values := make([]resp.Value, 0, 16)
line := b[:len(b)-2]
reading:
for len(line) != 0 {
if line[0] == '{' {
// The native protocol cannot understand json boundaries so it assumes that
// a json element must be at the end of the line.
values = append(values, resp.StringValue(string(line)))
break
}
i := 0
for ; i < len(line); i++ {
if line[i] == ' ' {
values = append(values, resp.StringValue(string(line[:i])))
line = line[i+1:]
continue reading
}
}
values = append(values, resp.StringValue(string(line)))
break
}
return &Message{Command: commandValues(values), Values: values, ConnType: Native, OutputType: JSON}, nil
}
func commandValues(values []resp.Value) string {
if len(values) == 0 {
return ""
}
return strings.ToLower(values[0].String())
}
func (ar *AnyReaderWriter) readMultiBulkMessage() (*Message, error) {
rd := resp.NewReader(ar.rd)
v, telnet, _, err := rd.ReadMultiBulk()
if err != nil {
return nil, err
}
values := v.Array()
if len(values) == 0 {
return nil, nil
}
if telnet && TelnetIsJSON {
return &Message{Command: commandValues(values), Values: values, ConnType: Telnet, OutputType: JSON}, nil
}
return &Message{Command: commandValues(values), Values: values, ConnType: RESP, OutputType: RESP}, nil
}
func (ar *AnyReaderWriter) readHTTPMessage() (*Message, error) {
msg := &Message{ConnType: HTTP, OutputType: JSON}
line, err := ar.readcrlfline()
if err != nil {
return nil, err
}
parts := strings.Split(line, " ")
if len(parts) != 3 {
return nil, errors.New("invalid HTTP request")
}
method := parts[0]
path := parts[1]
if len(path) == 0 || path[0] != '/' {
return nil, errors.New("invalid HTTP request")
}
path, err = url.QueryUnescape(path[1:])
if err != nil {
return nil, errors.New("invalid HTTP request")
}
if method != "GET" && method != "POST" {
return nil, errors.New("invalid HTTP method")
}
contentLength := 0
websocket := false
websocketVersion := 0
websocketKey := ""
for {
header, err := ar.readcrlfline()
if err != nil {
return nil, err
}
if header == "" {
break // end of headers
}
if header[0] == 'a' || header[0] == 'A' {
if strings.HasPrefix(strings.ToLower(header), "authorization:") {
msg.Auth = strings.TrimSpace(header[len("authorization:"):])
}
} else if header[0] == 'u' || header[0] == 'U' {
if strings.HasPrefix(strings.ToLower(header), "upgrade:") && strings.ToLower(strings.TrimSpace(header[len("upgrade:"):])) == "websocket" {
websocket = true
}
} else if header[0] == 's' || header[0] == 'S' {
if strings.HasPrefix(strings.ToLower(header), "sec-websocket-version:") {
var n uint64
n, err = strconv.ParseUint(strings.TrimSpace(header[len("sec-websocket-version:"):]), 10, 64)
if err != nil {
return nil, err
}
websocketVersion = int(n)
} else if strings.HasPrefix(strings.ToLower(header), "sec-websocket-key:") {
websocketKey = strings.TrimSpace(header[len("sec-websocket-key:"):])
}
} else if header[0] == 'c' || header[0] == 'C' {
if strings.HasPrefix(strings.ToLower(header), "content-length:") {
var n uint64
n, err = strconv.ParseUint(strings.TrimSpace(header[len("content-length:"):]), 10, 64)
if err != nil {
return nil, err
}
contentLength = int(n)
}
}
}
if websocket && websocketVersion >= 13 && websocketKey != "" {
msg.ConnType = WebSocket
if ar.wr == nil {
return nil, errors.New("connection is nil")
}
sum := sha1.Sum([]byte(websocketKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"))
accept := base64.StdEncoding.EncodeToString(sum[:])
wshead := "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: " + accept + "\r\n\r\n"
if _, err = ar.wr.Write([]byte(wshead)); err != nil {
return nil, err
}
ar.ws = true
} else if contentLength > 0 {
msg.ConnType = HTTP
buf := make([]byte, contentLength)
if _, err = io.ReadFull(ar.rd, buf); err != nil {
return nil, err
}
path += string(buf)
}
if path == "" {
return msg, nil
}
if !strings.HasSuffix(path, "\r\n") {
path += "\r\n"
}
rd := NewAnyReaderWriter(bytes.NewBufferString(path))
nmsg, err := rd.ReadMessage()
if err != nil {
return nil, err
}
msg.OutputType = nmsg.OutputType
msg.Values = nmsg.Values
msg.Command = commandValues(nmsg.Values)
return msg, nil
}

View File

@ -2,14 +2,14 @@ package server
import ( import (
"bufio" "bufio"
"bytes"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net" "net"
"strings" "strings"
"github.com/tidwall/tile38/client" //"github.com/tidwall/tile38/client"
"github.com/tidwall/tile38/controller/log" "github.com/tidwall/tile38/controller/log"
"github.com/tidwall/tile38/core" "github.com/tidwall/tile38/core"
) )
@ -51,7 +51,7 @@ var errCloseHTTP = errors.New("close http")
func ListenAndServe( func ListenAndServe(
host string, port int, host string, port int,
protected func() bool, protected func() bool,
handler func(conn *Conn, command []byte, rd *bufio.Reader, w io.Writer, websocket bool) error, handler func(conn *Conn, msg *Message, rd *bufio.Reader, w io.Writer, websocket bool) error,
) error { ) error {
ln, err := net.Listen("tcp", fmt.Sprintf("%s:%d", host, port)) ln, err := net.Listen("tcp", fmt.Sprintf("%s:%d", host, port))
if err != nil { if err != nil {
@ -68,17 +68,17 @@ func ListenAndServe(
} }
} }
func writeCommandErr(proto client.Proto, conn *Conn, err error) error { // func writeCommandErr(proto client.Proto, conn *Conn, err error) error {
if proto == client.HTTP || proto == client.WebSocket { // if proto == client.HTTP || proto == client.WebSocket {
conn.Write([]byte(`HTTP/1.1 500 ` + err.Error() + "\r\nConnection: close\r\n\r\n")) // conn.Write([]byte(`HTTP/1.1 500 ` + err.Error() + "\r\nConnection: close\r\n\r\n"))
} // }
return err // return err
} // }
func handleConn( func handleConn(
conn *Conn, conn *Conn,
protected func() bool, protected func() bool,
handler func(conn *Conn, command []byte, rd *bufio.Reader, w io.Writer, websocket bool) error, handler func(conn *Conn, msg *Message, rd *bufio.Reader, w io.Writer, websocket bool) error,
) { ) {
addr := conn.RemoteAddr().String() addr := conn.RemoteAddr().String()
if core.ShowDebugMessages { if core.ShowDebugMessages {
@ -96,59 +96,10 @@ func handleConn(
} }
} }
defer conn.Close() defer conn.Close()
rd := bufio.NewReader(conn) rd := NewAnyReaderWriter(conn)
for i := 0; ; i++ { brd := rd.rd
err := func() error { for {
command, proto, auth, err := client.ReadMessage(rd, conn) msg, err := rd.ReadMessage()
if err != nil {
return err
}
if len(command) > 0 && (command[0] == 'Q' || command[0] == 'q') && strings.ToLower(string(command)) == "quit" {
return io.EOF
}
var b bytes.Buffer
var denied bool
if (proto == client.HTTP || proto == client.WebSocket) && auth != "" {
if err := handler(conn, []byte("AUTH "+auth), rd, &b, proto == client.WebSocket); err != nil {
return writeCommandErr(proto, conn, err)
}
if strings.HasPrefix(b.String(), `{"ok":false`) {
denied = true
} else {
b.Reset()
}
}
if !denied {
if err := handler(conn, command, rd, &b, proto == client.WebSocket); err != nil {
return writeCommandErr(proto, conn, err)
}
}
switch proto {
case client.Native:
if err := client.WriteMessage(conn, b.Bytes()); err != nil {
return err
}
case client.HTTP:
if err := client.WriteHTTP(conn, b.Bytes()); err != nil {
return err
}
return errCloseHTTP
case client.WebSocket:
if err := client.WriteWebSocket(conn, b.Bytes()); err != nil {
return err
}
if _, err := conn.Write([]byte{137, 0}); err != nil {
return err
}
return errCloseHTTP
default:
b.WriteString("\r\n")
if _, err := conn.Write(b.Bytes()); err != nil {
return err
}
}
return nil
}()
if err != nil { if err != nil {
if err == io.EOF { if err == io.EOF {
return return
@ -160,5 +111,83 @@ func handleConn(
log.Error(err) log.Error(err)
return return
} }
if msg != nil && msg.Command != "" {
if msg.Command == "quit" {
if msg.OutputType == RESP {
io.WriteString(conn, "+OK\r\n")
}
return
}
err := handler(conn, msg, brd, conn, msg.ConnType == WebSocket)
if err != nil {
log.Error(err)
return
}
}
} }
} }
//err := func() error {
// command, proto, auth, err := client.ReadMessage(rd, conn)
// if err != nil {
// return err
// }
// if len(command) > 0 && (command[0] == 'Q' || command[0] == 'q') && strings.ToLower(string(command)) == "quit" {
// return io.EOF
// }
// var b bytes.Buffer
// var denied bool
// if (proto == client.HTTP || proto == client.WebSocket) && auth != "" {
// if err := handler(conn, []byte("AUTH "+auth), rd, &b, proto == client.WebSocket); err != nil {
// return writeCommandErr(proto, conn, err)
// }
// if strings.HasPrefix(b.String(), `{"ok":false`) {
// denied = true
// } else {
// b.Reset()
// }
// }
// if !denied {
// if err := handler(conn, command, rd, &b, proto == client.WebSocket); err != nil {
// return writeCommandErr(proto, conn, err)
// }
// }
// switch proto {
// case client.Native:
// if err := client.WriteMessage(conn, b.Bytes()); err != nil {
// return err
// }
// case client.HTTP:
// if err := client.WriteHTTP(conn, b.Bytes()); err != nil {
// return err
// }
// return errCloseHTTP
// case client.WebSocket:
// if err := client.WriteWebSocket(conn, b.Bytes()); err != nil {
// return err
// }
// if _, err := conn.Write([]byte{137, 0}); err != nil {
// return err
// }
// return errCloseHTTP
// default:
// b.WriteString("\r\n")
// if _, err := conn.Write(b.Bytes()); err != nil {
// return err
// }
// }
// return nil
//}()
// if err != nil {
// if err == io.EOF {
// return
// }
// if err == errCloseHTTP ||
// strings.Contains(err.Error(), "use of closed network connection") {
// return
// }
// log.Error(err)
// return
// }
// }
// }

View File

@ -6,6 +6,8 @@ import (
"math" "math"
"strconv" "strconv"
"strings" "strings"
"github.com/tidwall/resp"
) )
const defaultSearchOutput = outputObjects const defaultSearchOutput = outputObjects
@ -29,6 +31,15 @@ func token(line string) (newLine, token string) {
return "", line return "", line
} }
func tokenval(vs []resp.Value) (nvs []resp.Value, token string, ok bool) {
if len(vs) > 0 {
token = vs[0].String()
nvs = vs[1:]
ok = true
}
return
}
func tokenlc(line string) (newLine, token string) { func tokenlc(line string) (newLine, token string) {
for i := 0; i < len(line); i++ { for i := 0; i < len(line); i++ {
ch := line[i] ch := line[i]