diff --git a/cluster_pipeline.go b/cluster_pipeline.go index 619ad82f..90687a8b 100644 --- a/cluster_pipeline.go +++ b/cluster_pipeline.go @@ -100,7 +100,7 @@ func (pipe *ClusterPipeline) execClusterCmds( var firstCmdErr error for i, cmd := range cmds { - err := cmd.parseReply(cn) + err := cmd.readReply(cn) if err == nil { continue } diff --git a/command.go b/command.go index 8bae5f4b..6b4465e8 100644 --- a/command.go +++ b/command.go @@ -28,7 +28,7 @@ var ( type Cmder interface { args() []interface{} - parseReply(*conn) error + readReply(*conn) error setErr(error) reset() @@ -152,14 +152,20 @@ func (cmd *Cmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *Cmd) parseReply(cn *conn) error { - cmd.val, cmd.err = parseReply(cn, parseSlice) - // Convert to string to preserve old behaviour. - // TODO: remove in v4 - if v, ok := cmd.val.([]byte); ok { - cmd.val = string(v) +func (cmd *Cmd) readReply(cn *conn) error { + val, err := readReply(cn, sliceParser) + if err != nil { + cmd.err = err + return cmd.err } - return cmd.err + if v, ok := val.([]byte); ok { + // Convert to string to preserve old behaviour. + // TODO: remove in v4 + cmd.val = string(v) + } else { + cmd.val = val + } + return nil } //------------------------------------------------------------------------------ @@ -191,8 +197,8 @@ func (cmd *SliceCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *SliceCmd) parseReply(cn *conn) error { - v, err := parseReply(cn, parseSlice) +func (cmd *SliceCmd) readReply(cn *conn) error { + v, err := readReply(cn, sliceParser) if err != nil { cmd.err = err return err @@ -234,8 +240,8 @@ func (cmd *StatusCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *StatusCmd) parseReply(cn *conn) error { - v, err := parseReply(cn, nil) +func (cmd *StatusCmd) readReply(cn *conn) error { + v, err := readReply(cn, nil) if err != nil { cmd.err = err return err @@ -273,8 +279,8 @@ func (cmd *IntCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *IntCmd) parseReply(cn *conn) error { - v, err := parseReply(cn, nil) +func (cmd *IntCmd) readReply(cn *conn) error { + v, err := readReply(cn, nil) if err != nil { cmd.err = err return err @@ -316,8 +322,8 @@ func (cmd *DurationCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *DurationCmd) parseReply(cn *conn) error { - v, err := parseReply(cn, nil) +func (cmd *DurationCmd) readReply(cn *conn) error { + v, err := readReply(cn, nil) if err != nil { cmd.err = err return err @@ -357,8 +363,8 @@ func (cmd *BoolCmd) String() string { var ok = []byte("OK") -func (cmd *BoolCmd) parseReply(cn *conn) error { - v, err := parseReply(cn, nil) +func (cmd *BoolCmd) readReply(cn *conn) error { + v, err := readReply(cn, nil) // `SET key value NX` returns nil when key already exists, which // is inconsistent with `SETNX key value`. // TODO: is this okay? @@ -443,8 +449,8 @@ func (cmd *StringCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *StringCmd) parseReply(cn *conn) error { - v, err := parseReply(cn, nil) +func (cmd *StringCmd) readReply(cn *conn) error { + v, err := readReply(cn, nil) if err != nil { cmd.err = err return err @@ -482,8 +488,8 @@ func (cmd *FloatCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *FloatCmd) parseReply(cn *conn) error { - v, err := parseReply(cn, nil) +func (cmd *FloatCmd) readReply(cn *conn) error { + v, err := readReply(cn, nil) if err != nil { cmd.err = err return err @@ -522,8 +528,8 @@ func (cmd *StringSliceCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *StringSliceCmd) parseReply(cn *conn) error { - v, err := parseReply(cn, parseStringSlice) +func (cmd *StringSliceCmd) readReply(cn *conn) error { + v, err := readReply(cn, stringSliceParser) if err != nil { cmd.err = err return err @@ -561,8 +567,8 @@ func (cmd *BoolSliceCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *BoolSliceCmd) parseReply(cn *conn) error { - v, err := parseReply(cn, parseBoolSlice) +func (cmd *BoolSliceCmd) readReply(cn *conn) error { + v, err := readReply(cn, boolSliceParser) if err != nil { cmd.err = err return err @@ -600,8 +606,8 @@ func (cmd *StringStringMapCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *StringStringMapCmd) parseReply(cn *conn) error { - v, err := parseReply(cn, parseStringStringMap) +func (cmd *StringStringMapCmd) readReply(cn *conn) error { + v, err := readReply(cn, stringStringMapParser) if err != nil { cmd.err = err return err @@ -639,8 +645,8 @@ func (cmd *StringIntMapCmd) reset() { cmd.err = nil } -func (cmd *StringIntMapCmd) parseReply(cn *conn) error { - v, err := parseReply(cn, parseStringIntMap) +func (cmd *StringIntMapCmd) readReply(cn *conn) error { + v, err := readReply(cn, stringIntMapParser) if err != nil { cmd.err = err return err @@ -678,8 +684,8 @@ func (cmd *ZSliceCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *ZSliceCmd) parseReply(cn *conn) error { - v, err := parseReply(cn, parseZSlice) +func (cmd *ZSliceCmd) readReply(cn *conn) error { + v, err := readReply(cn, zSliceParser) if err != nil { cmd.err = err return err @@ -719,8 +725,8 @@ func (cmd *ScanCmd) String() string { return cmdString(cmd, cmd.keys) } -func (cmd *ScanCmd) parseReply(cn *conn) error { - vi, err := parseReply(cn, parseSlice) +func (cmd *ScanCmd) readReply(cn *conn) error { + vi, err := readReply(cn, sliceParser) if err != nil { cmd.err = err return cmd.err @@ -743,8 +749,9 @@ func (cmd *ScanCmd) parseReply(cn *conn) error { //------------------------------------------------------------------------------ type ClusterSlotInfo struct { - Start, End int - Addrs []string + Start int + End int + Addrs []string } type ClusterSlotCmd struct { @@ -774,8 +781,8 @@ func (cmd *ClusterSlotCmd) reset() { cmd.err = nil } -func (cmd *ClusterSlotCmd) parseReply(cn *conn) error { - v, err := parseReply(cn, parseClusterSlotInfoSlice) +func (cmd *ClusterSlotCmd) readReply(cn *conn) error { + v, err := readReply(cn, clusterSlotInfoSliceParser) if err != nil { cmd.err = err return err @@ -786,99 +793,62 @@ func (cmd *ClusterSlotCmd) parseReply(cn *conn) error { //------------------------------------------------------------------------------ -// Location type for GEO operations in Redis +// GeoLocation is used with GeoAdd to add geospatial location. type GeoLocation struct { - Name string + Name string Longitude, Latitude, Distance float64 - GeoHash int64 + GeoHash int64 } -type GeoCmd struct { +// GeoRadiusQuery is used with GeoRadius to query geospatial index. +type GeoRadiusQuery struct { + Key string + Longitude float64 + Latitude float64 + Radius float64 + // Can be m, km, ft, or mi. Default is km. + Unit string + WithCoordinates bool + WithDistance bool + WithGeoHash bool + Count int + // Can be ASC or DESC. Default is no sort order. + Sort string +} + +type GeoLocationCmd struct { baseCmd locations []GeoLocation } -// Query type for geo radius -type GeoRadiusQuery struct { - Key string - Longitude, Latitude, Radius float64 - // Unit default to km when nil - Unit string - WithCoordinates, WithDistance, WithGeoHash bool - // Count default to 0 and ignored limit. - Count int - // Sort default to unsorted, ASC or DESC otherwise - Sort string +func NewGeoLocationCmd(args ...interface{}) *GeoLocationCmd { + return &GeoLocationCmd{baseCmd: baseCmd{_args: args, _clusterKeyPos: 1}} } -func NewGeoCmd(args ...interface{}) *GeoCmd { - return &GeoCmd{baseCmd: baseCmd{_args: args, _clusterKeyPos: 1}} -} - -func (cmd *GeoCmd) reset() { +func (cmd *GeoLocationCmd) reset() { cmd.locations = nil cmd.err = nil } -func (cmd *GeoCmd) Val() ([]GeoLocation) { +func (cmd *GeoLocationCmd) Val() []GeoLocation { return cmd.locations } -func (cmd *GeoCmd) Result() ([]GeoLocation, error) { +func (cmd *GeoLocationCmd) Result() ([]GeoLocation, error) { return cmd.locations, cmd.err } -func (cmd *GeoCmd) String() string { +func (cmd *GeoLocationCmd) String() string { return cmdString(cmd, cmd.locations) } -func (cmd *GeoCmd) parseReply(cn *conn) error { - vi, err := parseReply(cn, parseSlice) +func (cmd *GeoLocationCmd) readReply(cn *conn) error { + reply, err := readReply(cn, geoLocationSliceParser) if err != nil { cmd.err = err - return cmd.err - } - - v := vi.([]interface{}) - - if len(v) == 0 { - return nil - } - - if _, ok := v[0].(string); ok { // Location names only (single level string array) - for _, keyi := range v { - cmd.locations = append(cmd.locations, GeoLocation{Name: keyi.(string)}) - } - } else { // Full location details (nested arrays) - for _, keyi := range v { - tmpLocation := GeoLocation{} - keyiface := keyi.([]interface{}) - for _, subKeyi := range keyiface { - if strVal, ok := subKeyi.(string); ok { - if len(tmpLocation.Name) == 0 { - tmpLocation.Name = strVal - } else { - tmpLocation.Distance, err = strconv.ParseFloat(strVal, 64) - if err != nil { - return err - } - } - } else if intVal, ok := subKeyi.(int64); ok { - tmpLocation.GeoHash = intVal - } else if ifcVal, ok := subKeyi.([]interface{}); ok { - tmpLocation.Longitude, err = strconv.ParseFloat(ifcVal[0].(string), 64) - if err != nil { - return err - } - tmpLocation.Latitude, err = strconv.ParseFloat(ifcVal[1].(string), 64) - if err != nil { - return err - } - } - } - cmd.locations = append(cmd.locations, tmpLocation) - } + return err } + cmd.locations = reply.([]GeoLocation) return nil } diff --git a/commands.go b/commands.go index 9e8abd94..1916aecf 100644 --- a/commands.go +++ b/commands.go @@ -1688,25 +1688,8 @@ func (c *commandable) GeoAdd(key string, geoLocation ...*GeoLocation) *IntCmd { return cmd } -func (c *commandable) GeoRadius(query *GeoRadiusQuery) *GeoCmd { - var options, optionsCtr int - if query.WithCoordinates { - options++ - } - if query.WithDistance { - options++ - } - if query.WithGeoHash { - options++ - } - if query.Count > 0 { - options += 2 - } - if query.Sort != "" { - options++ - } - - args := make([]interface{}, 6 + options) +func (c *commandable) GeoRadius(query *GeoRadiusQuery) *GeoLocationCmd { + args := make([]interface{}, 6) args[0] = "GEORADIUS" args[1] = query.Key args[2] = query.Longitude @@ -1718,28 +1701,22 @@ func (c *commandable) GeoRadius(query *GeoRadiusQuery) *GeoCmd { args[5] = "km" } if query.WithCoordinates { - args[6+optionsCtr] = "WITHCOORD" - optionsCtr++ + args = append(args, "WITHCOORD") } if query.WithDistance { - args[6+optionsCtr] = "WITHDIST" - optionsCtr++ + args = append(args, "WITHDIST") } if query.WithGeoHash { - args[6+optionsCtr] = "WITHHASH" - optionsCtr++ + args = append(args, "WITHHASH") } if query.Count > 0 { - args[6+optionsCtr] = "COUNT" - optionsCtr++ - args[6+optionsCtr] = query.Count - optionsCtr++ + args = append(args, "COUNT", query.Count) } if query.Sort != "" { - args[6+optionsCtr] = query.Sort + args = append(args, query.Sort) } - cmd := NewGeoCmd(args...) + cmd := NewGeoLocationCmd(args...) c.Process(cmd) return cmd } diff --git a/multi.go b/multi.go index 7b55c7b6..e3d628fd 100644 --- a/multi.go +++ b/multi.go @@ -116,7 +116,7 @@ func (c *Multi) execCmds(cn *conn, cmds []Cmder) error { // Parse queued replies. for i := 0; i < cmdsLen; i++ { - if err := statusCmd.parseReply(cn); err != nil { + if err := statusCmd.readReply(cn); err != nil { setCmdsErr(cmds[1:len(cmds)-1], err) return err } @@ -144,7 +144,7 @@ func (c *Multi) execCmds(cn *conn, cmds []Cmder) error { // Loop starts from 1 to omit MULTI cmd. for i := 1; i < cmdsLen; i++ { cmd := cmds[i] - if err := cmd.parseReply(cn); err != nil { + if err := cmd.readReply(cn); err != nil { if firstCmdErr == nil { firstCmdErr = err } diff --git a/parser.go b/parser.go index 5b6e0733..9acb0f17 100644 --- a/parser.go +++ b/parser.go @@ -8,6 +8,14 @@ import ( "strconv" ) +const ( + errorReply = '-' + statusReply = '+' + intReply = ':' + stringReply = '$' + arrayReply = '*' +) + type multiBulkParser func(cn *conn, n int64) (interface{}, error) var ( @@ -239,57 +247,157 @@ func readN(cn *conn, n int) ([]byte, error) { //------------------------------------------------------------------------------ -func parseReply(cn *conn, p multiBulkParser) (interface{}, error) { +func parseErrorReply(cn *conn, line []byte) error { + return errorf(string(line[1:])) +} + +func parseIntReply(cn *conn, line []byte) (int64, error) { + n, err := strconv.ParseInt(bytesToString(line[1:]), 10, 64) + if err != nil { + return 0, err + } + return n, nil +} + +func readIntReply(cn *conn) (int64, error) { + line, err := readLine(cn) + if err != nil { + return 0, err + } + switch line[0] { + case errorReply: + return 0, parseErrorReply(cn, line) + case intReply: + return parseIntReply(cn, line) + default: + return 0, fmt.Errorf("readIntReply: can't parse %.100q", line) + } +} + +func parseBytesReply(cn *conn, line []byte) ([]byte, error) { + if len(line) == 3 && line[1] == '-' && line[2] == '1' { + return nil, Nil + } + + replyLen, err := strconv.Atoi(bytesToString(line[1:])) + if err != nil { + return nil, err + } + + b, err := readN(cn, replyLen+2) + if err != nil { + return nil, err + } + + return b[:replyLen], nil +} + +func readBytesReply(cn *conn) ([]byte, error) { + line, err := readLine(cn) + if err != nil { + return nil, err + } + switch line[0] { + case errorReply: + return nil, parseErrorReply(cn, line) + case stringReply: + return parseBytesReply(cn, line) + default: + return nil, fmt.Errorf("readBytesReply: can't parse %.100q", line) + } +} + +func readStringReply(cn *conn) (string, error) { + b, err := readBytesReply(cn) + if err != nil { + return "", err + } + return string(b), nil +} + +func readFloatReply(cn *conn) (float64, error) { + b, err := readBytesReply(cn) + if err != nil { + return 0, err + } + return strconv.ParseFloat(bytesToString(b), 64) +} + +func parseArrayHeader(cn *conn, line []byte) (int64, error) { + if len(line) == 3 && line[1] == '-' && line[2] == '1' { + return 0, Nil + } + + n, err := strconv.ParseInt(bytesToString(line[1:]), 10, 64) + if err != nil { + return 0, err + } + return n, nil +} + +func parseArrayReply(cn *conn, p multiBulkParser, line []byte) (interface{}, error) { + n, err := parseArrayHeader(cn, line) + if err != nil { + return nil, err + } + return p(cn, n) +} + +func readArrayHeader(cn *conn) (int64, error) { + line, err := readLine(cn) + if err != nil { + return 0, err + } + switch line[0] { + case errorReply: + return 0, parseErrorReply(cn, line) + case arrayReply: + return parseArrayHeader(cn, line) + default: + return 0, fmt.Errorf("readArrayReply: can't parse %.100q", line) + } +} + +func readArrayReply(cn *conn, p multiBulkParser) (interface{}, error) { + line, err := readLine(cn) + if err != nil { + return nil, err + } + switch line[0] { + case errorReply: + return nil, parseErrorReply(cn, line) + case arrayReply: + return parseArrayReply(cn, p, line) + default: + return nil, fmt.Errorf("readArrayReply: can't parse %.100q", line) + } +} + +func readReply(cn *conn, p multiBulkParser) (interface{}, error) { line, err := readLine(cn) if err != nil { return nil, err } switch line[0] { - case '-': - return nil, errorf(string(line[1:])) - case '+': + case errorReply: + return nil, parseErrorReply(cn, line) + case statusReply: return line[1:], nil - case ':': - v, err := strconv.ParseInt(bytesToString(line[1:]), 10, 64) - if err != nil { - return nil, err - } - return v, nil - case '$': - if len(line) == 3 && line[1] == '-' && line[2] == '1' { - return nil, Nil - } - - replyLen, err := strconv.Atoi(bytesToString(line[1:])) - if err != nil { - return nil, err - } - - b, err := readN(cn, replyLen+2) - if err != nil { - return nil, err - } - return b[:replyLen], nil - case '*': - if len(line) == 3 && line[1] == '-' && line[2] == '1' { - return nil, Nil - } - - repliesNum, err := strconv.ParseInt(bytesToString(line[1:]), 10, 64) - if err != nil { - return nil, err - } - - return p(cn, repliesNum) + case intReply: + return parseIntReply(cn, line) + case stringReply: + return parseBytesReply(cn, line) + case arrayReply: + return parseArrayReply(cn, p, line) } return nil, fmt.Errorf("redis: can't parse %q", line) } -func parseSlice(cn *conn, n int64) (interface{}, error) { +func sliceParser(cn *conn, n int64) (interface{}, error) { vals := make([]interface{}, 0, n) for i := int64(0); i < n; i++ { - v, err := parseReply(cn, parseSlice) + v, err := readReply(cn, sliceParser) if err == Nil { vals = append(vals, nil) } else if err != nil { @@ -306,171 +414,224 @@ func parseSlice(cn *conn, n int64) (interface{}, error) { return vals, nil } -func parseStringSlice(cn *conn, n int64) (interface{}, error) { - vals := make([]string, 0, n) +func intSliceParser(cn *conn, n int64) (interface{}, error) { + ints := make([]int64, 0, n) for i := int64(0); i < n; i++ { - viface, err := parseReply(cn, nil) + n, err := readIntReply(cn) if err != nil { return nil, err } - v, ok := viface.([]byte) - if !ok { - return nil, fmt.Errorf("got %T, expected string", viface) - } - vals = append(vals, string(v)) + ints = append(ints, n) } - return vals, nil + return ints, nil } -func parseBoolSlice(cn *conn, n int64) (interface{}, error) { - vals := make([]bool, 0, n) +func boolSliceParser(cn *conn, n int64) (interface{}, error) { + bools := make([]bool, 0, n) for i := int64(0); i < n; i++ { - viface, err := parseReply(cn, nil) + n, err := readIntReply(cn) if err != nil { return nil, err } - v, ok := viface.(int64) - if !ok { - return nil, fmt.Errorf("got %T, expected int64", viface) - } - vals = append(vals, v == 1) + bools = append(bools, n == 1) } - return vals, nil + return bools, nil } -func parseStringStringMap(cn *conn, n int64) (interface{}, error) { +func stringSliceParser(cn *conn, n int64) (interface{}, error) { + ss := make([]string, 0, n) + for i := int64(0); i < n; i++ { + s, err := readStringReply(cn) + if err != nil { + return nil, err + } + ss = append(ss, s) + } + return ss, nil +} + +func floatSliceParser(cn *conn, n int64) (interface{}, error) { + nn := make([]float64, 0, n) + for i := int64(0); i < n; i++ { + n, err := readFloatReply(cn) + if err != nil { + return nil, err + } + nn = append(nn, n) + } + return nn, nil +} + +func stringStringMapParser(cn *conn, n int64) (interface{}, error) { m := make(map[string]string, n/2) for i := int64(0); i < n; i += 2 { - keyIface, err := parseReply(cn, nil) + key, err := readStringReply(cn) if err != nil { return nil, err } - keyBytes, ok := keyIface.([]byte) - if !ok { - return nil, fmt.Errorf("got %T, expected []byte", keyIface) - } - key := string(keyBytes) - valueIface, err := parseReply(cn, nil) + value, err := readStringReply(cn) if err != nil { return nil, err } - valueBytes, ok := valueIface.([]byte) - if !ok { - return nil, fmt.Errorf("got %T, expected []byte", valueIface) - } - m[key] = string(valueBytes) + m[key] = value } return m, nil } -func parseStringIntMap(cn *conn, n int64) (interface{}, error) { +func stringIntMapParser(cn *conn, n int64) (interface{}, error) { m := make(map[string]int64, n/2) for i := int64(0); i < n; i += 2 { - keyiface, err := parseReply(cn, nil) + key, err := readStringReply(cn) if err != nil { return nil, err } - key, ok := keyiface.([]byte) - if !ok { - return nil, fmt.Errorf("got %T, expected string", keyiface) - } - valueiface, err := parseReply(cn, nil) + n, err := readIntReply(cn) if err != nil { return nil, err } - switch value := valueiface.(type) { - case int64: - m[string(key)] = value - case string: - m[string(key)], err = strconv.ParseInt(value, 10, 64) - if err != nil { - return nil, fmt.Errorf("got %v, expected number", value) - } - default: - return nil, fmt.Errorf("got %T, expected number or string", valueiface) - } + + m[key] = n } return m, nil } -func parseZSlice(cn *conn, n int64) (interface{}, error) { +func zSliceParser(cn *conn, n int64) (interface{}, error) { zz := make([]Z, n/2) for i := int64(0); i < n; i += 2 { + var err error + z := &zz[i/2] - memberiface, err := parseReply(cn, nil) + z.Member, err = readStringReply(cn) if err != nil { return nil, err } - member, ok := memberiface.([]byte) - if !ok { - return nil, fmt.Errorf("got %T, expected string", memberiface) - } - z.Member = string(member) - scoreiface, err := parseReply(cn, nil) + z.Score, err = readFloatReply(cn) if err != nil { return nil, err } - scoreb, ok := scoreiface.([]byte) - if !ok { - return nil, fmt.Errorf("got %T, expected string", scoreiface) - } - score, err := strconv.ParseFloat(bytesToString(scoreb), 64) - if err != nil { - return nil, err - } - z.Score = score } return zz, nil } -func parseClusterSlotInfoSlice(cn *conn, n int64) (interface{}, error) { +func clusterSlotInfoSliceParser(cn *conn, n int64) (interface{}, error) { infos := make([]ClusterSlotInfo, 0, n) for i := int64(0); i < n; i++ { - viface, err := parseReply(cn, parseSlice) + n, err := readArrayHeader(cn) + if err != nil { + return nil, err + } + if n < 2 { + return nil, fmt.Errorf("got %d elements in cluster info, expected at least 2", n) + } + + start, err := readIntReply(cn) if err != nil { return nil, err } - item, ok := viface.([]interface{}) - if !ok { - return nil, fmt.Errorf("got %T, expected []interface{}", viface) - } else if len(item) < 3 { - return nil, fmt.Errorf("got %v, expected {int64, int64, string...}", item) + end, err := readIntReply(cn) + if err != nil { + return nil, err } - start, ok := item[0].(int64) - if !ok || start < 0 || start > hashSlots { - return nil, fmt.Errorf("got %v, expected {int64, int64, string...}", item) - } - end, ok := item[1].(int64) - if !ok || end < 0 || end > hashSlots { - return nil, fmt.Errorf("got %v, expected {int64, int64, string...}", item) + addrsn := n - 2 + info := ClusterSlotInfo{ + Start: int(start), + End: int(end), + Addrs: make([]string, addrsn), } - info := ClusterSlotInfo{int(start), int(end), make([]string, len(item)-2)} - for n, ipair := range item[2:] { - pair, ok := ipair.([]interface{}) - if !ok || len(pair) != 2 { - return nil, fmt.Errorf("got %v, expected []interface{host, port}", viface) + for i := int64(0); i < addrsn; i++ { + n, err := readArrayHeader(cn) + if err != nil { + return nil, err + } + if n != 2 { + return nil, fmt.Errorf("got %d elements in cluster info address, expected 2", n) } - ip, ok := pair[0].(string) - if !ok || len(ip) < 1 { - return nil, fmt.Errorf("got %v, expected IP PORT pair", pair) - } - port, ok := pair[1].(int64) - if !ok || port < 1 { - return nil, fmt.Errorf("got %v, expected IP PORT pair", pair) + ip, err := readStringReply(cn) + if err != nil { + return nil, err } - info.Addrs[n] = net.JoinHostPort(ip, strconv.FormatInt(port, 10)) + port, err := readIntReply(cn) + if err != nil { + return nil, err + } + + info.Addrs[i] = net.JoinHostPort(ip, strconv.FormatInt(port, 10)) } + infos = append(infos, info) } return infos, nil } + +func geoLocationParser(cn *conn, n int64) (interface{}, error) { + loc := &GeoLocation{} + + var err error + loc.Name, err = readStringReply(cn) + if err != nil { + return nil, err + } + if n >= 2 { + loc.Distance, err = readFloatReply(cn) + if err != nil { + return nil, err + } + } + if n >= 3 { + loc.GeoHash, err = readIntReply(cn) + if err != nil { + return nil, err + } + } + if n >= 4 { + n, err := readArrayHeader(cn) + if err != nil { + return nil, err + } + if n != 2 { + return nil, fmt.Errorf("got %d coordinates, expected 2", n) + } + + loc.Longitude, err = readFloatReply(cn) + if err != nil { + return nil, err + } + loc.Latitude, err = readFloatReply(cn) + if err != nil { + return nil, err + } + } + + return loc, nil +} + +func geoLocationSliceParser(cn *conn, n int64) (interface{}, error) { + locs := make([]GeoLocation, 0, n) + for i := int64(0); i < n; i++ { + v, err := readReply(cn, geoLocationParser) + if err != nil { + return nil, err + } + switch vv := v.(type) { + case []byte: + locs = append(locs, GeoLocation{ + Name: string(vv), + }) + case *GeoLocation: + locs = append(locs, *vv) + default: + return nil, fmt.Errorf("got %T, expected string or *GeoLocation", v) + } + } + return locs, nil +} diff --git a/parser_test.go b/parser_test.go index 10403f62..b1c74344 100644 --- a/parser_test.go +++ b/parser_test.go @@ -23,7 +23,7 @@ func BenchmarkParseReplyString(b *testing.B) { } func BenchmarkParseReplySlice(b *testing.B) { - benchmarkParseReply(b, "*2\r\n$5\r\nhello\r\n$5\r\nworld\r\n", parseSlice, false) + benchmarkParseReply(b, "*2\r\n$5\r\nhello\r\n$5\r\nworld\r\n", sliceParser, false) } func benchmarkParseReply(b *testing.B, reply string, p multiBulkParser, wanterr bool) { @@ -39,7 +39,7 @@ func benchmarkParseReply(b *testing.B, reply string, p multiBulkParser, wanterr b.ResetTimer() for i := 0; i < b.N; i++ { - _, err := parseReply(cn, p) + _, err := readReply(cn, p) if !wanterr && err != nil { b.Fatal(err) } diff --git a/pipeline.go b/pipeline.go index 02ecbacc..d7d13042 100644 --- a/pipeline.go +++ b/pipeline.go @@ -97,7 +97,7 @@ func execCmds(cn *conn, cmds []Cmder) ([]Cmder, error) { var firstCmdErr error var failedCmds []Cmder for _, cmd := range cmds { - err := cmd.parseReply(cn) + err := cmd.readReply(cn) if err == nil { continue } diff --git a/pubsub.go b/pubsub.go index fa804eb4..ba053e47 100644 --- a/pubsub.go +++ b/pubsub.go @@ -215,7 +215,7 @@ func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) { cn.ReadTimeout = timeout cmd := NewSliceCmd() - if err := cmd.parseReply(cn); err != nil { + if err := cmd.readReply(cn); err != nil { return nil, err } return newMessage(cmd.Val()) diff --git a/redis.go b/redis.go index aea53a22..2d1076d2 100644 --- a/redis.go +++ b/redis.go @@ -69,7 +69,7 @@ func (c *baseClient) process(cmd Cmder) { return } - err = cmd.parseReply(cn) + err = cmd.readReply(cn) c.putConn(cn, err) if shouldRetry(err) { continue