Test overwrite foreign keys

This commit is contained in:
Jinzhu 2016-03-07 23:51:04 +08:00
parent 2c089573cd
commit 2e9d5e6f76
2 changed files with 106 additions and 14 deletions

View File

@ -191,12 +191,12 @@ func TestBelongsToOverrideForeignKey1(t *testing.T) {
ProfileRefer int
}
DB.AutoMigrate(&User{})
DB.AutoMigrate(&Profile{})
var user = User{Model: gorm.Model{ID: 1}, ProfileRefer: 10}
if err := DB.Model(&user).Association("Profile").Find(&[]Profile{}).Error; err != nil {
t.Errorf("Override belongs to foreign key with tag")
if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
if relation.Relationship.Kind != "belongs_to" ||
!reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"ProfileRefer"}) ||
!reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) {
t.Errorf("Override belongs to foreign key with tag")
}
}
}
@ -213,12 +213,12 @@ func TestBelongsToOverrideForeignKey2(t *testing.T) {
ProfileID int
}
DB.AutoMigrate(&User{})
DB.AutoMigrate(&Profile{})
var user = User{Model: gorm.Model{ID: 1}, ProfileID: 10}
if err := DB.Model(&user).Association("Profile").Find(&[]Profile{}).Error; err != nil {
t.Errorf("Override belongs to foreign key with tag")
if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
if relation.Relationship.Kind != "belongs_to" ||
!reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"ProfileID"}) ||
!reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) {
t.Errorf("Override belongs to foreign key with tag")
}
}
}
@ -368,6 +368,49 @@ func TestHasOne(t *testing.T) {
}
}
func TestHasOneOverrideForeignKey1(t *testing.T) {
type Profile struct {
gorm.Model
Name string
UserRefer uint
}
type User struct {
gorm.Model
Profile Profile `gorm:"ForeignKey:UserRefer"`
}
if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
if relation.Relationship.Kind != "has_one" ||
!reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserRefer"}) ||
!reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) {
t.Errorf("Override belongs to foreign key with tag")
}
}
}
func TestHasOneOverrideForeignKey2(t *testing.T) {
type Profile struct {
gorm.Model
Name string
UserID uint
}
type User struct {
gorm.Model
Refer string
Profile Profile `gorm:"ForeignKey:UserID;AssociationForeignKey:Refer"`
}
if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
if relation.Relationship.Kind != "has_one" ||
!reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserID"}) ||
!reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) {
t.Errorf("Override belongs to foreign key with tag")
}
}
}
func TestHasMany(t *testing.T) {
post := Post{
Title: "post has many",
@ -507,6 +550,49 @@ func TestHasMany(t *testing.T) {
}
}
func TestHasManyOverrideForeignKey1(t *testing.T) {
type Profile struct {
gorm.Model
Name string
UserRefer uint
}
type User struct {
gorm.Model
Profile []Profile `gorm:"ForeignKey:UserRefer"`
}
if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
if relation.Relationship.Kind != "has_many" ||
!reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserRefer"}) ||
!reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) {
t.Errorf("Override belongs to foreign key with tag")
}
}
}
func TestHasManyOverrideForeignKey2(t *testing.T) {
type Profile struct {
gorm.Model
Name string
UserID uint
}
type User struct {
gorm.Model
Refer string
Profile []Profile `gorm:"ForeignKey:UserID;AssociationForeignKey:Refer"`
}
if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
if relation.Relationship.Kind != "has_many" ||
!reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserID"}) ||
!reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) {
t.Errorf("Override belongs to foreign key with tag")
}
}
}
func TestManyToMany(t *testing.T) {
DB.Raw("delete from languages")
var languages = []Language{{Name: "ZH"}, {Name: "EN"}}

View File

@ -298,7 +298,10 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
if len(associationForeignKeys) == 0 {
for _, foreignKey := range foreignKeys {
if strings.HasPrefix(foreignKey, associationType) {
associationForeignKeys = append(associationForeignKeys, strings.TrimPrefix(foreignKey, associationType))
associationForeignKey := strings.TrimPrefix(foreignKey, associationType)
if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil {
associationForeignKeys = append(associationForeignKeys, associationForeignKey)
}
}
}
if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 {
@ -391,7 +394,10 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
if len(associationForeignKeys) == 0 {
for _, foreignKey := range foreignKeys {
if strings.HasPrefix(foreignKey, associationType) {
associationForeignKeys = append(associationForeignKeys, strings.TrimPrefix(foreignKey, associationType))
associationForeignKey := strings.TrimPrefix(foreignKey, associationType)
if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil {
associationForeignKeys = append(associationForeignKeys, associationForeignKey)
}
}
}
if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 {