From a16db07945e5f5acf348649debd2130dfcfeeb92 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 7 Sep 2021 21:21:44 +0800 Subject: [PATCH] Refactor Join ON --- callbacks/query.go | 69 +++++++++++++++++++++++---------------------- chainable_api.go | 4 ++- statement.go | 2 +- tests/joins_test.go | 5 ++-- 4 files changed, 42 insertions(+), 38 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index a4093c63..1cfd618c 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -125,47 +125,48 @@ func BuildQuerySQL(db *gorm.DB) { }) } - if join.On != nil { - primaryFields := make([]clause.Column, len(relation.FieldSchema.PrimaryFieldDBNames)) - for idx, ref := range relation.FieldSchema.PrimaryFieldDBNames { - primaryFields[idx] = clause.Column{Table: tableAliasName, Name: ref} - } - - exprs := db.Statement.BuildCondition("(?) = (?)", primaryFields, join.On) - joins = append(joins, clause.Join{ - Type: clause.LeftJoin, - Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, - ON: clause.Where{Exprs: exprs}, - }) - } else { - exprs := make([]clause.Expression, len(relation.References)) - for idx, ref := range relation.References { - if ref.OwnPrimaryKey { + exprs := make([]clause.Expression, len(relation.References)) + for idx, ref := range relation.References { + if ref.OwnPrimaryKey { + exprs[idx] = clause.Eq{ + Column: clause.Column{Table: clause.CurrentTable, Name: ref.PrimaryKey.DBName}, + Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, + } + } else { + if ref.PrimaryValue == "" { exprs[idx] = clause.Eq{ - Column: clause.Column{Table: clause.CurrentTable, Name: ref.PrimaryKey.DBName}, - Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, + Column: clause.Column{Table: clause.CurrentTable, Name: ref.ForeignKey.DBName}, + Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName}, } } else { - if ref.PrimaryValue == "" { - exprs[idx] = clause.Eq{ - Column: clause.Column{Table: clause.CurrentTable, Name: ref.ForeignKey.DBName}, - Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName}, - } - } else { - exprs[idx] = clause.Eq{ - Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, - Value: ref.PrimaryValue, - } + exprs[idx] = clause.Eq{ + Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, + Value: ref.PrimaryValue, } } } - - joins = append(joins, clause.Join{ - Type: clause.LeftJoin, - Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, - ON: clause.Where{Exprs: exprs}, - }) } + + if join.On != nil { + onStmt := gorm.Statement{Table: tableAliasName, DB: db} + join.On.Build(&onStmt) + onSQL := onStmt.SQL.String() + vars := onStmt.Vars + for idx, v := range onStmt.Vars { + bindvar := strings.Builder{} + onStmt.Vars = vars[0 : idx+1] + db.Dialector.BindVarTo(&bindvar, &onStmt, v) + onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1) + } + + exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars}) + } + + joins = append(joins, clause.Join{ + Type: clause.LeftJoin, + Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, + ON: clause.Where{Exprs: exprs}, + }) } else { joins = append(joins, clause.Join{ Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, diff --git a/chainable_api.go b/chainable_api.go index 8fd7ee3c..01ab2597 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -177,7 +177,9 @@ func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { if len(args) > 0 { if db, ok := args[0].(*DB); ok { - tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args[1:], On: db}) + if where, ok := db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok { + tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args[1:], On: &where}) + } return } } diff --git a/statement.go b/statement.go index b21b8854..38363443 100644 --- a/statement.go +++ b/statement.go @@ -50,7 +50,7 @@ type Statement struct { type join struct { Name string Conds []interface{} - On interface{} + On *clause.Where } // StatementModifier statement modifier interface diff --git a/tests/joins_test.go b/tests/joins_test.go index 21c73c19..e560f38a 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -109,14 +109,15 @@ func TestJoinOn(t *testing.T) { DB.Save(&user) var user1 User - onQuery := DB.Select("id").Where("user_id = users.id AND name = ?", "joins-on_pet_1").Model(&Pet{}) + onQuery := DB.Where(&Pet{Name: "joins-on_pet_1"}) if err := DB.Joins("NamedPet", onQuery).Where("users.name = ?", user.Name).First(&user1).Error; err != nil { t.Fatalf("Failed to load with joins on, got error: %v", err) } + AssertEqual(t, user1.NamedPet.Name, "joins-on_pet_1") - onQuery2 := DB.Select("id").Where("user_id = users.id AND name = ?", "joins-on_pet_2").Model(&Pet{}) + onQuery2 := DB.Where(&Pet{Name: "joins-on_pet_2"}) var user2 User if err := DB.Joins("NamedPet", onQuery2).Where("users.name = ?", user.Name).First(&user2).Error; err != nil { t.Fatalf("Failed to load with joins on, got error: %v", err)