mirror of https://github.com/go-gorm/gorm.git
fix: circular reference save, close #5140
commit 2ac099a37ac7bd74f0a98a6fdc42cc8527404144 Author: Jinzhu <wosmvp@gmail.com> Date: Thu Mar 17 23:49:21 2022 +0800 Refactor #5140 commit6e3ca2d1aa
Author: a631807682 <631807682@qq.com> Date: Sun Mar 13 12:52:08 2022 +0800 test: add test for LoadOrStoreVisitMap commit9d5c68e410
Author: chenrui <chenrui@jingdaka.com> Date: Thu Mar 10 20:33:47 2022 +0800 chore: add more comment commitbfffefb179
Author: chenrui <chenrui@jingdaka.com> Date: Thu Mar 10 20:28:48 2022 +0800 fix: should check values has been saved instead of rel.Name commite55cdfa4b3
Author: chenrui <chenrui@jingdaka.com> Date: Tue Mar 8 17:48:01 2022 +0800 chore: go lint commitfe4715c5bd
Author: chenrui <chenrui@jingdaka.com> Date: Tue Mar 8 17:27:24 2022 +0800 chore: add test comment commit326862f3f8
Author: chenrui <chenrui@jingdaka.com> Date: Tue Mar 8 17:22:33 2022 +0800 fix: circular reference save
This commit is contained in:
parent
2990790fbc
commit
9b9ae325bb
|
@ -69,7 +69,7 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if elems.Len() > 0 {
|
if elems.Len() > 0 {
|
||||||
if saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) == nil {
|
if saveAssociations(db, rel, elems, selectColumns, restricted, nil) == nil {
|
||||||
for i := 0; i < elems.Len(); i++ {
|
for i := 0; i < elems.Len(); i++ {
|
||||||
setupReferences(objs[i], elems.Index(i))
|
setupReferences(objs[i], elems.Index(i))
|
||||||
}
|
}
|
||||||
|
@ -82,7 +82,7 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
|
||||||
rv = rv.Addr()
|
rv = rv.Addr()
|
||||||
}
|
}
|
||||||
|
|
||||||
if saveAssociations(db, rel, rv.Interface(), selectColumns, restricted, nil) == nil {
|
if saveAssociations(db, rel, rv, selectColumns, restricted, nil) == nil {
|
||||||
setupReferences(db.Statement.ReflectValue, rv)
|
setupReferences(db.Statement.ReflectValue, rv)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -146,7 +146,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
|
||||||
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
||||||
}
|
}
|
||||||
|
|
||||||
saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns)
|
saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns)
|
||||||
}
|
}
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero {
|
if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero {
|
||||||
|
@ -166,7 +166,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
|
||||||
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
||||||
}
|
}
|
||||||
|
|
||||||
saveAssociations(db, rel, f.Interface(), selectColumns, restricted, assignmentColumns)
|
saveAssociations(db, rel, f, selectColumns, restricted, assignmentColumns)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -237,7 +237,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
|
||||||
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
||||||
}
|
}
|
||||||
|
|
||||||
saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns)
|
saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -304,7 +304,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
|
||||||
// optimize elems of reflect value length
|
// optimize elems of reflect value length
|
||||||
if elemLen := elems.Len(); elemLen > 0 {
|
if elemLen := elems.Len(); elemLen > 0 {
|
||||||
if v, ok := selectColumns[rel.Name+".*"]; !ok || v {
|
if v, ok := selectColumns[rel.Name+".*"]; !ok || v {
|
||||||
saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil)
|
saveAssociations(db, rel, elems, selectColumns, restricted, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < elemLen; i++ {
|
for i := 0; i < elemLen; i++ {
|
||||||
|
@ -341,11 +341,17 @@ func onConflictOption(stmt *gorm.Statement, s *schema.Schema, selectColumns map[
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error {
|
func saveAssociations(db *gorm.DB, rel *schema.Relationship, rValues reflect.Value, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error {
|
||||||
|
// stop save association loop
|
||||||
|
if checkAssociationsSaved(db, rValues) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
selects, omits []string
|
selects, omits []string
|
||||||
onConflict = onConflictOption(db.Statement, rel.FieldSchema, selectColumns, restricted, defaultUpdatingColumns)
|
onConflict = onConflictOption(db.Statement, rel.FieldSchema, selectColumns, restricted, defaultUpdatingColumns)
|
||||||
refName = rel.Name + "."
|
refName = rel.Name + "."
|
||||||
|
values = rValues.Interface()
|
||||||
)
|
)
|
||||||
|
|
||||||
for name, ok := range selectColumns {
|
for name, ok := range selectColumns {
|
||||||
|
@ -390,3 +396,24 @@ func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{},
|
||||||
|
|
||||||
return db.AddError(tx.Create(values).Error)
|
return db.AddError(tx.Create(values).Error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// check association values has been saved
|
||||||
|
// if values kind is Struct, check it has been saved
|
||||||
|
// if values kind is Slice/Array, check all items have been saved
|
||||||
|
var visitMapStoreKey = "gorm:saved_association_map"
|
||||||
|
|
||||||
|
func checkAssociationsSaved(db *gorm.DB, values reflect.Value) bool {
|
||||||
|
if visit, ok := db.Get(visitMapStoreKey); ok {
|
||||||
|
if v, ok := visit.(*visitMap); ok {
|
||||||
|
if loadOrStoreVisitMap(v, values) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
vistMap := make(visitMap)
|
||||||
|
loadOrStoreVisitMap(&vistMap, values)
|
||||||
|
db.Set(visitMapStoreKey, &vistMap)
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package callbacks
|
package callbacks
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"reflect"
|
||||||
"sort"
|
"sort"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
@ -120,3 +121,32 @@ func checkMissingWhereConditions(db *gorm.DB) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type visitMap = map[reflect.Value]bool
|
||||||
|
|
||||||
|
// Check if circular values, return true if loaded
|
||||||
|
func loadOrStoreVisitMap(vistMap *visitMap, v reflect.Value) (loaded bool) {
|
||||||
|
if v.Kind() == reflect.Ptr {
|
||||||
|
v = v.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
switch v.Kind() {
|
||||||
|
case reflect.Slice, reflect.Array:
|
||||||
|
loaded = true
|
||||||
|
for i := 0; i < v.Len(); i++ {
|
||||||
|
if !loadOrStoreVisitMap(vistMap, v.Index(i)) {
|
||||||
|
loaded = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case reflect.Struct, reflect.Interface:
|
||||||
|
if v.CanAddr() {
|
||||||
|
p := v.Addr()
|
||||||
|
if _, ok := (*vistMap)[p]; ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
(*vistMap)[p] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,36 @@
|
||||||
|
package callbacks
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLoadOrStoreVisitMap(t *testing.T) {
|
||||||
|
var vm visitMap
|
||||||
|
var loaded bool
|
||||||
|
type testM struct {
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
t1 := testM{Name: "t1"}
|
||||||
|
t2 := testM{Name: "t2"}
|
||||||
|
t3 := testM{Name: "t3"}
|
||||||
|
|
||||||
|
vm = make(visitMap)
|
||||||
|
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); loaded {
|
||||||
|
t.Fatalf("loaded should be false")
|
||||||
|
}
|
||||||
|
|
||||||
|
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); !loaded {
|
||||||
|
t.Fatalf("loaded should be true")
|
||||||
|
}
|
||||||
|
|
||||||
|
// t1 already exist but t2 not
|
||||||
|
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t1, &t2, &t3})); loaded {
|
||||||
|
t.Fatalf("loaded should be false")
|
||||||
|
}
|
||||||
|
|
||||||
|
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t2, &t3})); !loaded {
|
||||||
|
t.Fatalf("loaded should be true")
|
||||||
|
}
|
||||||
|
}
|
|
@ -220,3 +220,44 @@ func TestFullSaveAssociations(t *testing.T) {
|
||||||
t.Errorf("Failed to preload AppliesToProduct")
|
t.Errorf("Failed to preload AppliesToProduct")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSaveBelongsCircularReference(t *testing.T) {
|
||||||
|
parent := Parent{}
|
||||||
|
DB.Create(&parent)
|
||||||
|
|
||||||
|
child := Child{ParentID: &parent.ID, Parent: &parent}
|
||||||
|
DB.Create(&child)
|
||||||
|
|
||||||
|
parent.FavChildID = child.ID
|
||||||
|
parent.FavChild = &child
|
||||||
|
DB.Save(&parent)
|
||||||
|
|
||||||
|
var parent1 Parent
|
||||||
|
DB.First(&parent1, parent.ID)
|
||||||
|
AssertObjEqual(t, parent, parent1, "ID", "FavChildID")
|
||||||
|
|
||||||
|
// Save and Updates is the same
|
||||||
|
DB.Updates(&parent)
|
||||||
|
DB.First(&parent1, parent.ID)
|
||||||
|
AssertObjEqual(t, parent, parent1, "ID", "FavChildID")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSaveHasManyCircularReference(t *testing.T) {
|
||||||
|
parent := Parent{}
|
||||||
|
DB.Create(&parent)
|
||||||
|
|
||||||
|
child := Child{ParentID: &parent.ID, Parent: &parent, Name: "HasManyCircularReference"}
|
||||||
|
child1 := Child{ParentID: &parent.ID, Parent: &parent, Name: "HasManyCircularReference1"}
|
||||||
|
|
||||||
|
parent.Children = []*Child{&child, &child1}
|
||||||
|
DB.Save(&parent)
|
||||||
|
|
||||||
|
var children []*Child
|
||||||
|
DB.Where("parent_id = ?", parent.ID).Find(&children)
|
||||||
|
if len(children) != len(parent.Children) ||
|
||||||
|
children[0].ID != parent.Children[0].ID ||
|
||||||
|
children[1].ID != parent.Children[1].ID {
|
||||||
|
t.Errorf("circular reference children save not equal children:%v parent.Children:%v",
|
||||||
|
children, parent.Children)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -95,7 +95,7 @@ func OpenTestConnection() (db *gorm.DB, err error) {
|
||||||
|
|
||||||
func RunMigrations() {
|
func RunMigrations() {
|
||||||
var err error
|
var err error
|
||||||
allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}}
|
allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}, &Parent{}, &Child{}}
|
||||||
rand.Seed(time.Now().UnixNano())
|
rand.Seed(time.Now().UnixNano())
|
||||||
rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] })
|
rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] })
|
||||||
|
|
||||||
|
|
|
@ -80,3 +80,17 @@ type Order struct {
|
||||||
Coupon *Coupon
|
Coupon *Coupon
|
||||||
CouponID string
|
CouponID string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Parent struct {
|
||||||
|
gorm.Model
|
||||||
|
FavChildID uint
|
||||||
|
FavChild *Child
|
||||||
|
Children []*Child
|
||||||
|
}
|
||||||
|
|
||||||
|
type Child struct {
|
||||||
|
gorm.Model
|
||||||
|
Name string
|
||||||
|
ParentID *uint
|
||||||
|
Parent *Parent
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue