mirror of https://github.com/go-gorm/gorm.git
Set IsForeignKey for StructField
This commit is contained in:
parent
22b1a93e03
commit
da8fc53c86
|
@ -1,6 +1,9 @@
|
|||
package gorm_test
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHasOneAndHasManyAssociation(t *testing.T) {
|
||||
DB.DropTable(Category{})
|
||||
|
@ -219,3 +222,37 @@ func TestManyToMany(t *testing.T) {
|
|||
t.Errorf("Relations should be cleared")
|
||||
}
|
||||
}
|
||||
|
||||
func TestForeignKey(t *testing.T) {
|
||||
for _, structField := range DB.NewScope(&User{}).GetStructFields() {
|
||||
for _, foreignKey := range []string{"BillingAddressID", "ShippingAddressId", "CompanyID"} {
|
||||
if structField.Name == foreignKey && !structField.IsForeignKey {
|
||||
t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, structField := range DB.NewScope(&Email{}).GetStructFields() {
|
||||
for _, foreignKey := range []string{"UserId"} {
|
||||
if structField.Name == foreignKey && !structField.IsForeignKey {
|
||||
t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, structField := range DB.NewScope(&Post{}).GetStructFields() {
|
||||
for _, foreignKey := range []string{"CategoryId", "MainCategoryId"} {
|
||||
if structField.Name == foreignKey && !structField.IsForeignKey {
|
||||
t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, structField := range DB.NewScope(&Comment{}).GetStructFields() {
|
||||
for _, foreignKey := range []string{"PostId"} {
|
||||
if structField.Name == foreignKey && !structField.IsForeignKey {
|
||||
t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -66,7 +66,7 @@ type Relationship struct {
|
|||
var pluralMapKeys = []*regexp.Regexp{regexp.MustCompile("ch$"), regexp.MustCompile("ss$"), regexp.MustCompile("sh$"), regexp.MustCompile("day$"), regexp.MustCompile("y$"), regexp.MustCompile("x$"), regexp.MustCompile("([^s])s?$")}
|
||||
var pluralMapValues = []string{"ches", "sses", "shes", "days", "ies", "xes", "${1}s"}
|
||||
|
||||
func (scope *Scope) GetModelStruct(noRelationship ...bool) *ModelStruct {
|
||||
func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||
var modelStruct ModelStruct
|
||||
|
||||
reflectValue := reflect.Indirect(reflect.ValueOf(scope.Value))
|
||||
|
@ -164,9 +164,10 @@ func (scope *Scope) GetModelStruct(noRelationship ...bool) *ModelStruct {
|
|||
field.IsNormal = true
|
||||
}
|
||||
|
||||
if !field.IsNormal && len(noRelationship) == 0 {
|
||||
if !field.IsNormal {
|
||||
gormSettings := parseTagSetting(field.Tag.Get("gorm"))
|
||||
toModelStruct := scope.New(reflect.New(fieldStruct.Type).Interface()).GetModelStruct(true)
|
||||
toScope := scope.New(reflect.New(fieldStruct.Type).Interface())
|
||||
|
||||
getForeignField := func(column string, fields []*StructField) *StructField {
|
||||
for _, field := range fields {
|
||||
if field.Name == column || field.DBName == ToDBName(column) {
|
||||
|
@ -180,8 +181,8 @@ func (scope *Scope) GetModelStruct(noRelationship ...bool) *ModelStruct {
|
|||
|
||||
foreignKey := gormSettings["FOREIGNKEY"]
|
||||
if polymorphic := gormSettings["POLYMORPHIC"]; polymorphic != "" {
|
||||
if polymorphicField := getForeignField(polymorphic+"Id", toModelStruct.StructFields); polymorphicField != nil {
|
||||
if polymorphicType := getForeignField(polymorphic+"Type", toModelStruct.StructFields); polymorphicType != nil {
|
||||
if polymorphicField := getForeignField(polymorphic+"Id", toScope.GetStructFields()); polymorphicField != nil {
|
||||
if polymorphicType := getForeignField(polymorphic+"Type", toScope.GetStructFields()); polymorphicType != nil {
|
||||
relationship.ForeignFieldName = polymorphicField.Name
|
||||
relationship.ForeignDBName = polymorphicField.DBName
|
||||
relationship.PolymorphicType = polymorphicType.Name
|
||||
|
@ -194,7 +195,7 @@ func (scope *Scope) GetModelStruct(noRelationship ...bool) *ModelStruct {
|
|||
|
||||
switch indirectType.Kind() {
|
||||
case reflect.Slice:
|
||||
if len(toModelStruct.StructFields) > 0 {
|
||||
if toStructFields := toScope.GetStructFields(); len(toStructFields) > 0 {
|
||||
if foreignKey == "" {
|
||||
foreignKey = scopeType.Name() + "Id"
|
||||
}
|
||||
|
@ -205,7 +206,7 @@ func (scope *Scope) GetModelStruct(noRelationship ...bool) *ModelStruct {
|
|||
|
||||
associationForeignKey := gormSettings["ASSOCIATIONFOREIGNKEY"]
|
||||
if associationForeignKey == "" {
|
||||
associationForeignKey = toModelStruct.ModelType.Name() + "Id"
|
||||
associationForeignKey = toScope.GetModelStruct().ModelType.Name() + "Id"
|
||||
}
|
||||
|
||||
relationship.ForeignFieldName = foreignKey
|
||||
|
@ -215,7 +216,7 @@ func (scope *Scope) GetModelStruct(noRelationship ...bool) *ModelStruct {
|
|||
field.Relationship = relationship
|
||||
} else {
|
||||
relationship.Kind = "has_many"
|
||||
if foreignField := getForeignField(foreignKey, toModelStruct.StructFields); foreignField != nil {
|
||||
if foreignField := getForeignField(foreignKey, toStructFields); foreignField != nil {
|
||||
relationship.ForeignFieldName = foreignField.Name
|
||||
relationship.ForeignDBName = foreignField.DBName
|
||||
foreignField.IsForeignKey = true
|
||||
|
@ -229,12 +230,12 @@ func (scope *Scope) GetModelStruct(noRelationship ...bool) *ModelStruct {
|
|||
}
|
||||
case reflect.Struct:
|
||||
if _, ok := gormSettings["EMBEDDED"]; ok || fieldStruct.Anonymous {
|
||||
for _, f := range scope.New(reflect.New(indirectType).Interface()).GetStructFields() {
|
||||
f = f.clone()
|
||||
f.Names = append([]string{fieldStruct.Name}, f.Names...)
|
||||
modelStruct.StructFields = append(modelStruct.StructFields, f)
|
||||
if f.IsPrimaryKey {
|
||||
modelStruct.PrimaryKeyField = f
|
||||
for _, toField := range toScope.GetStructFields() {
|
||||
toField = toField.clone()
|
||||
toField.Names = append([]string{fieldStruct.Name}, toField.Names...)
|
||||
modelStruct.StructFields = append(modelStruct.StructFields, toField)
|
||||
if toField.IsPrimaryKey {
|
||||
modelStruct.PrimaryKeyField = toField
|
||||
}
|
||||
}
|
||||
continue
|
||||
|
@ -255,7 +256,7 @@ func (scope *Scope) GetModelStruct(noRelationship ...bool) *ModelStruct {
|
|||
foreignKey = modelStruct.ModelType.Name() + "Id"
|
||||
}
|
||||
relationship.Kind = "has_one"
|
||||
if foreignField := getForeignField(foreignKey, toModelStruct.StructFields); foreignField != nil {
|
||||
if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil {
|
||||
relationship.ForeignFieldName = foreignField.Name
|
||||
relationship.ForeignDBName = foreignField.DBName
|
||||
foreignField.IsForeignKey = true
|
||||
|
@ -284,7 +285,7 @@ func (scope *Scope) GetModelStruct(noRelationship ...bool) *ModelStruct {
|
|||
modelStruct.StructFields = append(modelStruct.StructFields, field)
|
||||
}
|
||||
|
||||
if scope.db != nil && len(noRelationship) == 0 {
|
||||
if scope.db != nil {
|
||||
scope.db.parent.ModelStructs[scopeType] = &modelStruct
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue