diff --git a/tests/scan_test.go b/tests/scan_test.go index 1a188fac..89e2cb3c 100644 --- a/tests/scan_test.go +++ b/tests/scan_test.go @@ -10,6 +10,11 @@ import ( . "gorm.io/gorm/utils/tests" ) +type PersonAddressInfo struct { + Person *Person `gorm:"embedded"` + Address *Address `gorm:"embedded"` +} + func TestScan(t *testing.T) { user1 := User{Name: "ScanUser1", Age: 1} user2 := User{Name: "ScanUser2", Age: 10} @@ -156,3 +161,34 @@ func TestScanRows(t *testing.T) { t.Fatalf("failed to scan ages, got error %v, ages: %v", err, name) } } + +func TestScanToEmbedded(t *testing.T) { + person1 := Person{Name: "person 1"} + person2 := Person{Name: "person 2"} + DB.Save(&person1).Save(&person2) + + address1 := Address{Name: "address 1"} + address2 := Address{Name: "address 2"} + DB.Save(&address1).Save(&address2) + + DB.Create(&PersonAddress{PersonID: person1.ID, AddressID: int(address1.ID)}) + DB.Create(&PersonAddress{PersonID: person1.ID, AddressID: int(address2.ID)}) + DB.Create(&PersonAddress{PersonID: person2.ID, AddressID: int(address1.ID)}) + + var personAddressInfoList []*PersonAddressInfo + if err := DB.Select("people.*, addresses.*"). + Table("people"). + Joins("inner join person_addresses on people.id = person_addresses.person_id"). + Joins("inner join addresses on person_addresses.address_id = addresses.id"). + Find(&personAddressInfoList).Error; err != nil { + t.Errorf("Failed, got error: %v", err) + } + + for _, info := range personAddressInfoList { + if info.Person != nil { + if info.Person.ID == person1.ID && info.Person.Name != person1.Name { + t.Errorf("Failed, expected %v, got %v", person1.Name, info.Person.Name) + } + } + } +}