From 3195ae12072f51d15064a3428f4e906c6873c4e2 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 25 Aug 2020 18:59:19 +0800 Subject: [PATCH] Allow override alias table in preload conditions --- callbacks/preload.go | 6 +++--- tests/preload_test.go | 15 +++++++++++++++ 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/callbacks/preload.go b/callbacks/preload.go index cd09a6d6..25b8cb2b 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -50,7 +50,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { joinResults := rel.JoinTable.MakeSlice().Elem() column, values := schema.ToQueryValues(rel.JoinTable.Table, joinForeignKeys, joinForeignValues) - tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()) + db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error) // convert join identity map to relation identity map fieldValues := make([]interface{}, len(joinForeignFields)) @@ -93,7 +93,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { } reflectResults := rel.FieldSchema.MakeSlice().Elem() - column, values := schema.ToQueryValues(rel.FieldSchema.Table, relForeignKeys, foreignValues) + column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues) for _, cond := range conds { if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok { @@ -103,7 +103,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { } } - tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...) + db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error) fieldValues := make([]interface{}, len(relForeignFields)) diff --git a/tests/preload_test.go b/tests/preload_test.go index 3caa17b4..7e5d2622 100644 --- a/tests/preload_test.go +++ b/tests/preload_test.go @@ -5,6 +5,7 @@ import ( "strconv" "testing" + "gorm.io/gorm" "gorm.io/gorm/clause" . "gorm.io/gorm/utils/tests" ) @@ -108,6 +109,20 @@ func TestPreloadWithConds(t *testing.T) { } CheckUser(t, users2[0], users[0]) + + var users3 []User + if err := DB.Preload("Account", func(tx *gorm.DB) *gorm.DB { + return tx.Table("accounts AS a").Select("a.*") + }).Find(&users3, "id IN ?", userIDs).Error; err != nil { + t.Errorf("failed to query, got error %v", err) + } + sort.Slice(users3, func(i, j int) bool { + return users2[i].ID < users2[j].ID + }) + + for i, u := range users3 { + CheckUser(t, u, users[i]) + } } func TestNestedPreloadWithConds(t *testing.T) {