fix: association many2many duplicate elem (#5473)

* fix: association many2many duplicate elem

* chore: gofumpt style
This commit is contained in:
Cr 2022-07-01 15:12:15 +08:00 committed by GitHub
parent 235c093bb9
commit c74bc57add
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 51 additions and 12 deletions

View File

@ -253,6 +253,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
fieldType = reflect.PtrTo(fieldType) fieldType = reflect.PtrTo(fieldType)
} }
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 10) joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 10)
objs := []reflect.Value{} objs := []reflect.Value{}
@ -272,19 +273,31 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
joins = reflect.Append(joins, joinValue) joins = reflect.Append(joins, joinValue)
} }
identityMap := map[string]bool{}
appendToElems := func(v reflect.Value) { appendToElems := func(v reflect.Value) {
if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero { if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero {
f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v)) f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v))
for i := 0; i < f.Len(); i++ { for i := 0; i < f.Len(); i++ {
elem := f.Index(i) elem := f.Index(i)
if !isPtr {
objs = append(objs, v) elem = elem.Addr()
if isPtr {
elems = reflect.Append(elems, elem)
} else {
elems = reflect.Append(elems, elem.Addr())
} }
objs = append(objs, v)
elems = reflect.Append(elems, elem)
relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields))
for _, pf := range rel.FieldSchema.PrimaryFields {
if pfv, ok := pf.ValueOf(db.Statement.Context, elem); !ok {
relPrimaryValues = append(relPrimaryValues, pfv)
}
}
cacheKey := utils.ToStringKey(relPrimaryValues)
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
identityMap[cacheKey] = true
distinctElems = reflect.Append(distinctElems, elem)
}
} }
} }
} }
@ -304,7 +317,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
// optimize elems of reflect value length // optimize elems of reflect value length
if elemLen := elems.Len(); elemLen > 0 { if elemLen := elems.Len(); elemLen > 0 {
if v, ok := selectColumns[rel.Name+".*"]; !ok || v { if v, ok := selectColumns[rel.Name+".*"]; !ok || v {
saveAssociations(db, rel, elems, selectColumns, restricted, nil) saveAssociations(db, rel, distinctElems, selectColumns, restricted, nil)
} }
for i := 0; i < elemLen; i++ { for i := 0; i < elemLen; i++ {

View File

@ -3,6 +3,7 @@ package tests_test
import ( import (
"testing" "testing"
"gorm.io/gorm"
. "gorm.io/gorm/utils/tests" . "gorm.io/gorm/utils/tests"
) )
@ -324,3 +325,29 @@ func TestSingleTableMany2ManyAssociationForSlice(t *testing.T) {
DB.Model(&users).Association("Team").Clear() DB.Model(&users).Association("Team").Clear()
AssertAssociationCount(t, users, "Team", 0, "After Clear") AssertAssociationCount(t, users, "Team", 0, "After Clear")
} }
func TestDuplicateMany2ManyAssociation(t *testing.T) {
user1 := User{Name: "TestDuplicateMany2ManyAssociation-1", Languages: []Language{
{Code: "TestDuplicateMany2ManyAssociation-language-1"},
{Code: "TestDuplicateMany2ManyAssociation-language-2"},
}}
user2 := User{Name: "TestDuplicateMany2ManyAssociation-1", Languages: []Language{
{Code: "TestDuplicateMany2ManyAssociation-language-1"},
{Code: "TestDuplicateMany2ManyAssociation-language-3"},
}}
users := []*User{&user1, &user2}
var err error
err = DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(users).Error
AssertEqual(t, nil, err)
var findUser1 User
err = DB.Preload("Languages").Where("id = ?", user1.ID).First(&findUser1).Error
AssertEqual(t, nil, err)
AssertEqual(t, user1, findUser1)
var findUser2 User
err = DB.Preload("Languages").Where("id = ?", user2.ID).First(&findUser2).Error
AssertEqual(t, nil, err)
AssertEqual(t, user2, findUser2)
}

View File

@ -830,11 +830,11 @@ func TestUniqueColumn(t *testing.T) {
value, ok = ct.DefaultValue() value, ok = ct.DefaultValue()
AssertEqual(t, "", value) AssertEqual(t, "", value)
AssertEqual(t, false, ok) AssertEqual(t, false, ok)
} }
func findColumnType(dest interface{}, columnName string) ( func findColumnType(dest interface{}, columnName string) (
foundColumn gorm.ColumnType, err error) { foundColumn gorm.ColumnType, err error,
) {
columnTypes, err := DB.Migrator().ColumnTypes(dest) columnTypes, err := DB.Migrator().ColumnTypes(dest)
if err != nil { if err != nil {
err = fmt.Errorf("ColumnTypes err:%v", err) err = fmt.Errorf("ColumnTypes err:%v", err)

View File

@ -113,7 +113,6 @@ func TestSerializer(t *testing.T) {
} }
AssertEqual(t, result, data) AssertEqual(t, result, data)
} }
func TestSerializerAssignFirstOrCreate(t *testing.T) { func TestSerializerAssignFirstOrCreate(t *testing.T) {
@ -152,7 +151,7 @@ func TestSerializerAssignFirstOrCreate(t *testing.T) {
} }
AssertEqual(t, result, out) AssertEqual(t, result, out)
//update record // update record
data.Roles = append(data.Roles, "r3") data.Roles = append(data.Roles, "r3")
data.JobInfo.Location = "Gates Hillman Complex" data.JobInfo.Location = "Gates Hillman Complex"
if err := DB.Assign(data).FirstOrCreate(&out).Error; err != nil { if err := DB.Assign(data).FirstOrCreate(&out).Error; err != nil {