From 63e48191a83f0891af4c7a19a8a0c89a521240a0 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 1 Jul 2020 21:28:19 +0800 Subject: [PATCH] Test failed to save association should rollback, close #3100 --- callbacks/associations.go | 16 ++++++------- tests/hooks_test.go | 48 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+), 8 deletions(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index bcb6c414..0968b460 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -139,10 +139,10 @@ func SaveAfterAssociations(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ + db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ Columns: []clause.Column{{Name: rel.FieldSchema.PrioritizedPrimaryField.DBName}}, DoUpdates: clause.AssignmentColumns(assignmentColumns), - }).Create(elems.Interface()) + }).Create(elems.Interface()).Error) } case reflect.Struct: if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { @@ -162,10 +162,10 @@ func SaveAfterAssociations(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ + db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ Columns: []clause.Column{{Name: rel.FieldSchema.PrioritizedPrimaryField.DBName}}, DoUpdates: clause.AssignmentColumns(assignmentColumns), - }).Create(f.Interface()) + }).Create(f.Interface()).Error) } } } @@ -221,10 +221,10 @@ func SaveAfterAssociations(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ + db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ Columns: []clause.Column{{Name: rel.FieldSchema.PrioritizedPrimaryField.DBName}}, DoUpdates: clause.AssignmentColumns(assignmentColumns), - }).Create(elems.Interface()) + }).Create(elems.Interface()).Error) } } @@ -286,7 +286,7 @@ func SaveAfterAssociations(db *gorm.DB) { } if elems.Len() > 0 { - db.Session(&gorm.Session{}).Clauses(clause.OnConflict{DoNothing: true}).Create(elems.Interface()) + db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{DoNothing: true}).Create(elems.Interface()).Error) for i := 0; i < elems.Len(); i++ { appendToJoins(objs[i], elems.Index(i)) @@ -294,7 +294,7 @@ func SaveAfterAssociations(db *gorm.DB) { } if joins.Len() > 0 { - db.Session(&gorm.Session{}).Clauses(clause.OnConflict{DoNothing: true}).Create(joins.Interface()) + db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{DoNothing: true}).Create(joins.Interface()).Error) } } } diff --git a/tests/hooks_test.go b/tests/hooks_test.go index ed5ee746..3612857b 100644 --- a/tests/hooks_test.go +++ b/tests/hooks_test.go @@ -368,6 +368,9 @@ func TestSetColumn(t *testing.T) { } func TestHooksForSlice(t *testing.T) { + DB.Migrator().DropTable(&Product3{}) + DB.AutoMigrate(&Product3{}) + products := []*Product3{ {Name: "Product-1", Price: 100}, {Name: "Product-2", Price: 200}, @@ -414,3 +417,48 @@ func TestHooksForSlice(t *testing.T) { } } } + +type Product4 struct { + gorm.Model + Name string + Code string + Price int64 + Owner string + Item ProductItem +} + +type ProductItem struct { + gorm.Model + Code string + Product4ID uint +} + +func (pi ProductItem) BeforeCreate(*gorm.DB) error { + if pi.Code == "invalid" { + return errors.New("invalid item") + } + return nil +} + +func TestFailedToSaveAssociationShouldRollback(t *testing.T) { + DB.Migrator().DropTable(&Product4{}, &ProductItem{}) + DB.AutoMigrate(&Product4{}, &ProductItem{}) + + product := Product4{Name: "Product-1", Price: 100, Item: ProductItem{Code: "invalid"}} + if err := DB.Create(&product).Error; err == nil { + t.Errorf("should got failed to save, but error is nil") + } + + if DB.First(&Product4{}, "name = ?", product.Name).Error == nil { + t.Errorf("should got RecordNotFound, but got nil") + } + + product = Product4{Name: "Product-2", Price: 100, Item: ProductItem{Code: "valid"}} + if err := DB.Create(&product).Error; err != nil { + t.Errorf("should create product, but got error %v", err) + } + + if err := DB.First(&Product4{}, "name = ?", product.Name).Error; err != nil { + t.Errorf("should find product, but got error %v", err) + } +}