Set IsForeignKey for StructField

This commit is contained in:
Jinzhu 2015-02-28 15:54:38 +08:00
parent 22b1a93e03
commit da8fc53c86
2 changed files with 55 additions and 17 deletions

View File

@ -1,6 +1,9 @@
package gorm_test package gorm_test
import "testing" import (
"fmt"
"testing"
)
func TestHasOneAndHasManyAssociation(t *testing.T) { func TestHasOneAndHasManyAssociation(t *testing.T) {
DB.DropTable(Category{}) DB.DropTable(Category{})
@ -219,3 +222,37 @@ func TestManyToMany(t *testing.T) {
t.Errorf("Relations should be cleared") t.Errorf("Relations should be cleared")
} }
} }
func TestForeignKey(t *testing.T) {
for _, structField := range DB.NewScope(&User{}).GetStructFields() {
for _, foreignKey := range []string{"BillingAddressID", "ShippingAddressId", "CompanyID"} {
if structField.Name == foreignKey && !structField.IsForeignKey {
t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey))
}
}
}
for _, structField := range DB.NewScope(&Email{}).GetStructFields() {
for _, foreignKey := range []string{"UserId"} {
if structField.Name == foreignKey && !structField.IsForeignKey {
t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey))
}
}
}
for _, structField := range DB.NewScope(&Post{}).GetStructFields() {
for _, foreignKey := range []string{"CategoryId", "MainCategoryId"} {
if structField.Name == foreignKey && !structField.IsForeignKey {
t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey))
}
}
}
for _, structField := range DB.NewScope(&Comment{}).GetStructFields() {
for _, foreignKey := range []string{"PostId"} {
if structField.Name == foreignKey && !structField.IsForeignKey {
t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey))
}
}
}
}

View File

