fix: should check values has been saved instead of rel.Name

This commit is contained in:
chenrui 2022-03-10 20:28:48 +08:00
parent e55cdfa4b3
commit bfffefb179
2 changed files with 66 additions and 4 deletions

View File

@ -1,7 +1,6 @@
package callbacks
import (
"fmt"
"reflect"
"strings"
@ -350,11 +349,9 @@ func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{},
)
// stop save association loop
savedRelKey := fmt.Sprintf("gorm:saved_relation_%s", rel.Name)
if _, ok := db.Get(savedRelKey); ok {
if checkAssociationsSaved(db, values) {
return nil
}
db.Set(savedRelKey, true)
for name, ok := range selectColumns {
columnName := ""
@ -398,3 +395,23 @@ func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{},
return db.AddError(tx.Create(values).Error)
}
// check association values has been saved
func checkAssociationsSaved(db *gorm.DB, values interface{}) (saved bool) {
visitMapStoreKey := "gorm:saved_association_map"
var vistMap VisitMap
if visit, ok := db.Get(visitMapStoreKey); ok {
if v, ok := visit.(VisitMap); ok {
vistMap = v
if LoadOrStoreVisitMap(&vistMap, values) {
return true
}
}
} else {
vistMap = make(VisitMap)
LoadOrStoreVisitMap(&vistMap, values)
}
db.Set(visitMapStoreKey, vistMap)
return false
}

View File

@ -1,6 +1,7 @@
package callbacks
import (
"reflect"
"sort"
"gorm.io/gorm"
@ -120,3 +121,47 @@ func checkMissingWhereConditions(db *gorm.DB) {
return
}
}
type VisitMap = map[reflect.Value]bool
// Check if circular values
func LoadOrStoreVisitMap(origin *VisitMap, v interface{}) (loaded bool) {
if v == nil {
return
}
value := reflect.ValueOf(v)
return loadOrStoreVisitMap(origin, value)
}
func loadOrStoreVisitMap(vistMap *VisitMap, v reflect.Value) (loaded bool) {
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
switch v.Kind() {
case reflect.Slice, reflect.Array:
sameCount := 0
for i := 0; i < v.Len(); i++ {
subv := v.Index(i)
if subv.CanAddr() {
if loadOrStoreVisitMap(vistMap, subv) {
sameCount++
}
}
}
// all slice item already visited
if v.Len() == sameCount {
return true
}
case reflect.Struct, reflect.Interface:
if v.CanAddr() {
p := v.Addr()
if _, ok := (*vistMap)[p]; ok {
return true
} else {
(*vistMap)[p] = true
}
}
}
return false
}