forked from mirror/gorm
Fix unordered joins, close #3267
This commit is contained in:
parent
2b510d6423
commit
3a97639880
|
@ -104,12 +104,12 @@ func BuildQuerySQL(db *gorm.DB) {
|
||||||
}
|
}
|
||||||
|
|
||||||
joins := []clause.Join{}
|
joins := []clause.Join{}
|
||||||
for name, conds := range db.Statement.Joins {
|
for _, join := range db.Statement.Joins {
|
||||||
if db.Statement.Schema == nil {
|
if db.Statement.Schema == nil {
|
||||||
joins = append(joins, clause.Join{
|
joins = append(joins, clause.Join{
|
||||||
Expression: clause.Expr{SQL: name, Vars: conds},
|
Expression: clause.Expr{SQL: join.Name, Vars: join.Conds},
|
||||||
})
|
})
|
||||||
} else if relation, ok := db.Statement.Schema.Relationships.Relations[name]; ok {
|
} else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok {
|
||||||
tableAliasName := relation.Name
|
tableAliasName := relation.Name
|
||||||
|
|
||||||
for _, s := range relation.FieldSchema.DBNames {
|
for _, s := range relation.FieldSchema.DBNames {
|
||||||
|
@ -149,7 +149,7 @@ func BuildQuerySQL(db *gorm.DB) {
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
joins = append(joins, clause.Join{
|
joins = append(joins, clause.Join{
|
||||||
Expression: clause.Expr{SQL: name, Vars: conds},
|
Expression: clause.Expr{SQL: join.Name, Vars: join.Conds},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -172,10 +172,7 @@ func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
|
||||||
// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
|
// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
|
||||||
func (db *DB) Joins(query string, args ...interface{}) (tx *DB) {
|
func (db *DB) Joins(query string, args ...interface{}) (tx *DB) {
|
||||||
tx = db.getInstance()
|
tx = db.getInstance()
|
||||||
if tx.Statement.Joins == nil {
|
tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args})
|
||||||
tx.Statement.Joins = map[string][]interface{}{}
|
|
||||||
}
|
|
||||||
tx.Statement.Joins[query] = args
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
13
statement.go
13
statement.go
|
@ -29,7 +29,7 @@ type Statement struct {
|
||||||
Distinct bool
|
Distinct bool
|
||||||
Selects []string // selected columns
|
Selects []string // selected columns
|
||||||
Omits []string // omit columns
|
Omits []string // omit columns
|
||||||
Joins map[string][]interface{}
|
Joins []join
|
||||||
Preloads map[string][]interface{}
|
Preloads map[string][]interface{}
|
||||||
Settings sync.Map
|
Settings sync.Map
|
||||||
ConnPool ConnPool
|
ConnPool ConnPool
|
||||||
|
@ -44,6 +44,11 @@ type Statement struct {
|
||||||
assigns []interface{}
|
assigns []interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type join struct {
|
||||||
|
Name string
|
||||||
|
Conds []interface{}
|
||||||
|
}
|
||||||
|
|
||||||
// StatementModifier statement modifier interface
|
// StatementModifier statement modifier interface
|
||||||
type StatementModifier interface {
|
type StatementModifier interface {
|
||||||
ModifyStatement(*Statement)
|
ModifyStatement(*Statement)
|
||||||
|
@ -401,7 +406,6 @@ func (stmt *Statement) clone() *Statement {
|
||||||
Distinct: stmt.Distinct,
|
Distinct: stmt.Distinct,
|
||||||
Selects: stmt.Selects,
|
Selects: stmt.Selects,
|
||||||
Omits: stmt.Omits,
|
Omits: stmt.Omits,
|
||||||
Joins: map[string][]interface{}{},
|
|
||||||
Preloads: map[string][]interface{}{},
|
Preloads: map[string][]interface{}{},
|
||||||
ConnPool: stmt.ConnPool,
|
ConnPool: stmt.ConnPool,
|
||||||
Schema: stmt.Schema,
|
Schema: stmt.Schema,
|
||||||
|
@ -417,8 +421,9 @@ func (stmt *Statement) clone() *Statement {
|
||||||
newStmt.Preloads[k] = p
|
newStmt.Preloads[k] = p
|
||||||
}
|
}
|
||||||
|
|
||||||
for k, j := range stmt.Joins {
|
if len(stmt.Joins) > 0 {
|
||||||
newStmt.Joins[k] = j
|
newStmt.Joins = make([]join, len(stmt.Joins))
|
||||||
|
copy(newStmt.Joins, stmt.Joins)
|
||||||
}
|
}
|
||||||
|
|
||||||
stmt.Settings.Range(func(k, v interface{}) bool {
|
stmt.Settings.Range(func(k, v interface{}) bool {
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package tests_test
|
package tests_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"regexp"
|
||||||
"sort"
|
"sort"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
@ -88,6 +89,13 @@ func TestJoinConds(t *testing.T) {
|
||||||
if db5.Error != nil {
|
if db5.Error != nil {
|
||||||
t.Errorf("Should not raise error for join where identical fields in different tables. Error: %s", db5.Error.Error())
|
t.Errorf("Should not raise error for join where identical fields in different tables. Error: %s", db5.Error.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
dryDB := DB.Session(&gorm.Session{DryRun: true})
|
||||||
|
stmt := dryDB.Joins("left join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Joins("join accounts on accounts.user_id = users.id AND accounts.number = ?", user.Account.Number).Where(User{Model: gorm.Model{ID: 1}}).Where(Account{Model: gorm.Model{ID: 1}}).Not(Pet{Model: gorm.Model{ID: 1}}).Find(&users5).Statement
|
||||||
|
|
||||||
|
if !regexp.MustCompile("SELECT .* FROM .users. left join pets.*join accounts.*").MatchString(stmt.SQL.String()) {
|
||||||
|
t.Errorf("joins should be ordered, but got %v", stmt.SQL.String())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestJoinsWithSelect(t *testing.T) {
|
func TestJoinsWithSelect(t *testing.T) {
|
||||||
|
|
Loading…
Reference in New Issue