Fix unordered joins, close #3267

This commit is contained in:
Jinzhu 2020-08-23 10:40:37 +08:00
parent 2b510d6423
commit 3a97639880
4 changed files with 22 additions and 12 deletions

View File

@ -104,12 +104,12 @@ func BuildQuerySQL(db *gorm.DB) {
} }
joins := []clause.Join{} joins := []clause.Join{}
for name, conds := range db.Statement.Joins { for _, join := range db.Statement.Joins {
if db.Statement.Schema == nil { if db.Statement.Schema == nil {
joins = append(joins, clause.Join{ joins = append(joins, clause.Join{
Expression: clause.Expr{SQL: name, Vars: conds}, Expression: clause.Expr{SQL: join.Name, Vars: join.Conds},
}) })
} else if relation, ok := db.Statement.Schema.Relationships.Relations[name]; ok { } else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok {
tableAliasName := relation.Name tableAliasName := relation.Name
for _, s := range relation.FieldSchema.DBNames { for _, s := range relation.FieldSchema.DBNames {
@ -149,7 +149,7 @@ func BuildQuerySQL(db *gorm.DB) {
}) })
} else { } else {
joins = append(joins, clause.Join{ joins = append(joins, clause.Join{
Expression: clause.Expr{SQL: name, Vars: conds}, Expression: clause.Expr{SQL: join.Name, Vars: join.Conds},
}) })
} }
} }

View File

@ -172,10 +172,7 @@ func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) // db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { func (db *DB) Joins(query string, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
if tx.Statement.Joins == nil { tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args})
tx.Statement.Joins = map[string][]interface{}{}
}
tx.Statement.Joins[query] = args
return return
} }

View File

@ -29,7 +29,7 @@ type Statement struct {
Distinct bool Distinct bool
Selects []string // selected columns Selects []string // selected columns
Omits []string // omit columns Omits []string // omit columns
Joins map[string][]interface{} Joins []join
Preloads map[string][]interface{} Preloads map[string][]interface{}
Settings sync.Map Settings sync.Map
ConnPool ConnPool ConnPool ConnPool
@ -44,6 +44,11 @@ type Statement struct {
assigns []interface{} assigns []interface{}
} }
type join struct {
Name string
Conds []interface{}
}
// StatementModifier statement modifier interface // StatementModifier statement modifier interface
type StatementModifier interface { type StatementModifier interface {
ModifyStatement(*Statement) ModifyStatement(*Statement)
@ -401,7 +406,6 @@ func (stmt *Statement) clone() *Statement {
Distinct: stmt.Distinct, Distinct: stmt.Distinct,
Selects: stmt.Selects, Selects: stmt.Selects,
Omits: stmt.Omits, Omits: stmt.Omits,
Joins: map[string][]interface{}{},
Preloads: map[string][]interface{}{}, Preloads: map[string][]interface{}{},
ConnPool: stmt.ConnPool, ConnPool: stmt.ConnPool,
Schema: stmt.Schema, Schema: stmt.Schema,
@ -417,8 +421,9 @@ func (stmt *Statement) clone() *Statement {
newStmt.Preloads[k] = p newStmt.Preloads[k] = p
} }
for k, j := range stmt.Joins { if len(stmt.Joins) > 0 {
newStmt.Joins[k] = j newStmt.Joins = make([]join, len(stmt.Joins))
copy(newStmt.Joins, stmt.Joins)
} }
stmt.Settings.Range(func(k, v interface{}) bool { stmt.Settings.Range(func(k, v interface{}) bool {

View File

@ -1,6 +1,7 @@
package tests_test package tests_test
import ( import (
"regexp"
"sort" "sort"
"testing" "testing"
@ -88,6 +89,13 @@ func TestJoinConds(t *testing.T) {
if db5.Error != nil { if db5.Error != nil {
t.Errorf("Should not raise error for join where identical fields in different tables. Error: %s", db5.Error.Error()) t.Errorf("Should not raise error for join where identical fields in different tables. Error: %s", db5.Error.Error())
} }
dryDB := DB.Session(&gorm.Session{DryRun: true})
stmt := dryDB.Joins("left join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Joins("join accounts on accounts.user_id = users.id AND accounts.number = ?", user.Account.Number).Where(User{Model: gorm.Model{ID: 1}}).Where(Account{Model: gorm.Model{ID: 1}}).Not(Pet{Model: gorm.Model{ID: 1}}).Find(&users5).Statement
if !regexp.MustCompile("SELECT .* FROM .users. left join pets.*join accounts.*").MatchString(stmt.SQL.String()) {
t.Errorf("joins should be ordered, but got %v", stmt.SQL.String())
}
} }
func TestJoinsWithSelect(t *testing.T) { func TestJoinsWithSelect(t *testing.T) {