From 8aa7b475d4a891312c13d7b27ca904a066bfd3e5 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 4 Jan 2016 12:21:35 +0800 Subject: [PATCH] Add more tests for one to many relations with customized foreign keys --- customize_column_test.go | 96 ++++++++++++++++++++++++++++++++++++++++ model_struct.go | 15 ++++--- 2 files changed, 105 insertions(+), 6 deletions(-) diff --git a/customize_column_test.go b/customize_column_test.go index 5897b2cd..29400cd0 100644 --- a/customize_column_test.go +++ b/customize_column_test.go @@ -139,3 +139,99 @@ func TestOneToOneWithCustomizedColumn(t *testing.T) { t.Errorf("Should preload one to one relation with customize foreign keys") } } + +type PromotionDiscount struct { + gorm.Model + Name string + Coupons []*PromotionCoupon `gorm:"ForeignKey:discount_id"` + Rule *PromotionRule `gorm:"ForeignKey:discount_id"` +} + +type PromotionCoupon struct { + gorm.Model + Code string + DiscountID uint + Discount PromotionDiscount +} + +type PromotionRule struct { + gorm.Model + Name string + Begin *time.Time + End *time.Time + DiscountID uint + Discount *PromotionDiscount +} + +func TestOneToManyWithCustomizedColumn(t *testing.T) { + DB.DropTable(&PromotionDiscount{}, &PromotionCoupon{}) + DB.AutoMigrate(&PromotionDiscount{}, &PromotionCoupon{}) + + discount := PromotionDiscount{ + Name: "Happy New Year", + Coupons: []*PromotionCoupon{ + {Code: "newyear1"}, + {Code: "newyear2"}, + }, + } + + if err := DB.Create(&discount).Error; err != nil { + t.Errorf("no error should happen but got %v", err) + } + + var discount1 PromotionDiscount + if err := DB.Preload("Coupons").First(&discount1, "id = ?", discount.ID).Error; err != nil { + t.Errorf("no error should happen but got %v", err) + } + + if len(discount.Coupons) != 2 { + t.Errorf("should find two coupons") + } + + var coupon PromotionCoupon + if err := DB.Preload("Discount").First(&coupon, "code = ?", "newyear1").Error; err != nil { + t.Errorf("no error should happen but got %v", err) + } + + if coupon.Discount.Name != "Happy New Year" { + t.Errorf("should preload discount from coupon") + } +} + +func TestOneToOneWithPartialCustomizedColumn(t *testing.T) { + DB.DropTable(&PromotionDiscount{}, &PromotionRule{}) + DB.AutoMigrate(&PromotionDiscount{}, &PromotionRule{}) + + var begin = time.Now() + var end = time.Now().Add(24 * time.Hour) + discount := PromotionDiscount{ + Name: "Happy New Year 2", + Rule: &PromotionRule{ + Name: "time_limited", + Begin: &begin, + End: &end, + }, + } + + if err := DB.Create(&discount).Error; err != nil { + t.Errorf("no error should happen but got %v", err) + } + + var discount1 PromotionDiscount + if err := DB.Preload("Rule").First(&discount1, "id = ?", discount.ID).Error; err != nil { + t.Errorf("no error should happen but got %v", err) + } + + if discount.Rule.Begin.Format(time.RFC3339Nano) != begin.Format(time.RFC3339Nano) { + t.Errorf("Should be able to preload Rule") + } + + var rule PromotionRule + if err := DB.Preload("Discount").First(&rule, "name = ?", "time_limited").Error; err != nil { + t.Errorf("no error should happen but got %v", err) + } + + if rule.Discount.Name != "Happy New Year 2" { + t.Errorf("should preload discount from rule") + } +} diff --git a/model_struct.go b/model_struct.go index d094edad..16032f06 100644 --- a/model_struct.go +++ b/model_struct.go @@ -299,10 +299,11 @@ func (scope *Scope) GetModelStruct() *ModelStruct { for _, foreignKey := range foreignKeys { if strings.HasPrefix(foreignKey, associationType) { associationForeignKeys = append(associationForeignKeys, strings.TrimPrefix(foreignKey, associationType)) - } else { - scope.Err(fmt.Errorf("invalid foreign keys, foreign key %v should start with %v", foreignKey, associationType)) } } + if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { + associationForeignKeys = []string{scope.PrimaryKey()} + } } else if len(foreignKeys) != len(associationForeignKeys) { scope.Err(errors.New("invalid foreign keys, should have same length")) return @@ -391,10 +392,11 @@ func (scope *Scope) GetModelStruct() *ModelStruct { for _, foreignKey := range foreignKeys { if strings.HasPrefix(foreignKey, associationType) { associationForeignKeys = append(associationForeignKeys, strings.TrimPrefix(foreignKey, associationType)) - } else { - scope.Err(fmt.Errorf("invalid foreign keys, foreign key %v should start with %v", foreignKey, associationType)) } } + if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { + associationForeignKeys = []string{scope.PrimaryKey()} + } } else if len(foreignKeys) != len(associationForeignKeys) { scope.Err(errors.New("invalid foreign keys, should have same length")) return @@ -446,10 +448,11 @@ func (scope *Scope) GetModelStruct() *ModelStruct { for _, foreignKey := range foreignKeys { if strings.HasPrefix(foreignKey, field.Name) { associationForeignKeys = append(associationForeignKeys, strings.TrimPrefix(foreignKey, field.Name)) - } else { - scope.Err(fmt.Errorf("invalid foreign keys, foreign key %v should start with %v", foreignKey, field.Name)) } } + if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { + associationForeignKeys = []string{toScope.PrimaryKey()} + } } else if len(foreignKeys) != len(associationForeignKeys) { scope.Err(errors.New("invalid foreign keys, should have same length")) return