From 418c60c83cf8472d883bb9ab8b9821444e7c8f0a Mon Sep 17 00:00:00 2001 From: kinggo <30891428+longlihale@users.noreply.github.com> Date: Sat, 9 Oct 2021 16:55:45 +0800 Subject: [PATCH] fixed: clauseSelect.Columns missed when use Join And execute multiple query. (#4757) --- callbacks/query.go | 13 ++++++------- tests/joins_test.go | 27 +++++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 7 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index 1cfd618c..0eee2a43 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -95,7 +95,12 @@ func BuildQuerySQL(db *gorm.DB) { } // inline joins - if len(db.Statement.Joins) != 0 { + joins := []clause.Join{} + if fromClause, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok { + joins = fromClause.Joins + } + + if len(db.Statement.Joins) != 0 || len(joins) != 0 { if len(db.Statement.Selects) == 0 && db.Statement.Schema != nil { clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames)) for idx, dbName := range db.Statement.Schema.DBNames { @@ -103,12 +108,6 @@ func BuildQuerySQL(db *gorm.DB) { } } - joins := []clause.Join{} - - if fromClause, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok { - joins = fromClause.Joins - } - for _, join := range db.Statement.Joins { if db.Statement.Schema == nil { joins = append(joins, clause.Join{ diff --git a/tests/joins_test.go b/tests/joins_test.go index 25fa20b4..ca8477dc 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -157,3 +157,30 @@ func TestJoinsWithSelect(t *testing.T) { t.Errorf("Should find all two pets with Join select, got %+v", results) } } + +func TestJoinCount(t *testing.T) { + companyA := Company{Name: "A"} + companyB := Company{Name: "B"} + DB.Create(&companyA) + DB.Create(&companyB) + + user := User{Name: "kingGo", CompanyID: &companyB.ID} + DB.Create(&user) + + query := DB.Model(&User{}).Joins("Company") + //Bug happens when .Count is called on a query. + //Removing the below two lines or downgrading to gorm v1.20.12 will make this test pass. + var total int64 + query.Count(&total) + + var result User + + // Incorrectly generates a 'SELECT *' query which causes companies.id to overwrite users.id + if err := query.First(&result, user.ID).Error; err != nil { + t.Fatalf("Failed, got error: %v", err) + } + + if result.ID != user.ID { + t.Fatalf("result's id, %d, doesn't match user's id, %d", result.ID, user.ID) + } +}