diff --git a/association.go b/association.go index 14fd1c35..3d522ccc 100644 --- a/association.go +++ b/association.go @@ -290,7 +290,9 @@ func (association *Association) Count() int { ) } - query.Model(fieldValue).Count(&count) + if err := query.Model(fieldValue).Count(&count).Error; err != nil { + association.Error = err + } return count } diff --git a/join_table_handler.go b/join_table_handler.go index 18c12a85..2d1a5055 100644 --- a/join_table_handler.go +++ b/join_table_handler.go @@ -59,6 +59,7 @@ func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, s s.TableName = tableName s.Source = JoinTableSource{ModelType: source} + s.Source.ForeignKeys = []JoinTableForeignKey{} for idx, dbName := range relationship.ForeignFieldNames { s.Source.ForeignKeys = append(s.Source.ForeignKeys, JoinTableForeignKey{ DBName: relationship.ForeignDBNames[idx], @@ -67,6 +68,7 @@ func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, s } s.Destination = JoinTableSource{ModelType: destination} + s.Destination.ForeignKeys = []JoinTableForeignKey{} for idx, dbName := range relationship.AssociationForeignFieldNames { s.Destination.ForeignKeys = append(s.Destination.ForeignKeys, JoinTableForeignKey{ DBName: relationship.AssociationForeignDBNames[idx], diff --git a/join_table_test.go b/join_table_test.go index f083ab02..dd2171e1 100644 --- a/join_table_test.go +++ b/join_table_test.go @@ -80,3 +80,38 @@ func TestJoinTable(t *testing.T) { t.Errorf("Should deleted all addresses") } } + +func TestEmbeddedMany2ManyRelationship(t *testing.T) { + type EmbeddedPerson struct { + ID int + Name string + Addresses []*Address `gorm:"many2many:person_addresses;"` + } + + type NewPerson struct { + EmbeddedPerson + ExternalID uint + } + DB.Exec("drop table person_addresses;") + DB.AutoMigrate(&NewPerson{}) + + address1 := &Address{Address1: "address 1"} + address2 := &Address{Address1: "address 2"} + person := &NewPerson{ExternalID: 100, EmbeddedPerson: EmbeddedPerson{Name: "person", Addresses: []*Address{address1, address2}}} + if err := DB.Save(person).Error; err != nil { + t.Errorf("no error should return when save embedded many2many relationship, but got %v", err) + } + + if err := DB.Model(person).Association("Addresses").Delete(address1).Error; err != nil { + t.Errorf("no error should return when delete embedded many2many relationship, but got %v", err) + } + + association := DB.Model(person).Debug().Association("Addresses") + if count := association.Count(); count != 1 || association.Error != nil { + t.Errorf("Should found one address, but got %v, error is %v", count, association.Error) + } + + if association.Clear(); association.Count() != 0 { + t.Errorf("Should deleted all addresses") + } +} diff --git a/model_struct.go b/model_struct.go index d4a46784..9c7c1a15 100644 --- a/model_struct.go +++ b/model_struct.go @@ -219,6 +219,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct { subField.IsPrimaryKey = false } } + + if subField.Relationship != nil && subField.Relationship.JoinTableHandler != nil { + if joinTableHandler, ok := subField.Relationship.JoinTableHandler.(*JoinTableHandler); ok { + joinTableHandler.Setup(subField.Relationship, joinTableHandler.TableName, reflectType, joinTableHandler.Destination.ModelType) + } + } + modelStruct.StructFields = append(modelStruct.StructFields, subField) } continue