From 6d944ada32407bbadcf1a5720773ef6c18261fc3 Mon Sep 17 00:00:00 2001 From: Josh Baker Date: Thu, 8 Sep 2016 16:11:53 -0700 Subject: [PATCH] fixed #49. fragmented pipeline requests. --- controller/controller.go | 5 +- controller/server/server.go | 6 +- tests/tests.go | 1 + tests/tests_test.go | 268 ++++++++++++++++++++ vendor/github.com/tidwall/resp/aof.go | 24 +- vendor/github.com/tidwall/resp/aof_test.go | 56 +++- vendor/github.com/tidwall/resp/resp.go | 186 +++----------- vendor/github.com/tidwall/resp/resp_test.go | 154 +++++++++++ 8 files changed, 544 insertions(+), 156 deletions(-) create mode 100644 tests/tests.go create mode 100644 tests/tests_test.go diff --git a/controller/controller.go b/controller/controller.go index 1eea6b19..dbf4c8b8 100644 --- a/controller/controller.go +++ b/controller/controller.go @@ -87,6 +87,9 @@ type Controller struct { // ListenAndServe starts a new tile38 server func ListenAndServe(host string, port int, dir string) error { + return ListenAndServeEx(host, port, dir, nil) +} +func ListenAndServeEx(host string, port int, dir string, ln *net.Listener) error { log.Infof("Server started, Tile38 version %s, git %s", core.Version, core.GitSHA) c := &Controller{ host: host, @@ -178,7 +181,7 @@ func ListenAndServe(host string, port int, dir string) error { delete(c.conns, conn) c.mu.Unlock() } - return server.ListenAndServe(host, port, protected, handler, opened, closed) + return server.ListenAndServe(host, port, protected, handler, opened, closed, ln) } func (c *Controller) watchMemory() { diff --git a/controller/server/server.go b/controller/server/server.go index 48e93621..f83cd8d1 100644 --- a/controller/server/server.go +++ b/controller/server/server.go @@ -54,17 +54,21 @@ func ListenAndServe( handler func(conn *Conn, msg *Message, rd *AnyReaderWriter, w io.Writer, websocket bool) error, opened func(conn *Conn), closed func(conn *Conn), + lnp *net.Listener, ) error { ln, err := net.Listen("tcp", fmt.Sprintf("%s:%d", host, port)) if err != nil { return err } + if lnp != nil { + *lnp = ln + } log.Infof("The server is now ready to accept connections on port %d", port) for { conn, err := ln.Accept() if err != nil { log.Error(err) - continue + return err } go handleConn(&Conn{Conn: conn}, protected, handler, opened, closed) } diff --git a/tests/tests.go b/tests/tests.go new file mode 100644 index 00000000..ca8701d2 --- /dev/null +++ b/tests/tests.go @@ -0,0 +1 @@ +package tests diff --git a/tests/tests_test.go b/tests/tests_test.go new file mode 100644 index 00000000..55f8396f --- /dev/null +++ b/tests/tests_test.go @@ -0,0 +1,268 @@ +package tests + +import ( + "bufio" + "bytes" + "encoding/hex" + "errors" + "fmt" + "io" + "io/ioutil" + "math/rand" + "net" + "os" + "strconv" + "testing" + "time" + + "github.com/tidwall/tile38/controller" +) + +const port = 21098 + +func init() { + rand.Seed(time.Now().UnixNano()) +} + +func uid() string { + b := make([]byte, 8) + if _, err := rand.Read(b); err != nil { + panic("random error: " + err.Error()) + } + return hex.EncodeToString(b) +} +func makeKey(size int) string { + s := "key+" + uid() + for len(s) < size { + s += "+" + uid() + } + return s +} +func makeID(size int) string { + s := "key+" + uid() + for len(s) < size { + s += "+" + uid() + } + return s +} +func makeJSON(size int) string { + var buf bytes.Buffer + buf.WriteString(`{"type":"MultiPoint","coordinates":[`) + for i := 0; buf.Len() < size; i++ { + if i > 0 { + buf.WriteString(",") + } + fmt.Fprintf(&buf, "[%f,%f]", rand.Float64()*360-180, rand.Float64()*180-90) + } + + buf.WriteString(`]}`) + return buf.String() +} + +func TestServer(t *testing.T) { + dir, err := ioutil.TempDir("", "tile38") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + var ln net.Listener + var done = make(chan bool, 2) + var ignoreErrs bool + go func() { + // log.Default = log.New(ioutil.Discard, nil) + err := controller.ListenAndServeEx("localhost", port, dir, &ln) + if err != nil { + if !ignoreErrs { + t.Fatal(err) + } + } + done <- true + }() + defer func() { + ignoreErrs = true + ln.Close() + <-done + }() + time.Sleep(time.Millisecond * 100) + t.Run("PingPong", SubTestPingPong) + t.Run("SetPoint", SubTestSetPoint) + t.Run("Set100KB", SubTestSet100KB) + t.Run("Set1MB", SubTestSet1MB) + t.Run("Set10MB", SubTestSet10MB) +} + +func SubTestPingPong(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + rd := bufio.NewReader(conn) + if _, err := conn.Write(buildCommand("PING")); err != nil { + t.Fatal(err) + } + resp, err := readResponse(rd) + if err != nil { + t.Fatal(err) + } + if resp != "+PONG\r\n" { + t.Fatal("expected pong") + } +} + +func SubTestSetPoint(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + rd := bufio.NewReader(conn) + cmd := buildCommand("SET", makeKey(100), makeID(100), "POINT", "33.5", "-115.5") + if _, err := conn.Write(cmd); err != nil { + t.Fatal(err) + } + resp, err := readResponse(rd) + if err != nil { + t.Fatal(err) + } + if resp != "+OK\r\n" { + t.Fatal("expected pong") + } +} +func testSet(t *testing.T, jsonSize, keyIDSize, frag int) { + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + rd := bufio.NewReader(conn) + key := makeKey(keyIDSize) + id := makeID(keyIDSize) + json := makeJSON(jsonSize) + cmd := buildCommand("SET", key, id, "OBJECT", json) + if frag == 0 { + if _, err := conn.Write(cmd); err != nil { + t.Fatal(err) + } + } else { + var nn int + olen := len(cmd) + for len(cmd) >= frag { + if n, err := conn.Write(cmd[:frag]); err != nil { + t.Fatal(err) + } else { + nn += int(n) + } + cmd = cmd[frag:] + } + if len(cmd) > 0 { + if n, err := conn.Write(cmd); err != nil { + t.Fatal(err) + } else { + nn += int(n) + } + } + if nn != olen { + t.Fatal("invalid sent amount") + } + } + resp, err := readResponse(rd) + if err != nil { + println(len(resp)) + t.Fatal(err) + } + if resp != "+OK\r\n" { + t.Fatal("expected pong") + } + cmd = buildCommand("GET", key, id) + if _, err := conn.Write(cmd); err != nil { + t.Fatal(err) + } + resp, err = readResponse(rd) + if err != nil { + t.Fatal(err) + } + diff := float64(len(json))/float64(len(resp)) - 1.0 + if diff > 0.1 { + t.Fatal("too big of a difference") + } +} +func SubTestSet100KB(t *testing.T) { + testSet(t, 100*1024, 100, 1024) +} +func SubTestSet1MB(t *testing.T) { + testSet(t, 1024*1024, 100, 1024) +} +func SubTestSet10MB(t *testing.T) { + testSet(t, 10*1024*1024, 100, 1024) +} +func buildCommand(arg ...string) []byte { + var b []byte + b = append(b, '*') + b = append(b, []byte(strconv.FormatInt(int64(len(arg)), 10))...) + b = append(b, '\r', '\n') + for _, arg := range arg { + b = append(b, '$') + b = append(b, []byte(strconv.FormatInt(int64(len(arg)), 10))...) + b = append(b, '\r', '\n') + b = append(b, []byte(arg)...) + b = append(b, '\r', '\n') + } + return b +} + +func readResponse(rd *bufio.Reader) (string, error) { + c, err := rd.ReadByte() + if err != nil { + return "", err + } + var resp []byte + switch c { + default: + return string(resp), errors.New("invalid response") + case '+': + resp, err = readString(rd, []byte{'+'}) + case '$': + resp, err = readBulk(rd, []byte{'$'}) + } + if err != nil { + return string(resp), err + } + return string(resp), nil +} + +func readString(rd *bufio.Reader, b []byte) ([]byte, error) { + line, err := rd.ReadBytes('\n') + if err != nil { + return b, err + } + if len(line) == 1 || line[len(line)-2] != '\r' { + return b, errors.New("invalid response") + } + b = append(b, line...) + return b, nil +} + +func readBulk(rd *bufio.Reader, b []byte) ([]byte, error) { + line, err := rd.ReadBytes('\n') + if err != nil { + return b, err + } + if len(line) == 1 || line[len(line)-2] != '\r' { + return b, errors.New("invalid response") + } + b = append(b, line...) + sz, err := strconv.ParseUint(string(line[:len(line)-2]), 10, 64) + if err != nil { + return b, err + } + data := make([]byte, int(sz)) + if _, err := io.ReadFull(rd, data); err != nil { + return b, err + } + if len(data) < 2 || line[len(line)-2] != '\r' || line[len(line)-1] != '\n' { + return b, errors.New("invalid response") + } + b = append(b, data...) + return b, nil +} diff --git a/vendor/github.com/tidwall/resp/aof.go b/vendor/github.com/tidwall/resp/aof.go index 4930b88b..7f96bcdd 100644 --- a/vendor/github.com/tidwall/resp/aof.go +++ b/vendor/github.com/tidwall/resp/aof.go @@ -54,7 +54,6 @@ func OpenAOF(path string) (*AOF, error) { aof.policy = EverySecond go func() { for { - time.Sleep(time.Second) aof.mu.Lock() if aof.closed { aof.mu.Unlock() @@ -64,6 +63,7 @@ func OpenAOF(path string) (*AOF, error) { aof.f.Sync() } aof.mu.Unlock() + time.Sleep(time.Second) } }() return aof, nil @@ -127,9 +127,23 @@ func (aof *AOF) readValues(iterator func(v Value)) error { // Append writes a value to the end of the file. func (aof *AOF) Append(v Value) error { - b, err := v.MarshalRESP() - if err != nil { - return err + return aof.AppendMulti([]Value{v}) +} + +// AppendMulti writes multiple values to the end of the file. +// This operation can increase performance over calling multiple Append()s and also has the benefit of transactional writes. +func (aof *AOF) AppendMulti(vs []Value) error { + var bs []byte + for _, v := range vs { + b, err := v.MarshalRESP() + if err != nil { + return err + } + if bs == nil { + bs = b + } else { + bs = append(bs, b...) + } } aof.mu.Lock() defer aof.mu.Unlock() @@ -141,7 +155,7 @@ func (aof *AOF) Append(v Value) error { return err } } - _, err = aof.f.Write(b) + _, err := aof.f.Write(bs) if err != nil { return err } diff --git a/vendor/github.com/tidwall/resp/aof_test.go b/vendor/github.com/tidwall/resp/aof_test.go index 25913f64..0c39a0d4 100644 --- a/vendor/github.com/tidwall/resp/aof_test.go +++ b/vendor/github.com/tidwall/resp/aof_test.go @@ -4,20 +4,43 @@ import ( "fmt" "os" "testing" + "time" ) func TestAOF(t *testing.T) { + os.RemoveAll("aof.tmp") + if err := os.MkdirAll("aof.tmp", 0700); err != nil { + t.Fatal(err) + } defer func() { os.RemoveAll("aof.tmp") }() - os.RemoveAll("aof.tmp") - f, err := OpenAOF("aof.tmp") + + if _, err := OpenAOF("aof.tmp"); err == nil { + t.Fatal("expecting error, got nil") + } + + f, err := OpenAOF("aof.tmp/aof") if err != nil { t.Fatal(err) } defer func() { f.Close() + if err := f.Close(); err == nil || err.Error() != "closed" { + t.Fatalf("expected 'closed', got '%v'", err) + } }() + // Test Setting Sync Policies + f.SetSyncPolicy(Never) + sps := fmt.Sprintf("%s %s %s %s", SyncPolicy(99), Never, Always, EverySecond) + if sps != "unknown never always every second" { + t.Fatalf("expected '%s', got '%s'", "unknown never always every second", sps) + } + f.SetSyncPolicy(99) + f.SetSyncPolicy(Never) + f.SetSyncPolicy(Always) + f.SetSyncPolicy(EverySecond) + f.SetSyncPolicy(EverySecond) for i := 0; i < 12345; i++ { if err := f.Append(StringValue(fmt.Sprintf("hello world #%d\n", i))); err != nil { t.Fatal(err) @@ -35,7 +58,8 @@ func TestAOF(t *testing.T) { t.Fatal(err) } f.Close() - f, err = OpenAOF("aof.tmp") + f.Close() // Test closing twice + f, err = OpenAOF("aof.tmp/aof") if err != nil { t.Fatal(err) } @@ -56,4 +80,30 @@ func TestAOF(t *testing.T) { }); err != nil { t.Fatal(err) } + + var multi []Value + for i := 0; i < 50; i++ { + multi = append(multi, StringValue(fmt.Sprintf("hello multi world #%d\n", i))) + } + if err := f.AppendMulti(multi); err != nil { + t.Fatal(err) + } + + skip := i + i = 0 + j := 0 + if err := f.Scan(func(v Value) { + if i >= skip { + s := v.String() + e := fmt.Sprintf("hello multi world #%d\n", j) + if s != e { + t.Fatalf("#%d is '%s', expect '%s'", j, s, e) + } + j++ + } + i++ + }); err != nil { + t.Fatal(err) + } + time.Sleep(time.Millisecond * 10) } diff --git a/vendor/github.com/tidwall/resp/resp.go b/vendor/github.com/tidwall/resp/resp.go index 3ba435a2..79c28e53 100644 --- a/vendor/github.com/tidwall/resp/resp.go +++ b/vendor/github.com/tidwall/resp/resp.go @@ -1,6 +1,7 @@ package resp import ( + "bufio" "bytes" "errors" "fmt" @@ -227,15 +228,12 @@ func (err errProtocol) Error() string { // Reader is a specialized RESP Value type reader. type Reader struct { - rd io.Reader - buf []byte - p, l, s int - rerr error + rd *bufio.Reader } // NewReader returns a Reader for reading Value types. func NewReader(rd io.Reader) *Reader { - r := &Reader{rd: rd} + r := &Reader{rd: bufio.NewReader(rd)} return r } @@ -255,13 +253,13 @@ func (rd *Reader) ReadMultiBulk() (value Value, telnet bool, n int, err error) { func (rd *Reader) readValue(multibulk, child bool) (val Value, telnet bool, n int, err error) { var rn int var c byte - c, rn, err = rd.readByte() - n += rn + c, err = rd.rd.ReadByte() if err != nil { return nullValue, false, n, err } + n++ if c == '*' { - val, n, err = rd.readArrayValue(multibulk) + val, rn, err = rd.readArrayValue(multibulk) } else if multibulk && !child { telnet = true } else { @@ -275,16 +273,17 @@ func (rd *Reader) readValue(multibulk, child bool) (val Value, telnet bool, n in } telnet = true case '-', '+': - val, n, err = rd.readSimpleValue(c) + val, rn, err = rd.readSimpleValue(c) case ':': - val, n, err = rd.readIntegerValue() + val, rn, err = rd.readIntegerValue() case '$': - val, n, err = rd.readBulkValue() + val, rn, err = rd.readBulkValue() } } if telnet { - rd.unreadByte(c) - val, n, err = rd.readTelnetMultiBulk() + n-- + rd.rd.UnreadByte() + val, rn, err = rd.readTelnetMultiBulk() if err == nil { telnet = true } @@ -297,17 +296,16 @@ func (rd *Reader) readValue(multibulk, child bool) (val Value, telnet bool, n in } func (rd *Reader) readTelnetMultiBulk() (v Value, n int, err error) { - var rn int values := make([]Value, 0, 8) var c byte var bline []byte var quote, mustspace bool for { - c, rn, err = rd.readByte() - n += rn + c, err = rd.rd.ReadByte() if err != nil { return nullValue, n, err } + n += 1 if c == '\n' { if len(bline) > 0 && bline[len(bline)-1] == '\r' { bline = bline[:len(bline)-1] @@ -354,7 +352,20 @@ func (rd *Reader) readSimpleValue(typ byte) (val Value, n int, err error) { } return Value{typ: Type(typ), str: line}, n, nil } - +func (rd *Reader) readLine() (line []byte, n int, err error) { + for { + b, err := rd.rd.ReadBytes('\n') + if err != nil { + return nil, 0, err + } + n += len(b) + line = append(line, b...) + if len(line) >= 2 && line[len(line)-2] == '\r' { + break + } + } + return line[:len(line)-2], n, nil +} func (rd *Reader) readBulkValue() (val Value, n int, err error) { var rn int var l int @@ -372,8 +383,8 @@ func (rd *Reader) readBulkValue() (val Value, n int, err error) { if l > 512*1024*1024 { return nullValue, n, &errProtocol{"invalid bulk length"} } - var b []byte - b, rn, err = rd.readBytes(l + 2) + b := make([]byte, l+2) + rn, err = io.ReadFull(rd.rd, b) n += rn if err != nil { return nullValue, n, err @@ -427,136 +438,15 @@ func (rd *Reader) readIntegerValue() (val Value, n int, err error) { } func (rd *Reader) readInt() (x int, n int, err error) { - var rn int - var c byte - neg := 1 - c, rn, err = rd.readByte() - n += rn + line, n, err := rd.readLine() + if err != nil { + return 0, 0, err + } + i64, err := strconv.ParseInt(string(line), 10, 64) if err != nil { return 0, n, err } - if c == '-' { - neg = -1 - c, rn, err = rd.readByte() - n += rn - if err != nil { - return 0, n, err - } - } - var length int - for { - switch c { - default: - return 0, n, &errProtocol{"invalid length"} - case '\r': - c, rn, err = rd.readByte() - n += rn - if err != nil { - return 0, n, err - } - if c != '\n' { - return 0, n, &errProtocol{"invalid length"} - } - return length * neg, n, nil - case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': - length = (length * 10) + int(c-'0') - } - c, rn, err = rd.readByte() - n += rn - if err != nil { - return 0, n, err - } - } -} - -func (rd *Reader) readLine() (b []byte, n int, err error) { - var lc byte - p := rd.p - l := rd.l - for { - // read byte - for l == 0 { - if err := rd.fillBuffer(true); err != nil { - return nil, 0, err - } - l = rd.l - (p - rd.p) - } - c := rd.buf[p] - p++ - l-- - n++ - if c == '\n' && lc == '\r' { - b = rd.buf[rd.p : rd.p+n-2] - rd.p = p - rd.l -= n - return b, n, nil - } - lc = c - } -} - -func (rd *Reader) readBytes(count int) (b []byte, n int, err error) { - if count < 0 { - return nil, 0, errors.New("invalid argument") - } - for rd.l < count { - if err := rd.fillBuffer(false); err != nil { - return nil, 0, err - } - } - b = rd.buf[rd.p : rd.p+count] - rd.p += count - rd.l -= count - return b, count, nil -} - -func (rd *Reader) readByte() (c byte, n int, err error) { - for rd.l < 1 { - if err := rd.fillBuffer(false); err != nil { - return 0, 0, err - } - } - c = rd.buf[rd.p] - rd.p++ - rd.l-- - return c, 1, nil -} - -func (rd *Reader) unreadByte(c byte) { - if rd.p > 0 { - rd.p-- - rd.l++ - rd.buf[rd.p] = c - return - } - buf := make([]byte, rd.l+1) - buf[0] = c - copy(buf[1:], rd.buf[:rd.l]) - rd.l++ - rd.s = rd.l -} - -func (rd *Reader) fillBuffer(ignoreRebuffering bool) error { - if rd.rerr != nil { - return rd.rerr - } - buf := make([]byte, bufsz) - n, err := rd.rd.Read(buf) - rd.rerr = err - if n > 0 { - if !ignoreRebuffering && rd.l == 0 { - rd.l = n - rd.s = n - rd.p = 0 - rd.buf = buf - } else { - rd.buf = append(rd.buf, buf[:n]...) - rd.s += n - rd.l += n - } - return nil - } - return rd.rerr + return int(i64), n, nil } // AnyValue returns a RESP value from an interface. This function infers the types. Arrays are not allowed. @@ -663,6 +553,10 @@ func MultiBulkValue(commandName string, args ...interface{}) Value { vals := make([]Value, len(args)+1) vals[0] = StringValue(commandName) for i, arg := range args { + if rval, ok := arg.(Value); ok && rval.Type() == BulkString { + vals[i+1] = rval + continue + } switch arg := arg.(type) { default: vals[i+1] = StringValue(fmt.Sprintf("%v", arg)) diff --git a/vendor/github.com/tidwall/resp/resp_test.go b/vendor/github.com/tidwall/resp/resp_test.go index 4e37708b..9ffd7a61 100644 --- a/vendor/github.com/tidwall/resp/resp_test.go +++ b/vendor/github.com/tidwall/resp/resp_test.go @@ -4,6 +4,7 @@ import ( "bytes" "crypto/rand" "encoding/binary" + "errors" "fmt" "io" "strconv" @@ -118,6 +119,20 @@ func TestLotsaRandomness(t *testing.T) { if err != nil { t.Fatal(err) } + ts := fmt.Sprintf("%v", v.Type()) + if ts == "Unknown" { + t.Fatal("got 'Unknown'") + } + tvs := fmt.Sprintf("%v %v %v %v %v %v %v %v", + v.String(), v.Float(), v.Integer(), v.Array(), + v.Bool(), v.Bytes(), v.IsNull(), v.Error(), + ) + if len(tvs) < 10 { + t.Fatal("conversion error") + } + if !v.Equals(v) { + t.Fatal("equals failed") + } resp, err := v.MarshalRESP() if err != nil { t.Fatal(err) @@ -127,6 +142,145 @@ func TestLotsaRandomness(t *testing.T) { } } } +func TestBigFragmented(t *testing.T) { + b := make([]byte, 10*1024*1024) + if _, err := rand.Read(b); err != nil { + t.Fatal(err) + } + cmd := []byte("*3\r\n$3\r\nSET\r\n$3\r\nKEY\r\n$" + strconv.FormatInt(int64(len(b)), 10) + "\r\n" + string(b) + "\r\n") + cmdlen := len(cmd) + pr, pw := io.Pipe() + frag := 1024 + go func() { + defer pw.Close() + for len(cmd) >= frag { + if _, err := pw.Write(cmd[:frag]); err != nil { + t.Fatal(err) + } + cmd = cmd[frag:] + } + if len(cmd) > 0 { + if _, err := pw.Write(cmd); err != nil { + t.Fatal(err) + } + } + }() + r := NewReader(pr) + value, telnet, n, err := r.ReadMultiBulk() + if err != nil { + t.Fatal(err) + } + if n != cmdlen { + t.Fatalf("expected %v, got %v", cmdlen, n) + } + if telnet { + t.Fatalf("expected false, got true") + } + arr := value.Array() + if len(arr) != 3 { + t.Fatalf("expected 3, got %v", len(arr)) + } + if arr[0].String() != "SET" { + t.Fatalf("expected 'SET', got %v", arr[0].String()) + } + if arr[1].String() != "KEY" { + t.Fatalf("expected 'KEY', got %v", arr[0].String()) + } + if bytes.Compare(arr[2].Bytes(), b) != 0 { + t.Fatal("bytes not equal") + } +} + +func TestAnyValues(t *testing.T) { + var vs = []interface{}{ + nil, + int(10), uint(10), int8(10), + uint8(10), int16(10), uint16(10), + int32(10), uint32(10), int64(10), + uint64(10), bool(true), bool(false), + float32(10), float64(10), + []byte("hello"), string("hello"), + } + for i, v := range vs { + if AnyValue(v).String() == "" && v != nil { + t.Fatalf("missing string value for #%d: '%v'", i, v) + } + } +} + +func TestMarshalStrangeValue(t *testing.T) { + var v Value + v.null = true + b, err := marshalAnyRESP(v) + if err != nil { + t.Fatal(err) + } + if string(b) != "$-1\r\n" { + t.Fatalf("expected '%v', got '%v'", "$-1\r\n", string(b)) + } + v.null = false + + _, err = marshalAnyRESP(v) + if err == nil || err.Error() != "unknown resp type encountered" { + t.Fatalf("expected '%v', got '%v'", "unknown resp type encountered", err) + } +} + +func TestTelnetReader(t *testing.T) { + rd := NewReader(bytes.NewBufferString("SET HELLO WORLD\r\nGET HELLO\r\n")) + for i := 0; ; i++ { + v, telnet, _, err := rd.ReadMultiBulk() + if err != nil { + if err == io.EOF { + break + } + t.Fatal(err) + } + if !telnet { + t.Fatalf("epxected true") + } + arr := v.Array() + switch i { + default: + t.Fatalf("i is %v, expected 0 or 1", i) + case 0: + if len(arr) != 3 { + t.Fatalf("expected 3, got %v", len(arr)) + } + case 1: + if len(arr) != 2 { + t.Fatalf("expected 2, got %v", len(arr)) + } + } + } +} + +func TestWriter(t *testing.T) { + var buf bytes.Buffer + wr := NewWriter(&buf) + wr.WriteArray(MultiBulkValue("HELLO", 1, 2, 3).Array()) + wr.WriteBytes([]byte("HELLO")) + wr.WriteString("HELLO") + wr.WriteSimpleString("HELLO") + wr.WriteError(errors.New("HELLO")) + wr.WriteInteger(1) + wr.WriteNull() + wr.WriteValue(SimpleStringValue("HELLO")) + + res := "" + + "*4\r\n$5\r\nHELLO\r\n$1\r\n1\r\n$1\r\n2\r\n$1\r\n3\r\n" + + "$5\r\nHELLO\r\n" + + "$5\r\nHELLO\r\n" + + "+HELLO\r\n" + + "-HELLO\r\n" + + ":1\r\n" + + "$-1\r\n" + + "+HELLO\r\n" + if buf.String() != res { + t.Fatalf("expected '%v', got '%v'", res, buf.String()) + } + +} func randRESPInteger() string { return fmt.Sprintf(":%d\r\n", (randInt()%1000000)-500000)