fix: circular reference save, close #5140

commit 2ac099a37ac7bd74f0a98a6fdc42cc8527404144
Author: Jinzhu <wosmvp@gmail.com>
Date:   Thu Mar 17 23:49:21 2022 +0800

    Refactor #5140

commit 6e3ca2d1aa
Author: a631807682 <631807682@qq.com>
Date:   Sun Mar 13 12:52:08 2022 +0800

    test: add test for LoadOrStoreVisitMap

commit 9d5c68e410
Author: chenrui <chenrui@jingdaka.com>
Date:   Thu Mar 10 20:33:47 2022 +0800

    chore: add more comment

commit bfffefb179
Author: chenrui <chenrui@jingdaka.com>
Date:   Thu Mar 10 20:28:48 2022 +0800

    fix: should check values has been saved instead of rel.Name

commit e55cdfa4b3
Author: chenrui <chenrui@jingdaka.com>
Date:   Tue Mar 8 17:48:01 2022 +0800

    chore: go lint

commit fe4715c5bd
Author: chenrui <chenrui@jingdaka.com>
Date:   Tue Mar 8 17:27:24 2022 +0800

    chore: add test comment

commit 326862f3f8
Author: chenrui <chenrui@jingdaka.com>
Date:   Tue Mar 8 17:22:33 2022 +0800

    fix: circular reference save
This commit is contained in:
chenrui 2022-03-17 23:53:31 +08:00 committed by Jinzhu
parent 2990790fbc
commit 9b9ae325bb
6 changed files with 156 additions and 8 deletions

View File

@ -69,7 +69,7 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
}
if elems.Len() > 0 {
if saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) == nil {
if saveAssociations(db, rel, elems, selectColumns, restricted, nil) == nil {
for i := 0; i < elems.Len(); i++ {
setupReferences(objs[i], elems.Index(i))
}
@ -82,7 +82,7 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
rv = rv.Addr()
}
if saveAssociations(db, rel, rv.Interface(), selectColumns, restricted, nil) == nil {
if saveAssociations(db, rel, rv, selectColumns, restricted, nil) == nil {
setupReferences(db.Statement.ReflectValue, rv)
}
}
@ -146,7 +146,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
}
saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns)
saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns)
}
case reflect.Struct:
if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero {
@ -166,7 +166,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
}
saveAssociations(db, rel, f.Interface(), selectColumns, restricted, assignmentColumns)
saveAssociations(db, rel, f, selectColumns, restricted, assignmentColumns)
}
}
}
@ -237,7 +237,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
}
saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns)
saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns)
}
}
@ -304,7 +304,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
// optimize elems of reflect value length
if elemLen := elems.Len(); elemLen > 0 {
if v, ok := selectColumns[rel.Name+".*"]; !ok || v {
saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil)
saveAssociations(db, rel, elems, selectColumns, restricted, nil)
}
for i := 0; i < elemLen; i++ {
@ -341,11 +341,17 @@ func onConflictOption(stmt *gorm.Statement, s *schema.Schema, selectColumns map[
return
}
func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error {
func saveAssociations(db *gorm.DB, rel *schema.Relationship, rValues reflect.Value, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error {
// stop save association loop
if checkAssociationsSaved(db, rValues) {
return nil
}
var (
selects, omits []string
onConflict = onConflictOption(db.Statement, rel.FieldSchema, selectColumns, restricted, defaultUpdatingColumns)
refName = rel.Name + "."
values = rValues.Interface()
)
for name, ok := range selectColumns {
@ -390,3 +396,24 @@ func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{},
return db.AddError(tx.Create(values).Error)
}
// check association values has been saved
// if values kind is Struct, check it has been saved
// if values kind is Slice/Array, check all items have been saved
var visitMapStoreKey = "gorm:saved_association_map"
func checkAssociationsSaved(db *gorm.DB, values reflect.Value) bool {
if visit, ok := db.Get(visitMapStoreKey); ok {
if v, ok := visit.(*visitMap); ok {
if loadOrStoreVisitMap(v, values) {
return true
}
}
} else {
vistMap := make(visitMap)
loadOrStoreVisitMap(&vistMap, values)
db.Set(visitMapStoreKey, &vistMap)
}
return false
}

View File

@ -1,6 +1,7 @@
package callbacks
import (
"reflect"
"sort"
"gorm.io/gorm"
@ -120,3 +121,32 @@ func checkMissingWhereConditions(db *gorm.DB) {
return
}
}
type visitMap = map[reflect.Value]bool
// Check if circular values, return true if loaded
func loadOrStoreVisitMap(vistMap *visitMap, v reflect.Value) (loaded bool) {
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
switch v.Kind() {
case reflect.Slice, reflect.Array:
loaded = true
for i := 0; i < v.Len(); i++ {
if !loadOrStoreVisitMap(vistMap, v.Index(i)) {
loaded = false
}
}
case reflect.Struct, reflect.Interface:
if v.CanAddr() {
p := v.Addr()
if _, ok := (*vistMap)[p]; ok {
return true
}
(*vistMap)[p] = true
}
}
return
}

View File

@ -0,0 +1,36 @@
package callbacks
import (
"reflect"
"testing"
)
func TestLoadOrStoreVisitMap(t *testing.T) {
var vm visitMap
var loaded bool
type testM struct {
Name string
}
t1 := testM{Name: "t1"}
t2 := testM{Name: "t2"}
t3 := testM{Name: "t3"}
vm = make(visitMap)
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); loaded {
t.Fatalf("loaded should be false")
}
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); !loaded {
t.Fatalf("loaded should be true")
}
// t1 already exist but t2 not
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t1, &t2, &t3})); loaded {
t.Fatalf("loaded should be false")
}
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t2, &t3})); !loaded {
t.Fatalf("loaded should be true")
}
}

