From c1782d60c149483111b021e29c412d9139bd46ba Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 19 Aug 2020 15:47:08 +0800 Subject: [PATCH] Fix embedded scanner/valuer, close #3283 --- schema/field.go | 34 +++++++++++++++++++++------------- tests/scanner_valuer_test.go | 6 ++++++ 2 files changed, 27 insertions(+), 13 deletions(-) diff --git a/schema/field.go b/schema/field.go index 35c1e44d..59367399 100644 --- a/schema/field.go +++ b/schema/field.go @@ -92,32 +92,40 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { valuer, isValuer := fieldValue.Interface().(driver.Valuer) if isValuer { if _, ok := fieldValue.Interface().(GormDataTypeInterface); !ok { - var overrideFieldValue bool - if v, err := valuer.Value(); v != nil && err == nil { - overrideFieldValue = true + if v, err := valuer.Value(); reflect.ValueOf(v).IsValid() && err == nil { fieldValue = reflect.ValueOf(v) } - if field.IndirectFieldType.Kind() == reflect.Struct { - for i := 0; i < field.IndirectFieldType.NumField(); i++ { - if !overrideFieldValue { - newFieldType := field.IndirectFieldType.Field(i).Type + var getRealFieldValue func(reflect.Value) + getRealFieldValue = func(v reflect.Value) { + rv := reflect.Indirect(v) + if rv.Kind() == reflect.Struct && !rv.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) { + for i := 0; i < rv.Type().NumField(); i++ { + newFieldType := rv.Type().Field(i).Type for newFieldType.Kind() == reflect.Ptr { newFieldType = newFieldType.Elem() } fieldValue = reflect.New(newFieldType) - overrideFieldValue = true - } - // copy tag settings from valuer - for key, value := range ParseTagSetting(field.IndirectFieldType.Field(i).Tag.Get("gorm"), ";") { - if _, ok := field.TagSettings[key]; !ok { - field.TagSettings[key] = value + if rv.Type() != reflect.Indirect(fieldValue).Type() { + getRealFieldValue(fieldValue) + } + + if fieldValue.IsValid() { + return + } + + for key, value := range ParseTagSetting(field.IndirectFieldType.Field(i).Tag.Get("gorm"), ";") { + if _, ok := field.TagSettings[key]; !ok { + field.TagSettings[key] = value + } } } } } + + getRealFieldValue(fieldValue) } } diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go index b8306af7..ce8a2b50 100644 --- a/tests/scanner_valuer_test.go +++ b/tests/scanner_valuer_test.go @@ -27,6 +27,7 @@ func TestScannerValuer(t *testing.T) { Male: sql.NullBool{Bool: true, Valid: true}, Height: sql.NullFloat64{Float64: 1.8888, Valid: true}, Birthday: sql.NullTime{Time: time.Now(), Valid: true}, + Allergen: NullString{sql.NullString{String: "Allergen", Valid: true}}, Password: EncryptedData("pass1"), Bytes: []byte("byte"), Num: 18, @@ -143,6 +144,7 @@ type ScannerValuerStruct struct { Male sql.NullBool Height sql.NullFloat64 Birthday sql.NullTime + Allergen NullString Password EncryptedData Bytes []byte Num Num @@ -299,3 +301,7 @@ func (t *EmptyTime) Scan(v interface{}) error { func (t EmptyTime) Value() (driver.Value, error) { return time.Now() /* pass tests, mysql 8 doesn't support 0000-00-00 by default */, nil } + +type NullString struct { + sql.NullString +}