diff --git a/model_struct.go b/model_struct.go index 119e6dc9..6e1ff055 100644 --- a/model_struct.go +++ b/model_struct.go @@ -256,18 +256,27 @@ func (scope *Scope) GetModelStruct() *ModelStruct { relationship.JoinTableHandler = &joinTableHandler field.Relationship = relationship } else { + relationship.Kind = "has_many" + if len(foreignKeys) == 0 { for _, field := range scope.PrimaryFields() { - foreignKeys = append(foreignKeys, scopeType.Name()+field.Name) + if foreignField := getForeignField(scopeType.Name()+field.Name, toScope.GetStructFields()); foreignField != nil { + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.Name) + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, field.DBName) + relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) + relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) + foreignField.IsForeignKey = true + } } - } - - relationship.Kind = "has_many" - for _, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) - relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) - foreignField.IsForeignKey = true + } else { + for _, foreignKey := range foreignKeys { + if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scope.PrimaryField().Name) + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scope.PrimaryField().DBName) + relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) + relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) + foreignField.IsForeignKey = true + } } } @@ -293,15 +302,23 @@ func (scope *Scope) GetModelStruct() *ModelStruct { belongsToForeignKeys := foreignKeys if len(belongsToForeignKeys) == 0 { for _, field := range toScope.PrimaryFields() { - belongsToForeignKeys = append(belongsToForeignKeys, field.Name+field.Name) + if foreignField := getForeignField(field.Name+field.Name, fields); foreignField != nil { + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.Name) + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, field.DBName) + relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) + relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) + foreignField.IsForeignKey = true + } } - } - - for _, foreignKey := range belongsToForeignKeys { - if foreignField := getForeignField(foreignKey, fields); foreignField != nil { - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) - relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) - foreignField.IsForeignKey = true + } else { + for _, foreignKey := range foreignKeys { + if foreignField := getForeignField(foreignKey, fields); foreignField != nil { + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, toScope.PrimaryField().Name) + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, toScope.PrimaryField().DBName) + relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) + relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) + foreignField.IsForeignKey = true + } } } @@ -312,15 +329,23 @@ func (scope *Scope) GetModelStruct() *ModelStruct { hasOneForeignKeys := foreignKeys if len(hasOneForeignKeys) == 0 { for _, field := range toScope.PrimaryFields() { - hasOneForeignKeys = append(hasOneForeignKeys, modelStruct.ModelType.Name()+field.Name) + if foreignField := getForeignField(modelStruct.ModelType.Name()+field.Name, fields); foreignField != nil { + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.Name) + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, field.DBName) + relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) + relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) + foreignField.IsForeignKey = true + } } - } - - for _, foreignKey := range hasOneForeignKeys { - if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) - relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) - foreignField.IsForeignKey = true + } else { + for _, foreignKey := range hasOneForeignKeys { + if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scope.PrimaryField().Name) + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scope.PrimaryField().DBName) + relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) + relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) + foreignField.IsForeignKey = true + } } } diff --git a/scope_private.go b/scope_private.go index edd0dbe9..931db3de 100644 --- a/scope_private.go +++ b/scope_private.go @@ -415,12 +415,21 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { joinTableHandler := relationship.JoinTableHandler scope.Err(joinTableHandler.JoinWith(joinTableHandler, toScope.db, scope.Value).Find(value).Error) } else if relationship.Kind == "belongs_to" { - sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey())) - foreignKeyValue := fromFields[relationship.ForeignDBName].Field.Interface() - scope.Err(toScope.db.Where(sql, foreignKeyValue).Find(value).Error) + query := toScope.db + for idx, foreignKey := range relationship.ForeignDBNames { + if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok { + query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) + } + } + scope.Err(query.Find(value).Error) } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { - sql := fmt.Sprintf("%v = ?", scope.Quote(relationship.ForeignDBName)) - query := toScope.db.Where(sql, scope.PrimaryKeyValue()) + query := toScope.db + for idx, foreignKey := range relationship.ForeignDBNames { + if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok { + query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) + } + } + if relationship.PolymorphicType != "" { query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName()) }