diff --git a/customize_column_test.go b/customize_column_test.go index ddb536b8..c96b2d40 100644 --- a/customize_column_test.go +++ b/customize_column_test.go @@ -279,3 +279,25 @@ func TestBelongsToWithPartialCustomizedColumn(t *testing.T) { t.Errorf("should preload discount from coupon") } } + +type SelfReferencingUser struct { + gorm.Model + Friends []*SelfReferencingUser `gorm:"many2many:UserFriends;AssociationForeignKey:ID=friend_id"` +} + +func TestSelfReferencingMany2ManyColumn(t *testing.T) { + DB.DropTable(&SelfReferencingUser{}, "UserFriends") + DB.AutoMigrate(&SelfReferencingUser{}) + + friend := SelfReferencingUser{} + if err := DB.Create(&friend).Error; err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + user := SelfReferencingUser{ + Friends: []*SelfReferencingUser{&friend}, + } + if err := DB.Create(&user).Error; err != nil { + t.Errorf("no error should happen, but got %v", err) + } +} diff --git a/join_table_handler.go b/join_table_handler.go index 2d1a5055..b4be6cf9 100644 --- a/join_table_handler.go +++ b/join_table_handler.go @@ -109,7 +109,24 @@ func (s JoinTableHandler) getSearchMap(db *DB, sources ...interface{}) map[strin // Add create relationship in join table for source and destination func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error { scope := db.NewScope("") - searchMap := s.getSearchMap(db, source, destination) + searchMap := map[string]interface{}{} + + // getSearchMap() cannot be used here since the source and destination + // model types may be identical + + sourceScope := db.NewScope(source) + for _, foreignKey := range s.Source.ForeignKeys { + if field, ok := sourceScope.FieldByName(foreignKey.AssociationDBName); ok { + searchMap[foreignKey.DBName] = field.Field.Interface() + } + } + + destinationScope := db.NewScope(destination) + for _, foreignKey := range s.Destination.ForeignKeys { + if field, ok := destinationScope.FieldByName(foreignKey.AssociationDBName); ok { + searchMap[foreignKey.DBName] = field.Field.Interface() + } + } var assignColumns, binVars, conditions []string var values []interface{} diff --git a/model_struct.go b/model_struct.go index 315028c4..463ec517 100644 --- a/model_struct.go +++ b/model_struct.go @@ -289,11 +289,24 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } for _, name := range associationForeignKeys { + + // In order to allow self-referencing many2many tables, the name + // may be followed by "=" to allow renaming the column + parts := strings.Split(name, "=") + name = parts[0] + if field, ok := toScope.FieldByName(name); ok { // association foreign keys (db names) relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName) + + // If a new name was provided for the field, use it + name = field.DBName + if len(parts) > 1 { + name = parts[1] + } + // join table foreign keys for association - joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName + joinTableDBName := ToDBName(elemType.Name()) + "_" + name relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) } }