From a948c846071f7e4fd264c6a95a81a0ef04293a28 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 5 Mar 2021 22:18:12 +0800 Subject: [PATCH] Revert "Revert "Don't override the from clauses, close #4129" close #4139" This reverts commit d6c23586ae435a124353d3c5dfa6f504c24c5c3c. --- callbacks/query.go | 6 ++++++ tests/sql_builder_test.go | 45 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+) diff --git a/callbacks/query.go b/callbacks/query.go index 658216df..1868c247 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -104,6 +104,11 @@ func BuildQuerySQL(db *gorm.DB) { } joins := []clause.Join{} + + if fromClause, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok { + joins = fromClause.Joins + } + for _, join := range db.Statement.Joins { if db.Statement.Schema == nil { joins = append(joins, clause.Join{ @@ -154,6 +159,7 @@ func BuildQuerySQL(db *gorm.DB) { } } + db.Statement.Joins = nil db.Statement.AddClause(clause.From{Joins: joins}) } else { db.Statement.AddClauseIfNotExists(clause.From{}) diff --git a/tests/sql_builder_test.go b/tests/sql_builder_test.go index acb08130..081b96c9 100644 --- a/tests/sql_builder_test.go +++ b/tests/sql_builder_test.go @@ -6,6 +6,7 @@ import ( "testing" "gorm.io/gorm" + "gorm.io/gorm/clause" . "gorm.io/gorm/utils/tests" ) @@ -242,3 +243,47 @@ func TestCombineStringConditions(t *testing.T) { t.Fatalf("invalid sql generated, got %v", sql) } } + +func TestFromWithJoins(t *testing.T) { + var result User + + newDB := DB.Session(&gorm.Session{NewDB: true, DryRun: true}).Table("users") + + newDB.Clauses( + clause.From{ + Tables: []clause.Table{{Name: "users"}}, + Joins: []clause.Join{ + { + Table: clause.Table{Name: "companies", Raw: false}, + ON: clause.Where{ + Exprs: []clause.Expression{ + clause.Eq{ + Column: clause.Column{ + Table: "users", + Name: "company_id", + }, + Value: clause.Column{ + Table: "companies", + Name: "id", + }, + }, + }, + }, + }, + }, + }, + ) + + newDB.Joins("inner join rgs on rgs.id = user.id") + + stmt := newDB.First(&result).Statement + str := stmt.SQL.String() + + if !strings.Contains(str, "rgs.id = user.id") { + t.Errorf("The second join condition is over written instead of combining") + } + + if !strings.Contains(str, "`users`.`company_id` = `companies`.`id`") && !strings.Contains(str, "\"users\".\"company_id\" = \"companies\".\"id\"") { + t.Errorf("The first join condition is over written instead of combining") + } +}