diff --git a/commands.go b/commands.go index f58b9a36..0bad2324 100644 --- a/commands.go +++ b/commands.go @@ -5,6 +5,7 @@ import ( "errors" "io" "reflect" + "strings" "time" "github.com/go-redis/redis/v9/internal" @@ -75,31 +76,44 @@ func appendArg(dst []interface{}, arg interface{}) []interface{} { } return dst default: + // scan struct field + v := reflect.ValueOf(arg) + if v.Type().Kind() == reflect.Ptr { + if v.IsNil() { + // error: arg is not a valid object + return dst + } + v = v.Elem() + } + + if v.Type().Kind() == reflect.Struct { + return appendStructField(dst, v) + } + return append(dst, arg) } } -func structToMap(items interface{}) map[string]interface{} { - res := map[string]interface{}{} - if items == nil { - return res - } - v := reflect.TypeOf(items) - reflectValue := reflect.Indirect(reflect.ValueOf(items)) +// appendStructField appends the field and value held by the structure v to dst, and returns the appended dst. +func appendStructField(dst []interface{}, v reflect.Value) []interface{} { + typ := v.Type() + for i := 0; i < typ.NumField(); i++ { + tag := typ.Field(i).Tag.Get("redis") + if tag == "" || tag == "-" { + continue + } + tag = strings.Split(tag, ",")[0] + if tag == "" { + continue + } - if v.Kind() == reflect.Ptr { - v = v.Elem() - } - for i := 0; i < v.NumField(); i++ { - tag := v.Field(i).Tag.Get("redis") - - if tag != "" && v.Field(i).Type.Kind() != reflect.Struct { - field := reflectValue.Field(i).Interface() - res[tag] = field + field := v.Field(i) + if field.CanInterface() { + dst = append(dst, tag, field.Interface()) } } - return res + return dst } type Cmdable interface { @@ -904,6 +918,7 @@ func (c cmdable) MGet(ctx context.Context, keys ...string) *SliceCmd { // - MSet("key1", "value1", "key2", "value2") // - MSet([]string{"key1", "value1", "key2", "value2"}) // - MSet(map[string]interface{}{"key1": "value1", "key2": "value2"}) +// - MSet(struct), For struct types, see HSet description. func (c cmdable) MSet(ctx context.Context, values ...interface{}) *StatusCmd { args := make([]interface{}, 1, 1+len(values)) args[0] = "mset" @@ -917,6 +932,7 @@ func (c cmdable) MSet(ctx context.Context, values ...interface{}) *StatusCmd { // - MSetNX("key1", "value1", "key2", "value2") // - MSetNX([]string{"key1", "value1", "key2", "value2"}) // - MSetNX(map[string]interface{}{"key1": "value1", "key2": "value2"}) +// - MSetNX(struct), For struct types, see HSet description. func (c cmdable) MSetNX(ctx context.Context, values ...interface{}) *BoolCmd { args := make([]interface{}, 1, 1+len(values)) args[0] = "msetnx" @@ -1319,21 +1335,27 @@ func (c cmdable) HMGet(ctx context.Context, key string, fields ...string) *Slice } // HSet accepts values in following formats: +// // - HSet("myhash", "key1", "value1", "key2", "value2") +// // - HSet("myhash", []string{"key1", "value1", "key2", "value2"}) +// // - HSet("myhash", map[string]interface{}{"key1": "value1", "key2": "value2"}) // -// Playing struct With "redis" tag -// - type MyHash struct { Key1 string `redis:"key1"`; Key2 int `redis:"key2"` } +// Playing struct With "redis" tag. +// type MyHash struct { Key1 string `redis:"key1"`; Key2 int `redis:"key2"` } +// // - HSet("myhash", MyHash{"value1", "value2"}) // +// For struct, can be a structure pointer type, we only parse the field whose tag is redis. +// if you don't want the field to be read, you can use the `redis:"-"` flag to ignore it, +// or you don't need to set the redis tag. +// For the type of structure field, we only support simple data types: +// string, int/uint(8,16,32,64), float(32,64), time.Time(to RFC3339Nano), time.Duration(to Nanoseconds ), +// if you are other more complex or custom data types, please implement the encoding.BinaryMarshaler interface. +// // Note that it requires Redis v4 for multiple field/value pairs support. func (c cmdable) HSet(ctx context.Context, key string, values ...interface{}) *IntCmd { - if len(values) == 1 { - if reflect.ValueOf(values[0]).Kind() == reflect.Struct { - values = []interface{}{structToMap(values[0])} - } - } args := make([]interface{}, 2, 2+len(values)) args[0] = "hset" args[1] = key diff --git a/commands_test.go b/commands_test.go index def9d7fb..6a11ec8b 100644 --- a/commands_test.go +++ b/commands_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "reflect" + "strconv" "time" . "github.com/onsi/ginkgo" @@ -1220,6 +1221,33 @@ var _ = Describe("Commands", func() { mGet := client.MGet(ctx, "key1", "key2", "_") Expect(mGet.Err()).NotTo(HaveOccurred()) Expect(mGet.Val()).To(Equal([]interface{}{"hello1", "hello2", nil})) + + // MSet struct + type set struct { + Set1 string `redis:"set1"` + Set2 int16 `redis:"set2"` + Set3 time.Duration `redis:"set3"` + Set4 interface{} `redis:"set4"` + Set5 map[string]interface{} `redis:"-"` + } + mSet = client.MSet(ctx, &set{ + Set1: "val1", + Set2: 1024, + Set3: 2 * time.Minute, + Set4: nil, + Set5: map[string]interface{}{"k1": 1}, + }) + Expect(mSet.Err()).NotTo(HaveOccurred()) + Expect(mSet.Val()).To(Equal("OK")) + + mGet = client.MGet(ctx, "set1", "set2", "set3", "set4") + Expect(mGet.Err()).NotTo(HaveOccurred()) + Expect(mGet.Val()).To(Equal([]interface{}{ + "val1", + "1024", + strconv.Itoa(int(2 * time.Minute.Nanoseconds())), + "", + })) }) It("should scan Mget", func() { @@ -1255,6 +1283,25 @@ var _ = Describe("Commands", func() { mSetNX = client.MSetNX(ctx, "key2", "hello1", "key3", "hello2") Expect(mSetNX.Err()).NotTo(HaveOccurred()) Expect(mSetNX.Val()).To(Equal(false)) + + // set struct + // MSet struct + type set struct { + Set1 string `redis:"set1"` + Set2 int16 `redis:"set2"` + Set3 time.Duration `redis:"set3"` + Set4 interface{} `redis:"set4"` + Set5 map[string]interface{} `redis:"-"` + } + mSetNX = client.MSetNX(ctx, &set{ + Set1: "val1", + Set2: 1024, + Set3: 2 * time.Minute, + Set4: nil, + Set5: map[string]interface{}{"k1": 1}, + }) + Expect(mSetNX.Err()).NotTo(HaveOccurred()) + Expect(mSetNX.Val()).To(Equal(true)) }) It("should SetWithArgs with TTL", func() { @@ -1895,6 +1942,35 @@ var _ = Describe("Commands", func() { hGet := client.HGet(ctx, "hash", "key") Expect(hGet.Err()).NotTo(HaveOccurred()) Expect(hGet.Val()).To(Equal("hello")) + + // set struct + // MSet struct + type set struct { + Set1 string `redis:"set1"` + Set2 int16 `redis:"set2"` + Set3 time.Duration `redis:"set3"` + Set4 interface{} `redis:"set4"` + Set5 map[string]interface{} `redis:"-"` + } + + hSet = client.HSet(ctx, "hash", &set{ + Set1: "val1", + Set2: 1024, + Set3: 2 * time.Minute, + Set4: nil, + Set5: map[string]interface{}{"k1": 1}, + }) + Expect(hSet.Err()).NotTo(HaveOccurred()) + Expect(hSet.Val()).To(Equal(int64(4))) + + hMGet := client.HMGet(ctx, "hash", "set1", "set2", "set3", "set4") + Expect(hMGet.Err()).NotTo(HaveOccurred()) + Expect(hMGet.Val()).To(Equal([]interface{}{ + "val1", + "1024", + strconv.Itoa(int(2 * time.Minute.Nanoseconds())), + "", + })) }) It("should HSetNX", func() {