fix: serializer use default valueOf in assignInterfacesToValue, close #5168

commit 58e1b2bffbc216f2862d040fb545a8a486e473b6
Author: Jinzhu <wosmvp@gmail.com>
Date:   Fri Mar 18 17:06:43 2022 +0800

    Refactor #5168

commit fb9233011d
Author: ag9920 <alexgong7@outlook.com>
Date:   Thu Mar 17 21:23:28 2022 +0800

    fix: serializer use default valueOf in assignInterfacesToValue
This commit is contained in:
ag9920 2022-03-18 17:12:17 +08:00 committed by Jinzhu
parent e6f7da0e0d
commit 3c00980e01
3 changed files with 95 additions and 38 deletions

View File

@ -435,39 +435,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
// create valuer, setter when parse struct // create valuer, setter when parse struct
func (field *Field) setupValuerAndSetter() { func (field *Field) setupValuerAndSetter() {
// Setup NewValuePool // Setup NewValuePool
var fieldValue = reflect.New(field.FieldType).Interface() field.setupNewValuePool()
if field.Serializer != nil {
field.NewValuePool = &sync.Pool{
New: func() interface{} {
return &serializer{
Field: field,
Serializer: reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface),
}
},
}
} else if _, ok := fieldValue.(sql.Scanner); !ok {
// set default NewValuePool
switch field.IndirectFieldType.Kind() {
case reflect.String:
field.NewValuePool = stringPool
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
field.NewValuePool = intPool
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
field.NewValuePool = uintPool
case reflect.Float32, reflect.Float64:
field.NewValuePool = floatPool
case reflect.Bool:
field.NewValuePool = boolPool
default:
if field.IndirectFieldType == TimeReflectType {
field.NewValuePool = timePool
}
}
}
if field.NewValuePool == nil {
field.NewValuePool = poolInitializer(reflect.PtrTo(field.IndirectFieldType))
}
// ValueOf returns field's value and if it is zero // ValueOf returns field's value and if it is zero
fieldIndex := field.StructField.Index[0] fieldIndex := field.StructField.Index[0]
@ -512,7 +480,7 @@ func (field *Field) setupValuerAndSetter() {
s = field.Serializer s = field.Serializer
} }
return serializer{ return &serializer{
Field: field, Field: field,
SerializeValuer: s, SerializeValuer: s,
Destination: v, Destination: v,
@ -943,7 +911,9 @@ func (field *Field) setupValuerAndSetter() {
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
if s, ok := v.(*serializer); ok { if s, ok := v.(*serializer); ok {
if err = s.Serializer.Scan(ctx, field, value, s.value); err == nil { if s.fieldValue != nil {
err = oldFieldSetter(ctx, value, s.fieldValue)
} else if err = s.Serializer.Scan(ctx, field, value, s.value); err == nil {
if sameElemType { if sameElemType {
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer).Elem()) field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer).Elem())
s.Serializer = reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface) s.Serializer = reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface)
@ -959,3 +929,43 @@ func (field *Field) setupValuerAndSetter() {
} }
} }
} }
func (field *Field) setupNewValuePool() {
var fieldValue = reflect.New(field.FieldType).Interface()
if field.Serializer != nil {
field.NewValuePool = &sync.Pool{
New: func() interface{} {
return &serializer{
Field: field,
Serializer: reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface),
}
},
}
} else if _, ok := fieldValue.(sql.Scanner); !ok {
field.setupDefaultNewValuePool()
}
if field.NewValuePool == nil {
field.NewValuePool = poolInitializer(reflect.PtrTo(field.IndirectFieldType))
}
}
func (field *Field) setupDefaultNewValuePool() {
// set default NewValuePool
switch field.IndirectFieldType.Kind() {
case reflect.String:
field.NewValuePool = stringPool
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
field.NewValuePool = intPool
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
field.NewValuePool = uintPool
case reflect.Float32, reflect.Float64:
field.NewValuePool = floatPool
case reflect.Bool:
field.NewValuePool = boolPool
default:
if field.IndirectFieldType == TimeReflectType {
field.NewValuePool = timePool
}
}
}

View File

@ -202,8 +202,6 @@ func TestJoinCount(t *testing.T) {
} }
func TestJoinWithSoftDeleted(t *testing.T) { func TestJoinWithSoftDeleted(t *testing.T) {
DB = DB.Debug()
user := GetUser("TestJoinWithSoftDeletedUser", Config{Account: true, NamedPet: true}) user := GetUser("TestJoinWithSoftDeletedUser", Config{Account: true, NamedPet: true})
DB.Create(&user) DB.Create(&user)

View File

@ -42,7 +42,7 @@ func (es *EncryptedString) Scan(ctx context.Context, field *schema.Field, dst re
case string: case string:
*es = EncryptedString(strings.TrimPrefix(value, "hello")) *es = EncryptedString(strings.TrimPrefix(value, "hello"))
default: default:
return fmt.Errorf("unsupported data %v", dbValue) return fmt.Errorf("unsupported data %#v", dbValue)
} }
return nil return nil
} }
@ -83,4 +83,53 @@ func TestSerializer(t *testing.T) {
} }
AssertEqual(t, result, data) AssertEqual(t, result, data)
}
func TestSerializerAssignFirstOrCreate(t *testing.T) {
DB.Migrator().DropTable(&SerializerStruct{})
if err := DB.Migrator().AutoMigrate(&SerializerStruct{}); err != nil {
t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err)
}
createdAt := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)
data := SerializerStruct{
Name: []byte("ag9920"),
Roles: []string{"r1", "r2"},
Contracts: map[string]interface{}{"name": "jing1", "age": 11},
EncryptedString: EncryptedString("pass"),
CreatedTime: createdAt.Unix(),
JobInfo: Job{
Title: "programmer",
Number: 9920,
Location: "Shadyside",
IsIntern: false,
},
}
// first time insert record
out := SerializerStruct{}
if err := DB.Assign(data).FirstOrCreate(&out).Error; err != nil {
t.Fatalf("failed to FirstOrCreate Assigned data, got error %v", err)
}
var result SerializerStruct
if err := DB.First(&result, out.ID).Error; err != nil {
t.Fatalf("failed to query data, got error %v", err)
}
AssertEqual(t, result, out)
//update record
data.Roles = append(data.Roles, "r3")
data.JobInfo.Location = "Gates Hillman Complex"
if err := DB.Assign(data).FirstOrCreate(&out).Error; err != nil {
t.Fatalf("failed to FirstOrCreate Assigned data, got error %v", err)
}
if err := DB.First(&result, out.ID).Error; err != nil {
t.Fatalf("failed to query data, got error %v", err)
}
AssertEqual(t, result.Roles, data.Roles)
AssertEqual(t, result.JobInfo.Location, data.JobInfo.Location)
} }