Fix embedded scanner/valuer, close #3283

This commit is contained in:
Jinzhu 2020-08-19 15:47:08 +08:00
parent 3411425d65
commit c1782d60c1
2 changed files with 27 additions and 13 deletions

View File

@ -92,25 +92,30 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
valuer, isValuer := fieldValue.Interface().(driver.Valuer) valuer, isValuer := fieldValue.Interface().(driver.Valuer)
if isValuer { if isValuer {
if _, ok := fieldValue.Interface().(GormDataTypeInterface); !ok { if _, ok := fieldValue.Interface().(GormDataTypeInterface); !ok {
var overrideFieldValue bool if v, err := valuer.Value(); reflect.ValueOf(v).IsValid() && err == nil {
if v, err := valuer.Value(); v != nil && err == nil {
overrideFieldValue = true
fieldValue = reflect.ValueOf(v) fieldValue = reflect.ValueOf(v)
} }
if field.IndirectFieldType.Kind() == reflect.Struct { var getRealFieldValue func(reflect.Value)
for i := 0; i < field.IndirectFieldType.NumField(); i++ { getRealFieldValue = func(v reflect.Value) {
if !overrideFieldValue { rv := reflect.Indirect(v)
newFieldType := field.IndirectFieldType.Field(i).Type if rv.Kind() == reflect.Struct && !rv.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) {
for i := 0; i < rv.Type().NumField(); i++ {
newFieldType := rv.Type().Field(i).Type
for newFieldType.Kind() == reflect.Ptr { for newFieldType.Kind() == reflect.Ptr {
newFieldType = newFieldType.Elem() newFieldType = newFieldType.Elem()
} }
fieldValue = reflect.New(newFieldType) fieldValue = reflect.New(newFieldType)
overrideFieldValue = true
if rv.Type() != reflect.Indirect(fieldValue).Type() {
getRealFieldValue(fieldValue)
}
if fieldValue.IsValid() {
return
} }
// copy tag settings from valuer
for key, value := range ParseTagSetting(field.IndirectFieldType.Field(i).Tag.Get("gorm"), ";") { for key, value := range ParseTagSetting(field.IndirectFieldType.Field(i).Tag.Get("gorm"), ";") {
if _, ok := field.TagSettings[key]; !ok { if _, ok := field.TagSettings[key]; !ok {
field.TagSettings[key] = value field.TagSettings[key] = value
@ -119,6 +124,9 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
} }
} }
} }
getRealFieldValue(fieldValue)
}
} }
if dbName, ok := field.TagSettings["COLUMN"]; ok { if dbName, ok := field.TagSettings["COLUMN"]; ok {

View File

@ -27,6 +27,7 @@ func TestScannerValuer(t *testing.T) {
Male: sql.NullBool{Bool: true, Valid: true}, Male: sql.NullBool{Bool: true, Valid: true},
Height: sql.NullFloat64{Float64: 1.8888, Valid: true}, Height: sql.NullFloat64{Float64: 1.8888, Valid: true},
Birthday: sql.NullTime{Time: time.Now(), Valid: true}, Birthday: sql.NullTime{Time: time.Now(), Valid: true},
Allergen: NullString{sql.NullString{String: "Allergen", Valid: true}},
Password: EncryptedData("pass1"), Password: EncryptedData("pass1"),
Bytes: []byte("byte"), Bytes: []byte("byte"),
Num: 18, Num: 18,
@ -143,6 +144,7 @@ type ScannerValuerStruct struct {
Male sql.NullBool Male sql.NullBool
Height sql.NullFloat64 Height sql.NullFloat64
Birthday sql.NullTime Birthday sql.NullTime
Allergen NullString
Password EncryptedData Password EncryptedData
Bytes []byte Bytes []byte
Num Num Num Num
@ -299,3 +301,7 @@ func (t *EmptyTime) Scan(v interface{}) error {
func (t EmptyTime) Value() (driver.Value, error) { func (t EmptyTime) Value() (driver.Value, error) {
return time.Now() /* pass tests, mysql 8 doesn't support 0000-00-00 by default */, nil return time.Now() /* pass tests, mysql 8 doesn't support 0000-00-00 by default */, nil
} }
type NullString struct {
sql.NullString
}