From 52cc438d07cef6975b3407594c612f8e856b88af Mon Sep 17 00:00:00 2001 From: Adrien Carreira Date: Sat, 17 Jul 2021 15:45:15 +0200 Subject: [PATCH] JoinsOn unit test + use all primary keys --- callbacks/query.go | 10 ++++++++-- chainable_api.go | 2 +- statement.go | 2 +- tests/joins_test.go | 20 ++++++++++++++++++++ utils/tests/models.go | 2 ++ 5 files changed, 32 insertions(+), 4 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index e5f1250c..570a85d0 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -125,7 +125,7 @@ func BuildQuerySQL(db *gorm.DB) { }) } - if join.On != nil { + if join.On == nil { exprs := make([]clause.Expression, len(relation.References)) for idx, ref := range relation.References { if ref.OwnPrimaryKey { @@ -153,10 +153,16 @@ func BuildQuerySQL(db *gorm.DB) { ON: clause.Where{Exprs: exprs}, }) } else { + primaryFields := make([]clause.Column, len(relation.FieldSchema.PrimaryFieldDBNames)) + for idx, ref := range relation.FieldSchema.PrimaryFieldDBNames { + primaryFields[idx] = clause.Column{Table: tableAliasName, Name: ref} + } + + exprs := db.Statement.BuildCondition("(?) = (?)", primaryFields, join.On) joins = append(joins, clause.Join{ Type: clause.LeftJoin, Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, - ON: clause.Where{Exprs: []clause.Expression{join.On}}, + ON: clause.Where{Exprs: exprs}, }) } } else { diff --git a/chainable_api.go b/chainable_api.go index 32943a83..184931ff 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -177,7 +177,7 @@ func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { return } -func (db *DB) JoinsOn(query string, on clause.Expression, args ...interface{}) (tx *DB) { +func (db *DB) JoinsOn(query string, on interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args, On: on}) return diff --git a/statement.go b/statement.go index 89824bc1..b21b8854 100644 --- a/statement.go +++ b/statement.go @@ -50,7 +50,7 @@ type Statement struct { type join struct { Name string Conds []interface{} - On clause.Expression + On interface{} } // StatementModifier statement modifier interface diff --git a/tests/joins_test.go b/tests/joins_test.go index 46611f5f..0b46d69c 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -104,6 +104,26 @@ func TestJoinConds(t *testing.T) { } } +func TestJoinOn(t *testing.T) { + var user = *GetUser("joins-on", Config{Pets: 2}) + DB.Save(&user) + + var user1 User + onQuery := DB.Select("id").Where("user_id = users.id AND name = ?", "joins-on_pet_1").Model(&Pet{}) + + if err := DB.JoinsOn("NamedPet", onQuery).Where("users.name = ?", user.Name).First(&user1).Error; err != nil { + t.Fatalf("Failed to load with joins on, got error: %v", err) + } + AssertEqual(t, user1.NamedPet.Name, "joins-on_pet_1") + + onQuery2 := DB.Select("id").Where("user_id = users.id AND name = ?", "joins-on_pet_2").Model(&Pet{}) + var user2 User + if err := DB.JoinsOn("NamedPet", onQuery2).Where("users.name = ?", user.Name).First(&user2).Error; err != nil { + t.Fatalf("Failed to load with joins on, got error: %v", err) + } + AssertEqual(t, user2.NamedPet.Name, "joins-on_pet_2") +} + func TestJoinsWithSelect(t *testing.T) { type result struct { ID uint diff --git a/utils/tests/models.go b/utils/tests/models.go index 2c5e71c0..8e833c93 100644 --- a/utils/tests/models.go +++ b/utils/tests/models.go @@ -11,6 +11,7 @@ import ( // He works in a Company (belongs to), he has a Manager (belongs to - single-table), and also managed a Team (has many - single-table) // He speaks many languages (many to many) and has many friends (many to many - single-table) // His pet also has one Toy (has one - polymorphic) +// NamedPet is a reference to a Named `Pets` (has many) type User struct { gorm.Model Name string @@ -18,6 +19,7 @@ type User struct { Birthday *time.Time Account Account Pets []*Pet + NamedPet *Pet Toys []Toy `gorm:"polymorphic:Owner"` CompanyID *int Company Company