diff --git a/callbacks/associations.go b/callbacks/associations.go index 270a7d75..f5b208fe 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -1,7 +1,6 @@ package callbacks import ( - "fmt" "reflect" "strings" @@ -350,11 +349,9 @@ func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, ) // stop save association loop - savedRelKey := fmt.Sprintf("gorm:saved_relation_%s", rel.Name) - if _, ok := db.Get(savedRelKey); ok { + if checkAssociationsSaved(db, values) { return nil } - db.Set(savedRelKey, true) for name, ok := range selectColumns { columnName := "" @@ -398,3 +395,23 @@ func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, return db.AddError(tx.Create(values).Error) } + +// check association values has been saved +func checkAssociationsSaved(db *gorm.DB, values interface{}) (saved bool) { + visitMapStoreKey := "gorm:saved_association_map" + var vistMap VisitMap + if visit, ok := db.Get(visitMapStoreKey); ok { + if v, ok := visit.(VisitMap); ok { + vistMap = v + if LoadOrStoreVisitMap(&vistMap, values) { + return true + } + } + } else { + vistMap = make(VisitMap) + LoadOrStoreVisitMap(&vistMap, values) + } + + db.Set(visitMapStoreKey, vistMap) + return false +} diff --git a/callbacks/helper.go b/callbacks/helper.go index a5eb047e..8f1a3873 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -1,6 +1,7 @@ package callbacks import ( + "reflect" "sort" "gorm.io/gorm" @@ -120,3 +121,47 @@ func checkMissingWhereConditions(db *gorm.DB) { return } } + +type VisitMap = map[reflect.Value]bool + +// Check if circular values +func LoadOrStoreVisitMap(origin *VisitMap, v interface{}) (loaded bool) { + if v == nil { + return + } + value := reflect.ValueOf(v) + return loadOrStoreVisitMap(origin, value) +} + +func loadOrStoreVisitMap(vistMap *VisitMap, v reflect.Value) (loaded bool) { + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + + switch v.Kind() { + case reflect.Slice, reflect.Array: + sameCount := 0 + for i := 0; i < v.Len(); i++ { + subv := v.Index(i) + if subv.CanAddr() { + if loadOrStoreVisitMap(vistMap, subv) { + sameCount++ + } + } + } + // all slice item already visited + if v.Len() == sameCount { + return true + } + case reflect.Struct, reflect.Interface: + if v.CanAddr() { + p := v.Addr() + if _, ok := (*vistMap)[p]; ok { + return true + } else { + (*vistMap)[p] = true + } + } + } + return false +}