From 5b1d3e4a771947f5caae6950b86ab32fd8e56507 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 31 May 2020 20:21:52 +0800 Subject: [PATCH] Test Joins --- callbacks/query.go | 6 +----- finisher_api.go | 5 +++-- statement.go | 10 ++++----- tests/joins_test.go | 52 +++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 61 insertions(+), 12 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index 9f96fd1a..55f2c65b 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -123,11 +123,7 @@ func BuildQuerySQL(db *gorm.DB) { db.Statement.AddClauseIfNotExists(clause.From{}) } - if len(clauseSelect.Columns) > 0 { - db.Statement.AddClause(clauseSelect) - } else { - db.Statement.AddClauseIfNotExists(clauseSelect) - } + db.Statement.AddClauseIfNotExists(clauseSelect) db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") } diff --git a/finisher_api.go b/finisher_api.go index cfbb98c1..49b08fa4 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -233,9 +233,10 @@ func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) { func (db *DB) Count(count *int64) (tx *DB) { tx = db.getInstance() - if len(tx.Statement.Selects) == 0 { - tx.Statement.Selects = []string{"count(1)"} + if s, ok := tx.Statement.Clauses["SELECT"].Expression.(clause.Select); !ok || len(s.Columns) == 0 { + tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(1)"}}) } + if tx.Statement.Model == nil { tx.Statement.Model = tx.Statement.Dest } diff --git a/statement.go b/statement.go index e0d92c5e..444d5c37 100644 --- a/statement.go +++ b/statement.go @@ -196,7 +196,7 @@ func (stmt *Statement) AddClause(v clause.Interface) { // AddClauseIfNotExists add clause if not exists func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) { - if _, ok := stmt.Clauses[v.Name()]; !ok { + if c, ok := stmt.Clauses[v.Name()]; !ok && c.Expression == nil { stmt.AddClause(v) } } @@ -248,9 +248,9 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con for _, field := range s.Fields { if v, isZero := field.ValueOf(reflectValue); !isZero { if field.DBName == "" { - conds = append(conds, clause.Eq{Column: field.Name, Value: v}) + conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v}) } else { - conds = append(conds, clause.Eq{Column: field.DBName, Value: v}) + conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.DBName}, Value: v}) } } } @@ -259,9 +259,9 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con for _, field := range s.Fields { if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero { if field.DBName == "" { - conds = append(conds, clause.Eq{Column: field.Name, Value: v}) + conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v}) } else { - conds = append(conds, clause.Eq{Column: field.DBName, Value: v}) + conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.DBName}, Value: v}) } } } diff --git a/tests/joins_test.go b/tests/joins_test.go index 556130ee..8a9cdde5 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -4,6 +4,7 @@ import ( "sort" "testing" + "github.com/jinzhu/gorm" . "github.com/jinzhu/gorm/tests" ) @@ -53,3 +54,54 @@ func TestJoinsForSlice(t *testing.T) { CheckUser(t, user, users2[idx]) } } + +func TestJoinConds(t *testing.T) { + var user = *GetUser("joins-conds", Config{Account: true, Pets: 3}) + DB.Save(&user) + + var users1 []User + DB.Joins("left join pets on pets.user_id = users.id").Where("users.name = ?", user.Name).Find(&users1) + if len(users1) != 3 { + t.Errorf("should find two users using left join, but got %v", len(users1)) + } + + var users2 []User + DB.Joins("left join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Where("users.name = ?", user.Name).First(&users2) + if len(users2) != 1 { + t.Errorf("should find one users using left join with conditions, but got %v", len(users2)) + } + + var users3 []User + DB.Joins("left join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Joins("join accounts on accounts.user_id = users.id AND accounts.number = ?", user.Account.Number).Where("users.name = ?", user.Name).First(&users3) + if len(users3) != 1 { + t.Errorf("should find one users using multiple left join conditions, but got %v", len(users3)) + } + + var users4 []User + DB.Joins("left join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Joins("join accounts on accounts.user_id = users.id AND accounts.number = ?", user.Account.Number+"non-exist").Where("users.name = ?", user.Name).First(&users4) + if len(users4) != 0 { + t.Errorf("should find no user when searching with unexisting credit card, but got %v", len(users4)) + } + + var users5 []User + db5 := DB.Joins("left join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Joins("join accounts on accounts.user_id = users.id AND accounts.number = ?", user.Account.Number).Where(User{Model: gorm.Model{ID: 1}}).Where(Account{Model: gorm.Model{ID: 1}}).Not(Pet{Model: gorm.Model{ID: 1}}).Find(&users5) + if db5.Error != nil { + t.Errorf("Should not raise error for join where identical fields in different tables. Error: %s", db5.Error.Error()) + } +} + +func TestJoinsWithSelect(t *testing.T) { + type result struct { + ID uint + Name string + } + + user := *GetUser("joins_with_select", Config{Pets: 2}) + DB.Save(&user) + + var results []result + DB.Table("users").Select("users.id, pets.name").Joins("left join pets on pets.user_id = users.id").Where("users.name = ?", "joins_with_select").Scan(&results) + if len(results) != 2 || results[0].Name != user.Pets[0].Name || results[1].Name != user.Pets[1].Name { + t.Errorf("Should find all two pets with Join select") + } +}