From f00b95d305e086d2644f26ef1f445f4df391470e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 30 Jul 2015 22:59:25 +0800 Subject: [PATCH] Passed all tests for multiple primary keys --- association.go | 2 +- model_struct.go | 26 ++++++++++++++------------ 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/association.go b/association.go index b088c1dd..4d3fb15f 100644 --- a/association.go +++ b/association.go @@ -174,7 +174,7 @@ func (association *Association) Count() int { if relationship.PolymorphicType != "" { query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.PolymorphicDBName)), scope.TableName()) } - query.Count(&count) + query.Table(newScope.TableName()).Count(&count) } else if relationship.Kind == "belongs_to" { query := scope.DB() for idx, foreignKey := range relationship.ForeignDBNames { diff --git a/model_struct.go b/model_struct.go index 50437778..468002d5 100644 --- a/model_struct.go +++ b/model_struct.go @@ -195,6 +195,8 @@ func (scope *Scope) GetModelStruct() *ModelStruct { if polymorphicType := getForeignField(polymorphic+"Type", toScope.GetStructFields()); polymorphicType != nil { relationship.ForeignFieldNames = []string{polymorphicField.Name} relationship.ForeignDBNames = []string{polymorphicField.DBName} + relationship.AssociationForeignFieldNames = []string{scope.PrimaryField().Name} + relationship.AssociationForeignDBNames = []string{scope.PrimaryField().DBName} relationship.PolymorphicType = polymorphicType.Name relationship.PolymorphicDBName = polymorphicType.DBName polymorphicType.IsForeignKey = true @@ -300,8 +302,8 @@ func (scope *Scope) GetModelStruct() *ModelStruct { continue } else { if len(foreignKeys) == 0 { - for _, f := range toScope.PrimaryFields() { - if foreignField := getForeignField(field.Name+f.Name, fields); foreignField != nil { + for _, f := range scope.PrimaryFields() { + if foreignField := getForeignField(modelStruct.ModelType.Name()+f.Name, toScope.GetStructFields()); foreignField != nil { relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, f.Name) relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, f.DBName) relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) @@ -311,9 +313,9 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } } else { for _, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, fields); foreignField != nil { - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, toScope.PrimaryField().Name) - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, toScope.PrimaryField().DBName) + if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scope.PrimaryField().Name) + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scope.PrimaryField().DBName) relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) foreignField.IsForeignKey = true @@ -322,12 +324,12 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } if len(relationship.ForeignFieldNames) != 0 { - relationship.Kind = "belongs_to" + relationship.Kind = "has_one" field.Relationship = relationship } else { if len(foreignKeys) == 0 { - for _, f := range scope.PrimaryFields() { - if foreignField := getForeignField(modelStruct.ModelType.Name()+f.Name, toScope.GetStructFields()); foreignField != nil { + for _, f := range toScope.PrimaryFields() { + if foreignField := getForeignField(field.Name+f.Name, fields); foreignField != nil { relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, f.Name) relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, f.DBName) relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) @@ -337,9 +339,9 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } } else { for _, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scope.PrimaryField().Name) - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scope.PrimaryField().DBName) + if foreignField := getForeignField(foreignKey, fields); foreignField != nil { + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, toScope.PrimaryField().Name) + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, toScope.PrimaryField().DBName) relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) foreignField.IsForeignKey = true @@ -348,7 +350,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } if len(relationship.ForeignFieldNames) != 0 { - relationship.Kind = "has_one" + relationship.Kind = "belongs_to" field.Relationship = relationship } }