From f3e2da5ba359f0d672249fc52f54ae41c5a66d3a Mon Sep 17 00:00:00 2001 From: Hasan Date: Thu, 17 Mar 2022 22:51:56 +0800 Subject: [PATCH] Added offset when scanning the result back to struct, close #5143 commit 9a2058164d44c98d7b586b87bed1757f89d6fad7 Author: Jinzhu Date: Thu Mar 17 22:34:19 2022 +0800 Refactor #5143 commit c259de21768936428c9d89f7b31afb95b8acb36a Author: Hasan Date: Mon Mar 14 20:04:01 2022 +0545 Update scan_test.go commit 09f127b49151a52fbb8b354a03e6610d4f70262f Author: Hasan Date: Mon Mar 14 19:23:47 2022 +0545 Added test for scanning embedded data into structs commit aeaca493cf412def7813d36fd6a68acc832bf79f Author: Hasan Date: Tue Mar 8 04:08:16 2022 +0600 Added offset when scanning the result back to struct --- scan.go | 22 +++++++++++++++++----- tests/go.mod | 2 +- tests/scan_test.go | 36 ++++++++++++++++++++++++++++++++++++ 3 files changed, 54 insertions(+), 6 deletions(-) diff --git a/scan.go b/scan.go index a4243d12..89d92354 100644 --- a/scan.go +++ b/scan.go @@ -156,10 +156,11 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { } default: var ( - fields = make([]*schema.Field, len(columns)) - joinFields [][2]*schema.Field - sch = db.Statement.Schema - reflectValue = db.Statement.ReflectValue + fields = make([]*schema.Field, len(columns)) + selectedColumnsMap = make(map[string]int, len(columns)) + joinFields [][2]*schema.Field + sch = db.Statement.Schema + reflectValue = db.Statement.ReflectValue ) if reflectValue.Kind() == reflect.Interface { @@ -194,7 +195,18 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { if sch != nil { for idx, column := range columns { if field := sch.LookUpField(column); field != nil && field.Readable { - fields[idx] = field + if curIndex, ok := selectedColumnsMap[column]; ok { + for fieldIndex, selectField := range sch.Fields[curIndex:] { + if selectField.DBName == column && selectField.Readable { + selectedColumnsMap[column] = curIndex + fieldIndex + 1 + fields[idx] = selectField + break + } + } + } else { + fields[idx] = field + selectedColumnsMap[column] = idx + } } else if names := strings.Split(column, "__"); len(names) > 1 { if rel, ok := sch.Relationships.Relations[names[0]]; ok { if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { diff --git a/tests/go.mod b/tests/go.mod index 4ef7fbe2..9dfa26ff 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,7 +6,7 @@ require ( github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/google/uuid v1.3.0 github.com/jackc/pgx/v4 v4.15.0 // indirect - github.com/jinzhu/now v1.1.4 + github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.4 github.com/mattn/go-sqlite3 v1.14.12 // indirect golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd // indirect diff --git a/tests/scan_test.go b/tests/scan_test.go index 1a188fac..ec1e652f 100644 --- a/tests/scan_test.go +++ b/tests/scan_test.go @@ -10,6 +10,11 @@ import ( . "gorm.io/gorm/utils/tests" ) +type PersonAddressInfo struct { + Person *Person `gorm:"embedded"` + Address *Address `gorm:"embedded"` +} + func TestScan(t *testing.T) { user1 := User{Name: "ScanUser1", Age: 1} user2 := User{Name: "ScanUser2", Age: 10} @@ -156,3 +161,34 @@ func TestScanRows(t *testing.T) { t.Fatalf("failed to scan ages, got error %v, ages: %v", err, name) } } + +func TestScanToEmbedded(t *testing.T) { + person1 := Person{Name: "person 1"} + person2 := Person{Name: "person 2"} + DB.Save(&person1).Save(&person2) + + address1 := Address{Name: "address 1"} + address2 := Address{Name: "address 2"} + DB.Save(&address1).Save(&address2) + + DB.Create(&PersonAddress{PersonID: person1.ID, AddressID: int(address1.ID)}) + DB.Create(&PersonAddress{PersonID: person1.ID, AddressID: int(address2.ID)}) + DB.Create(&PersonAddress{PersonID: person2.ID, AddressID: int(address1.ID)}) + + var personAddressInfoList []*PersonAddressInfo + if err := DB.Select("people.*, addresses.*"). + Table("people"). + Joins("inner join person_addresses on people.id = person_addresses.person_id"). + Joins("inner join addresses on person_addresses.address_id = addresses.id"). + Find(&personAddressInfoList).Error; err != nil { + t.Errorf("Failed to run join query, got error: %v", err) + } + + for _, info := range personAddressInfoList { + if info.Person != nil { + if info.Person.ID == person1.ID && info.Person.Name != person1.Name { + t.Errorf("Failed, expected %v, got %v", person1.Name, info.Person.Name) + } + } + } +}