diff --git a/schema/relationship.go b/schema/relationship.go index 3dcef9fc..dffe5988 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -168,31 +168,76 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel joinTableFields []reflect.StructField fieldsMap = map[string]*Field{} ownFieldsMap = map[string]bool{} // fix self join many2many + joinForeignKeys = toColumns(field.TagSettings["JOINFOREIGNKEY"]) + joinReferences = toColumns(field.TagSettings["JOINREFERENCES"]) ) - for _, s := range []*Schema{schema, relation.FieldSchema} { - for _, primaryField := range s.PrimaryFields { - fieldName := s.Name + primaryField.Name - if _, ok := fieldsMap[fieldName]; ok { - if field.Name != s.Name { - fieldName = inflection.Singular(field.Name) + primaryField.Name - } else { - fieldName = s.Name + primaryField.Name + "Reference" - } - } else { - ownFieldsMap[fieldName] = true - } + ownForeignFields := schema.PrimaryFields + refForeignFields := relation.FieldSchema.PrimaryFields - fieldsMap[fieldName] = primaryField - joinTableFields = append(joinTableFields, reflect.StructField{ - Name: fieldName, - PkgPath: primaryField.StructField.PkgPath, - Type: primaryField.StructField.Type, - Tag: removeSettingFromTag(primaryField.StructField.Tag, "column"), - }) + if len(relation.foreignKeys) > 0 { + ownForeignFields = []*Field{} + for _, foreignKey := range relation.foreignKeys { + if field := schema.LookUpField(foreignKey); field != nil { + ownForeignFields = append(ownForeignFields, field) + } else { + schema.err = fmt.Errorf("invalid foreign key: %v", foreignKey) + return + } } } + if len(relation.primaryKeys) > 0 { + refForeignFields = []*Field{} + for _, foreignKey := range relation.primaryKeys { + if field := relation.FieldSchema.LookUpField(foreignKey); field != nil { + refForeignFields = append(refForeignFields, field) + } else { + schema.err = fmt.Errorf("invalid foreign key: %v", foreignKey) + return + } + } + } + + for idx, ownField := range ownForeignFields { + joinFieldName := schema.Name + ownField.Name + if len(joinForeignKeys) > idx { + joinFieldName = joinForeignKeys[idx] + } + + ownFieldsMap[joinFieldName] = true + fieldsMap[joinFieldName] = ownField + joinTableFields = append(joinTableFields, reflect.StructField{ + Name: joinFieldName, + PkgPath: ownField.StructField.PkgPath, + Type: ownField.StructField.Type, + Tag: removeSettingFromTag(ownField.StructField.Tag, "column"), + }) + } + + for idx, relField := range refForeignFields { + joinFieldName := relation.FieldSchema.Name + relField.Name + if len(joinReferences) > idx { + joinFieldName = joinReferences[idx] + } + + if _, ok := ownFieldsMap[joinFieldName]; ok { + if field.Name != relation.FieldSchema.Name { + joinFieldName = inflection.Singular(field.Name) + relField.Name + } else { + joinFieldName += "Reference" + } + } + + fieldsMap[joinFieldName] = relField + joinTableFields = append(joinTableFields, reflect.StructField{ + Name: joinFieldName, + PkgPath: relField.StructField.PkgPath, + Type: relField.StructField.Type, + Tag: removeSettingFromTag(relField.StructField.Tag, "column"), + }) + } + if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil { schema.err = err } diff --git a/schema/relationship_test.go b/schema/relationship_test.go new file mode 100644 index 00000000..41e8c7bd --- /dev/null +++ b/schema/relationship_test.go @@ -0,0 +1,199 @@ +package schema_test + +import ( + "sync" + "testing" + + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/schema" +) + +func checkStructRelation(t *testing.T, data interface{}, relations ...Relation) { + if s, err := schema.Parse(data, &sync.Map{}, schema.NamingStrategy{}); err != nil { + t.Errorf("Failed to parse schema") + } else { + for _, rel := range relations { + checkSchemaRelation(t, s, rel) + } + } +} + +func TestBelongsToOverrideForeignKey(t *testing.T) { + type Profile struct { + gorm.Model + Name string + } + + type User struct { + gorm.Model + Profile Profile `gorm:"ForeignKey:ProfileRefer"` + ProfileRefer int + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profile", Type: schema.BelongsTo, Schema: "User", FieldSchema: "Profile", + References: []Reference{{"ID", "Profile", "ProfileRefer", "User", "", false}}, + }) +} + +func TestBelongsToOverrideReferences(t *testing.T) { + type Profile struct { + gorm.Model + Refer string + Name string + } + + type User struct { + gorm.Model + Profile Profile `gorm:"ForeignKey:ProfileID;References:Refer"` + ProfileID int + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profile", Type: schema.BelongsTo, Schema: "User", FieldSchema: "Profile", + References: []Reference{{"Refer", "Profile", "ProfileID", "User", "", false}}, + }) +} + +func TestHasOneOverrideForeignKey(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserRefer uint + } + + type User struct { + gorm.Model + Profile Profile `gorm:"ForeignKey:UserRefer"` + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profile", Type: schema.HasOne, Schema: "User", FieldSchema: "Profile", + References: []Reference{{"ID", "User", "UserRefer", "Profile", "", true}}, + }) +} + +func TestHasOneOverrideReferences(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserID uint + } + + type User struct { + gorm.Model + Refer string + Profile Profile `gorm:"ForeignKey:UserID;References:Refer"` + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profile", Type: schema.HasOne, Schema: "User", FieldSchema: "Profile", + References: []Reference{{"Refer", "User", "UserID", "Profile", "", true}}, + }) +} + +func TestHasManyOverrideForeignKey(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserRefer uint + } + + type User struct { + gorm.Model + Profile []Profile `gorm:"ForeignKey:UserRefer"` + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profile", Type: schema.HasMany, Schema: "User", FieldSchema: "Profile", + References: []Reference{{"ID", "User", "UserRefer", "Profile", "", true}}, + }) +} + +func TestHasManyOverrideReferences(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserID uint + } + + type User struct { + gorm.Model + Refer string + Profile []Profile `gorm:"ForeignKey:UserID;References:Refer"` + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profile", Type: schema.HasMany, Schema: "User", FieldSchema: "Profile", + References: []Reference{{"Refer", "User", "UserID", "Profile", "", true}}, + }) +} + +func TestMany2ManyOverrideForeignKeyAndReferences(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserRefer uint + } + + type User struct { + gorm.Model + Profiles []Profile `gorm:"many2many:user_profiles;ForeignKey:Refer;JoinForeignKey:UserReferID;References:UserRefer;JoinReferences:ProfileRefer"` + Refer uint + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile", + JoinTable: JoinTable{Name: "user_profiles", Table: "user_profiles"}, + References: []Reference{ + {"Refer", "User", "UserReferID", "user_profiles", "", true}, + {"UserRefer", "Profile", "ProfileRefer", "user_profiles", "", false}, + }, + }) +} + +func TestMany2ManyOverrideForeignKey(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserRefer uint + } + + type User struct { + gorm.Model + Profiles []Profile `gorm:"many2many:user_profiles;ForeignKey:Refer;References:UserRefer"` + Refer uint + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile", + JoinTable: JoinTable{Name: "user_profiles", Table: "user_profiles"}, + References: []Reference{ + {"Refer", "User", "UserRefer", "user_profiles", "", true}, + {"UserRefer", "Profile", "ProfileUserRefer", "user_profiles", "", false}, + }, + }) +} + +func TestMany2ManyOverrideJoinForeignKey(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserRefer uint + } + + type User struct { + gorm.Model + Profiles []Profile `gorm:"many2many:user_profiles;JoinForeignKey:UserReferID;JoinReferences:ProfileRefer"` + Refer uint + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile", + JoinTable: JoinTable{Name: "user_profiles", Table: "user_profiles"}, + References: []Reference{ + {"ID", "User", "UserReferID", "user_profiles", "", true}, + {"ID", "Profile", "ProfileRefer", "user_profiles", "", false}, + }, + }) +} diff --git a/schema/schema_helper_test.go b/schema/schema_helper_test.go index 24920515..b5474fe7 100644 --- a/schema/schema_helper_test.go +++ b/schema/schema_helper_test.go @@ -127,7 +127,7 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) { } if r.FieldSchema.Name != relation.FieldSchema { - t.Errorf("schema %v relation's schema expects %v, but got %v", s, relation.Schema, r.Schema.Name) + t.Errorf("schema %v field relation's schema expects %v, but got %v", s, relation.FieldSchema, r.FieldSchema.Name) } if r.Polymorphic != nil {