This commit is contained in:
oldme 2024-09-04 14:33:32 +08:00
parent 00d98485f8
commit 69005e6c3f
2 changed files with 45 additions and 0 deletions

View File

@ -5,6 +5,7 @@ import (
"fmt"
"io"
"net"
"reflect"
"strconv"
"time"
@ -140,6 +141,36 @@ func (w *Writer) WriteArg(v interface{}) error {
return w.bytes(b)
case net.IP:
return w.bytes(v)
default:
return w.writeArgExtra(v)
}
}
func (w *Writer) writeArgExtra(v interface{}) error {
var (
rfValue = reflect.ValueOf(v)
rfKind = rfValue.Kind()
)
switch rfKind {
case reflect.Bool:
if rfValue.Bool() {
return w.int(1)
}
return w.int(0)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return w.int(rfValue.Int())
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return w.uint(rfValue.Uint())
case reflect.Float32, reflect.Float64:
return w.float(rfValue.Float())
case reflect.String:
return w.string(rfValue.String())
case reflect.Slice:
if rfValue.Type().Elem().Kind() == reflect.Uint8 {
return w.bytes(rfValue.Bytes())
}
fallthrough
default:
return fmt.Errorf(
"redis: can't marshal %T (implement encoding.BinaryMarshaler)", v)

View File

@ -362,6 +362,20 @@ var _ = Describe("Client", func() {
Expect(ip2).To(Equal(ip))
})
It("should set and scan custom type", func() {
type customString string
val := customString("hello")
err := client.Set(ctx, "custom", val, 0).Err()
Expect(err).NotTo(HaveOccurred())
var val2 customString
err = client.Get(ctx, "custom").Scan(&val2)
Expect(err).NotTo(HaveOccurred())
Expect(val2).To(Equal(val))
})
})
var _ = Describe("Client timeout", func() {