forked from mirror/gorm
fix: circular reference save, close #5140
commit 2ac099a37ac7bd74f0a98a6fdc42cc8527404144 Author: Jinzhu <wosmvp@gmail.com> Date: Thu Mar 17 23:49:21 2022 +0800 Refactor #5140 commit6e3ca2d1aa
Author: a631807682 <631807682@qq.com> Date: Sun Mar 13 12:52:08 2022 +0800 test: add test for LoadOrStoreVisitMap commit9d5c68e410
Author: chenrui <chenrui@jingdaka.com> Date: Thu Mar 10 20:33:47 2022 +0800 chore: add more comment commitbfffefb179
Author: chenrui <chenrui@jingdaka.com> Date: Thu Mar 10 20:28:48 2022 +0800 fix: should check values has been saved instead of rel.Name commite55cdfa4b3
Author: chenrui <chenrui@jingdaka.com> Date: Tue Mar 8 17:48:01 2022 +0800 chore: go lint commitfe4715c5bd
Author: chenrui <chenrui@jingdaka.com> Date: Tue Mar 8 17:27:24 2022 +0800 chore: add test comment commit326862f3f8
Author: chenrui <chenrui@jingdaka.com> Date: Tue Mar 8 17:22:33 2022 +0800 fix: circular reference save
This commit is contained in:
parent
2990790fbc
commit
9b9ae325bb
|
@ -69,7 +69,7 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
|
|||
}
|
||||
|
||||
if elems.Len() > 0 {
|
||||
if saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) == nil {
|
||||
if saveAssociations(db, rel, elems, selectColumns, restricted, nil) == nil {
|
||||
for i := 0; i < elems.Len(); i++ {
|
||||
setupReferences(objs[i], elems.Index(i))
|
||||
}
|
||||
|
@ -82,7 +82,7 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
|
|||
rv = rv.Addr()
|
||||
}
|
||||
|
||||
if saveAssociations(db, rel, rv.Interface(), selectColumns, restricted, nil) == nil {
|
||||
if saveAssociations(db, rel, rv, selectColumns, restricted, nil) == nil {
|
||||
setupReferences(db.Statement.ReflectValue, rv)
|
||||
}
|
||||
}
|
||||
|
@ -146,7 +146,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
|
|||
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
||||
}
|
||||
|
||||
saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns)
|
||||
saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns)
|
||||
}
|
||||
case reflect.Struct:
|
||||
if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero {
|
||||
|
@ -166,7 +166,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
|
|||
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
||||
}
|
||||
|
||||
saveAssociations(db, rel, f.Interface(), selectColumns, restricted, assignmentColumns)
|
||||
saveAssociations(db, rel, f, selectColumns, restricted, assignmentColumns)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -237,7 +237,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
|
|||
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
||||
}
|
||||
|
||||
saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns)
|
||||
saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -304,7 +304,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
|
|||
// optimize elems of reflect value length
|
||||
if elemLen := elems.Len(); elemLen > 0 {
|
||||
if v, ok := selectColumns[rel.Name+".*"]; !ok || v {
|
||||
saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil)
|
||||
saveAssociations(db, rel, elems, selectColumns, restricted, nil)
|
||||
}
|
||||
|
||||
for i := 0; i < elemLen; i++ {
|
||||
|
@ -341,11 +341,17 @@ func onConflictOption(stmt *gorm.Statement, s *schema.Schema, selectColumns map[
|
|||
return
|
||||
}
|
||||
|
||||
func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error {
|
||||
func saveAssociations(db *gorm.DB, rel *schema.Relationship, rValues reflect.Value, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error {
|
||||
// stop save association loop
|
||||
if checkAssociationsSaved(db, rValues) {
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
selects, omits []string
|
||||
onConflict = onConflictOption(db.Statement, rel.FieldSchema, selectColumns, restricted, defaultUpdatingColumns)
|
||||
refName = rel.Name + "."
|
||||
values = rValues.Interface()
|
||||
)
|
||||
|
||||
for name, ok := range selectColumns {
|
||||
|
@ -390,3 +396,24 @@ func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{},
|
|||
|
||||
return db.AddError(tx.Create(values).Error)
|
||||
}
|
||||
|
||||
// check association values has been saved
|
||||
// if values kind is Struct, check it has been saved
|
||||
// if values kind is Slice/Array, check all items have been saved
|
||||
var visitMapStoreKey = "gorm:saved_association_map"
|
||||
|
||||
func checkAssociationsSaved(db *gorm.DB, values reflect.Value) bool {
|
||||
if visit, ok := db.Get(visitMapStoreKey); ok {
|
||||
if v, ok := visit.(*visitMap); ok {
|
||||
if loadOrStoreVisitMap(v, values) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
} else {
|
||||
vistMap := make(visitMap)
|
||||
loadOrStoreVisitMap(&vistMap, values)
|
||||
db.Set(visitMapStoreKey, &vistMap)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package callbacks
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sort"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
@ -120,3 +121,32 @@ func checkMissingWhereConditions(db *gorm.DB) {
|
|||
return
|
||||
}
|
||||
}
|
||||
|
||||
type visitMap = map[reflect.Value]bool
|
||||
|
||||
// Check if circular values, return true if loaded
|
||||
func loadOrStoreVisitMap(vistMap *visitMap, v reflect.Value) (loaded bool) {
|
||||
if v.Kind() == reflect.Ptr {
|
||||
v = v.Elem()
|
||||
}
|
||||
|
||||
switch v.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
loaded = true
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
if !loadOrStoreVisitMap(vistMap, v.Index(i)) {
|
||||
loaded = false
|
||||
}
|
||||
}
|
||||
case reflect.Struct, reflect.Interface:
|
||||
if v.CanAddr() {
|
||||
p := v.Addr()
|
||||
if _, ok := (*vistMap)[p]; ok {
|
||||
return true
|
||||
}
|
||||
(*vistMap)[p] = true
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
package callbacks
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoadOrStoreVisitMap(t *testing.T) {
|
||||
var vm visitMap
|
||||
var loaded bool
|
||||
type testM struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
t1 := testM{Name: "t1"}
|
||||
t2 := testM{Name: "t2"}
|
||||
t3 := testM{Name: "t3"}
|
||||
|
||||
vm = make(visitMap)
|
||||
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); loaded {
|
||||
t.Fatalf("loaded should be false")
|
||||
}
|
||||
|
||||
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); !loaded {
|
||||
t.Fatalf("loaded should be true")
|
||||
}
|
||||
|
||||
// t1 already exist but t2 not
|
||||
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t1, &t2, &t3})); loaded {
|
||||
t.Fatalf("loaded should be false")
|
||||
}
|
||||
|
||||
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t2, &t3})); !loaded {
|
||||
t.Fatalf("loaded should be true")
|
||||
}
|
||||
}
|
|
@ -220,3 +220,44 @@ func TestFullSaveAssociations(t *testing.T) {
|
|||
t.Errorf("Failed to preload AppliesToProduct")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveBelongsCircularReference(t *testing.T) {
|
||||
parent := Parent{}
|
||||
DB.Create(&parent)
|
||||
|
||||
child := Child{ParentID: &parent.ID, Parent: &parent}
|
||||
DB.Create(&child)
|
||||
|
||||
parent.FavChildID = child.ID
|
||||
parent.FavChild = &child
|
||||
DB.Save(&parent)
|
||||
|
||||
var parent1 Parent
|
||||
DB.First(&parent1, parent.ID)
|
||||
AssertObjEqual(t, parent, parent1, "ID", "FavChildID")
|
||||
|
||||
// Save and Updates is the same
|
||||
DB.Updates(&parent)
|
||||
DB.First(&parent1, parent.ID)
|
||||
AssertObjEqual(t, parent, parent1, "ID", "FavChildID")
|
||||
}
|
||||
|
||||
func TestSaveHasManyCircularReference(t *testing.T) {
|
||||
parent := Parent{}
|
||||
DB.Create(&parent)
|
||||
|
||||
child := Child{ParentID: &parent.ID, Parent: &parent, Name: "HasManyCircularReference"}
|
||||
child1 := Child{ParentID: &parent.ID, Parent: &parent, Name: "HasManyCircularReference1"}
|
||||
|
||||
parent.Children = []*Child{&child, &child1}
|
||||
DB.Save(&parent)
|
||||
|
||||
var children []*Child
|
||||
DB.Where("parent_id = ?", parent.ID).Find(&children)
|
||||
if len(children) != len(parent.Children) ||
|
||||
children[0].ID != parent.Children[0].ID ||
|
||||
children[1].ID != parent.Children[1].ID {
|
||||
t.Errorf("circular reference children save not equal children:%v parent.Children:%v",
|
||||
children, parent.Children)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -95,7 +95,7 @@ func OpenTestConnection() (db *gorm.DB, err error) {
|
|||
|
||||
func RunMigrations() {
|
||||
var err error
|
||||
allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}}
|
||||
allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}, &Parent{}, &Child{}}
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] })
|
||||
|
||||
|
|
|
@ -80,3 +80,17 @@ type Order struct {
|
|||
Coupon *Coupon
|
||||
CouponID string
|
||||
}
|
||||
|
||||
type Parent struct {
|
||||
gorm.Model
|
||||
FavChildID uint
|
||||
FavChild *Child
|
||||
Children []*Child
|
||||
}
|
||||
|
||||
type Child struct {
|
||||
gorm.Model
|
||||
Name string
|
||||
ParentID *uint
|
||||
Parent *Parent
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue