diff --git a/.travis.yml b/.travis.yml index 1dea73b..6ef52f7 100644 --- a/.travis.yml +++ b/.travis.yml @@ -6,7 +6,7 @@ services: go: - 1.3 - 1.4 - - tip + - 1.5 install: - go get gopkg.in/bsm/ratelimit.v1 diff --git a/Makefile b/Makefile index 1107e5f..1b43765 100644 --- a/Makefile +++ b/Makefile @@ -1,9 +1,9 @@ all: testdeps - go test ./... -v=1 -cpu=1,2,4 - go test ./... -short -race + go test ./... -test.v -test.cpu=1,2,4 + go test ./... -test.short -test.race test: testdeps - go test ./... -v=1 + go test ./... -test.v=1 testdeps: .test/redis/src/redis-server @@ -11,7 +11,7 @@ testdeps: .test/redis/src/redis-server .test/redis: mkdir -p $@ - wget -qO- https://github.com/antirez/redis/archive/3.0.3.tar.gz | tar xvz --strip-components=1 -C $@ + wget -qO- https://github.com/antirez/redis/archive/unstable.tar.gz | tar xvz --strip-components=1 -C $@ .test/redis/src/redis-server: .test/redis cd $< && make all diff --git a/cluster_pipeline.go b/cluster_pipeline.go index 619ad82..90687a8 100644 --- a/cluster_pipeline.go +++ b/cluster_pipeline.go @@ -100,7 +100,7 @@ func (pipe *ClusterPipeline) execClusterCmds( var firstCmdErr error for i, cmd := range cmds { - err := cmd.parseReply(cn) + err := cmd.readReply(cn) if err == nil { continue } diff --git a/command.go b/command.go index 6c80906..6b4465e 100644 --- a/command.go +++ b/command.go @@ -28,7 +28,7 @@ var ( type Cmder interface { args() []interface{} - parseReply(*conn) error + readReply(*conn) error setErr(error) reset() @@ -152,14 +152,20 @@ func (cmd *Cmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *Cmd) parseReply(cn *conn) error { - cmd.val, cmd.err = parseReply(cn, parseSlice) - // Convert to string to preserve old behaviour. - // TODO: remove in v4 - if v, ok := cmd.val.([]byte); ok { - cmd.val = string(v) +func (cmd *Cmd) readReply(cn *conn) error { + val, err := readReply(cn, sliceParser) + if err != nil { + cmd.err = err + return cmd.err } - return cmd.err + if v, ok := val.([]byte); ok { + // Convert to string to preserve old behaviour. + // TODO: remove in v4 + cmd.val = string(v) + } else { + cmd.val = val + } + return nil } //------------------------------------------------------------------------------ @@ -191,8 +197,8 @@ func (cmd *SliceCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *SliceCmd) parseReply(cn *conn) error { - v, err := parseReply(cn, parseSlice) +func (cmd *SliceCmd) readReply(cn *conn) error { + v, err := readReply(cn, sliceParser) if err != nil { cmd.err = err return err @@ -234,8 +240,8 @@ func (cmd *StatusCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *StatusCmd) parseReply(cn *conn) error { - v, err := parseReply(cn, nil) +func (cmd *StatusCmd) readReply(cn *conn) error { + v, err := readReply(cn, nil) if err != nil { cmd.err = err return err @@ -273,8 +279,8 @@ func (cmd *IntCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *IntCmd) parseReply(cn *conn) error { - v, err := parseReply(cn, nil) +func (cmd *IntCmd) readReply(cn *conn) error { + v, err := readReply(cn, nil) if err != nil { cmd.err = err return err @@ -316,8 +322,8 @@ func (cmd *DurationCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *DurationCmd) parseReply(cn *conn) error { - v, err := parseReply(cn, nil) +func (cmd *DurationCmd) readReply(cn *conn) error { + v, err := readReply(cn, nil) if err != nil { cmd.err = err return err @@ -357,8 +363,8 @@ func (cmd *BoolCmd) String() string { var ok = []byte("OK") -func (cmd *BoolCmd) parseReply(cn *conn) error { - v, err := parseReply(cn, nil) +func (cmd *BoolCmd) readReply(cn *conn) error { + v, err := readReply(cn, nil) // `SET key value NX` returns nil when key already exists, which // is inconsistent with `SETNX key value`. // TODO: is this okay? @@ -443,8 +449,8 @@ func (cmd *StringCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *StringCmd) parseReply(cn *conn) error { - v, err := parseReply(cn, nil) +func (cmd *StringCmd) readReply(cn *conn) error { + v, err := readReply(cn, nil) if err != nil { cmd.err = err return err @@ -482,8 +488,8 @@ func (cmd *FloatCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *FloatCmd) parseReply(cn *conn) error { - v, err := parseReply(cn, nil) +func (cmd *FloatCmd) readReply(cn *conn) error { + v, err := readReply(cn, nil) if err != nil { cmd.err = err return err @@ -522,8 +528,8 @@ func (cmd *StringSliceCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *StringSliceCmd) parseReply(cn *conn) error { - v, err := parseReply(cn, parseStringSlice) +func (cmd *StringSliceCmd) readReply(cn *conn) error { + v, err := readReply(cn, stringSliceParser) if err != nil { cmd.err = err return err @@ -561,8 +567,8 @@ func (cmd *BoolSliceCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *BoolSliceCmd) parseReply(cn *conn) error { - v, err := parseReply(cn, parseBoolSlice) +func (cmd *BoolSliceCmd) readReply(cn *conn) error { + v, err := readReply(cn, boolSliceParser) if err != nil { cmd.err = err return err @@ -600,8 +606,8 @@ func (cmd *StringStringMapCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *StringStringMapCmd) parseReply(cn *conn) error { - v, err := parseReply(cn, parseStringStringMap) +func (cmd *StringStringMapCmd) readReply(cn *conn) error { + v, err := readReply(cn, stringStringMapParser) if err != nil { cmd.err = err return err @@ -639,8 +645,8 @@ func (cmd *StringIntMapCmd) reset() { cmd.err = nil } -func (cmd *StringIntMapCmd) parseReply(cn *conn) error { - v, err := parseReply(cn, parseStringIntMap) +func (cmd *StringIntMapCmd) readReply(cn *conn) error { + v, err := readReply(cn, stringIntMapParser) if err != nil { cmd.err = err return err @@ -678,8 +684,8 @@ func (cmd *ZSliceCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *ZSliceCmd) parseReply(cn *conn) error { - v, err := parseReply(cn, parseZSlice) +func (cmd *ZSliceCmd) readReply(cn *conn) error { + v, err := readReply(cn, zSliceParser) if err != nil { cmd.err = err return err @@ -719,8 +725,8 @@ func (cmd *ScanCmd) String() string { return cmdString(cmd, cmd.keys) } -func (cmd *ScanCmd) parseReply(cn *conn) error { - vi, err := parseReply(cn, parseSlice) +func (cmd *ScanCmd) readReply(cn *conn) error { + vi, err := readReply(cn, sliceParser) if err != nil { cmd.err = err return cmd.err @@ -743,8 +749,9 @@ func (cmd *ScanCmd) parseReply(cn *conn) error { //------------------------------------------------------------------------------ type ClusterSlotInfo struct { - Start, End int - Addrs []string + Start int + End int + Addrs []string } type ClusterSlotCmd struct { @@ -774,8 +781,8 @@ func (cmd *ClusterSlotCmd) reset() { cmd.err = nil } -func (cmd *ClusterSlotCmd) parseReply(cn *conn) error { - v, err := parseReply(cn, parseClusterSlotInfoSlice) +func (cmd *ClusterSlotCmd) readReply(cn *conn) error { + v, err := readReply(cn, clusterSlotInfoSliceParser) if err != nil { cmd.err = err return err @@ -783,3 +790,65 @@ func (cmd *ClusterSlotCmd) parseReply(cn *conn) error { cmd.val = v.([]ClusterSlotInfo) return nil } + +//------------------------------------------------------------------------------ + +// GeoLocation is used with GeoAdd to add geospatial location. +type GeoLocation struct { + Name string + Longitude, Latitude, Distance float64 + GeoHash int64 +} + +// GeoRadiusQuery is used with GeoRadius to query geospatial index. +type GeoRadiusQuery struct { + Key string + Longitude float64 + Latitude float64 + Radius float64 + // Can be m, km, ft, or mi. Default is km. + Unit string + WithCoordinates bool + WithDistance bool + WithGeoHash bool + Count int + // Can be ASC or DESC. Default is no sort order. + Sort string +} + +type GeoLocationCmd struct { + baseCmd + + locations []GeoLocation +} + +func NewGeoLocationCmd(args ...interface{}) *GeoLocationCmd { + return &GeoLocationCmd{baseCmd: baseCmd{_args: args, _clusterKeyPos: 1}} +} + +func (cmd *GeoLocationCmd) reset() { + cmd.locations = nil + cmd.err = nil +} + +func (cmd *GeoLocationCmd) Val() []GeoLocation { + return cmd.locations +} + +func (cmd *GeoLocationCmd) Result() ([]GeoLocation, error) { + return cmd.locations, cmd.err +} + +func (cmd *GeoLocationCmd) String() string { + return cmdString(cmd, cmd.locations) +} + +func (cmd *GeoLocationCmd) readReply(cn *conn) error { + reply, err := readReply(cn, geoLocationSliceParser) + if err != nil { + cmd.err = err + return err + } + cmd.locations = reply.([]GeoLocation) + return nil +} diff --git a/commands.go b/commands.go index 44a77d0..1916aec 100644 --- a/commands.go +++ b/commands.go @@ -1671,3 +1671,52 @@ func (c *commandable) ClusterAddSlotsRange(min, max int) *StatusCmd { } return c.ClusterAddSlots(slots...) } + +//------------------------------------------------------------------------------ + +func (c *commandable) GeoAdd(key string, geoLocation ...*GeoLocation) *IntCmd { + args := make([]interface{}, 2+3*len(geoLocation)) + args[0] = "GEOADD" + args[1] = key + for i, eachLoc := range geoLocation { + args[2+3*i] = eachLoc.Longitude + args[2+3*i+1] = eachLoc.Latitude + args[2+3*i+2] = eachLoc.Name + } + cmd := NewIntCmd(args...) + c.Process(cmd) + return cmd +} + +func (c *commandable) GeoRadius(query *GeoRadiusQuery) *GeoLocationCmd { + args := make([]interface{}, 6) + args[0] = "GEORADIUS" + args[1] = query.Key + args[2] = query.Longitude + args[3] = query.Latitude + args[4] = query.Radius + if query.Unit != "" { + args[5] = query.Unit + } else { + args[5] = "km" + } + if query.WithCoordinates { + args = append(args, "WITHCOORD") + } + if query.WithDistance { + args = append(args, "WITHDIST") + } + if query.WithGeoHash { + args = append(args, "WITHHASH") + } + if query.Count > 0 { + args = append(args, "COUNT", query.Count) + } + if query.Sort != "" { + args = append(args, query.Sort) + } + + cmd := NewGeoLocationCmd(args...) + c.Process(cmd) + return cmd +} diff --git a/commands_test.go b/commands_test.go index 448e042..f10cfc3 100644 --- a/commands_test.go +++ b/commands_test.go @@ -193,7 +193,7 @@ var _ = Describe("Commands", func() { dump := client.Dump("key") Expect(dump.Err()).NotTo(HaveOccurred()) - Expect(dump.Val()).To(Equal("\x00\x05hello\x06\x00\xf5\x9f\xb7\xf6\x90a\x1c\x99")) + Expect(dump.Val()).NotTo(BeEmpty()) }) It("should Exists", func() { @@ -2521,6 +2521,66 @@ var _ = Describe("Commands", func() { }) + Describe("Geo add and radius search", func() { + It("should add one geo location", func() { + geoAdd := client.GeoAdd("Sicily", &redis.GeoLocation{Longitude: 13.361389, Latitude: 38.115556, Name: "Palermo"}) + Expect(geoAdd.Err()).NotTo(HaveOccurred()) + Expect(geoAdd.Val()).To(Equal(int64(1))) + }) + + It("should add multiple geo locations", func() { + geoAdd := client.GeoAdd("Sicily", &redis.GeoLocation{Longitude: 13.361389, Latitude: 38.115556, Name: "Palermo"}, + &redis.GeoLocation{Longitude: 15.087269, Latitude: 37.502669, Name: "Catania"}) + Expect(geoAdd.Err()).NotTo(HaveOccurred()) + Expect(geoAdd.Val()).To(Equal(int64(2))) + }) + + It("should search geo radius", func() { + geoAdd := client.GeoAdd("Sicily", &redis.GeoLocation{Longitude: 13.361389, Latitude: 38.115556, Name: "Palermo"}, + &redis.GeoLocation{Longitude: 15.087269, Latitude: 37.502669, Name: "Catania"}) + Expect(geoAdd.Err()).NotTo(HaveOccurred()) + Expect(geoAdd.Val()).To(Equal(int64(2))) + + geoRadius := client.GeoRadius(&redis.GeoRadiusQuery{Key: "Sicily", Longitude: 15, Latitude: 37, Radius: 200}) + Expect(geoRadius.Err()).NotTo(HaveOccurred()) + Expect(geoRadius.Val()[0].Name).To(Equal("Palermo")) + Expect(geoRadius.Val()[1].Name).To(Equal("Catania")) + }) + + It("should search geo radius with options", func() { + locations := []*redis.GeoLocation{&redis.GeoLocation{Longitude: 13.361389, Latitude: 38.115556, Name: "Palermo"}, + &redis.GeoLocation{Longitude: 15.087269, Latitude: 37.502669, Name: "Catania"}} + + geoAdd := client.GeoAdd("Sicily", locations...) + Expect(geoAdd.Err()).NotTo(HaveOccurred()) + Expect(geoAdd.Val()).To(Equal(int64(2))) + + geoRadius := client.GeoRadius(&redis.GeoRadiusQuery{Key: "Sicily", Longitude: 15, Latitude: 37, Radius: 200, Unit: "km", WithGeoHash: true, WithCoordinates: true, WithDistance: true, Count: 2, Sort: "ASC"}) + Expect(geoRadius.Err()).NotTo(HaveOccurred()) + Expect(geoRadius.Val()[1].Name).To(Equal("Palermo")) + Expect(geoRadius.Val()[1].Distance).To(Equal(190.4424)) + Expect(geoRadius.Val()[1].GeoHash).To(Equal(int64(3479099956230698))) + Expect(geoRadius.Val()[1].Longitude).To(Equal(13.361389338970184)) + Expect(geoRadius.Val()[1].Latitude).To(Equal(38.115556395496299)) + Expect(geoRadius.Val()[0].Name).To(Equal("Catania")) + Expect(geoRadius.Val()[0].Distance).To(Equal(56.4413)) + Expect(geoRadius.Val()[0].GeoHash).To(Equal(int64(3479447370796909))) + Expect(geoRadius.Val()[0].Longitude).To(Equal(15.087267458438873)) + Expect(geoRadius.Val()[0].Latitude).To(Equal(37.50266842333162)) + }) + + It("should search geo radius with no results", func() { + geoAdd := client.GeoAdd("Sicily", &redis.GeoLocation{Longitude: 13.361389, Latitude: 38.115556, Name: "Palermo"}, + &redis.GeoLocation{Longitude: 15.087269, Latitude: 37.502669, Name: "Catania"}) + Expect(geoAdd.Err()).NotTo(HaveOccurred()) + Expect(geoAdd.Val()).To(Equal(int64(2))) + + geoRadius := client.GeoRadius(&redis.GeoRadiusQuery{Key: "Sicily", Longitude: 99, Latitude: 37, Radius: 200, Unit: "km", WithGeoHash: true, WithCoordinates: true, WithDistance: true}) + Expect(geoRadius.Err()).NotTo(HaveOccurred()) + Expect(len(geoRadius.Val())).To(Equal(0)) + }) + }) + Describe("marshaling/unmarshaling", func() { type convTest struct { diff --git a/multi.go b/multi.go index 7b55c7b..e3d628f 100644 --- a/multi.go +++ b/multi.go @@ -116,7 +116,7 @@ func (c *Multi) execCmds(cn *conn, cmds []Cmder) error { // Parse queued replies. for i := 0; i < cmdsLen; i++ { - if err := statusCmd.parseReply(cn); err != nil { + if err := statusCmd.readReply(cn); err != nil { setCmdsErr(cmds[1:len(cmds)-1], err) return err } @@ -144,7 +144,7 @@ func (c *Multi) execCmds(cn *conn, cmds []Cmder) error { // Loop starts from 1 to omit MULTI cmd. for i := 1; i < cmdsLen; i++ { cmd := cmds[i] - if err := cmd.parseReply(cn); err != nil { + if err := cmd.readReply(cn); err != nil { if firstCmdErr == nil { firstCmdErr = err } diff --git a/parser.go b/parser.go index 5b6e073..9acb0f1 100644 --- a/parser.go +++ b/parser.go @@ -8,6 +8,14 @@ import ( "strconv" ) +const ( + errorReply = '-' + statusReply = '+' + intReply = ':' + stringReply = '$' + arrayReply = '*' +) + type multiBulkParser func(cn *conn, n int64) (interface{}, error) var ( @@ -239,57 +247,157 @@ func readN(cn *conn, n int) ([]byte, error) { //------------------------------------------------------------------------------ -func parseReply(cn *conn, p multiBulkParser) (interface{}, error) { +func parseErrorReply(cn *conn, line []byte) error { + return errorf(string(line[1:])) +} + +func parseIntReply(cn *conn, line []byte) (int64, error) { + n, err := strconv.ParseInt(bytesToString(line[1:]), 10, 64) + if err != nil { + return 0, err + } + return n, nil +} + +func readIntReply(cn *conn) (int64, error) { + line, err := readLine(cn) + if err != nil { + return 0, err + } + switch line[0] { + case errorReply: + return 0, parseErrorReply(cn, line) + case intReply: + return parseIntReply(cn, line) + default: + return 0, fmt.Errorf("readIntReply: can't parse %.100q", line) + } +} + +func parseBytesReply(cn *conn, line []byte) ([]byte, error) { + if len(line) == 3 && line[1] == '-' && line[2] == '1' { + return nil, Nil + } + + replyLen, err := strconv.Atoi(bytesToString(line[1:])) + if err != nil { + return nil, err + } + + b, err := readN(cn, replyLen+2) + if err != nil { + return nil, err + } + + return b[:replyLen], nil +} + +func readBytesReply(cn *conn) ([]byte, error) { + line, err := readLine(cn) + if err != nil { + return nil, err + } + switch line[0] { + case errorReply: + return nil, parseErrorReply(cn, line) + case stringReply: + return parseBytesReply(cn, line) + default: + return nil, fmt.Errorf("readBytesReply: can't parse %.100q", line) + } +} + +func readStringReply(cn *conn) (string, error) { + b, err := readBytesReply(cn) + if err != nil { + return "", err + } + return string(b), nil +} + +func readFloatReply(cn *conn) (float64, error) { + b, err := readBytesReply(cn) + if err != nil { + return 0, err + } + return strconv.ParseFloat(bytesToString(b), 64) +} + +func parseArrayHeader(cn *conn, line []byte) (int64, error) { + if len(line) == 3 && line[1] == '-' && line[2] == '1' { + return 0, Nil + } + + n, err := strconv.ParseInt(bytesToString(line[1:]), 10, 64) + if err != nil { + return 0, err + } + return n, nil +} + +func parseArrayReply(cn *conn, p multiBulkParser, line []byte) (interface{}, error) { + n, err := parseArrayHeader(cn, line) + if err != nil { + return nil, err + } + return p(cn, n) +} + +func readArrayHeader(cn *conn) (int64, error) { + line, err := readLine(cn) + if err != nil { + return 0, err + } + switch line[0] { + case errorReply: + return 0, parseErrorReply(cn, line) + case arrayReply: + return parseArrayHeader(cn, line) + default: + return 0, fmt.Errorf("readArrayReply: can't parse %.100q", line) + } +} + +func readArrayReply(cn *conn, p multiBulkParser) (interface{}, error) { + line, err := readLine(cn) + if err != nil { + return nil, err + } + switch line[0] { + case errorReply: + return nil, parseErrorReply(cn, line) + case arrayReply: + return parseArrayReply(cn, p, line) + default: + return nil, fmt.Errorf("readArrayReply: can't parse %.100q", line) + } +} + +func readReply(cn *conn, p multiBulkParser) (interface{}, error) { line, err := readLine(cn) if err != nil { return nil, err } switch line[0] { - case '-': - return nil, errorf(string(line[1:])) - case '+': + case errorReply: + return nil, parseErrorReply(cn, line) + case statusReply: return line[1:], nil - case ':': - v, err := strconv.ParseInt(bytesToString(line[1:]), 10, 64) - if err != nil { - return nil, err - } - return v, nil - case '$': - if len(line) == 3 && line[1] == '-' && line[2] == '1' { - return nil, Nil - } - - replyLen, err := strconv.Atoi(bytesToString(line[1:])) - if err != nil { - return nil, err - } - - b, err := readN(cn, replyLen+2) - if err != nil { - return nil, err - } - return b[:replyLen], nil - case '*': - if len(line) == 3 && line[1] == '-' && line[2] == '1' { - return nil, Nil - } - - repliesNum, err := strconv.ParseInt(bytesToString(line[1:]), 10, 64) - if err != nil { - return nil, err - } - - return p(cn, repliesNum) + case intReply: + return parseIntReply(cn, line) + case stringReply: + return parseBytesReply(cn, line) + case arrayReply: + return parseArrayReply(cn, p, line) } return nil, fmt.Errorf("redis: can't parse %q", line) } -func parseSlice(cn *conn, n int64) (interface{}, error) { +func sliceParser(cn *conn, n int64) (interface{}, error) { vals := make([]interface{}, 0, n) for i := int64(0); i < n; i++ { - v, err := parseReply(cn, parseSlice) + v, err := readReply(cn, sliceParser) if err == Nil { vals = append(vals, nil) } else if err != nil { @@ -306,171 +414,224 @@ func parseSlice(cn *conn, n int64) (interface{}, error) { return vals, nil } -func parseStringSlice(cn *conn, n int64) (interface{}, error) { - vals := make([]string, 0, n) +func intSliceParser(cn *conn, n int64) (interface{}, error) { + ints := make([]int64, 0, n) for i := int64(0); i < n; i++ { - viface, err := parseReply(cn, nil) + n, err := readIntReply(cn) if err != nil { return nil, err } - v, ok := viface.([]byte) - if !ok { - return nil, fmt.Errorf("got %T, expected string", viface) - } - vals = append(vals, string(v)) + ints = append(ints, n) } - return vals, nil + return ints, nil } -func parseBoolSlice(cn *conn, n int64) (interface{}, error) { - vals := make([]bool, 0, n) +func boolSliceParser(cn *conn, n int64) (interface{}, error) { + bools := make([]bool, 0, n) for i := int64(0); i < n; i++ { - viface, err := parseReply(cn, nil) + n, err := readIntReply(cn) if err != nil { return nil, err } - v, ok := viface.(int64) - if !ok { - return nil, fmt.Errorf("got %T, expected int64", viface) - } - vals = append(vals, v == 1) + bools = append(bools, n == 1) } - return vals, nil + return bools, nil } -func parseStringStringMap(cn *conn, n int64) (interface{}, error) { +func stringSliceParser(cn *conn, n int64) (interface{}, error) { + ss := make([]string, 0, n) + for i := int64(0); i < n; i++ { + s, err := readStringReply(cn) + if err != nil { + return nil, err + } + ss = append(ss, s) + } + return ss, nil +} + +func floatSliceParser(cn *conn, n int64) (interface{}, error) { + nn := make([]float64, 0, n) + for i := int64(0); i < n; i++ { + n, err := readFloatReply(cn) + if err != nil { + return nil, err + } + nn = append(nn, n) + } + return nn, nil +} + +func stringStringMapParser(cn *conn, n int64) (interface{}, error) { m := make(map[string]string, n/2) for i := int64(0); i < n; i += 2 { - keyIface, err := parseReply(cn, nil) + key, err := readStringReply(cn) if err != nil { return nil, err } - keyBytes, ok := keyIface.([]byte) - if !ok { - return nil, fmt.Errorf("got %T, expected []byte", keyIface) - } - key := string(keyBytes) - valueIface, err := parseReply(cn, nil) + value, err := readStringReply(cn) if err != nil { return nil, err } - valueBytes, ok := valueIface.([]byte) - if !ok { - return nil, fmt.Errorf("got %T, expected []byte", valueIface) - } - m[key] = string(valueBytes) + m[key] = value } return m, nil } -func parseStringIntMap(cn *conn, n int64) (interface{}, error) { +func stringIntMapParser(cn *conn, n int64) (interface{}, error) { m := make(map[string]int64, n/2) for i := int64(0); i < n; i += 2 { - keyiface, err := parseReply(cn, nil) + key, err := readStringReply(cn) if err != nil { return nil, err } - key, ok := keyiface.([]byte) - if !ok { - return nil, fmt.Errorf("got %T, expected string", keyiface) - } - valueiface, err := parseReply(cn, nil) + n, err := readIntReply(cn) if err != nil { return nil, err } - switch value := valueiface.(type) { - case int64: - m[string(key)] = value - case string: - m[string(key)], err = strconv.ParseInt(value, 10, 64) - if err != nil { - return nil, fmt.Errorf("got %v, expected number", value) - } - default: - return nil, fmt.Errorf("got %T, expected number or string", valueiface) - } + + m[key] = n } return m, nil } -func parseZSlice(cn *conn, n int64) (interface{}, error) { +func zSliceParser(cn *conn, n int64) (interface{}, error) { zz := make([]Z, n/2) for i := int64(0); i < n; i += 2 { + var err error + z := &zz[i/2] - memberiface, err := parseReply(cn, nil) + z.Member, err = readStringReply(cn) if err != nil { return nil, err } - member, ok := memberiface.([]byte) - if !ok { - return nil, fmt.Errorf("got %T, expected string", memberiface) - } - z.Member = string(member) - scoreiface, err := parseReply(cn, nil) + z.Score, err = readFloatReply(cn) if err != nil { return nil, err } - scoreb, ok := scoreiface.([]byte) - if !ok { - return nil, fmt.Errorf("got %T, expected string", scoreiface) - } - score, err := strconv.ParseFloat(bytesToString(scoreb), 64) - if err != nil { - return nil, err - } - z.Score = score } return zz, nil } -func parseClusterSlotInfoSlice(cn *conn, n int64) (interface{}, error) { +func clusterSlotInfoSliceParser(cn *conn, n int64) (interface{}, error) { infos := make([]ClusterSlotInfo, 0, n) for i := int64(0); i < n; i++ { - viface, err := parseReply(cn, parseSlice) + n, err := readArrayHeader(cn) + if err != nil { + return nil, err + } + if n < 2 { + return nil, fmt.Errorf("got %d elements in cluster info, expected at least 2", n) + } + + start, err := readIntReply(cn) if err != nil { return nil, err } - item, ok := viface.([]interface{}) - if !ok { - return nil, fmt.Errorf("got %T, expected []interface{}", viface) - } else if len(item) < 3 { - return nil, fmt.Errorf("got %v, expected {int64, int64, string...}", item) + end, err := readIntReply(cn) + if err != nil { + return nil, err } - start, ok := item[0].(int64) - if !ok || start < 0 || start > hashSlots { - return nil, fmt.Errorf("got %v, expected {int64, int64, string...}", item) - } - end, ok := item[1].(int64) - if !ok || end < 0 || end > hashSlots { - return nil, fmt.Errorf("got %v, expected {int64, int64, string...}", item) + addrsn := n - 2 + info := ClusterSlotInfo{ + Start: int(start), + End: int(end), + Addrs: make([]string, addrsn), } - info := ClusterSlotInfo{int(start), int(end), make([]string, len(item)-2)} - for n, ipair := range item[2:] { - pair, ok := ipair.([]interface{}) - if !ok || len(pair) != 2 { - return nil, fmt.Errorf("got %v, expected []interface{host, port}", viface) + for i := int64(0); i < addrsn; i++ { + n, err := readArrayHeader(cn) + if err != nil { + return nil, err + } + if n != 2 { + return nil, fmt.Errorf("got %d elements in cluster info address, expected 2", n) } - ip, ok := pair[0].(string) - if !ok || len(ip) < 1 { - return nil, fmt.Errorf("got %v, expected IP PORT pair", pair) - } - port, ok := pair[1].(int64) - if !ok || port < 1 { - return nil, fmt.Errorf("got %v, expected IP PORT pair", pair) + ip, err := readStringReply(cn) + if err != nil { + return nil, err } - info.Addrs[n] = net.JoinHostPort(ip, strconv.FormatInt(port, 10)) + port, err := readIntReply(cn) + if err != nil { + return nil, err + } + + info.Addrs[i] = net.JoinHostPort(ip, strconv.FormatInt(port, 10)) } + infos = append(infos, info) } return infos, nil } + +func geoLocationParser(cn *conn, n int64) (interface{}, error) { + loc := &GeoLocation{} + + var err error + loc.Name, err = readStringReply(cn) + if err != nil { + return nil, err + } + if n >= 2 { + loc.Distance, err = readFloatReply(cn) + if err != nil { + return nil, err + } + } + if n >= 3 { + loc.GeoHash, err = readIntReply(cn) + if err != nil { + return nil, err + } + } + if n >= 4 { + n, err := readArrayHeader(cn) + if err != nil { + return nil, err + } + if n != 2 { + return nil, fmt.Errorf("got %d coordinates, expected 2", n) + } + + loc.Longitude, err = readFloatReply(cn) + if err != nil { + return nil, err + } + loc.Latitude, err = readFloatReply(cn) + if err != nil { + return nil, err + } + } + + return loc, nil +} + +func geoLocationSliceParser(cn *conn, n int64) (interface{}, error) { + locs := make([]GeoLocation, 0, n) + for i := int64(0); i < n; i++ { + v, err := readReply(cn, geoLocationParser) + if err != nil { + return nil, err + } + switch vv := v.(type) { + case []byte: + locs = append(locs, GeoLocation{ + Name: string(vv), + }) + case *GeoLocation: + locs = append(locs, *vv) + default: + return nil, fmt.Errorf("got %T, expected string or *GeoLocation", v) + } + } + return locs, nil +} diff --git a/parser_test.go b/parser_test.go index 10403f6..b1c7434 100644 --- a/parser_test.go +++ b/parser_test.go @@ -23,7 +23,7 @@ func BenchmarkParseReplyString(b *testing.B) { } func BenchmarkParseReplySlice(b *testing.B) { - benchmarkParseReply(b, "*2\r\n$5\r\nhello\r\n$5\r\nworld\r\n", parseSlice, false) + benchmarkParseReply(b, "*2\r\n$5\r\nhello\r\n$5\r\nworld\r\n", sliceParser, false) } func benchmarkParseReply(b *testing.B, reply string, p multiBulkParser, wanterr bool) { @@ -39,7 +39,7 @@ func benchmarkParseReply(b *testing.B, reply string, p multiBulkParser, wanterr b.ResetTimer() for i := 0; i < b.N; i++ { - _, err := parseReply(cn, p) + _, err := readReply(cn, p) if !wanterr && err != nil { b.Fatal(err) } diff --git a/pipeline.go b/pipeline.go index 02ecbac..d7d1304 100644 --- a/pipeline.go +++ b/pipeline.go @@ -97,7 +97,7 @@ func execCmds(cn *conn, cmds []Cmder) ([]Cmder, error) { var firstCmdErr error var failedCmds []Cmder for _, cmd := range cmds { - err := cmd.parseReply(cn) + err := cmd.readReply(cn) if err == nil { continue } diff --git a/pubsub.go b/pubsub.go index fa804eb..ba053e4 100644 --- a/pubsub.go +++ b/pubsub.go @@ -215,7 +215,7 @@ func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) { cn.ReadTimeout = timeout cmd := NewSliceCmd() - if err := cmd.parseReply(cn); err != nil { + if err := cmd.readReply(cn); err != nil { return nil, err } return newMessage(cmd.Val()) diff --git a/redis.go b/redis.go index aea53a2..2d1076d 100644 --- a/redis.go +++ b/redis.go @@ -69,7 +69,7 @@ func (c *baseClient) process(cmd Cmder) { return } - err = cmd.parseReply(cn) + err = cmd.readReply(cn) c.putConn(cn, err) if shouldRetry(err) { continue