Test failed to save association should rollback, close #3100

This commit is contained in:
Jinzhu 2020-07-01 21:28:19 +08:00
parent b0aae504ab
commit 63e48191a8
2 changed files with 56 additions and 8 deletions

View File

@ -139,10 +139,10 @@ func SaveAfterAssociations(db *gorm.DB) {
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
} }
db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: rel.FieldSchema.PrioritizedPrimaryField.DBName}}, Columns: []clause.Column{{Name: rel.FieldSchema.PrioritizedPrimaryField.DBName}},
DoUpdates: clause.AssignmentColumns(assignmentColumns), DoUpdates: clause.AssignmentColumns(assignmentColumns),
}).Create(elems.Interface()) }).Create(elems.Interface()).Error)
} }
case reflect.Struct: case reflect.Struct:
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
@ -162,10 +162,10 @@ func SaveAfterAssociations(db *gorm.DB) {
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
} }
db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: rel.FieldSchema.PrioritizedPrimaryField.DBName}}, Columns: []clause.Column{{Name: rel.FieldSchema.PrioritizedPrimaryField.DBName}},
DoUpdates: clause.AssignmentColumns(assignmentColumns), DoUpdates: clause.AssignmentColumns(assignmentColumns),
}).Create(f.Interface()) }).Create(f.Interface()).Error)
} }
} }
} }
@ -221,10 +221,10 @@ func SaveAfterAssociations(db *gorm.DB) {
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
} }
db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: rel.FieldSchema.PrioritizedPrimaryField.DBName}}, Columns: []clause.Column{{Name: rel.FieldSchema.PrioritizedPrimaryField.DBName}},
DoUpdates: clause.AssignmentColumns(assignmentColumns), DoUpdates: clause.AssignmentColumns(assignmentColumns),
}).Create(elems.Interface()) }).Create(elems.Interface()).Error)
} }
} }
@ -286,7 +286,7 @@ func SaveAfterAssociations(db *gorm.DB) {
} }
if elems.Len() > 0 { if elems.Len() > 0 {
db.Session(&gorm.Session{}).Clauses(clause.OnConflict{DoNothing: true}).Create(elems.Interface()) db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{DoNothing: true}).Create(elems.Interface()).Error)
for i := 0; i < elems.Len(); i++ { for i := 0; i < elems.Len(); i++ {
appendToJoins(objs[i], elems.Index(i)) appendToJoins(objs[i], elems.Index(i))
@ -294,7 +294,7 @@ func SaveAfterAssociations(db *gorm.DB) {
} }
if joins.Len() > 0 { if joins.Len() > 0 {
db.Session(&gorm.Session{}).Clauses(clause.OnConflict{DoNothing: true}).Create(joins.Interface()) db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{DoNothing: true}).Create(joins.Interface()).Error)
} }
} }
} }

View File

@ -368,6 +368,9 @@ func TestSetColumn(t *testing.T) {
} }
func TestHooksForSlice(t *testing.T) { func TestHooksForSlice(t *testing.T) {
DB.Migrator().DropTable(&Product3{})
DB.AutoMigrate(&Product3{})
products := []*Product3{ products := []*Product3{
{Name: "Product-1", Price: 100}, {Name: "Product-1", Price: 100},
{Name: "Product-2", Price: 200}, {Name: "Product-2", Price: 200},
@ -414,3 +417,48 @@ func TestHooksForSlice(t *testing.T) {
} }
} }
} }
type Product4 struct {
gorm.Model
Name string
Code string
Price int64
Owner string
Item ProductItem
}
type ProductItem struct {
gorm.Model
Code string
Product4ID uint
}
func (pi ProductItem) BeforeCreate(*gorm.DB) error {
if pi.Code == "invalid" {
return errors.New("invalid item")
}
return nil
}
func TestFailedToSaveAssociationShouldRollback(t *testing.T) {
DB.Migrator().DropTable(&Product4{}, &ProductItem{})
DB.AutoMigrate(&Product4{}, &ProductItem{})
product := Product4{Name: "Product-1", Price: 100, Item: ProductItem{Code: "invalid"}}
if err := DB.Create(&product).Error; err == nil {
t.Errorf("should got failed to save, but error is nil")
}
if DB.First(&Product4{}, "name = ?", product.Name).Error == nil {
t.Errorf("should got RecordNotFound, but got nil")
}
product = Product4{Name: "Product-2", Price: 100, Item: ProductItem{Code: "valid"}}
if err := DB.Create(&product).Error; err != nil {
t.Errorf("should create product, but got error %v", err)
}
if err := DB.First(&Product4{}, "name = ?", product.Name).Error; err != nil {
t.Errorf("should find product, but got error %v", err)
}
}