From fb9233011d209174e8223e970f0f732412852908 Mon Sep 17 00:00:00 2001 From: ag9920 Date: Thu, 17 Mar 2022 21:23:28 +0800 Subject: [PATCH] fix: serializer use default valueOf in assignInterfacesToValue --- finisher_api.go | 4 ++ schema/field.go | 97 ++++++++++++++++++++++++++-------------- tests/serializer_test.go | 48 ++++++++++++++++++++ 3 files changed, 116 insertions(+), 33 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 4b428a59..5a89e348 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -240,6 +240,10 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) { if f.Readable { if v, isZero := f.ValueOf(tx.Statement.Context, reflectValue); !isZero { if field := tx.Statement.Schema.LookUpField(f.Name); field != nil { + // serializer should use default implementation of ValueOf when assign to value + if field.Serializer != nil { + v, _ = field.DefaultValueOf(tx.Statement.Context, reflectValue) + } tx.AddError(field.Set(tx.Statement.Context, tx.Statement.ReflectValue, v)) } } diff --git a/schema/field.go b/schema/field.go index 826680c5..97938d95 100644 --- a/schema/field.go +++ b/schema/field.go @@ -430,39 +430,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { // create valuer, setter when parse struct func (field *Field) setupValuerAndSetter() { // Setup NewValuePool - var fieldValue = reflect.New(field.FieldType).Interface() - if field.Serializer != nil { - field.NewValuePool = &sync.Pool{ - New: func() interface{} { - return &serializer{ - Field: field, - Serializer: reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface), - } - }, - } - } else if _, ok := fieldValue.(sql.Scanner); !ok { - // set default NewValuePool - switch field.IndirectFieldType.Kind() { - case reflect.String: - field.NewValuePool = stringPool - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - field.NewValuePool = intPool - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - field.NewValuePool = uintPool - case reflect.Float32, reflect.Float64: - field.NewValuePool = floatPool - case reflect.Bool: - field.NewValuePool = boolPool - default: - if field.IndirectFieldType == TimeReflectType { - field.NewValuePool = timePool - } - } - } - - if field.NewValuePool == nil { - field.NewValuePool = poolInitializer(reflect.PtrTo(field.IndirectFieldType)) - } + field.setupNewValuePool() // ValueOf returns field's value and if it is zero fieldIndex := field.StructField.Index[0] @@ -954,3 +922,66 @@ func (field *Field) setupValuerAndSetter() { } } } + +func (field *Field) DefaultValueOf(ctx context.Context, v reflect.Value) (interface{}, bool) { + fieldIndex := field.StructField.Index[0] + if len(field.StructField.Index) == 1 && fieldIndex > 0 { + fieldValue := reflect.Indirect(v).Field(fieldIndex) + return fieldValue.Interface(), fieldValue.IsZero() + } + v = reflect.Indirect(v) + for _, fieldIdx := range field.StructField.Index { + if fieldIdx >= 0 { + v = v.Field(fieldIdx) + } else { + v = v.Field(-fieldIdx - 1) + + if !v.IsNil() { + v = v.Elem() + } else { + return nil, true + } + } + } + return v.Interface(), v.IsZero() +} + +func (field *Field) setupNewValuePool() { + var fieldValue = reflect.New(field.FieldType).Interface() + if field.Serializer != nil { + field.NewValuePool = &sync.Pool{ + New: func() interface{} { + return &serializer{ + Field: field, + Serializer: reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface), + } + }, + } + } else if _, ok := fieldValue.(sql.Scanner); !ok { + field.setupDefaultNewValuePool() + } + + if field.NewValuePool == nil { + field.NewValuePool = poolInitializer(reflect.PtrTo(field.IndirectFieldType)) + } +} + +func (field *Field) setupDefaultNewValuePool() { + // set default NewValuePool + switch field.IndirectFieldType.Kind() { + case reflect.String: + field.NewValuePool = stringPool + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + field.NewValuePool = intPool + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + field.NewValuePool = uintPool + case reflect.Float32, reflect.Float64: + field.NewValuePool = floatPool + case reflect.Bool: + field.NewValuePool = boolPool + default: + if field.IndirectFieldType == TimeReflectType { + field.NewValuePool = timePool + } + } +} diff --git a/tests/serializer_test.go b/tests/serializer_test.go index a8a4e28f..b65fa823 100644 --- a/tests/serializer_test.go +++ b/tests/serializer_test.go @@ -83,4 +83,52 @@ func TestSerializer(t *testing.T) { } AssertEqual(t, result, data) + +} + +func TestSerializer_AssignFirstOrCreate(t *testing.T) { + DB.Migrator().DropTable(&SerializerStruct{}) + if err := DB.Migrator().AutoMigrate(&SerializerStruct{}); err != nil { + t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err) + } + + createdAt := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) + + data := SerializerStruct{ + Name: []byte("ag9920"), + Roles: []string{"r1", "r2"}, + Contracts: map[string]interface{}{"name": "jing1", "age": 11}, + EncryptedString: EncryptedString("pass"), + CreatedTime: createdAt.Unix(), + JobInfo: Job{ + Title: "programmer", + Number: 9920, + Location: "Shadyside", + IsIntern: false, + }, + } + + // first time insert record + out := SerializerStruct{} + if err := DB.Assign(data).FirstOrCreate(&out).Error; err != nil { + t.Fatalf("failed to FirstOrCreate Assigned data, got error %v", err) + } + var result SerializerStruct + if err := DB.First(&result, out.ID).Error; err != nil { + t.Fatalf("failed to query data, got error %v", err) + } + AssertEqual(t, result, out) + + //update record + data.Roles = append(data.Roles, "r3") + data.JobInfo.Location = "Gates Hillman Complex" + if err := DB.Assign(data).FirstOrCreate(&out).Error; err != nil { + t.Fatalf("failed to FirstOrCreate Assigned data, got error %v", err) + } + if err := DB.First(&result, out.ID).Error; err != nil { + t.Fatalf("failed to query data, got error %v", err) + } + + AssertEqual(t, result.Roles, data.Roles) + AssertEqual(t, result.JobInfo.Location, data.JobInfo.Location) }