diff --git a/internal/proto/scan.go b/internal/proto/scan.go index 8fa0323..08d18d3 100644 --- a/internal/proto/scan.go +++ b/internal/proto/scan.go @@ -4,10 +4,13 @@ import ( "encoding" "fmt" "reflect" + "time" "github.com/go-redis/redis/v8/internal/util" ) +// Scan parses bytes `b` to `v` with appropriate type. +// nolint: gocyclo func Scan(b []byte, v interface{}) error { switch v := v.(type) { case nil: @@ -99,6 +102,10 @@ func Scan(b []byte, v interface{}) error { case *bool: *v = len(b) == 1 && b[0] == '1' return nil + case *time.Time: + var err error + *v, err = time.Parse(time.RFC3339Nano, util.BytesToString(b)) + return err case encoding.BinaryUnmarshaler: return v.UnmarshalBinary(b) default: diff --git a/internal/proto/scan_test.go b/internal/proto/scan_test.go index fadcd05..034648c 100644 --- a/internal/proto/scan_test.go +++ b/internal/proto/scan_test.go @@ -1,8 +1,14 @@ -package proto +package proto_test import ( + "context" "encoding/json" + "errors" + "testing" + "time" + "github.com/go-redis/redis/v8" + "github.com/go-redis/redis/v8/internal/proto" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -28,7 +34,7 @@ var _ = Describe("ScanSlice", func() { It("[]testScanSliceStruct", func() { var slice []testScanSliceStruct - err := ScanSlice(data, &slice) + err := proto.ScanSlice(data, &slice) Expect(err).NotTo(HaveOccurred()) Expect(slice).To(Equal([]testScanSliceStruct{ {-1, "Back Yu"}, @@ -38,7 +44,7 @@ var _ = Describe("ScanSlice", func() { It("var testContainer []*testScanSliceStruct", func() { var slice []*testScanSliceStruct - err := ScanSlice(data, &slice) + err := proto.ScanSlice(data, &slice) Expect(err).NotTo(HaveOccurred()) Expect(slice).To(Equal([]*testScanSliceStruct{ {-1, "Back Yu"}, @@ -46,3 +52,28 @@ var _ = Describe("ScanSlice", func() { })) }) }) + +func TestScan(t *testing.T) { + t.Parallel() + + t.Run("time", func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + rdb := redis.NewClient(&redis.Options{ + Addr: ":6379", + }) + + tm := time.Now() + rdb.Set(ctx, "now", tm, 0) + + var tm2 time.Time + rdb.Get(ctx, "now").Scan(&tm2) + + if !tm2.Equal(tm) { + t.Fatal(errors.New("tm2 and tm are not equal")) + } + }) + +}