diff --git a/association_test.go b/association_test.go index d7984b12..51a0c553 100644 --- a/association_test.go +++ b/association_test.go @@ -1,6 +1,9 @@ package gorm_test -import "testing" +import ( + "fmt" + "testing" +) func TestHasOneAndHasManyAssociation(t *testing.T) { DB.DropTable(Category{}) @@ -219,3 +222,37 @@ func TestManyToMany(t *testing.T) { t.Errorf("Relations should be cleared") } } + +func TestForeignKey(t *testing.T) { + for _, structField := range DB.NewScope(&User{}).GetStructFields() { + for _, foreignKey := range []string{"BillingAddressID", "ShippingAddressId", "CompanyID"} { + if structField.Name == foreignKey && !structField.IsForeignKey { + t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey)) + } + } + } + + for _, structField := range DB.NewScope(&Email{}).GetStructFields() { + for _, foreignKey := range []string{"UserId"} { + if structField.Name == foreignKey && !structField.IsForeignKey { + t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey)) + } + } + } + + for _, structField := range DB.NewScope(&Post{}).GetStructFields() { + for _, foreignKey := range []string{"CategoryId", "MainCategoryId"} { + if structField.Name == foreignKey && !structField.IsForeignKey { + t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey)) + } + } + } + + for _, structField := range DB.NewScope(&Comment{}).GetStructFields() { + for _, foreignKey := range []string{"PostId"} { + if structField.Name == foreignKey && !structField.IsForeignKey { + t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey)) + } + } + } +} diff --git a/model_struct.go b/model_struct.go index ed299e85..25c8e50c 100644 --- a/model_struct.go +++ b/model_struct.go @@ -66,7 +66,7 @@ type Relationship struct { var pluralMapKeys = []*regexp.Regexp{regexp.MustCompile("ch$"), regexp.MustCompile("ss$"), regexp.MustCompile("sh$"), regexp.MustCompile("day$"), regexp.MustCompile("y$"), regexp.MustCompile("x$"), regexp.MustCompile("([^s])s?$")} var pluralMapValues = []string{"ches", "sses", "shes", "days", "ies", "xes", "${1}s"} -func (scope *Scope) GetModelStruct(noRelationship ...bool) *ModelStruct { +func (scope *Scope) GetModelStruct() *ModelStruct { var modelStruct ModelStruct reflectValue := reflect.Indirect(reflect.ValueOf(scope.Value)) @@ -164,9 +164,10 @@ func (scope *Scope) GetModelStruct(noRelationship ...bool) *ModelStruct { field.IsNormal = true } - if !field.IsNormal && len(noRelationship) == 0 { + if !field.IsNormal { gormSettings := parseTagSetting(field.Tag.Get("gorm")) - toModelStruct := scope.New(reflect.New(fieldStruct.Type).Interface()).GetModelStruct(true) + toScope := scope.New(reflect.New(fieldStruct.Type).Interface()) + getForeignField := func(column string, fields []*StructField) *StructField { for _, field := range fields { if field.Name == column || field.DBName == ToDBName(column) { @@ -180,8 +181,8 @@ func (scope *Scope) GetModelStruct(noRelationship ...bool) *ModelStruct { foreignKey := gormSettings["FOREIGNKEY"] if polymorphic := gormSettings["POLYMORPHIC"]; polymorphic != "" { - if polymorphicField := getForeignField(polymorphic+"Id", toModelStruct.StructFields); polymorphicField != nil { - if polymorphicType := getForeignField(polymorphic+"Type", toModelStruct.StructFields); polymorphicType != nil { + if polymorphicField := getForeignField(polymorphic+"Id", toScope.GetStructFields()); polymorphicField != nil { + if polymorphicType := getForeignField(polymorphic+"Type", toScope.GetStructFields()); polymorphicType != nil { relationship.ForeignFieldName = polymorphicField.Name relationship.ForeignDBName = polymorphicField.DBName relationship.PolymorphicType = polymorphicType.Name @@ -194,7 +195,7 @@ func (scope *Scope) GetModelStruct(noRelationship ...bool) *ModelStruct { switch indirectType.Kind() { case reflect.Slice: - if len(toModelStruct.StructFields) > 0 { + if toStructFields := toScope.GetStructFields(); len(toStructFields) > 0 { if foreignKey == "" { foreignKey = scopeType.Name() + "Id" } @@ -205,7 +206,7 @@ func (scope *Scope) GetModelStruct(noRelationship ...bool) *ModelStruct { associationForeignKey := gormSettings["ASSOCIATIONFOREIGNKEY"] if associationForeignKey == "" { - associationForeignKey = toModelStruct.ModelType.Name() + "Id" + associationForeignKey = toScope.GetModelStruct().ModelType.Name() + "Id" } relationship.ForeignFieldName = foreignKey @@ -215,7 +216,7 @@ func (scope *Scope) GetModelStruct(noRelationship ...bool) *ModelStruct { field.Relationship = relationship } else { relationship.Kind = "has_many" - if foreignField := getForeignField(foreignKey, toModelStruct.StructFields); foreignField != nil { + if foreignField := getForeignField(foreignKey, toStructFields); foreignField != nil { relationship.ForeignFieldName = foreignField.Name relationship.ForeignDBName = foreignField.DBName foreignField.IsForeignKey = true @@ -229,12 +230,12 @@ func (scope *Scope) GetModelStruct(noRelationship ...bool) *ModelStruct { } case reflect.Struct: if _, ok := gormSettings["EMBEDDED"]; ok || fieldStruct.Anonymous { - for _, f := range scope.New(reflect.New(indirectType).Interface()).GetStructFields() { - f = f.clone() - f.Names = append([]string{fieldStruct.Name}, f.Names...) - modelStruct.StructFields = append(modelStruct.StructFields, f) - if f.IsPrimaryKey { - modelStruct.PrimaryKeyField = f + for _, toField := range toScope.GetStructFields() { + toField = toField.clone() + toField.Names = append([]string{fieldStruct.Name}, toField.Names...) + modelStruct.StructFields = append(modelStruct.StructFields, toField) + if toField.IsPrimaryKey { + modelStruct.PrimaryKeyField = toField } } continue @@ -255,7 +256,7 @@ func (scope *Scope) GetModelStruct(noRelationship ...bool) *ModelStruct { foreignKey = modelStruct.ModelType.Name() + "Id" } relationship.Kind = "has_one" - if foreignField := getForeignField(foreignKey, toModelStruct.StructFields); foreignField != nil { + if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { relationship.ForeignFieldName = foreignField.Name relationship.ForeignDBName = foreignField.DBName foreignField.IsForeignKey = true @@ -284,7 +285,7 @@ func (scope *Scope) GetModelStruct(noRelationship ...bool) *ModelStruct { modelStruct.StructFields = append(modelStruct.StructFields, field) } - if scope.db != nil && len(noRelationship) == 0 { + if scope.db != nil { scope.db.parent.ModelStructs[scopeType] = &modelStruct }