diff --git a/internal/proto/scan.go b/internal/proto/scan.go index 0431a877..03c8b59a 100644 --- a/internal/proto/scan.go +++ b/internal/proto/scan.go @@ -120,8 +120,9 @@ func ScanSlice(data []string, slice interface{}) error { return fmt.Errorf("redis: ScanSlice(non-slice %T)", slice) } + next := internal.MakeSliceNextElemFunc(v) for i, s := range data { - elem := internal.SliceNextElem(v) + elem := next() if err := Scan(internal.StringToBytes(s), elem.Addr().Interface()); err != nil { return fmt.Errorf("redis: ScanSlice(index=%d value=%q) failed: %s", i, s, err) } diff --git a/internal/util.go b/internal/util.go index 520596fd..1ba9805f 100644 --- a/internal/util.go +++ b/internal/util.go @@ -28,20 +28,35 @@ func isLower(s string) bool { return true } -func SliceNextElem(v reflect.Value) reflect.Value { - if v.Len() < v.Cap() { - v.Set(v.Slice(0, v.Len()+1)) - return v.Index(v.Len() - 1) - } - +func MakeSliceNextElemFunc(v reflect.Value) func() reflect.Value { elemType := v.Type().Elem() if elemType.Kind() == reflect.Ptr { - elem := reflect.New(elemType.Elem()) - v.Set(reflect.Append(v, elem)) - return elem.Elem() + elemType = elemType.Elem() + return func() reflect.Value { + if v.Len() < v.Cap() { + v.Set(v.Slice(0, v.Len()+1)) + elem := v.Index(v.Len() - 1) + if elem.IsNil() { + elem.Set(reflect.New(elemType)) + } + return elem.Elem() + } + + elem := reflect.New(elemType) + v.Set(reflect.Append(v, elem)) + return elem.Elem() + } } - v.Set(reflect.Append(v, reflect.Zero(elemType))) - return v.Index(v.Len() - 1) + zero := reflect.Zero(elemType) + return func() reflect.Value { + if v.Len() < v.Cap() { + v.Set(v.Slice(0, v.Len()+1)) + return v.Index(v.Len() - 1) + } + + v.Set(reflect.Append(v, zero)) + return v.Index(v.Len() - 1) + } }