diff --git a/command.go b/command.go index 6b4465e8..850d1eb4 100644 --- a/command.go +++ b/command.go @@ -198,7 +198,7 @@ func (cmd *SliceCmd) String() string { } func (cmd *SliceCmd) readReply(cn *conn) error { - v, err := readReply(cn, sliceParser) + v, err := readArrayReply(cn, sliceParser) if err != nil { cmd.err = err return err @@ -241,13 +241,8 @@ func (cmd *StatusCmd) String() string { } func (cmd *StatusCmd) readReply(cn *conn) error { - v, err := readReply(cn, nil) - if err != nil { - cmd.err = err - return err - } - cmd.val = string(v.([]byte)) - return nil + cmd.val, cmd.err = readStringReply(cn) + return cmd.err } //------------------------------------------------------------------------------ @@ -280,13 +275,8 @@ func (cmd *IntCmd) String() string { } func (cmd *IntCmd) readReply(cn *conn) error { - v, err := readReply(cn, nil) - if err != nil { - cmd.err = err - return err - } - cmd.val = v.(int64) - return nil + cmd.val, cmd.err = readIntReply(cn) + return cmd.err } //------------------------------------------------------------------------------ @@ -323,12 +313,12 @@ func (cmd *DurationCmd) String() string { } func (cmd *DurationCmd) readReply(cn *conn) error { - v, err := readReply(cn, nil) + n, err := readIntReply(cn) if err != nil { cmd.err = err return err } - cmd.val = time.Duration(v.(int64)) * cmd.precision + cmd.val = time.Duration(n) * cmd.precision return nil } @@ -365,8 +355,8 @@ var ok = []byte("OK") func (cmd *BoolCmd) readReply(cn *conn) error { v, err := readReply(cn, nil) - // `SET key value NX` returns nil when key already exists, which - // is inconsistent with `SETNX key value`. + // `SET key value NX` returns nil when key already exists. But + // `SETNX key value` returns bool (0/1). So convert nil to bool. // TODO: is this okay? if err == Nil { cmd.val = false @@ -450,12 +440,12 @@ func (cmd *StringCmd) String() string { } func (cmd *StringCmd) readReply(cn *conn) error { - v, err := readReply(cn, nil) + b, err := readBytesReply(cn) if err != nil { cmd.err = err return err } - cmd.val = cn.copyBuf(v.([]byte)) + cmd.val = cn.copyBuf(b) return nil } @@ -489,13 +479,7 @@ func (cmd *FloatCmd) String() string { } func (cmd *FloatCmd) readReply(cn *conn) error { - v, err := readReply(cn, nil) - if err != nil { - cmd.err = err - return err - } - b := v.([]byte) - cmd.val, cmd.err = strconv.ParseFloat(bytesToString(b), 64) + cmd.val, cmd.err = readFloatReply(cn) return cmd.err } @@ -529,7 +513,7 @@ func (cmd *StringSliceCmd) String() string { } func (cmd *StringSliceCmd) readReply(cn *conn) error { - v, err := readReply(cn, stringSliceParser) + v, err := readArrayReply(cn, stringSliceParser) if err != nil { cmd.err = err return err @@ -568,7 +552,7 @@ func (cmd *BoolSliceCmd) String() string { } func (cmd *BoolSliceCmd) readReply(cn *conn) error { - v, err := readReply(cn, boolSliceParser) + v, err := readArrayReply(cn, boolSliceParser) if err != nil { cmd.err = err return err @@ -607,7 +591,7 @@ func (cmd *StringStringMapCmd) String() string { } func (cmd *StringStringMapCmd) readReply(cn *conn) error { - v, err := readReply(cn, stringStringMapParser) + v, err := readArrayReply(cn, stringStringMapParser) if err != nil { cmd.err = err return err @@ -646,7 +630,7 @@ func (cmd *StringIntMapCmd) reset() { } func (cmd *StringIntMapCmd) readReply(cn *conn) error { - v, err := readReply(cn, stringIntMapParser) + v, err := readArrayReply(cn, stringIntMapParser) if err != nil { cmd.err = err return err @@ -685,7 +669,7 @@ func (cmd *ZSliceCmd) String() string { } func (cmd *ZSliceCmd) readReply(cn *conn) error { - v, err := readReply(cn, zSliceParser) + v, err := readArrayReply(cn, zSliceParser) if err != nil { cmd.err = err return err @@ -713,6 +697,9 @@ func (cmd *ScanCmd) reset() { cmd.err = nil } +// TODO: cursor should be string to match redis type +// TODO: swap return values + func (cmd *ScanCmd) Val() (int64, []string) { return cmd.cursor, cmd.keys } @@ -726,23 +713,13 @@ func (cmd *ScanCmd) String() string { } func (cmd *ScanCmd) readReply(cn *conn) error { - vi, err := readReply(cn, sliceParser) + keys, cursor, err := readScanReply(cn) if err != nil { cmd.err = err return cmd.err } - v := vi.([]interface{}) - - cmd.cursor, cmd.err = strconv.ParseInt(v[0].(string), 10, 64) - if cmd.err != nil { - return cmd.err - } - - keys := v[1].([]interface{}) - for _, keyi := range keys { - cmd.keys = append(cmd.keys, keyi.(string)) - } - + cmd.keys = keys + cmd.cursor = cursor return nil } @@ -782,7 +759,7 @@ func (cmd *ClusterSlotCmd) reset() { } func (cmd *ClusterSlotCmd) readReply(cn *conn) error { - v, err := readReply(cn, clusterSlotInfoSliceParser) + v, err := readArrayReply(cn, clusterSlotInfoSliceParser) if err != nil { cmd.err = err return err @@ -844,7 +821,7 @@ func (cmd *GeoLocationCmd) String() string { } func (cmd *GeoLocationCmd) readReply(cn *conn) error { - reply, err := readReply(cn, geoLocationSliceParser) + reply, err := readArrayReply(cn, geoLocationSliceParser) if err != nil { cmd.err = err return err diff --git a/parser.go b/parser.go index 9acb0f17..dd242ca1 100644 --- a/parser.go +++ b/parser.go @@ -251,6 +251,21 @@ func parseErrorReply(cn *conn, line []byte) error { return errorf(string(line[1:])) } +func isNilReply(b []byte) bool { + return len(b) == 3 && b[1] == '-' && b[2] == '1' +} + +func parseNilReply(cn *conn, line []byte) error { + if isNilReply(line) { + return Nil + } + return fmt.Errorf("redis: can't parse nil reply: %.100", line) +} + +func parseStatusReply(cn *conn, line []byte) ([]byte, error) { + return line[1:], nil +} + func parseIntReply(cn *conn, line []byte) (int64, error) { n, err := strconv.ParseInt(bytesToString(line[1:]), 10, 64) if err != nil { @@ -267,15 +282,17 @@ func readIntReply(cn *conn) (int64, error) { switch line[0] { case errorReply: return 0, parseErrorReply(cn, line) + case stringReply: + return 0, parseNilReply(cn, line) case intReply: return parseIntReply(cn, line) default: - return 0, fmt.Errorf("readIntReply: can't parse %.100q", line) + return 0, fmt.Errorf("redis: can't parse int reply: %.100q", line) } } func parseBytesReply(cn *conn, line []byte) ([]byte, error) { - if len(line) == 3 && line[1] == '-' && line[2] == '1' { + if isNilReply(line) { return nil, Nil } @@ -302,8 +319,10 @@ func readBytesReply(cn *conn) ([]byte, error) { return nil, parseErrorReply(cn, line) case stringReply: return parseBytesReply(cn, line) + case statusReply: + return parseStatusReply(cn, line) default: - return nil, fmt.Errorf("readBytesReply: can't parse %.100q", line) + return nil, fmt.Errorf("redis: can't parse string reply: %.100q", line) } } @@ -354,7 +373,7 @@ func readArrayHeader(cn *conn) (int64, error) { case arrayReply: return parseArrayHeader(cn, line) default: - return 0, fmt.Errorf("readArrayReply: can't parse %.100q", line) + return 0, fmt.Errorf("redis: can't parse array reply: %.100q", line) } } @@ -369,7 +388,7 @@ func readArrayReply(cn *conn, p multiBulkParser) (interface{}, error) { case arrayReply: return parseArrayReply(cn, p, line) default: - return nil, fmt.Errorf("readArrayReply: can't parse %.100q", line) + return nil, fmt.Errorf("redis: can't parse array reply: %.100q", line) } } @@ -383,7 +402,7 @@ func readReply(cn *conn, p multiBulkParser) (interface{}, error) { case errorReply: return nil, parseErrorReply(cn, line) case statusReply: - return line[1:], nil + return parseStatusReply(cn, line) case intReply: return parseIntReply(cn, line) case stringReply: @@ -391,7 +410,43 @@ func readReply(cn *conn, p multiBulkParser) (interface{}, error) { case arrayReply: return parseArrayReply(cn, p, line) } - return nil, fmt.Errorf("redis: can't parse %q", line) + return nil, fmt.Errorf("redis: can't parse %.100q", line) +} + +func readScanReply(cn *conn) ([]string, int64, error) { + n, err := readArrayHeader(cn) + if err != nil { + return nil, 0, err + } + if n != 2 { + return nil, 0, fmt.Errorf("redis: got %d elements in scan reply, expected 2") + } + + b, err := readBytesReply(cn) + if err != nil { + return nil, 0, err + } + + cursor, err := strconv.ParseInt(bytesToString(b), 10, 64) + if err != nil { + return nil, 0, err + } + + n, err = readArrayHeader(cn) + if err != nil { + return nil, 0, err + } + + keys := make([]string, n) + for i := int64(0); i < n; i++ { + key, err := readStringReply(cn) + if err != nil { + return nil, 0, err + } + keys[i] = key + } + + return keys, cursor, err } func sliceParser(cn *conn, n int64) (interface{}, error) { @@ -526,7 +581,8 @@ func clusterSlotInfoSliceParser(cn *conn, n int64) (interface{}, error) { return nil, err } if n < 2 { - return nil, fmt.Errorf("got %d elements in cluster info, expected at least 2", n) + err := fmt.Errorf("redis: got %d elements in cluster info, expected at least 2", n) + return nil, err } start, err := readIntReply(cn)