From 326862f3f8980482a09d7d1a7f4d1011bb8a7c59 Mon Sep 17 00:00:00 2001 From: chenrui Date: Tue, 8 Mar 2022 17:22:33 +0800 Subject: [PATCH] fix: circular reference save --- callbacks/associations.go | 8 ++++++++ tests/associations_test.go | 40 ++++++++++++++++++++++++++++++++++++++ tests/tests_test.go | 2 +- utils/tests/models.go | 13 +++++++++++++ 4 files changed, 62 insertions(+), 1 deletion(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index d6fd21de..270a7d75 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -1,6 +1,7 @@ package callbacks import ( + "fmt" "reflect" "strings" @@ -348,6 +349,13 @@ func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, refName = rel.Name + "." ) + // stop save association loop + savedRelKey := fmt.Sprintf("gorm:saved_relation_%s", rel.Name) + if _, ok := db.Get(savedRelKey); ok { + return nil + } + db.Set(savedRelKey, true) + for name, ok := range selectColumns { columnName := "" if strings.HasPrefix(name, refName) { diff --git a/tests/associations_test.go b/tests/associations_test.go index 5ce98c7d..696c84bb 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -220,3 +220,43 @@ func TestFullSaveAssociations(t *testing.T) { t.Errorf("Failed to preload AppliesToProduct") } } + +func TestSaveBelongsCircularReference(t *testing.T) { + parent := Parent{} + DB.Create(&parent) + + child := Child{ParentID: &parent.ID, Parent: &parent} + DB.Create(&child) + + parent.FavChildID = child.ID + parent.FavChild = &child + DB.Save(&parent) + + var parent1 Parent + DB.First(&parent1, parent.ID) + AssertObjEqual(t, parent, parent1, "ID", "FavChildID") + + DB.Updates(&parent) + DB.First(&parent1, parent.ID) + AssertObjEqual(t, parent, parent1, "ID", "FavChildID") +} + +func TestSaveHasManyCircularReference(t *testing.T) { + parent := Parent{} + DB.Create(&parent) + + child := Child{ParentID: &parent.ID, Parent: &parent, Name: "HasManyCircularReference"} + child1 := Child{ParentID: &parent.ID, Parent: &parent, Name: "HasManyCircularReference1"} + + parent.Children = []*Child{&child, &child1} + DB.Save(&parent) + + var children []*Child + DB.Where("parent_id = ?", parent.ID).Find(&children) + if len(children) != len(parent.Children) || + children[0].ID != parent.Children[0].ID || + children[1].ID != parent.Children[1].ID { + t.Errorf("circular reference children save not equal children:%v parent.Children:%v", + children, parent.Children) + } +} diff --git a/tests/tests_test.go b/tests/tests_test.go index 11b6f067..08f4f193 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -95,7 +95,7 @@ func OpenTestConnection() (db *gorm.DB, err error) { func RunMigrations() { var err error - allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}} + allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}, &Parent{}, &Child{}} rand.Seed(time.Now().UnixNano()) rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) diff --git a/utils/tests/models.go b/utils/tests/models.go index c84f9cae..d4fe754b 100644 --- a/utils/tests/models.go +++ b/utils/tests/models.go @@ -80,3 +80,16 @@ type Order struct { Coupon *Coupon CouponID string } + +type Parent struct { + gorm.Model + FavChildID uint + FavChild *Child + Children []*Child +} +type Child struct { + gorm.Model + Name string + ParentID *uint + Parent *Parent +}