From 457f1e5d7390c2b7f54c6111bfa863cfb35c5dbd Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 26 May 2020 01:21:15 +0800 Subject: [PATCH] Test Many2Many Association for Slice --- association.go | 32 ++++++++++++---- tests/associations_test.go | 78 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 99 insertions(+), 11 deletions(-) diff --git a/association.go b/association.go index 49fd4558..92a19efb 100644 --- a/association.go +++ b/association.go @@ -340,11 +340,16 @@ func (association *Association) Count() (count int64) { return } +type assignBack struct { + Source reflect.Value + Index int + Dest reflect.Value +} + func (association *Association) saveAssociation(clear bool, values ...interface{}) { var ( reflectValue = association.DB.Statement.ReflectValue - assignBacks = [][2]reflect.Value{} - assignBack = association.Relationship.Field.FieldType.Kind() == reflect.Struct + assignBacks []assignBack ) appendToRelations := func(source, rv reflect.Value, clear bool) { @@ -354,14 +359,14 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ case reflect.Slice, reflect.Array: if rv.Len() > 0 { association.Error = association.Relationship.Field.Set(source, rv.Index(0).Addr().Interface()) - if assignBack { - assignBacks = append(assignBacks, [2]reflect.Value{source, rv.Index(0)}) + if association.Relationship.Field.FieldType.Kind() == reflect.Struct { + assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv.Index(0)}) } } case reflect.Struct: association.Error = association.Relationship.Field.Set(source, rv.Addr().Interface()) - if assignBack { - assignBacks = append(assignBacks, [2]reflect.Value{source, rv}) + if association.Relationship.Field.FieldType.Kind() == reflect.Struct { + assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv}) } } case schema.HasMany, schema.Many2Many: @@ -379,6 +384,14 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } else { association.Error = fmt.Errorf("unsupported data type: %v for relation %v", ev.Type(), association.Relationship.Name) } + + if association.Relationship.Field.IndirectFieldType.Elem().Kind() == reflect.Struct { + assignBacks = append(assignBacks, assignBack{ + Source: source, + Index: fieldValue.Len(), + Dest: ev, + }) + } } switch rv.Kind() { @@ -451,6 +464,11 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } for _, assignBack := range assignBacks { - reflect.Indirect(assignBack[1]).Set(association.Relationship.Field.ReflectValueOf(assignBack[0])) + fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(assignBack.Source)) + if assignBack.Index > 0 { + reflect.Indirect(assignBack.Dest).Set(fieldValue.Index(assignBack.Index - 1)) + } else { + reflect.Indirect(assignBack.Dest).Set(fieldValue) + } } } diff --git a/tests/associations_test.go b/tests/associations_test.go index 3ab69b42..3aa11edb 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -786,7 +786,7 @@ func TestMany2ManyAssociation(t *testing.T) { AssertAssociationCount(t, user, "Languages", 2, "") // Append - var language = Language{Code: "language-has-many-append", Name: "language-has-many-append"} + var language = Language{Code: "language-many2many-append", Name: "language-many2many-append"} DB.Create(&language) if err := DB.Model(&user2).Association("Languages").Append(&language); err != nil { @@ -799,8 +799,8 @@ func TestMany2ManyAssociation(t *testing.T) { AssertAssociationCount(t, user, "Languages", 3, "AfterAppend") var languages = []Language{ - {Code: "language-has-many-append-1-1", Name: "language-has-many-append-1-1"}, - {Code: "language-has-many-append-2-1", Name: "language-has-many-append-2-1"}, + {Code: "language-many2many-append-1-1", Name: "language-many2many-append-1-1"}, + {Code: "language-many2many-append-2-1", Name: "language-many2many-append-2-1"}, } DB.Create(&languages) @@ -815,7 +815,7 @@ func TestMany2ManyAssociation(t *testing.T) { AssertAssociationCount(t, user, "Languages", 5, "AfterAppendSlice") // Replace - var language2 = Language{Code: "language-has-many-replace", Name: "language-has-many-replace"} + var language2 = Language{Code: "language-many2many-replace", Name: "language-many2many-replace"} DB.Create(&language2) if err := DB.Model(&user2).Association("Languages").Replace(&language2); err != nil { @@ -852,3 +852,73 @@ func TestMany2ManyAssociation(t *testing.T) { AssertAssociationCount(t, user2, "Languages", 0, "after clear") } + +func TestMany2ManyAssociationForSlice(t *testing.T) { + var users = []User{ + *GetUser("slice-many2many-1", Config{Languages: 2}), + *GetUser("slice-many2many-2", Config{Languages: 0}), + *GetUser("slice-many2many-3", Config{Languages: 4}), + } + + DB.Create(&users) + + // Count + AssertAssociationCount(t, users, "Languages", 6, "") + + // Find + var languages []Language + if DB.Model(&users).Association("Languages").Find(&languages); len(languages) != 6 { + t.Errorf("languages count should be %v, but got %v", 6, len(languages)) + } + + // Append + var languages1 = []Language{ + {Code: "language-many2many-append-1", Name: "language-many2many-append-1"}, + } + var languages2 = []Language{} + var languages3 = []Language{ + {Code: "language-many2many-append-3-1", Name: "language-many2many-append-3-1"}, + {Code: "language-many2many-append-3-2", Name: "language-many2many-append-3-2"}, + } + DB.Create(&languages1) + DB.Create(&languages3) + + DB.Model(&users).Association("Languages").Append(&languages1, &languages2, &languages3) + + AssertAssociationCount(t, users, "Languages", 9, "After Append") + + languages2_1 := []*Language{ + {Code: "language-slice-replace-1-1", Name: "language-slice-replace-1-1"}, + {Code: "language-slice-replace-1-2", Name: "language-slice-replace-1-2"}, + } + languages2_2 := []*Language{ + {Code: "language-slice-replace-2-1", Name: "language-slice-replace-2-1"}, + {Code: "language-slice-replace-2-2", Name: "language-slice-replace-2-2"}, + } + languages2_3 := &Language{Code: "language-slice-replace-3", Name: "language-slice-replace-3"} + DB.Create(&languages2_1) + DB.Create(&languages2_2) + DB.Create(&languages2_3) + + // Replace + DB.Model(&users).Association("Languages").Replace(&languages2_1, &languages2_2, languages2_3) + + AssertAssociationCount(t, users, "Languages", 5, "After Replace") + + // Delete + if err := DB.Model(&users).Association("Languages").Delete(&users[2].Languages); err != nil { + t.Errorf("no error should happend when deleting language, but got %v", err) + } + + AssertAssociationCount(t, users, "Languages", 4, "after delete") + + if err := DB.Model(&users).Association("Languages").Delete(users[0].Languages[0], users[1].Languages[1]); err != nil { + t.Errorf("no error should happend when deleting language, but got %v", err) + } + + AssertAssociationCount(t, users, "Languages", 2, "after delete") + + // Clear + DB.Model(&users).Association("Languages").Clear() + AssertAssociationCount(t, users, "Languages", 0, "After Clear") +}