diff --git a/callbacks/associations.go b/callbacks/associations.go index d6fd21de..3b204ab6 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -69,7 +69,7 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) { } if elems.Len() > 0 { - if saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) == nil { + if saveAssociations(db, rel, elems, selectColumns, restricted, nil) == nil { for i := 0; i < elems.Len(); i++ { setupReferences(objs[i], elems.Index(i)) } @@ -82,7 +82,7 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) { rv = rv.Addr() } - if saveAssociations(db, rel, rv.Interface(), selectColumns, restricted, nil) == nil { + if saveAssociations(db, rel, rv, selectColumns, restricted, nil) == nil { setupReferences(db.Statement.ReflectValue, rv) } } @@ -146,7 +146,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns) + saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns) } case reflect.Struct: if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero { @@ -166,7 +166,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - saveAssociations(db, rel, f.Interface(), selectColumns, restricted, assignmentColumns) + saveAssociations(db, rel, f, selectColumns, restricted, assignmentColumns) } } } @@ -237,7 +237,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns) + saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns) } } @@ -304,7 +304,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { // optimize elems of reflect value length if elemLen := elems.Len(); elemLen > 0 { if v, ok := selectColumns[rel.Name+".*"]; !ok || v { - saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) + saveAssociations(db, rel, elems, selectColumns, restricted, nil) } for i := 0; i < elemLen; i++ { @@ -341,11 +341,17 @@ func onConflictOption(stmt *gorm.Statement, s *schema.Schema, selectColumns map[ return } -func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error { +func saveAssociations(db *gorm.DB, rel *schema.Relationship, rValues reflect.Value, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error { + // stop save association loop + if checkAssociationsSaved(db, rValues) { + return nil + } + var ( selects, omits []string onConflict = onConflictOption(db.Statement, rel.FieldSchema, selectColumns, restricted, defaultUpdatingColumns) refName = rel.Name + "." + values = rValues.Interface() ) for name, ok := range selectColumns { @@ -390,3 +396,24 @@ func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, return db.AddError(tx.Create(values).Error) } + +// check association values has been saved +// if values kind is Struct, check it has been saved +// if values kind is Slice/Array, check all items have been saved +var visitMapStoreKey = "gorm:saved_association_map" + +func checkAssociationsSaved(db *gorm.DB, values reflect.Value) bool { + if visit, ok := db.Get(visitMapStoreKey); ok { + if v, ok := visit.(*visitMap); ok { + if loadOrStoreVisitMap(v, values) { + return true + } + } + } else { + vistMap := make(visitMap) + loadOrStoreVisitMap(&vistMap, values) + db.Set(visitMapStoreKey, &vistMap) + } + + return false +} diff --git a/callbacks/helper.go b/callbacks/helper.go index a5eb047e..71b67de5 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -1,6 +1,7 @@ package callbacks import ( + "reflect" "sort" "gorm.io/gorm" @@ -120,3 +121,32 @@ func checkMissingWhereConditions(db *gorm.DB) { return } } + +type visitMap = map[reflect.Value]bool + +// Check if circular values, return true if loaded +func loadOrStoreVisitMap(vistMap *visitMap, v reflect.Value) (loaded bool) { + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + + switch v.Kind() { + case reflect.Slice, reflect.Array: + loaded = true + for i := 0; i < v.Len(); i++ { + if !loadOrStoreVisitMap(vistMap, v.Index(i)) { + loaded = false + } + } + case reflect.Struct, reflect.Interface: + if v.CanAddr() { + p := v.Addr() + if _, ok := (*vistMap)[p]; ok { + return true + } + (*vistMap)[p] = true + } + } + + return +} diff --git a/callbacks/visit_map_test.go b/callbacks/visit_map_test.go new file mode 100644 index 00000000..b1fb86db --- /dev/null +++ b/callbacks/visit_map_test.go @@ -0,0 +1,36 @@ +package callbacks + +import ( + "reflect" + "testing" +) + +func TestLoadOrStoreVisitMap(t *testing.T) { + var vm visitMap + var loaded bool + type testM struct { + Name string + } + + t1 := testM{Name: "t1"} + t2 := testM{Name: "t2"} + t3 := testM{Name: "t3"} + + vm = make(visitMap) + if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); loaded { + t.Fatalf("loaded should be false") + } + + if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); !loaded { + t.Fatalf("loaded should be true") + } + + // t1 already exist but t2 not + if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t1, &t2, &t3})); loaded { + t.Fatalf("loaded should be false") + } + + if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t2, &t3})); !loaded { + t.Fatalf("loaded should be true") + } +} diff --git a/tests/associations_test.go b/tests/associations_test.go index 5ce98c7d..32f6525b 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -220,3 +220,44 @@ func TestFullSaveAssociations(t *testing.T) { t.Errorf("Failed to preload AppliesToProduct") } } + +func TestSaveBelongsCircularReference(t *testing.T) { + parent := Parent{} + DB.Create(&parent) + + child := Child{ParentID: &parent.ID, Parent: &parent} + DB.Create(&child) + + parent.FavChildID = child.ID + parent.FavChild = &child + DB.Save(&parent) + + var parent1 Parent + DB.First(&parent1, parent.ID) + AssertObjEqual(t, parent, parent1, "ID", "FavChildID") + + // Save and Updates is the same + DB.Updates(&parent) + DB.First(&parent1, parent.ID) + AssertObjEqual(t, parent, parent1, "ID", "FavChildID") +} + +func TestSaveHasManyCircularReference(t *testing.T) { + parent := Parent{} + DB.Create(&parent) + + child := Child{ParentID: &parent.ID, Parent: &parent, Name: "HasManyCircularReference"} + child1 := Child{ParentID: &parent.ID, Parent: &parent, Name: "HasManyCircularReference1"} + + parent.Children = []*Child{&child, &child1} + DB.Save(&parent) + + var children []*Child + DB.Where("parent_id = ?", parent.ID).Find(&children) + if len(children) != len(parent.Children) || + children[0].ID != parent.Children[0].ID || + children[1].ID != parent.Children[1].ID { + t.Errorf("circular reference children save not equal children:%v parent.Children:%v", + children, parent.Children) + } +} diff --git a/tests/tests_test.go b/tests/tests_test.go index 11b6f067..08f4f193 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -95,7 +95,7 @@ func OpenTestConnection() (db *gorm.DB, err error) { func RunMigrations() { var err error - allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}} + allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}, &Parent{}, &Child{}} rand.Seed(time.Now().UnixNano()) rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) diff --git a/utils/tests/models.go b/utils/tests/models.go index c84f9cae..22e8e659 100644 --- a/utils/tests/models.go +++ b/utils/tests/models.go @@ -80,3 +80,17 @@ type Order struct { Coupon *Coupon CouponID string } + +type Parent struct { + gorm.Model + FavChildID uint + FavChild *Child + Children []*Child +} + +type Child struct { + gorm.Model + Name string + ParentID *uint + Parent *Parent +}