@ -66,7 +66,7 @@ type Relationship struct {
var pluralMapKeys = []*regexp.Regexp{regexp.MustCompile("ch$"), regexp.MustCompile("ss$"), regexp.MustCompile("sh$"), regexp.MustCompile("day$"), regexp.MustCompile("y$"), regexp.MustCompile("x$"), regexp.MustCompile("([^s])s?$")} var pluralMapKeys = []*regexp.Regexp{regexp.MustCompile("ch$"), regexp.MustCompile("ss$"), regexp.MustCompile("sh$"), regexp.MustCompile("day$"), regexp.MustCompile("y$"), regexp.MustCompile("x$"), regexp.MustCompile("([^s])s?$")}
var pluralMapValues = []string{"ches", "sses", "shes", "days", "ies", "xes", "${1}s"} var pluralMapValues = []string{"ches", "sses", "shes", "days", "ies", "xes", "${1}s"}
func (scope *Scope) GetModelStruct(noRelationship ...bool) *ModelStruct { func (scope *Scope) GetModelStruct() *ModelStruct {
var modelStruct ModelStruct var modelStruct ModelStruct
reflectValue := reflect.Indirect(reflect.ValueOf(scope.Value)) reflectValue := reflect.Indirect(reflect.ValueOf(scope.Value))
@ -164,9 +164,10 @@ func (scope *Scope) GetModelStruct(noRelationship ...bool) *ModelStruct {
field.IsNormal = true field.IsNormal = true
} }
if !field.IsNormal && len(noRelationship) == 0 { if !field.IsNormal {
gormSettings := parseTagSetting(field.Tag.Get("gorm")) gormSettings := parseTagSetting(field.Tag.Get("gorm"))
toModelStruct := scope.New(reflect.New(fieldStruct.Type).Interface()).GetModelStruct(true) toScope := scope.New(reflect.New(fieldStruct.Type).Interface())
getForeignField := func(column string, fields []*StructField) *StructField { getForeignField := func(column string, fields []*StructField) *StructField {
for _, field := range fields { for _, field := range fields {
if field.Name == column || field.DBName == ToDBName(column) { if field.Name == column || field.DBName == ToDBName(column) {
@ -180,8 +181,8 @@ func (scope *Scope) GetModelStruct(noRelationship ...bool) *ModelStruct {
foreignKey := gormSettings["FOREIGNKEY"] foreignKey := gormSettings["FOREIGNKEY"]
if polymorphic := gormSettings["POLYMORPHIC"]; polymorphic != "" { if polymorphic := gormSettings["POLYMORPHIC"]; polymorphic != "" {
if polymorphicField := getForeignField(polymorphic+"Id", toModelStruct.StructFields); polymorphicField != nil { if polymorphicField := getForeignField(polymorphic+"Id", toScope.GetStructFields()); polymorphicField != nil {
if polymorphicType := getForeignField(polymorphic+"Type", toModelStruct.StructFields); polymorphicType != nil { if polymorphicType := getForeignField(polymorphic+"Type", toScope.GetStructFields()); polymorphicType != nil {
relationship.ForeignFieldName = polymorphicField.Name relationship.ForeignFieldName = polymorphicField.Name
relationship.ForeignDBName = polymorphicField.DBName relationship.ForeignDBName = polymorphicField.DBName
relationship.PolymorphicType = polymorphicType.Name relationship.PolymorphicType = polymorphicType.Name
@ -194,7 +195,7 @@ func (scope *Scope) GetModelStruct(noRelationship ...bool) *ModelStruct {
switch indirectType.Kind() { switch indirectType.Kind() {
case reflect.Slice: case reflect.Slice:
if len(toModelStruct.StructFields) > 0 { if toStructFields := toScope.GetStructFields(); len(toStructFields) > 0 {
if foreignKey == "" { if foreignKey == "" {
foreignKey = scopeType.Name() + "Id" foreignKey = scopeType.Name() + "Id"
} }
@ -205,7 +206,7 @@ func (scope *Scope) GetModelStruct(noRelationship ...bool) *ModelStruct {
associationForeignKey := gormSettings["ASSOCIATIONFOREIGNKEY"] associationForeignKey := gormSettings["ASSOCIATIONFOREIGNKEY"]
if associationForeignKey == "" { if associationForeignKey == "" {
associationForeignKey = toModelStruct.ModelType.Name() + "Id" associationForeignKey = toScope.GetModelStruct().ModelType.Name() + "Id"
} }
relationship.ForeignFieldName = foreignKey relationship.ForeignFieldName = foreignKey
@ -215,7 +216,7 @@ func (scope *Scope) GetModelStruct(noRelationship ...bool) *ModelStruct {
field.Relationship = relationship field.Relationship = relationship
} else { } else {
relationship.Kind = "has_many" relationship.Kind = "has_many"
if foreignField := getForeignField(foreignKey, toModelStruct.StructFields); foreignField != nil { if foreignField := getForeignField(foreignKey, toStructFields); foreignField != nil {
relationship.ForeignFieldName = foreignField.Name relationship.ForeignFieldName = foreignField.Name
relationship.ForeignDBName = foreignField.DBName relationship.ForeignDBName = foreignField.DBName
foreignField.IsForeignKey = true foreignField.IsForeignKey = true
@ -229,12 +230,12 @@ func (scope *Scope) GetModelStruct(noRelationship ...bool) *ModelStruct {
} }
case reflect.Struct: case reflect.Struct:
if _, ok := gormSettings["EMBEDDED"]; ok || fieldStruct.Anonymous { if _, ok := gormSettings["EMBEDDED"]; ok || fieldStruct.Anonymous {
for _, f := range scope.New(reflect.New(indirectType).Interface()).GetStructFields() { for _, toField := range toScope.GetStructFields() {
f = f.clone() toField = toField.clone()
f.Names = append([]string{fieldStruct.Name}, f.Names...) toField.Names = append([]string{fieldStruct.Name}, toField.Names...)
modelStruct.StructFields = append(modelStruct.StructFields, f) modelStruct.StructFields = append(modelStruct.StructFields, toField)
if f.IsPrimaryKey { if toField.IsPrimaryKey {
modelStruct.PrimaryKeyField = f modelStruct.PrimaryKeyField = toField
} }
} }
continue continue
@ -255,7 +256,7 @@ func (scope *Scope) GetModelStruct(noRelationship ...bool) *ModelStruct {
foreignKey = modelStruct.ModelType.Name() + "Id" foreignKey = modelStruct.ModelType.Name() + "Id"
} }
relationship.Kind = "has_one" relationship.Kind = "has_one"
if foreignField := getForeignField(foreignKey, toModelStruct.StructFields); foreignField != nil { if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil {
relationship.ForeignFieldName = foreignField.Name relationship.ForeignFieldName = foreignField.Name
relationship.ForeignDBName = foreignField.DBName relationship.ForeignDBName = foreignField.DBName
foreignField.IsForeignKey = true foreignField.IsForeignKey = true
@ -284,7 +285,7 @@ func (scope *Scope) GetModelStruct(noRelationship ...bool) *ModelStruct {
modelStruct.StructFields = append(modelStruct.StructFields, field) modelStruct.StructFields = append(modelStruct.StructFields, field)
} }
if scope.db != nil && len(noRelationship) == 0 { if scope.db != nil {
scope.db.parent.ModelStructs[scopeType] = &modelStruct scope.db.parent.ModelStructs[scopeType] = &modelStruct
} }