diff --git a/schema/field.go b/schema/field.go index 0d7085a9..45ec66e1 100644 --- a/schema/field.go +++ b/schema/field.go @@ -435,39 +435,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { // create valuer, setter when parse struct func (field *Field) setupValuerAndSetter() { // Setup NewValuePool - var fieldValue = reflect.New(field.FieldType).Interface() - if field.Serializer != nil { - field.NewValuePool = &sync.Pool{ - New: func() interface{} { - return &serializer{ - Field: field, - Serializer: reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface), - } - }, - } - } else if _, ok := fieldValue.(sql.Scanner); !ok { - // set default NewValuePool - switch field.IndirectFieldType.Kind() { - case reflect.String: - field.NewValuePool = stringPool - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - field.NewValuePool = intPool - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - field.NewValuePool = uintPool - case reflect.Float32, reflect.Float64: - field.NewValuePool = floatPool - case reflect.Bool: - field.NewValuePool = boolPool - default: - if field.IndirectFieldType == TimeReflectType { - field.NewValuePool = timePool - } - } - } - - if field.NewValuePool == nil { - field.NewValuePool = poolInitializer(reflect.PtrTo(field.IndirectFieldType)) - } + field.setupNewValuePool() // ValueOf returns field's value and if it is zero fieldIndex := field.StructField.Index[0] @@ -512,7 +480,7 @@ func (field *Field) setupValuerAndSetter() { s = field.Serializer } - return serializer{ + return &serializer{ Field: field, SerializeValuer: s, Destination: v, @@ -943,7 +911,9 @@ func (field *Field) setupValuerAndSetter() { field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { if s, ok := v.(*serializer); ok { - if err = s.Serializer.Scan(ctx, field, value, s.value); err == nil { + if s.fieldValue != nil { + err = oldFieldSetter(ctx, value, s.fieldValue) + } else if err = s.Serializer.Scan(ctx, field, value, s.value); err == nil { if sameElemType { field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer).Elem()) s.Serializer = reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface) @@ -959,3 +929,43 @@ func (field *Field) setupValuerAndSetter() { } } } + +func (field *Field) setupNewValuePool() { + var fieldValue = reflect.New(field.FieldType).Interface() + if field.Serializer != nil { + field.NewValuePool = &sync.Pool{ + New: func() interface{} { + return &serializer{ + Field: field, + Serializer: reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface), + } + }, + } + } else if _, ok := fieldValue.(sql.Scanner); !ok { + field.setupDefaultNewValuePool() + } + + if field.NewValuePool == nil { + field.NewValuePool = poolInitializer(reflect.PtrTo(field.IndirectFieldType)) + } +} + +func (field *Field) setupDefaultNewValuePool() { + // set default NewValuePool + switch field.IndirectFieldType.Kind() { + case reflect.String: + field.NewValuePool = stringPool + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + field.NewValuePool = intPool + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + field.NewValuePool = uintPool + case reflect.Float32, reflect.Float64: + field.NewValuePool = floatPool + case reflect.Bool: + field.NewValuePool = boolPool + default: + if field.IndirectFieldType == TimeReflectType { + field.NewValuePool = timePool + } + } +} diff --git a/tests/joins_test.go b/tests/joins_test.go index 0f02f3f9..bb5352ef 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -202,8 +202,6 @@ func TestJoinCount(t *testing.T) { } func TestJoinWithSoftDeleted(t *testing.T) { - DB = DB.Debug() - user := GetUser("TestJoinWithSoftDeletedUser", Config{Account: true, NamedPet: true}) DB.Create(&user) diff --git a/tests/serializer_test.go b/tests/serializer_test.go index a8a4e28f..ce60280e 100644 --- a/tests/serializer_test.go +++ b/tests/serializer_test.go @@ -42,7 +42,7 @@ func (es *EncryptedString) Scan(ctx context.Context, field *schema.Field, dst re case string: *es = EncryptedString(strings.TrimPrefix(value, "hello")) default: - return fmt.Errorf("unsupported data %v", dbValue) + return fmt.Errorf("unsupported data %#v", dbValue) } return nil } @@ -83,4 +83,53 @@ func TestSerializer(t *testing.T) { } AssertEqual(t, result, data) + +} + +func TestSerializerAssignFirstOrCreate(t *testing.T) { + DB.Migrator().DropTable(&SerializerStruct{}) + if err := DB.Migrator().AutoMigrate(&SerializerStruct{}); err != nil { + t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err) + } + + createdAt := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) + + data := SerializerStruct{ + Name: []byte("ag9920"), + Roles: []string{"r1", "r2"}, + Contracts: map[string]interface{}{"name": "jing1", "age": 11}, + EncryptedString: EncryptedString("pass"), + CreatedTime: createdAt.Unix(), + JobInfo: Job{ + Title: "programmer", + Number: 9920, + Location: "Shadyside", + IsIntern: false, + }, + } + + // first time insert record + out := SerializerStruct{} + if err := DB.Assign(data).FirstOrCreate(&out).Error; err != nil { + t.Fatalf("failed to FirstOrCreate Assigned data, got error %v", err) + } + + var result SerializerStruct + if err := DB.First(&result, out.ID).Error; err != nil { + t.Fatalf("failed to query data, got error %v", err) + } + AssertEqual(t, result, out) + + //update record + data.Roles = append(data.Roles, "r3") + data.JobInfo.Location = "Gates Hillman Complex" + if err := DB.Assign(data).FirstOrCreate(&out).Error; err != nil { + t.Fatalf("failed to FirstOrCreate Assigned data, got error %v", err) + } + if err := DB.First(&result, out.ID).Error; err != nil { + t.Fatalf("failed to query data, got error %v", err) + } + + AssertEqual(t, result.Roles, data.Roles) + AssertEqual(t, result.JobInfo.Location, data.JobInfo.Location) }