From 9c4070ed19fd0ba9c93b49131deffec31d678610 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B4=BE=E4=B8=80=E9=A5=BC?= Date: Wed, 12 Jun 2024 17:51:44 +0800 Subject: [PATCH] fix: AfterQuery should clear FROM Clause's Joins rather than the Statement (#7027) --- callbacks/query.go | 6 +++++- tests/joins_test.go | 8 +++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index 2a82eaba..9b2b17ea 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -286,7 +286,11 @@ func Preload(db *gorm.DB) { func AfterQuery(db *gorm.DB) { // clear the joins after query because preload need it - db.Statement.Joins = nil + if v, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok { + fromClause := db.Statement.Clauses["FROM"] + fromClause.Expression = clause.From{Tables: v.Tables, Joins: v.Joins[:len(v.Joins)-len(db.Statement.Joins)]} // keep the original From Joins + db.Statement.Clauses["FROM"] = fromClause + } if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 { callMethod(db, func(value interface{}, tx *gorm.DB) bool { if i, ok := value.(AfterFindInterface); ok { diff --git a/tests/joins_test.go b/tests/joins_test.go index 786fc37e..bcb60c88 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -184,14 +184,12 @@ func TestJoinCount(t *testing.T) { DB.Create(&user) query := DB.Model(&User{}).Joins("Company") - // Bug happens when .Count is called on a query. - // Removing the below two lines or downgrading to gorm v1.20.12 will make this test pass. + var total int64 query.Count(&total) var result User - // Incorrectly generates a 'SELECT *' query which causes companies.id to overwrite users.id if err := query.First(&result, user.ID).Error; err != nil { t.Fatalf("Failed, got error: %v", err) } @@ -199,6 +197,10 @@ func TestJoinCount(t *testing.T) { if result.ID != user.ID { t.Fatalf("result's id, %d, doesn't match user's id, %d", result.ID, user.ID) } + // should find company + if result.Company.ID != *user.CompanyID { + t.Fatalf("result's id, %d, doesn't match user's company id, %d", result.Company.ID, *user.CompanyID) + } } func TestJoinWithSoftDeleted(t *testing.T) {