diff --git a/callback_create.go b/callback_create.go index 7f21ed6a..bded5324 100644 --- a/callback_create.go +++ b/callback_create.go @@ -35,9 +35,11 @@ func Create(scope *Scope) { } } } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { - if relationField := fields[relationship.ForeignDBName]; !scope.changeableField(relationField) { - columns = append(columns, scope.Quote(relationField.DBName)) - sqls = append(sqls, scope.AddToVars(relationField.Field.Interface())) + for _, dbName := range relationship.ForeignDBNames { + if relationField := fields[dbName]; !scope.changeableField(relationField) { + columns = append(columns, scope.Quote(relationField.DBName)) + sqls = append(sqls, scope.AddToVars(relationField.Field.Interface())) + } } } } diff --git a/callback_shared.go b/callback_shared.go index c1b9bd00..1e9d320f 100644 --- a/callback_shared.go +++ b/callback_shared.go @@ -19,8 +19,13 @@ func SaveBeforeAssociations(scope *Scope) { if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { value := field.Field scope.Err(scope.NewDB().Save(value.Addr().Interface()).Error) - if relationship.ForeignFieldName != "" { - scope.Err(scope.SetColumn(relationship.ForeignFieldName, scope.New(value.Addr().Interface()).PrimaryKeyValue())) + if len(relationship.ForeignFieldNames) != 0 { + for idx, fieldName := range relationship.ForeignFieldNames { + associationForeignName := relationship.AssociationForeignDBNames[idx] + if f, ok := scope.New(value.Addr().Interface()).FieldByName(associationForeignName); ok { + scope.Err(scope.SetColumn(fieldName, f.Field.Interface())) + } + } } } } @@ -44,8 +49,13 @@ func SaveAfterAssociations(scope *Scope) { elem := value.Index(i).Addr().Interface() newScope := newDB.NewScope(elem) - if relationship.JoinTableHandler == nil && relationship.ForeignFieldName != "" { - scope.Err(newScope.SetColumn(relationship.ForeignFieldName, scope.PrimaryKeyValue())) + if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 { + for idx, fieldName := range relationship.ForeignFieldNames { + associationForeignName := relationship.AssociationForeignDBNames[idx] + if f, ok := scope.New(value.Addr().Interface()).FieldByName(associationForeignName); ok { + scope.Err(scope.SetColumn(fieldName, f.Field.Interface())) + } + } } if relationship.PolymorphicType != "" { @@ -61,8 +71,13 @@ func SaveAfterAssociations(scope *Scope) { default: elem := value.Addr().Interface() newScope := scope.New(elem) - if relationship.ForeignFieldName != "" { - scope.Err(newScope.SetColumn(relationship.ForeignFieldName, scope.PrimaryKeyValue())) + if len(relationship.ForeignFieldNames) != 0 { + for idx, fieldName := range relationship.ForeignFieldNames { + associationForeignName := relationship.AssociationForeignDBNames[idx] + if f, ok := scope.New(value.Addr().Interface()).FieldByName(associationForeignName); ok { + scope.Err(scope.SetColumn(fieldName, f.Field.Interface())) + } + } } if relationship.PolymorphicType != "" { diff --git a/callback_update.go b/callback_update.go index c3f7b4b6..6090ee6b 100644 --- a/callback_update.go +++ b/callback_update.go @@ -55,9 +55,10 @@ func Update(scope *Scope) { sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) } } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { - if relationField := fields[relationship.ForeignDBName]; !scope.changeableField(relationField) { - if !relationField.IsBlank { - sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(relationField.DBName), scope.AddToVars(relationField.Field.Interface()))) + for _, dbName := range relationship.ForeignDBNames { + if relationField := fields[dbName]; !scope.changeableField(relationField) && !relationField.IsBlank { + sql := fmt.Sprintf("%v = %v", scope.Quote(relationField.DBName), scope.AddToVars(relationField.Field.Interface())) + sqls = append(sqls, sql) } } } diff --git a/join_table_handler.go b/join_table_handler.go index 07ecee2e..10e1e848 100644 --- a/join_table_handler.go +++ b/join_table_handler.go @@ -45,41 +45,18 @@ func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, s s.TableName = tableName s.Source = JoinTableSource{ModelType: source} - sourceScope := &Scope{Value: reflect.New(source).Interface()} - sourcePrimaryFields := sourceScope.GetModelStruct().PrimaryFields - for _, primaryField := range sourcePrimaryFields { - if relationship.ForeignDBName == "" { - relationship.ForeignFieldName = source.Name() + primaryField.Name - relationship.ForeignDBName = ToDBName(relationship.ForeignFieldName) - } - - var dbName string - if len(sourcePrimaryFields) == 1 || primaryField.DBName == "id" { - dbName = relationship.ForeignDBName - } else { - dbName = ToDBName(source.Name() + primaryField.Name) - } - + for idx, dbName := range relationship.ForeignFieldNames { s.Source.ForeignKeys = append(s.Source.ForeignKeys, JoinTableForeignKey{ - DBName: dbName, - AssociationDBName: primaryField.DBName, + DBName: relationship.ForeignDBNames[idx], + AssociationDBName: dbName, }) } s.Destination = JoinTableSource{ModelType: destination} - destinationScope := &Scope{Value: reflect.New(destination).Interface()} - destinationPrimaryFields := destinationScope.GetModelStruct().PrimaryFields - for _, primaryField := range destinationPrimaryFields { - var dbName string - if len(sourcePrimaryFields) == 1 || primaryField.DBName == "id" { - dbName = relationship.AssociationForeignDBName - } else { - dbName = ToDBName(destinationScope.GetModelStruct().ModelType.Name() + primaryField.Name) - } - + for idx, dbName := range relationship.AssociationForeignFieldNames { s.Destination.ForeignKeys = append(s.Destination.ForeignKeys, JoinTableForeignKey{ - DBName: dbName, - AssociationDBName: primaryField.DBName, + DBName: relationship.AssociationForeignDBNames[idx], + AssociationDBName: dbName, }) } } diff --git a/main.go b/main.go index aba51fc4..7c4c4df4 100644 --- a/main.go +++ b/main.go @@ -445,7 +445,7 @@ func (s *DB) Association(column string) *Association { err = errors.New("primary key can't be nil") } else { if field, ok := scope.FieldByName(column); ok { - if field.Relationship == nil || field.Relationship.ForeignFieldName == "" { + if field.Relationship == nil || len(field.Relationship.ForeignFieldNames) == 0 { err = fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type()) } else { return &Association{Scope: scope, Column: column, PrimaryKey: primaryField.Field.Interface(), Field: field} diff --git a/model_struct.go b/model_struct.go index 10423ae2..119e6dc9 100644 --- a/model_struct.go +++ b/model_struct.go @@ -61,14 +61,14 @@ func (structField *StructField) clone() *StructField { } type Relationship struct { - Kind string - PolymorphicType string - PolymorphicDBName string - ForeignFieldName string - ForeignDBName string - AssociationForeignFieldName string - AssociationForeignDBName string - JoinTableHandler JoinTableHandlerInterface + Kind string + PolymorphicType string + PolymorphicDBName string + ForeignFieldNames []string + ForeignDBNames []string + AssociationForeignFieldNames []string + AssociationForeignDBNames []string + JoinTableHandler JoinTableHandlerInterface } var pluralMapKeys = []*regexp.Regexp{regexp.MustCompile("ch$"), regexp.MustCompile("ss$"), regexp.MustCompile("sh$"), regexp.MustCompile("day$"), regexp.MustCompile("y$"), regexp.MustCompile("x$"), regexp.MustCompile("([^s])s?$")} @@ -190,12 +190,11 @@ func (scope *Scope) GetModelStruct() *ModelStruct { var relationship = &Relationship{} - foreignKey := gormSettings["FOREIGNKEY"] if polymorphic := gormSettings["POLYMORPHIC"]; polymorphic != "" { if polymorphicField := getForeignField(polymorphic+"Id", toScope.GetStructFields()); polymorphicField != nil { if polymorphicType := getForeignField(polymorphic+"Type", toScope.GetStructFields()); polymorphicType != nil { - relationship.ForeignFieldName = polymorphicField.Name - relationship.ForeignDBName = polymorphicField.DBName + relationship.ForeignFieldNames = []string{polymorphicField.Name} + relationship.ForeignDBNames = []string{polymorphicField.DBName} relationship.PolymorphicType = polymorphicType.Name relationship.PolymorphicDBName = polymorphicType.DBName polymorphicType.IsForeignKey = true @@ -204,6 +203,10 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } } + var foreignKeys []string + if foreignKey, ok := gormSettings["FOREIGNKEY"]; ok { + foreignKeys := append(foreignKeys, gormSettings["FOREIGNKEY"]) + } switch indirectType.Kind() { case reflect.Slice: elemType := indirectType.Elem() @@ -212,34 +215,63 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } if elemType.Kind() == reflect.Struct { - if foreignKey == "" { - foreignKey = scopeType.Name() + "Id" - } - if many2many := gormSettings["MANY2MANY"]; many2many != "" { relationship.Kind = "many_to_many" - associationForeignKey := gormSettings["ASSOCIATIONFOREIGNKEY"] - if associationForeignKey == "" { - associationForeignKey = elemType.Name() + "Id" + + // foreign keys + if len(foreignKeys) == 0 { + for _, field := range scope.PrimaryFields() { + foreignKeys = append(foreignKeys, field.DBName) + } } - relationship.ForeignFieldName = foreignKey - relationship.ForeignDBName = ToDBName(foreignKey) - relationship.AssociationForeignFieldName = associationForeignKey - relationship.AssociationForeignDBName = ToDBName(associationForeignKey) + for _, foreignKey := range foreignKeys { + if field, ok := scope.FieldByName(foreignKey); ok { + relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, field.DBName) + joinTableDBName := ToDBName(scopeType.Name()) + "_" + field.DBName + relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBName) + } + } + + // association foreign keys + var associationForeignKeys []string + if foreignKey := gormSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { + associationForeignKeys = []string{gormSettings["ASSOCIATIONFOREIGNKEY"]} + } else { + for _, field := range toScope.PrimaryFields() { + associationForeignKeys = append(associationForeignKeys, field.DBName) + } + } + + for _, name := range associationForeignKeys { + if field, ok := toScope.FieldByName(name); ok { + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, name) + joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) + } + } joinTableHandler := JoinTableHandler{} joinTableHandler.Setup(relationship, many2many, scopeType, elemType) relationship.JoinTableHandler = &joinTableHandler field.Relationship = relationship } else { + if len(foreignKeys) == 0 { + for _, field := range scope.PrimaryFields() { + foreignKeys = append(foreignKeys, scopeType.Name()+field.Name) + } + } + relationship.Kind = "has_many" - if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { - relationship.ForeignFieldName = foreignField.Name - relationship.ForeignDBName = foreignField.DBName - foreignField.IsForeignKey = true - field.Relationship = relationship - } else if relationship.ForeignFieldName != "" { + for _, foreignKey := range foreignKeys { + if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { + relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) + relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) + foreignField.IsForeignKey = true + } + } + + if len(relationship.ForeignFieldNames) != 0 { field.Relationship = relationship } } @@ -258,28 +290,42 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } continue } else { - belongsToForeignKey := foreignKey - if belongsToForeignKey == "" { - belongsToForeignKey = field.Name + "Id" + belongsToForeignKeys := foreignKeys + if len(belongsToForeignKeys) == 0 { + for _, field := range toScope.PrimaryFields() { + belongsToForeignKeys = append(belongsToForeignKeys, field.Name+field.Name) + } } - if foreignField := getForeignField(belongsToForeignKey, fields); foreignField != nil { + for _, foreignKey := range belongsToForeignKeys { + if foreignField := getForeignField(foreignKey, fields); foreignField != nil { + relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) + relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) + foreignField.IsForeignKey = true + } + } + + if len(relationship.ForeignFieldNames) != 0 { relationship.Kind = "belongs_to" - relationship.ForeignFieldName = foreignField.Name - relationship.ForeignDBName = foreignField.DBName - foreignField.IsForeignKey = true field.Relationship = relationship } else { - if foreignKey == "" { - foreignKey = modelStruct.ModelType.Name() + "Id" + hasOneForeignKeys := foreignKeys + if len(hasOneForeignKeys) == 0 { + for _, field := range toScope.PrimaryFields() { + hasOneForeignKeys = append(hasOneForeignKeys, modelStruct.ModelType.Name()+field.Name) + } } - relationship.Kind = "has_one" - if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { - relationship.ForeignFieldName = foreignField.Name - relationship.ForeignDBName = foreignField.DBName - foreignField.IsForeignKey = true - field.Relationship = relationship - } else if relationship.ForeignFieldName != "" { + + for _, foreignKey := range hasOneForeignKeys { + if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { + relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) + relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) + foreignField.IsForeignKey = true + } + } + + if len(relationship.ForeignFieldNames) != 0 { + relationship.Kind = "has_one" field.Relationship = relationship } }