Set IsForeignKey for StructField

This commit is contained in:
Jinzhu 2015-02-28 15:54:38 +08:00
parent 22b1a93e03
commit da8fc53c86
2 changed files with 55 additions and 17 deletions

View File

@ -1,6 +1,9 @@
package gorm_test
import "testing"
import (
"fmt"
"testing"
)
func TestHasOneAndHasManyAssociation(t *testing.T) {
DB.DropTable(Category{})
@ -219,3 +222,37 @@ func TestManyToMany(t *testing.T) {
t.Errorf("Relations should be cleared")
}
}
func TestForeignKey(t *testing.T) {
for _, structField := range DB.NewScope(&User{}).GetStructFields() {
for _, foreignKey := range []string{"BillingAddressID", "ShippingAddressId", "CompanyID"} {
if structField.Name == foreignKey && !structField.IsForeignKey {
t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey))
}
}
}
for _, structField := range DB.NewScope(&Email{}).GetStructFields() {
for _, foreignKey := range []string{"UserId"} {
if structField.Name == foreignKey && !structField.IsForeignKey {
t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey))
}
}
}
for _, structField := range DB.NewScope(&Post{}).GetStructFields() {
for _, foreignKey := range []string{"CategoryId", "MainCategoryId"} {
if structField.Name == foreignKey && !structField.IsForeignKey {
t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey))
}
}
}
for _, structField := range DB.NewScope(&Comment{}).GetStructFields() {
for _, foreignKey := range []string{"PostId"} {
if structField.Name == foreignKey && !structField.IsForeignKey {
t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey))
}
}
}
}

View File

@ -66,7 +66,7 @@ type Relationship struct {
var pluralMapKeys = []*regexp.Regexp{regexp.MustCompile("ch$"), regexp.MustCompile("ss$"), regexp.MustCompile("sh$"), regexp.MustCompile("day$"), regexp.MustCompile("y$"), regexp.MustCompile("x$"), regexp.MustCompile("([^s])s?$")}
var pluralMapValues = []string{"ches", "sses", "shes", "days", "ies", "xes", "${1}s"}
func (scope *Scope) GetModelStruct(noRelationship ...bool) *ModelStruct {
func (scope *Scope) GetModelStruct() *ModelStruct {
var modelStruct ModelStruct
reflectValue := reflect.Indirect(reflect.ValueOf(scope.Value))
@ -164,9 +164,10 @@ func (scope *Scope) GetModelStruct(noRelationship ...bool) *ModelStruct {
field.IsNormal = true
}
if !field.IsNormal && len(noRelationship) == 0 {
if !field.IsNormal {
gormSettings := parseTagSetting(field.Tag.Get("gorm"))
toModelStruct := scope.New(reflect.New(fieldStruct.Type).Interface()).GetModelStruct(true)
toScope := scope.New(reflect.New(fieldStruct.Type).Interface())
getForeignField := func(column string, fields []*StructField) *StructField {
for _, field := range fields {
if field.Name == column || field.DBName == ToDBName(column) {
@ -180,8 +181,8 @@ func (scope *Scope) GetModelStruct(noRelationship ...bool) *ModelStruct {
foreignKey := gormSettings["FOREIGNKEY"]
if polymorphic := gormSettings["POLYMORPHIC"]; polymorphic != "" {
if polymorphicField := getForeignField(polymorphic+"Id", toModelStruct.StructFields); polymorphicField != nil {
if polymorphicType := getForeignField(polymorphic+"Type", toModelStruct.StructFields); polymorphicType != nil {
if polymorphicField := getForeignField(polymorphic+"Id", toScope.GetStructFields()); polymorphicField != nil {
if polymorphicType := getForeignField(polymorphic+"Type", toScope.GetStructFields()); polymorphicType != nil {
relationship.ForeignFieldName = polymorphicField.Name
relationship.ForeignDBName = polymorphicField.DBName
relationship.PolymorphicType = polymorphicType.Name
@ -194,7 +195,7 @@ func (scope *Scope) GetModelStruct(noRelationship ...bool) *ModelStruct {
switch indirectType.Kind() {
case reflect.Slice:
if len(toModelStruct.StructFields) > 0 {
if toStructFields := toScope.GetStructFields(); len(toStructFields) > 0 {
if foreignKey == "" {
foreignKey = scopeType.Name() + "Id"
}
@ -205,7 +206,7 @@ func (scope *Scope) GetModelStruct(noRelationship ...bool) *ModelStruct {
associationForeignKey := gormSettings["ASSOCIATIONFOREIGNKEY"]
if associationForeignKey == "" {
associationForeignKey = toModelStruct.ModelType.Name() + "Id"
associationForeignKey = toScope.GetModelStruct().ModelType.Name() + "Id"
}
relationship.ForeignFieldName = foreignKey
@ -215,7 +216,7 @@ func (scope *Scope) GetModelStruct(noRelationship ...bool) *ModelStruct {
field.Relationship = relationship
} else {
relationship.Kind = "has_many"
if foreignField := getForeignField(foreignKey, toModelStruct.StructFields); foreignField != nil {
if foreignField := getForeignField(foreignKey, toStructFields); foreignField != nil {
relationship.ForeignFieldName = foreignField.Name
relationship.ForeignDBName = foreignField.DBName
foreignField.IsForeignKey = true
@ -229,12 +230,12 @@ func (scope *Scope) GetModelStruct(noRelationship ...bool) *ModelStruct {
}
case reflect.Struct:
if _, ok := gormSettings["EMBEDDED"]; ok || fieldStruct.Anonymous {
for _, f := range scope.New(reflect.New(indirectType).Interface()).GetStructFields() {
f = f.clone()
f.Names = append([]string{fieldStruct.Name}, f.Names...)
modelStruct.StructFields = append(modelStruct.StructFields, f)
if f.IsPrimaryKey {
modelStruct.PrimaryKeyField = f
for _, toField := range toScope.GetStructFields() {
toField = toField.clone()
toField.Names = append([]string{fieldStruct.Name}, toField.Names...)
modelStruct.StructFields = append(modelStruct.StructFields, toField)
if toField.IsPrimaryKey {
modelStruct.PrimaryKeyField = toField
}
}
continue
@ -255,7 +256,7 @@ func (scope *Scope) GetModelStruct(noRelationship ...bool) *ModelStruct {
foreignKey = modelStruct.ModelType.Name() + "Id"
}
relationship.Kind = "has_one"
if foreignField := getForeignField(foreignKey, toModelStruct.StructFields); foreignField != nil {
if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil {
relationship.ForeignFieldName = foreignField.Name
relationship.ForeignDBName = foreignField.DBName
foreignField.IsForeignKey = true
@ -284,7 +285,7 @@ func (scope *Scope) GetModelStruct(noRelationship ...bool) *ModelStruct {
modelStruct.StructFields = append(modelStruct.StructFields, field)
}
if scope.db != nil && len(noRelationship) == 0 {
if scope.db != nil {
scope.db.parent.ModelStructs[scopeType] = &modelStruct
}