View File

@ -220,3 +220,44 @@ func TestFullSaveAssociations(t *testing.T) {
t.Errorf("Failed to preload AppliesToProduct")
}
}
func TestSaveBelongsCircularReference(t *testing.T) {
parent := Parent{}
DB.Create(&parent)
child := Child{ParentID: &parent.ID, Parent: &parent}
DB.Create(&child)
parent.FavChildID = child.ID
parent.FavChild = &child
DB.Save(&parent)
var parent1 Parent
DB.First(&parent1, parent.ID)
AssertObjEqual(t, parent, parent1, "ID", "FavChildID")
// Save and Updates is the same
DB.Updates(&parent)
DB.First(&parent1, parent.ID)
AssertObjEqual(t, parent, parent1, "ID", "FavChildID")
}
func TestSaveHasManyCircularReference(t *testing.T) {
parent := Parent{}
DB.Create(&parent)
child := Child{ParentID: &parent.ID, Parent: &parent, Name: "HasManyCircularReference"}
child1 := Child{ParentID: &parent.ID, Parent: &parent, Name: "HasManyCircularReference1"}
parent.Children = []*Child{&child, &child1}
DB.Save(&parent)
var children []*Child
DB.Where("parent_id = ?", parent.ID).Find(&children)
if len(children) != len(parent.Children) ||
children[0].ID != parent.Children[0].ID ||
children[1].ID != parent.Children[1].ID {
t.Errorf("circular reference children save not equal children:%v parent.Children:%v",
children, parent.Children)
}
}

View File

@ -95,7 +95,7 @@ func OpenTestConnection() (db *gorm.DB, err error) {
func RunMigrations() {
var err error
allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}}
allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}, &Parent{}, &Child{}}
rand.Seed(time.Now().UnixNano())
rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] })

View File

@ -80,3 +80,17 @@ type Order struct {
Coupon *Coupon
CouponID string
}
type Parent struct {
gorm.Model
FavChildID uint
FavChild *Child
Children []*Child
}
type Child struct {
gorm.Model
Name string
ParentID *uint
Parent *Parent
}