From 3a97639880a6a965c5e8209e2ff5557008e8b191 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 23 Aug 2020 10:40:37 +0800 Subject: [PATCH] Fix unordered joins, close #3267 --- callbacks/query.go | 8 ++++---- chainable_api.go | 5 +---- statement.go | 13 +++++++++---- tests/joins_test.go | 8 ++++++++ 4 files changed, 22 insertions(+), 12 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index 5ae1e904..f6cb32d5 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -104,12 +104,12 @@ func BuildQuerySQL(db *gorm.DB) { } joins := []clause.Join{} - for name, conds := range db.Statement.Joins { + for _, join := range db.Statement.Joins { if db.Statement.Schema == nil { joins = append(joins, clause.Join{ - Expression: clause.Expr{SQL: name, Vars: conds}, + Expression: clause.Expr{SQL: join.Name, Vars: join.Conds}, }) - } else if relation, ok := db.Statement.Schema.Relationships.Relations[name]; ok { + } else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok { tableAliasName := relation.Name for _, s := range relation.FieldSchema.DBNames { @@ -149,7 +149,7 @@ func BuildQuerySQL(db *gorm.DB) { }) } else { joins = append(joins, clause.Join{ - Expression: clause.Expr{SQL: name, Vars: conds}, + Expression: clause.Expr{SQL: join.Name, Vars: join.Conds}, }) } } diff --git a/chainable_api.go b/chainable_api.go index 9b46a95b..e1b73457 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -172,10 +172,7 @@ func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { // db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { tx = db.getInstance() - if tx.Statement.Joins == nil { - tx.Statement.Joins = map[string][]interface{}{} - } - tx.Statement.Joins[query] = args + tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args}) return } diff --git a/statement.go b/statement.go index 6114f468..214a15bb 100644 --- a/statement.go +++ b/statement.go @@ -29,7 +29,7 @@ type Statement struct { Distinct bool Selects []string // selected columns Omits []string // omit columns - Joins map[string][]interface{} + Joins []join Preloads map[string][]interface{} Settings sync.Map ConnPool ConnPool @@ -44,6 +44,11 @@ type Statement struct { assigns []interface{} } +type join struct { + Name string + Conds []interface{} +} + // StatementModifier statement modifier interface type StatementModifier interface { ModifyStatement(*Statement) @@ -401,7 +406,6 @@ func (stmt *Statement) clone() *Statement { Distinct: stmt.Distinct, Selects: stmt.Selects, Omits: stmt.Omits, - Joins: map[string][]interface{}{}, Preloads: map[string][]interface{}{}, ConnPool: stmt.ConnPool, Schema: stmt.Schema, @@ -417,8 +421,9 @@ func (stmt *Statement) clone() *Statement { newStmt.Preloads[k] = p } - for k, j := range stmt.Joins { - newStmt.Joins[k] = j + if len(stmt.Joins) > 0 { + newStmt.Joins = make([]join, len(stmt.Joins)) + copy(newStmt.Joins, stmt.Joins) } stmt.Settings.Range(func(k, v interface{}) bool { diff --git a/tests/joins_test.go b/tests/joins_test.go index e54d3784..f78ddf67 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -1,6 +1,7 @@ package tests_test import ( + "regexp" "sort" "testing" @@ -88,6 +89,13 @@ func TestJoinConds(t *testing.T) { if db5.Error != nil { t.Errorf("Should not raise error for join where identical fields in different tables. Error: %s", db5.Error.Error()) } + + dryDB := DB.Session(&gorm.Session{DryRun: true}) + stmt := dryDB.Joins("left join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Joins("join accounts on accounts.user_id = users.id AND accounts.number = ?", user.Account.Number).Where(User{Model: gorm.Model{ID: 1}}).Where(Account{Model: gorm.Model{ID: 1}}).Not(Pet{Model: gorm.Model{ID: 1}}).Find(&users5).Statement + + if !regexp.MustCompile("SELECT .* FROM .users. left join pets.*join accounts.*").MatchString(stmt.SQL.String()) { + t.Errorf("joins should be ordered, but got %v", stmt.SQL.String()) + } } func TestJoinsWithSelect(t *testing.T) {