diff --git a/commands_test.go b/commands_test.go index 476644d0..d0c43c7b 100644 --- a/commands_test.go +++ b/commands_test.go @@ -14,6 +14,15 @@ import ( "github.com/go-redis/redis/v9/internal/proto" ) +type TimeValue struct { + time.Time +} + +func (t *TimeValue) ScanRedis(s string) (err error) { + t.Time, err = time.Parse(time.RFC3339Nano, s) + return +} + var _ = Describe("Commands", func() { ctx := context.TODO() var client *redis.Client @@ -1192,19 +1201,28 @@ var _ = Describe("Commands", func() { }) It("should scan Mget", func() { - err := client.MSet(ctx, "key1", "hello1", "key2", 123).Err() + now := time.Now() + + err := client.MSet(ctx, "key1", "hello1", "key2", 123, "time", now.Format(time.RFC3339Nano)).Err() Expect(err).NotTo(HaveOccurred()) - res := client.MGet(ctx, "key1", "key2", "_") + res := client.MGet(ctx, "key1", "key2", "_", "time") Expect(res.Err()).NotTo(HaveOccurred()) type data struct { - Key1 string `redis:"key1"` - Key2 int `redis:"key2"` + Key1 string `redis:"key1"` + Key2 int `redis:"key2"` + Time TimeValue `redis:"time"` } var d data Expect(res.Scan(&d)).NotTo(HaveOccurred()) - Expect(d).To(Equal(data{Key1: "hello1", Key2: 123})) + Expect(d.Time.UnixNano()).To(Equal(now.UnixNano())) + d.Time.Time = time.Time{} + Expect(d).To(Equal(data{ + Key1: "hello1", + Key2: 123, + Time: TimeValue{Time: time.Time{}}, + })) }) It("should MSetNX", func() { @@ -1732,19 +1750,28 @@ var _ = Describe("Commands", func() { }) It("should scan", func() { - err := client.HMSet(ctx, "hash", "key1", "hello1", "key2", 123).Err() + now := time.Now() + + err := client.HMSet(ctx, "hash", "key1", "hello1", "key2", 123, "time", now.Format(time.RFC3339Nano)).Err() Expect(err).NotTo(HaveOccurred()) res := client.HGetAll(ctx, "hash") Expect(res.Err()).NotTo(HaveOccurred()) type data struct { - Key1 string `redis:"key1"` - Key2 int `redis:"key2"` + Key1 string `redis:"key1"` + Key2 int `redis:"key2"` + Time TimeValue `redis:"time"` } var d data Expect(res.Scan(&d)).NotTo(HaveOccurred()) - Expect(d).To(Equal(data{Key1: "hello1", Key2: 123})) + Expect(d.Time.UnixNano()).To(Equal(now.UnixNano())) + d.Time.Time = time.Time{} + Expect(d).To(Equal(data{ + Key1: "hello1", + Key2: 123, + Time: TimeValue{Time: time.Time{}}, + })) }) It("should HIncrBy", func() { diff --git a/internal/hscan/hscan.go b/internal/hscan/hscan.go index 852c8bd5..203ec4aa 100644 --- a/internal/hscan/hscan.go +++ b/internal/hscan/hscan.go @@ -10,6 +10,12 @@ import ( // decoderFunc represents decoding functions for default built-in types. type decoderFunc func(reflect.Value, string) error +// Scanner is the interface implemented by themselves, +// which will override the decoding behavior of decoderFunc. +type Scanner interface { + ScanRedis(s string) error +} + var ( // List of built-in decoders indexed by their numeric constant values (eg: reflect.Bool = 1). decoders = []decoderFunc{ diff --git a/internal/hscan/hscan_test.go b/internal/hscan/hscan_test.go index ab4c0e1d..72979fb3 100644 --- a/internal/hscan/hscan_test.go +++ b/internal/hscan/hscan_test.go @@ -4,6 +4,7 @@ import ( "math" "strconv" "testing" + "time" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -30,6 +31,20 @@ type data struct { Bool bool `redis:"bool"` } +type TimeRFC3339Nano struct { + time.Time +} + +func (t *TimeRFC3339Nano) ScanRedis(s string) (err error) { + t.Time, err = time.Parse(time.RFC3339Nano, s) + return +} + +type TimeData struct { + Name string `redis:"name"` + Time *TimeRFC3339Nano `redis:"login"` +} + type i []interface{} func TestGinkgoSuite(t *testing.T) { @@ -175,4 +190,14 @@ var _ = Describe("Scan", func() { Expect(Scan(&d, i{"bool"}, i{""})).To(HaveOccurred()) Expect(Scan(&d, i{"bool"}, i{"123"})).To(HaveOccurred()) }) + + It("Implements Scanner", func() { + var td TimeData + + now := time.Now() + Expect(Scan(&td, i{"name", "login"}, i{"hello", now.Format(time.RFC3339Nano)})).NotTo(HaveOccurred()) + Expect(td.Name).To(Equal("hello")) + Expect(td.Time.UnixNano()).To(Equal(now.UnixNano())) + Expect(td.Time.Format(time.RFC3339Nano)).To(Equal(now.Format(time.RFC3339Nano))) + }) }) diff --git a/internal/hscan/structmap.go b/internal/hscan/structmap.go index 6839412b..9befd987 100644 --- a/internal/hscan/structmap.go +++ b/internal/hscan/structmap.go @@ -84,7 +84,29 @@ func (s StructValue) Scan(key string, value string) error { if !ok { return nil } - if err := field.fn(s.value.Field(field.index), value); err != nil { + + v := s.value.Field(field.index) + isPtr := v.Kind() == reflect.Pointer + + if isPtr && v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + if !isPtr && v.Type().Name() != "" && v.CanAddr() { + v = v.Addr() + isPtr = true + } + + if isPtr && v.Type().NumMethod() > 0 && v.CanInterface() { + if scan, ok := v.Interface().(Scanner); ok { + return scan.ScanRedis(value) + } + } + + if isPtr { + v = v.Elem() + } + + if err := field.fn(v, value); err != nil { t := s.value.Type() return fmt.Errorf("cannot scan redis.result %s into struct field %s.%s of type %s, error-%s", value, t.Name(), t.Field(field.index).Name, t.Field(field.index).Type, err.Error()) diff --git a/redis.go b/redis.go index ee6c9bdb..83e9d404 100644 --- a/redis.go +++ b/redis.go @@ -10,10 +10,14 @@ import ( "time" "github.com/go-redis/redis/v9/internal" + "github.com/go-redis/redis/v9/internal/hscan" "github.com/go-redis/redis/v9/internal/pool" "github.com/go-redis/redis/v9/internal/proto" ) +// Scanner internal/hscan.Scanner exposed interface. +type Scanner = hscan.Scanner + // Nil reply returned by Redis when key does not exist. const Nil = proto.Nil