forked from mirror/gorm
fix: association many2many duplicate elem (#5473)
* fix: association many2many duplicate elem * chore: gofumpt style
This commit is contained in:
parent
235c093bb9
commit
c74bc57add
|
@ -253,6 +253,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
|
|||
fieldType = reflect.PtrTo(fieldType)
|
||||
}
|
||||
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
|
||||
distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
|
||||
joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 10)
|
||||
objs := []reflect.Value{}
|
||||
|
||||
|
@ -272,19 +273,31 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
|
|||
joins = reflect.Append(joins, joinValue)
|
||||
}
|
||||
|
||||
identityMap := map[string]bool{}
|
||||
appendToElems := func(v reflect.Value) {
|
||||
if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero {
|
||||
f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v))
|
||||
|
||||
for i := 0; i < f.Len(); i++ {
|
||||
elem := f.Index(i)
|
||||
|
||||
objs = append(objs, v)
|
||||
if isPtr {
|
||||
elems = reflect.Append(elems, elem)
|
||||
} else {
|
||||
elems = reflect.Append(elems, elem.Addr())
|
||||
if !isPtr {
|
||||
elem = elem.Addr()
|
||||
}
|
||||
objs = append(objs, v)
|
||||
elems = reflect.Append(elems, elem)
|
||||
|
||||
relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields))
|
||||
for _, pf := range rel.FieldSchema.PrimaryFields {
|
||||
if pfv, ok := pf.ValueOf(db.Statement.Context, elem); !ok {
|
||||
relPrimaryValues = append(relPrimaryValues, pfv)
|
||||
}
|
||||
}
|
||||
|
||||
cacheKey := utils.ToStringKey(relPrimaryValues)
|
||||
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
|
||||
identityMap[cacheKey] = true
|
||||
distinctElems = reflect.Append(distinctElems, elem)
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -304,7 +317,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
|
|||
// optimize elems of reflect value length
|
||||
if elemLen := elems.Len(); elemLen > 0 {
|
||||
if v, ok := selectColumns[rel.Name+".*"]; !ok || v {
|
||||
saveAssociations(db, rel, elems, selectColumns, restricted, nil)
|
||||
saveAssociations(db, rel, distinctElems, selectColumns, restricted, nil)
|
||||
}
|
||||
|
||||
for i := 0; i < elemLen; i++ {
|
||||
|
|
|
@ -3,6 +3,7 @@ package tests_test
|
|||
import (
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
|
@ -324,3 +325,29 @@ func TestSingleTableMany2ManyAssociationForSlice(t *testing.T) {
|
|||
DB.Model(&users).Association("Team").Clear()
|
||||
AssertAssociationCount(t, users, "Team", 0, "After Clear")
|
||||
}
|
||||
|
||||
func TestDuplicateMany2ManyAssociation(t *testing.T) {
|
||||
user1 := User{Name: "TestDuplicateMany2ManyAssociation-1", Languages: []Language{
|
||||
{Code: "TestDuplicateMany2ManyAssociation-language-1"},
|
||||
{Code: "TestDuplicateMany2ManyAssociation-language-2"},
|
||||
}}
|
||||
|
||||
user2 := User{Name: "TestDuplicateMany2ManyAssociation-1", Languages: []Language{
|
||||
{Code: "TestDuplicateMany2ManyAssociation-language-1"},
|
||||
{Code: "TestDuplicateMany2ManyAssociation-language-3"},
|
||||
}}
|
||||
users := []*User{&user1, &user2}
|
||||
var err error
|
||||
err = DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(users).Error
|
||||
AssertEqual(t, nil, err)
|
||||
|
||||
var findUser1 User
|
||||
err = DB.Preload("Languages").Where("id = ?", user1.ID).First(&findUser1).Error
|
||||
AssertEqual(t, nil, err)
|
||||
AssertEqual(t, user1, findUser1)
|
||||
|
||||
var findUser2 User
|
||||
err = DB.Preload("Languages").Where("id = ?", user2.ID).First(&findUser2).Error
|
||||
AssertEqual(t, nil, err)
|
||||
AssertEqual(t, user2, findUser2)
|
||||
}
|
||||
|
|
|
@ -830,11 +830,11 @@ func TestUniqueColumn(t *testing.T) {
|
|||
value, ok = ct.DefaultValue()
|
||||
AssertEqual(t, "", value)
|
||||
AssertEqual(t, false, ok)
|
||||
|
||||
}
|
||||
|
||||
func findColumnType(dest interface{}, columnName string) (
|
||||
foundColumn gorm.ColumnType, err error) {
|
||||
foundColumn gorm.ColumnType, err error,
|
||||
) {
|
||||
columnTypes, err := DB.Migrator().ColumnTypes(dest)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("ColumnTypes err:%v", err)
|
||||
|
|
|
@ -113,7 +113,6 @@ func TestSerializer(t *testing.T) {
|
|||
}
|
||||
|
||||
AssertEqual(t, result, data)
|
||||
|
||||
}
|
||||
|
||||
func TestSerializerAssignFirstOrCreate(t *testing.T) {
|
||||
|
@ -152,7 +151,7 @@ func TestSerializerAssignFirstOrCreate(t *testing.T) {
|
|||
}
|
||||
AssertEqual(t, result, out)
|
||||
|
||||
//update record
|
||||
// update record
|
||||
data.Roles = append(data.Roles, "r3")
|
||||
data.JobInfo.Location = "Gates Hillman Complex"
|
||||
if err := DB.Assign(data).FirstOrCreate(&out).Error; err != nil {
|
||||
|
|
Loading…
Reference in New Issue