mirror of https://github.com/go-gorm/gorm.git
Set IsForeignKey for StructField
This commit is contained in:
parent
22b1a93e03
commit
da8fc53c86
|
@ -1,6 +1,9 @@
|
||||||
package gorm_test
|
package gorm_test
|
||||||
|
|
||||||
import "testing"
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
func TestHasOneAndHasManyAssociation(t *testing.T) {
|
func TestHasOneAndHasManyAssociation(t *testing.T) {
|
||||||
DB.DropTable(Category{})
|
DB.DropTable(Category{})
|
||||||
|
@ -219,3 +222,37 @@ func TestManyToMany(t *testing.T) {
|
||||||
t.Errorf("Relations should be cleared")
|
t.Errorf("Relations should be cleared")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestForeignKey(t *testing.T) {
|
||||||
|
for _, structField := range DB.NewScope(&User{}).GetStructFields() {
|
||||||
|
for _, foreignKey := range []string{"BillingAddressID", "ShippingAddressId", "CompanyID"} {
|
||||||
|
if structField.Name == foreignKey && !structField.IsForeignKey {
|
||||||
|
t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, structField := range DB.NewScope(&Email{}).GetStructFields() {
|
||||||
|
for _, foreignKey := range []string{"UserId"} {
|
||||||
|
if structField.Name == foreignKey && !structField.IsForeignKey {
|
||||||
|
t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, structField := range DB.NewScope(&Post{}).GetStructFields() {
|
||||||
|
for _, foreignKey := range []string{"CategoryId", "MainCategoryId"} {
|
||||||
|
if structField.Name == foreignKey && !structField.IsForeignKey {
|
||||||
|
t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, structField := range DB.NewScope(&Comment{}).GetStructFields() {
|
||||||
|
for _, foreignKey := range []string{"PostId"} {
|
||||||
|
if structField.Name == foreignKey && !structField.IsForeignKey {
|
||||||
|
t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -66,7 +66,7 @@ type Relationship struct {
|
||||||
var pluralMapKeys = []*regexp.Regexp{regexp.MustCompile("ch$"), regexp.MustCompile("ss$"), regexp.MustCompile("sh$"), regexp.MustCompile("day$"), regexp.MustCompile("y$"), regexp.MustCompile("x$"), regexp.MustCompile("([^s])s?$")}
|
var pluralMapKeys = []*regexp.Regexp{regexp.MustCompile("ch$"), regexp.MustCompile("ss$"), regexp.MustCompile("sh$"), regexp.MustCompile("day$"), regexp.MustCompile("y$"), regexp.MustCompile("x$"), regexp.MustCompile("([^s])s?$")}
|
||||||
var pluralMapValues = []string{"ches", "sses", "shes", "days", "ies", "xes", "${1}s"}
|
var pluralMapValues = []string{"ches", "sses", "shes", "days", "ies", "xes", "${1}s"}
|
||||||
|
|
||||||
func (scope *Scope) GetModelStruct(noRelationship ...bool) *ModelStruct {
|
func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||||
var modelStruct ModelStruct
|
var modelStruct ModelStruct
|
||||||
|
|
||||||
reflectValue := reflect.Indirect(reflect.ValueOf(scope.Value))
|
reflectValue := reflect.Indirect(reflect.ValueOf(scope.Value))
|
||||||
|
@ -164,9 +164,10 @@ func (scope *Scope) GetModelStruct(noRelationship ...bool) *ModelStruct {
|
||||||
field.IsNormal = true
|
field.IsNormal = true
|
||||||
}
|
}
|
||||||
|
|
||||||
if !field.IsNormal && len(noRelationship) == 0 {
|
if !field.IsNormal {
|
||||||
gormSettings := parseTagSetting(field.Tag.Get("gorm"))
|
gormSettings := parseTagSetting(field.Tag.Get("gorm"))
|
||||||
toModelStruct := scope.New(reflect.New(fieldStruct.Type).Interface()).GetModelStruct(true)
|
toScope := scope.New(reflect.New(fieldStruct.Type).Interface())
|
||||||
|
|
||||||
getForeignField := func(column string, fields []*StructField) *StructField {
|
getForeignField := func(column string, fields []*StructField) *StructField {
|
||||||
for _, field := range fields {
|
for _, field := range fields {
|
||||||
if field.Name == column || field.DBName == ToDBName(column) {
|
if field.Name == column || field.DBName == ToDBName(column) {
|
||||||
|
@ -180,8 +181,8 @@ func (scope *Scope) GetModelStruct(noRelationship ...bool) *ModelStruct {
|
||||||
|
|
||||||
foreignKey := gormSettings["FOREIGNKEY"]
|
foreignKey := gormSettings["FOREIGNKEY"]
|
||||||
if polymorphic := gormSettings["POLYMORPHIC"]; polymorphic != "" {
|
if polymorphic := gormSettings["POLYMORPHIC"]; polymorphic != "" {
|
||||||
if polymorphicField := getForeignField(polymorphic+"Id", toModelStruct.StructFields); polymorphicField != nil {
|
if polymorphicField := getForeignField(polymorphic+"Id", toScope.GetStructFields()); polymorphicField != nil {
|
||||||
if polymorphicType := getForeignField(polymorphic+"Type", toModelStruct.StructFields); polymorphicType != nil {
|
if polymorphicType := getForeignField(polymorphic+"Type", toScope.GetStructFields()); polymorphicType != nil {
|
||||||
relationship.ForeignFieldName = polymorphicField.Name
|
relationship.ForeignFieldName = polymorphicField.Name
|
||||||
relationship.ForeignDBName = polymorphicField.DBName
|
relationship.ForeignDBName = polymorphicField.DBName
|
||||||
relationship.PolymorphicType = polymorphicType.Name
|
relationship.PolymorphicType = polymorphicType.Name
|
||||||
|
@ -194,7 +195,7 @@ func (scope *Scope) GetModelStruct(noRelationship ...bool) *ModelStruct {
|
||||||
|
|
||||||
switch indirectType.Kind() {
|
switch indirectType.Kind() {
|
||||||
case reflect.Slice:
|
case reflect.Slice:
|
||||||
if len(toModelStruct.StructFields) > 0 {
|
if toStructFields := toScope.GetStructFields(); len(toStructFields) > 0 {
|
||||||
if foreignKey == "" {
|
if foreignKey == "" {
|
||||||
foreignKey = scopeType.Name() + "Id"
|
foreignKey = scopeType.Name() + "Id"
|
||||||
}
|
}
|
||||||
|
@ -205,7 +206,7 @@ func (scope *Scope) GetModelStruct(noRelationship ...bool) *ModelStruct {
|
||||||
|
|
||||||
associationForeignKey := gormSettings["ASSOCIATIONFOREIGNKEY"]
|
associationForeignKey := gormSettings["ASSOCIATIONFOREIGNKEY"]
|
||||||
if associationForeignKey == "" {
|
if associationForeignKey == "" {
|
||||||
associationForeignKey = toModelStruct.ModelType.Name() + "Id"
|
associationForeignKey = toScope.GetModelStruct().ModelType.Name() + "Id"
|
||||||
}
|
}
|
||||||
|
|
||||||
relationship.ForeignFieldName = foreignKey
|
relationship.ForeignFieldName = foreignKey
|
||||||
|
@ -215,7 +216,7 @@ func (scope *Scope) GetModelStruct(noRelationship ...bool) *ModelStruct {
|
||||||
field.Relationship = relationship
|
field.Relationship = relationship
|
||||||
} else {
|
} else {
|
||||||
relationship.Kind = "has_many"
|
relationship.Kind = "has_many"
|
||||||
if foreignField := getForeignField(foreignKey, toModelStruct.StructFields); foreignField != nil {
|
if foreignField := getForeignField(foreignKey, toStructFields); foreignField != nil {
|
||||||
relationship.ForeignFieldName = foreignField.Name
|
relationship.ForeignFieldName = foreignField.Name
|
||||||
relationship.ForeignDBName = foreignField.DBName
|
relationship.ForeignDBName = foreignField.DBName
|
||||||
foreignField.IsForeignKey = true
|
foreignField.IsForeignKey = true
|
||||||
|
@ -229,12 +230,12 @@ func (scope *Scope) GetModelStruct(noRelationship ...bool) *ModelStruct {
|
||||||
}
|
}
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
if _, ok := gormSettings["EMBEDDED"]; ok || fieldStruct.Anonymous {
|
if _, ok := gormSettings["EMBEDDED"]; ok || fieldStruct.Anonymous {
|
||||||
for _, f := range scope.New(reflect.New(indirectType).Interface()).GetStructFields() {
|
for _, toField := range toScope.GetStructFields() {
|
||||||
f = f.clone()
|
toField = toField.clone()
|
||||||
f.Names = append([]string{fieldStruct.Name}, f.Names...)
|
toField.Names = append([]string{fieldStruct.Name}, toField.Names...)
|
||||||
modelStruct.StructFields = append(modelStruct.StructFields, f)
|
modelStruct.StructFields = append(modelStruct.StructFields, toField)
|
||||||
if f.IsPrimaryKey {
|
if toField.IsPrimaryKey {
|
||||||
modelStruct.PrimaryKeyField = f
|
modelStruct.PrimaryKeyField = toField
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
|
@ -255,7 +256,7 @@ func (scope *Scope) GetModelStruct(noRelationship ...bool) *ModelStruct {
|
||||||
foreignKey = modelStruct.ModelType.Name() + "Id"
|
foreignKey = modelStruct.ModelType.Name() + "Id"
|
||||||
}
|
}
|
||||||
relationship.Kind = "has_one"
|
relationship.Kind = "has_one"
|
||||||
if foreignField := getForeignField(foreignKey, toModelStruct.StructFields); foreignField != nil {
|
if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil {
|
||||||
relationship.ForeignFieldName = foreignField.Name
|
relationship.ForeignFieldName = foreignField.Name
|
||||||
relationship.ForeignDBName = foreignField.DBName
|
relationship.ForeignDBName = foreignField.DBName
|
||||||
foreignField.IsForeignKey = true
|
foreignField.IsForeignKey = true
|
||||||
|
@ -284,7 +285,7 @@ func (scope *Scope) GetModelStruct(noRelationship ...bool) *ModelStruct {
|
||||||
modelStruct.StructFields = append(modelStruct.StructFields, field)
|
modelStruct.StructFields = append(modelStruct.StructFields, field)
|
||||||
}
|
}
|
||||||
|
|
||||||
if scope.db != nil && len(noRelationship) == 0 {
|
if scope.db != nil {
|
||||||
scope.db.parent.ModelStructs[scopeType] = &modelStruct
|
scope.db.parent.ModelStructs[scopeType] = &modelStruct
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue