From 118a3f469ed8d4346584ceef302a0fc72d52b947 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 21 Sep 2022 17:29:38 +0800 Subject: [PATCH] Allow shared foreign key for many2many jointable --- schema/relationship.go | 60 ++++++++++++++++++++++--------------- schema/relationship_test.go | 29 +++++++++++++++++- tests/go.mod | 13 ++++---- 3 files changed, 71 insertions(+), 31 deletions(-) diff --git a/schema/relationship.go b/schema/relationship.go index 0aa33e51..bb8aeb64 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -191,7 +191,8 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel err error joinTableFields []reflect.StructField fieldsMap = map[string]*Field{} - ownFieldsMap = map[string]bool{} // fix self join many2many + ownFieldsMap = map[string]*Field{} // fix self join many2many + referFieldsMap = map[string]*Field{} joinForeignKeys = toColumns(field.TagSettings["JOINFOREIGNKEY"]) joinReferences = toColumns(field.TagSettings["JOINREFERENCES"]) ) @@ -229,7 +230,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel joinFieldName = strings.Title(joinForeignKeys[idx]) } - ownFieldsMap[joinFieldName] = true + ownFieldsMap[joinFieldName] = ownField fieldsMap[joinFieldName] = ownField joinTableFields = append(joinTableFields, reflect.StructField{ Name: joinFieldName, @@ -242,9 +243,6 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel for idx, relField := range refForeignFields { joinFieldName := strings.Title(relation.FieldSchema.Name) + relField.Name - if len(joinReferences) > idx { - joinFieldName = strings.Title(joinReferences[idx]) - } if _, ok := ownFieldsMap[joinFieldName]; ok { if field.Name != relation.FieldSchema.Name { @@ -254,14 +252,22 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel } } - fieldsMap[joinFieldName] = relField - joinTableFields = append(joinTableFields, reflect.StructField{ - Name: joinFieldName, - PkgPath: relField.StructField.PkgPath, - Type: relField.StructField.Type, - Tag: removeSettingFromTag(appendSettingFromTag(relField.StructField.Tag, "primaryKey"), - "column", "autoincrement", "index", "unique", "uniqueindex"), - }) + if len(joinReferences) > idx { + joinFieldName = strings.Title(joinReferences[idx]) + } + + referFieldsMap[joinFieldName] = relField + + if _, ok := fieldsMap[joinFieldName]; !ok { + fieldsMap[joinFieldName] = relField + joinTableFields = append(joinTableFields, reflect.StructField{ + Name: joinFieldName, + PkgPath: relField.StructField.PkgPath, + Type: relField.StructField.Type, + Tag: removeSettingFromTag(appendSettingFromTag(relField.StructField.Tag, "primaryKey"), + "column", "autoincrement", "index", "unique", "uniqueindex"), + }) + } } joinTableFields = append(joinTableFields, reflect.StructField{ @@ -317,31 +323,37 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel f.Size = fieldsMap[f.Name].Size } relation.JoinTable.PrimaryFields = append(relation.JoinTable.PrimaryFields, f) - ownPrimaryField := schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name] - if ownPrimaryField { + if of, ok := ownFieldsMap[f.Name]; ok { joinRel := relation.JoinTable.Relationships.Relations[relName] joinRel.Field = relation.Field joinRel.References = append(joinRel.References, &Reference{ - PrimaryKey: fieldsMap[f.Name], + PrimaryKey: of, ForeignKey: f, }) - } else { + + relation.References = append(relation.References, &Reference{ + PrimaryKey: of, + ForeignKey: f, + OwnPrimaryKey: true, + }) + } + + if rf, ok := referFieldsMap[f.Name]; ok { joinRefRel := relation.JoinTable.Relationships.Relations[relRefName] if joinRefRel.Field == nil { joinRefRel.Field = relation.Field } joinRefRel.References = append(joinRefRel.References, &Reference{ - PrimaryKey: fieldsMap[f.Name], + PrimaryKey: rf, + ForeignKey: f, + }) + + relation.References = append(relation.References, &Reference{ + PrimaryKey: rf, ForeignKey: f, }) } - - relation.References = append(relation.References, &Reference{ - PrimaryKey: fieldsMap[f.Name], - ForeignKey: f, - OwnPrimaryKey: ownPrimaryField, - }) } } } diff --git a/schema/relationship_test.go b/schema/relationship_test.go index 6fffbfcb..85c45589 100644 --- a/schema/relationship_test.go +++ b/schema/relationship_test.go @@ -10,7 +10,7 @@ import ( func checkStructRelation(t *testing.T, data interface{}, relations ...Relation) { if s, err := schema.Parse(data, &sync.Map{}, schema.NamingStrategy{}); err != nil { - t.Errorf("Failed to parse schema") + t.Errorf("Failed to parse schema, got error %v", err) } else { for _, rel := range relations { checkSchemaRelation(t, s, rel) @@ -305,6 +305,33 @@ func TestMany2ManyOverrideForeignKey(t *testing.T) { }) } +func TestMany2ManySharedForeignKey(t *testing.T) { + type Profile struct { + gorm.Model + Name string + Kind string + ProfileRefer uint + } + + type User struct { + gorm.Model + Profiles []Profile `gorm:"many2many:user_profiles;foreignKey:Refer,Kind;joinForeignKey:UserRefer,Kind;References:ProfileRefer,Kind;joinReferences:ProfileR,Kind"` + Kind string + Refer uint + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile", + JoinTable: JoinTable{Name: "user_profiles", Table: "user_profiles"}, + References: []Reference{ + {"Refer", "User", "UserRefer", "user_profiles", "", true}, + {"Kind", "User", "Kind", "user_profiles", "", true}, + {"ProfileRefer", "Profile", "ProfileR", "user_profiles", "", false}, + {"Kind", "Profile", "Kind", "user_profiles", "", false}, + }, + }) +} + func TestMany2ManyOverrideJoinForeignKey(t *testing.T) { type Profile struct { gorm.Model diff --git a/tests/go.mod b/tests/go.mod index 19280434..ebebabc0 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -3,17 +3,18 @@ module gorm.io/gorm/tests go 1.16 require ( + github.com/denisenkom/go-mssqldb v0.12.2 // indirect github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/google/uuid v1.3.0 github.com/jinzhu/now v1.1.5 - github.com/lib/pq v1.10.6 - github.com/mattn/go-sqlite3 v1.14.14 // indirect - golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d // indirect - gorm.io/driver/mysql v1.3.5 - gorm.io/driver/postgres v1.3.8 + github.com/lib/pq v1.10.7 + github.com/mattn/go-sqlite3 v1.14.15 // indirect + golang.org/x/crypto v0.0.0-20220919173607-35f4265a4bc0 // indirect + gorm.io/driver/mysql v1.3.6 + gorm.io/driver/postgres v1.3.10 gorm.io/driver/sqlite v1.3.6 gorm.io/driver/sqlserver v1.3.2 - gorm.io/gorm v1.23.8 + gorm.io/gorm v1.23.9 ) replace gorm.io/gorm => ../