Refactor Join ON

This commit is contained in:
Jinzhu 2021-09-07 21:21:44 +08:00
parent ba16b2368f
commit a16db07945
4 changed files with 42 additions and 38 deletions

View File

@ -125,19 +125,6 @@ func BuildQuerySQL(db *gorm.DB) {
}) })
} }
if join.On != nil {
primaryFields := make([]clause.Column, len(relation.FieldSchema.PrimaryFieldDBNames))
for idx, ref := range relation.FieldSchema.PrimaryFieldDBNames {
primaryFields[idx] = clause.Column{Table: tableAliasName, Name: ref}
}
exprs := db.Statement.BuildCondition("(?) = (?)", primaryFields, join.On)
joins = append(joins, clause.Join{
Type: clause.LeftJoin,
Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName},
ON: clause.Where{Exprs: exprs},
})
} else {
exprs := make([]clause.Expression, len(relation.References)) exprs := make([]clause.Expression, len(relation.References))
for idx, ref := range relation.References { for idx, ref := range relation.References {
if ref.OwnPrimaryKey { if ref.OwnPrimaryKey {
@ -160,12 +147,26 @@ func BuildQuerySQL(db *gorm.DB) {
} }
} }
if join.On != nil {
onStmt := gorm.Statement{Table: tableAliasName, DB: db}
join.On.Build(&onStmt)
onSQL := onStmt.SQL.String()
vars := onStmt.Vars
for idx, v := range onStmt.Vars {
bindvar := strings.Builder{}
onStmt.Vars = vars[0 : idx+1]
db.Dialector.BindVarTo(&bindvar, &onStmt, v)
onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1)
}
exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars})
}
joins = append(joins, clause.Join{ joins = append(joins, clause.Join{
Type: clause.LeftJoin, Type: clause.LeftJoin,
Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName},
ON: clause.Where{Exprs: exprs}, ON: clause.Where{Exprs: exprs},
}) })
}
} else { } else {
joins = append(joins, clause.Join{ joins = append(joins, clause.Join{
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},

View File

@ -177,7 +177,9 @@ func (db *DB) Joins(query string, args ...interface{}) (tx *DB) {
if len(args) > 0 { if len(args) > 0 {
if db, ok := args[0].(*DB); ok { if db, ok := args[0].(*DB); ok {
tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args[1:], On: db}) if where, ok := db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok {
tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args[1:], On: &where})
}
return return
} }
} }

View File

@ -50,7 +50,7 @@ type Statement struct {
type join struct { type join struct {
Name string Name string
Conds []interface{} Conds []interface{}
On interface{} On *clause.Where
} }
// StatementModifier statement modifier interface // StatementModifier statement modifier interface

View File

@ -109,14 +109,15 @@ func TestJoinOn(t *testing.T) {
DB.Save(&user) DB.Save(&user)
var user1 User var user1 User
onQuery := DB.Select("id").Where("user_id = users.id AND name = ?", "joins-on_pet_1").Model(&Pet{}) onQuery := DB.Where(&Pet{Name: "joins-on_pet_1"})
if err := DB.Joins("NamedPet", onQuery).Where("users.name = ?", user.Name).First(&user1).Error; err != nil { if err := DB.Joins("NamedPet", onQuery).Where("users.name = ?", user.Name).First(&user1).Error; err != nil {
t.Fatalf("Failed to load with joins on, got error: %v", err) t.Fatalf("Failed to load with joins on, got error: %v", err)
} }
AssertEqual(t, user1.NamedPet.Name, "joins-on_pet_1") AssertEqual(t, user1.NamedPet.Name, "joins-on_pet_1")
onQuery2 := DB.Select("id").Where("user_id = users.id AND name = ?", "joins-on_pet_2").Model(&Pet{}) onQuery2 := DB.Where(&Pet{Name: "joins-on_pet_2"})
var user2 User var user2 User
if err := DB.Joins("NamedPet", onQuery2).Where("users.name = ?", user.Name).First(&user2).Error; err != nil { if err := DB.Joins("NamedPet", onQuery2).Where("users.name = ?", user.Name).First(&user2).Error; err != nil {
t.Fatalf("Failed to load with joins on, got error: %v", err) t.Fatalf("Failed to load with joins on, got error: %v", err)