forked from mirror/gorm
fix: circular reference save
This commit is contained in:
parent
b566ed7913
commit
326862f3f8
|
@ -1,6 +1,7 @@
|
|||
package callbacks
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
|
@ -348,6 +349,13 @@ func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{},
|
|||
refName = rel.Name + "."
|
||||
)
|
||||
|
||||
// stop save association loop
|
||||
savedRelKey := fmt.Sprintf("gorm:saved_relation_%s", rel.Name)
|
||||
if _, ok := db.Get(savedRelKey); ok {
|
||||
return nil
|
||||
}
|
||||
db.Set(savedRelKey, true)
|
||||
|
||||
for name, ok := range selectColumns {
|
||||
columnName := ""
|
||||
if strings.HasPrefix(name, refName) {
|
||||
|
|
|
@ -220,3 +220,43 @@ func TestFullSaveAssociations(t *testing.T) {
|
|||
t.Errorf("Failed to preload AppliesToProduct")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveBelongsCircularReference(t *testing.T) {
|
||||
parent := Parent{}
|
||||
DB.Create(&parent)
|
||||
|
||||
child := Child{ParentID: &parent.ID, Parent: &parent}
|
||||
DB.Create(&child)
|
||||
|
||||
parent.FavChildID = child.ID
|
||||
parent.FavChild = &child
|
||||
DB.Save(&parent)
|
||||
|
||||
var parent1 Parent
|
||||
DB.First(&parent1, parent.ID)
|
||||
AssertObjEqual(t, parent, parent1, "ID", "FavChildID")
|
||||
|
||||
DB.Updates(&parent)
|
||||
DB.First(&parent1, parent.ID)
|
||||
AssertObjEqual(t, parent, parent1, "ID", "FavChildID")
|
||||
}
|
||||
|
||||
func TestSaveHasManyCircularReference(t *testing.T) {
|
||||
parent := Parent{}
|
||||
DB.Create(&parent)
|
||||
|
||||
child := Child{ParentID: &parent.ID, Parent: &parent, Name: "HasManyCircularReference"}
|
||||
child1 := Child{ParentID: &parent.ID, Parent: &parent, Name: "HasManyCircularReference1"}
|
||||
|
||||
parent.Children = []*Child{&child, &child1}
|
||||
DB.Save(&parent)
|
||||
|
||||
var children []*Child
|
||||
DB.Where("parent_id = ?", parent.ID).Find(&children)
|
||||
if len(children) != len(parent.Children) ||
|
||||
children[0].ID != parent.Children[0].ID ||
|
||||
children[1].ID != parent.Children[1].ID {
|
||||
t.Errorf("circular reference children save not equal children:%v parent.Children:%v",
|
||||
children, parent.Children)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -95,7 +95,7 @@ func OpenTestConnection() (db *gorm.DB, err error) {
|
|||
|
||||
func RunMigrations() {
|
||||
var err error
|
||||
allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}}
|
||||
allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}, &Parent{}, &Child{}}
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] })
|
||||
|
||||
|
|
|
@ -80,3 +80,16 @@ type Order struct {
|
|||
Coupon *Coupon
|
||||
CouponID string
|
||||
}
|
||||
|
||||
type Parent struct {
|
||||
gorm.Model
|
||||
FavChildID uint
|
||||
FavChild *Child
|
||||
Children []*Child
|
||||
}
|
||||
type Child struct {
|
||||
gorm.Model
|
||||
Name string
|
||||
ParentID *uint
|
||||
Parent *Parent
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue