Fix proto.RedisError in slices

This commit is contained in:
Vladimir Mihailenco 2018-02-22 14:14:30 +02:00
parent 71ed499c46
commit 56dea1f39a
14 changed files with 122 additions and 105 deletions

View File

@ -574,8 +574,8 @@ var _ = Describe("ClusterClient", func() {
Describe("ClusterClient failover", func() { Describe("ClusterClient failover", func() {
BeforeEach(func() { BeforeEach(func() {
opt = redisClusterOptions() opt = redisClusterOptions()
opt.MinRetryBackoff = 100 * time.Millisecond opt.MinRetryBackoff = 250 * time.Millisecond
opt.MaxRetryBackoff = 3 * time.Second opt.MaxRetryBackoff = time.Second
client = cluster.clusterClient(opt) client = cluster.clusterClient(opt)
_ = client.ForEachSlave(func(slave *redis.Client) error { _ = client.ForEachSlave(func(slave *redis.Client) error {

View File

@ -10,6 +10,7 @@ import (
"github.com/go-redis/redis/internal" "github.com/go-redis/redis/internal"
"github.com/go-redis/redis/internal/pool" "github.com/go-redis/redis/internal/pool"
"github.com/go-redis/redis/internal/proto" "github.com/go-redis/redis/internal/proto"
"github.com/go-redis/redis/internal/util"
) )
type Cmder interface { type Cmder interface {
@ -436,7 +437,7 @@ func NewStringCmd(args ...interface{}) *StringCmd {
} }
func (cmd *StringCmd) Val() string { func (cmd *StringCmd) Val() string {
return internal.BytesToString(cmd.val) return util.BytesToString(cmd.val)
} }
func (cmd *StringCmd) Result() (string, error) { func (cmd *StringCmd) Result() (string, error) {

View File

@ -1,6 +1,7 @@
package redis package redis
import ( import (
"errors"
"io" "io"
"time" "time"
@ -1802,7 +1803,7 @@ func (c *cmdable) shutdown(modifier string) *StatusCmd {
} }
} else { } else {
// Server did not quit. String reply contains the reason. // Server did not quit. String reply contains the reason.
cmd.err = internal.RedisError(cmd.val) cmd.err = errors.New(cmd.val)
cmd.val = "" cmd.val = ""
} }
return cmd return cmd

View File

@ -10,7 +10,7 @@ import (
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
"github.com/go-redis/redis" "github.com/go-redis/redis"
"github.com/go-redis/redis/internal" "github.com/go-redis/redis/internal/proto"
) )
var _ = Describe("Commands", func() { var _ = Describe("Commands", func() {
@ -3000,7 +3000,7 @@ var _ = Describe("Commands", func() {
nil, nil,
).Result() ).Result()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(vals).To(Equal([]interface{}{int64(12), internal.RedisError("error"), "abc"})) Expect(vals).To(Equal([]interface{}{int64(12), proto.RedisError("error"), "abc"}))
}) })
}) })

View File

@ -4,14 +4,10 @@ import (
"io" "io"
"net" "net"
"strings" "strings"
"github.com/go-redis/redis/internal/proto"
) )
const Nil = RedisError("redis: nil")
type RedisError string
func (e RedisError) Error() string { return string(e) }
func IsRetryableError(err error, retryNetError bool) bool { func IsRetryableError(err error, retryNetError bool) bool {
if IsNetworkError(err) { if IsNetworkError(err) {
return retryNetError return retryNetError
@ -30,7 +26,7 @@ func IsRetryableError(err error, retryNetError bool) bool {
} }
func IsRedisError(err error) bool { func IsRedisError(err error) bool {
_, ok := err.(RedisError) _, ok := err.(proto.RedisError)
return ok return ok
} }

View File

@ -6,7 +6,7 @@ import (
"io" "io"
"strconv" "strconv"
"github.com/go-redis/redis/internal" "github.com/go-redis/redis/internal/util"
) )
const bytesAllocLimit = 1024 * 1024 // 1mb const bytesAllocLimit = 1024 * 1024 // 1mb
@ -19,6 +19,16 @@ const (
ArrayReply = '*' ArrayReply = '*'
) )
//------------------------------------------------------------------------------
const Nil = RedisError("redis: nil")
type RedisError string
func (e RedisError) Error() string { return string(e) }
//------------------------------------------------------------------------------
type MultiBulkParse func(*Reader, int64) (interface{}, error) type MultiBulkParse func(*Reader, int64) (interface{}, error)
type Reader struct { type Reader struct {
@ -66,7 +76,7 @@ func (r *Reader) ReadLine() ([]byte, error) {
return nil, fmt.Errorf("redis: reply is empty") return nil, fmt.Errorf("redis: reply is empty")
} }
if isNilReply(line) { if isNilReply(line) {
return nil, internal.Nil return nil, Nil
} }
return line, nil return line, nil
} }
@ -83,7 +93,7 @@ func (r *Reader) ReadReply(m MultiBulkParse) (interface{}, error) {
case StatusReply: case StatusReply:
return parseStatusValue(line), nil return parseStatusValue(line), nil
case IntReply: case IntReply:
return parseInt(line[1:], 10, 64) return util.ParseInt(line[1:], 10, 64)
case StringReply: case StringReply:
return r.readTmpBytesValue(line) return r.readTmpBytesValue(line)
case ArrayReply: case ArrayReply:
@ -105,7 +115,7 @@ func (r *Reader) ReadIntReply() (int64, error) {
case ErrorReply: case ErrorReply:
return 0, ParseErrorReply(line) return 0, ParseErrorReply(line)
case IntReply: case IntReply:
return parseInt(line[1:], 10, 64) return util.ParseInt(line[1:], 10, 64)
default: default:
return 0, fmt.Errorf("redis: can't parse int reply: %.100q", line) return 0, fmt.Errorf("redis: can't parse int reply: %.100q", line)
} }
@ -151,7 +161,7 @@ func (r *Reader) ReadFloatReply() (float64, error) {
if err != nil { if err != nil {
return 0, err return 0, err
} }
return parseFloat(b, 64) return util.ParseFloat(b, 64)
} }
func (r *Reader) ReadArrayReply(m MultiBulkParse) (interface{}, error) { func (r *Reader) ReadArrayReply(m MultiBulkParse) (interface{}, error) {
@ -221,7 +231,7 @@ func (r *Reader) ReadScanReply() ([]string, uint64, error) {
func (r *Reader) readTmpBytesValue(line []byte) ([]byte, error) { func (r *Reader) readTmpBytesValue(line []byte) ([]byte, error) {
if isNilReply(line) { if isNilReply(line) {
return nil, internal.Nil return nil, Nil
} }
replyLen, err := strconv.Atoi(string(line[1:])) replyLen, err := strconv.Atoi(string(line[1:]))
@ -241,7 +251,7 @@ func (r *Reader) ReadInt() (int64, error) {
if err != nil { if err != nil {
return 0, err return 0, err
} }
return parseInt(b, 10, 64) return util.ParseInt(b, 10, 64)
} }
func (r *Reader) ReadUint() (uint64, error) { func (r *Reader) ReadUint() (uint64, error) {
@ -249,7 +259,7 @@ func (r *Reader) ReadUint() (uint64, error) {
if err != nil { if err != nil {
return 0, err return 0, err
} }
return parseUint(b, 10, 64) return util.ParseUint(b, 10, 64)
} }
// -------------------------------------------------------------------- // --------------------------------------------------------------------
@ -303,7 +313,7 @@ func isNilReply(b []byte) bool {
} }
func ParseErrorReply(line []byte) error { func ParseErrorReply(line []byte) error {
return internal.RedisError(string(line[1:])) return RedisError(string(line[1:]))
} }
func parseStatusValue(line []byte) []byte { func parseStatusValue(line []byte) []byte {
@ -312,23 +322,7 @@ func parseStatusValue(line []byte) []byte {
func parseArrayLen(line []byte) (int64, error) { func parseArrayLen(line []byte) (int64, error) {
if isNilReply(line) { if isNilReply(line) {
return 0, internal.Nil return 0, Nil
} }
return parseInt(line[1:], 10, 64) return util.ParseInt(line[1:], 10, 64)
}
func atoi(b []byte) (int, error) {
return strconv.Atoi(internal.BytesToString(b))
}
func parseInt(b []byte, base int, bitSize int) (int64, error) {
return strconv.ParseInt(internal.BytesToString(b), base, bitSize)
}
func parseUint(b []byte, base int, bitSize int) (uint64, error) {
return strconv.ParseUint(internal.BytesToString(b), base, bitSize)
}
func parseFloat(b []byte, bitSize int) (float64, error) {
return strconv.ParseFloat(internal.BytesToString(b), bitSize)
} }

View File

@ -5,7 +5,7 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"github.com/go-redis/redis/internal" "github.com/go-redis/redis/internal/util"
) )
func Scan(b []byte, v interface{}) error { func Scan(b []byte, v interface{}) error {
@ -13,80 +13,80 @@ func Scan(b []byte, v interface{}) error {
case nil: case nil:
return fmt.Errorf("redis: Scan(nil)") return fmt.Errorf("redis: Scan(nil)")
case *string: case *string:
*v = internal.BytesToString(b) *v = util.BytesToString(b)
return nil return nil
case *[]byte: case *[]byte:
*v = b *v = b
return nil return nil
case *int: case *int:
var err error var err error
*v, err = atoi(b) *v, err = util.Atoi(b)
return err return err
case *int8: case *int8:
n, err := parseInt(b, 10, 8) n, err := util.ParseInt(b, 10, 8)
if err != nil { if err != nil {
return err return err
} }
*v = int8(n) *v = int8(n)
return nil return nil
case *int16: case *int16:
n, err := parseInt(b, 10, 16) n, err := util.ParseInt(b, 10, 16)
if err != nil { if err != nil {
return err return err
} }
*v = int16(n) *v = int16(n)
return nil return nil
case *int32: case *int32:
n, err := parseInt(b, 10, 32) n, err := util.ParseInt(b, 10, 32)
if err != nil { if err != nil {
return err return err
} }
*v = int32(n) *v = int32(n)
return nil return nil
case *int64: case *int64:
n, err := parseInt(b, 10, 64) n, err := util.ParseInt(b, 10, 64)
if err != nil { if err != nil {
return err return err
} }
*v = n *v = n
return nil return nil
case *uint: case *uint:
n, err := parseUint(b, 10, 64) n, err := util.ParseUint(b, 10, 64)
if err != nil { if err != nil {
return err return err
} }
*v = uint(n) *v = uint(n)
return nil return nil
case *uint8: case *uint8:
n, err := parseUint(b, 10, 8) n, err := util.ParseUint(b, 10, 8)
if err != nil { if err != nil {
return err return err
} }
*v = uint8(n) *v = uint8(n)
return nil return nil
case *uint16: case *uint16:
n, err := parseUint(b, 10, 16) n, err := util.ParseUint(b, 10, 16)
if err != nil { if err != nil {
return err return err
} }
*v = uint16(n) *v = uint16(n)
return nil return nil
case *uint32: case *uint32:
n, err := parseUint(b, 10, 32) n, err := util.ParseUint(b, 10, 32)
if err != nil { if err != nil {
return err return err
} }
*v = uint32(n) *v = uint32(n)
return nil return nil
case *uint64: case *uint64:
n, err := parseUint(b, 10, 64) n, err := util.ParseUint(b, 10, 64)
if err != nil { if err != nil {
return err return err
} }
*v = n *v = n
return nil return nil
case *float32: case *float32:
n, err := parseFloat(b, 32) n, err := util.ParseFloat(b, 32)
if err != nil { if err != nil {
return err return err
} }
@ -94,7 +94,7 @@ func Scan(b []byte, v interface{}) error {
return err return err
case *float64: case *float64:
var err error var err error
*v, err = parseFloat(b, 64) *v, err = util.ParseFloat(b, 64)
return err return err
case *bool: case *bool:
*v = len(b) == 1 && b[0] == '1' *v = len(b) == 1 && b[0] == '1'
@ -120,7 +120,7 @@ func ScanSlice(data []string, slice interface{}) error {
return fmt.Errorf("redis: ScanSlice(non-slice %T)", slice) return fmt.Errorf("redis: ScanSlice(non-slice %T)", slice)
} }
next := internal.MakeSliceNextElemFunc(v) next := makeSliceNextElemFunc(v)
for i, s := range data { for i, s := range data {
elem := next() elem := next()
if err := Scan([]byte(s), elem.Addr().Interface()); err != nil { if err := Scan([]byte(s), elem.Addr().Interface()); err != nil {
@ -131,3 +131,36 @@ func ScanSlice(data []string, slice interface{}) error {
return nil return nil
} }
func makeSliceNextElemFunc(v reflect.Value) func() reflect.Value {
elemType := v.Type().Elem()
if elemType.Kind() == reflect.Ptr {
elemType = elemType.Elem()
return func() reflect.Value {
if v.Len() < v.Cap() {
v.Set(v.Slice(0, v.Len()+1))
elem := v.Index(v.Len() - 1)
if elem.IsNil() {
elem.Set(reflect.New(elemType))
}
return elem.Elem()
}
elem := reflect.New(elemType)
v.Set(reflect.Append(v, elem))
return elem.Elem()
}
}
zero := reflect.Zero(elemType)
return func() reflect.Value {
if v.Len() < v.Cap() {
v.Set(v.Slice(0, v.Len()+1))
return v.Index(v.Len() - 1)
}
v.Set(reflect.Append(v, zero))
return v.Index(v.Len() - 1)
}
}

View File

@ -1,6 +1,6 @@
package internal package internal
import "reflect" import "github.com/go-redis/redis/internal/util"
func ToLower(s string) string { func ToLower(s string) string {
if isLower(s) { if isLower(s) {
@ -15,7 +15,7 @@ func ToLower(s string) string {
} }
b[i] = c b[i] = c
} }
return BytesToString(b) return util.BytesToString(b)
} }
func isLower(s string) bool { func isLower(s string) bool {
@ -27,36 +27,3 @@ func isLower(s string) bool {
} }
return true return true
} }
func MakeSliceNextElemFunc(v reflect.Value) func() reflect.Value {
elemType := v.Type().Elem()
if elemType.Kind() == reflect.Ptr {
elemType = elemType.Elem()
return func() reflect.Value {
if v.Len() < v.Cap() {
v.Set(v.Slice(0, v.Len()+1))
elem := v.Index(v.Len() - 1)
if elem.IsNil() {
elem.Set(reflect.New(elemType))
}
return elem.Elem()
}
elem := reflect.New(elemType)
v.Set(reflect.Append(v, elem))
return elem.Elem()
}
}
zero := reflect.Zero(elemType)
return func() reflect.Value {
if v.Len() < v.Cap() {
v.Set(v.Slice(0, v.Len()+1))
return v.Index(v.Len() - 1)
}
v.Set(reflect.Append(v, zero))
return v.Index(v.Len() - 1)
}
}

View File

@ -1,6 +1,6 @@
// +build appengine // +build appengine
package internal package util
func BytesToString(b []byte) string { func BytesToString(b []byte) string {
return string(b) return string(b)

19
internal/util/strconv.go Normal file
View File

@ -0,0 +1,19 @@
package util
import "strconv"
func Atoi(b []byte) (int, error) {
return strconv.Atoi(BytesToString(b))
}
func ParseInt(b []byte, base int, bitSize int) (int64, error) {
return strconv.ParseInt(BytesToString(b), base, bitSize)
}
func ParseUint(b []byte, base int, bitSize int) (uint64, error) {
return strconv.ParseUint(BytesToString(b), base, bitSize)
}
func ParseFloat(b []byte, bitSize int) (float64, error) {
return strconv.ParseFloat(BytesToString(b), bitSize)
}

View File

@ -1,6 +1,6 @@
// +build !appengine // +build !appengine
package internal package util
import ( import (
"unsafe" "unsafe"

View File

@ -14,17 +14,23 @@ func sliceParser(rd *proto.Reader, n int64) (interface{}, error) {
vals := make([]interface{}, 0, n) vals := make([]interface{}, 0, n)
for i := int64(0); i < n; i++ { for i := int64(0); i < n; i++ {
v, err := rd.ReadReply(sliceParser) v, err := rd.ReadReply(sliceParser)
if err == Nil { if err != nil {
vals = append(vals, nil) if err == Nil {
} else if err != nil { vals = append(vals, nil)
vals = append(vals, err) continue
} else {
switch vv := v.(type) {
case []byte:
vals = append(vals, string(vv))
default:
vals = append(vals, v)
} }
if err, ok := err.(proto.RedisError); ok {
vals = append(vals, err)
continue
}
return nil, err
}
switch v := v.(type) {
case []byte:
vals = append(vals, string(v))
default:
vals = append(vals, v)
} }
} }
return vals, nil return vals, nil

View File

@ -12,7 +12,7 @@ import (
) )
// Nil reply redis returned when key does not exist. // Nil reply redis returned when key does not exist.
const Nil = internal.Nil const Nil = proto.Nil
func init() { func init() {
SetLogger(log.New(os.Stderr, "redis: ", log.LstdFlags|log.Lshortfile)) SetLogger(log.New(os.Stderr, "redis: ", log.LstdFlags|log.Lshortfile))

4
tx.go
View File

@ -1,12 +1,12 @@
package redis package redis
import ( import (
"github.com/go-redis/redis/internal"
"github.com/go-redis/redis/internal/pool" "github.com/go-redis/redis/internal/pool"
"github.com/go-redis/redis/internal/proto"
) )
// TxFailedErr transaction redis failed. // TxFailedErr transaction redis failed.
const TxFailedErr = internal.RedisError("redis: transaction failed") const TxFailedErr = proto.RedisError("redis: transaction failed")
// Tx implements Redis transactions as described in // Tx implements Redis transactions as described in
// http://redis.io/topics/transactions. It's NOT safe for concurrent use // http://redis.io/topics/transactions. It's NOT safe for concurrent use