From 362779575c2a91d29074b0a03b27187d615070ce Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 17 Jul 2020 11:24:24 +0800 Subject: [PATCH] Fix Select with specific symbol, close #3157 --- chainable_api.go | 6 ++++-- clause/select.go | 8 ++++++++ tests/distinct_test.go | 8 ++++++++ 3 files changed, 20 insertions(+), 2 deletions(-) diff --git a/chainable_api.go b/chainable_api.go index 730f6308..7c352268 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -60,11 +60,11 @@ func (db *DB) Table(name string) (tx *DB) { // Distinct specify distinct fields that you want querying func (db *DB) Distinct(args ...interface{}) (tx *DB) { - tx = db + tx = db.getInstance() + tx.Statement.Distinct = true if len(args) > 0 { tx = tx.Select(args[0], args[1:]...) } - tx.Statement.Distinct = true return tx } @@ -102,6 +102,7 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { tx.Statement.Selects = append(tx.Statement.Selects, arg...) default: tx.Statement.AddClause(clause.Select{ + Distinct: db.Statement.Distinct, Expression: clause.Expr{SQL: v, Vars: args}, }) return @@ -109,6 +110,7 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { } } else { tx.Statement.AddClause(clause.Select{ + Distinct: db.Statement.Distinct, Expression: clause.Expr{SQL: v, Vars: args}, }) } diff --git a/clause/select.go b/clause/select.go index 9c2bc625..b93b8769 100644 --- a/clause/select.go +++ b/clause/select.go @@ -30,6 +30,14 @@ func (s Select) Build(builder Builder) { func (s Select) MergeClause(clause *Clause) { if s.Expression != nil { + if s.Distinct { + if expr, ok := s.Expression.(Expr); ok { + expr.SQL = "DISTINCT " + expr.SQL + clause.Expression = expr + return + } + } + clause.Expression = s.Expression } else { clause.Expression = s diff --git a/tests/distinct_test.go b/tests/distinct_test.go index 248602d3..29a320ff 100644 --- a/tests/distinct_test.go +++ b/tests/distinct_test.go @@ -1,8 +1,10 @@ package tests_test import ( + "regexp" "testing" + "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) @@ -57,4 +59,10 @@ func TestDistinct(t *testing.T) { if err := DB.Model(&User{}).Distinct("name").Where("name like ?", "distinct%").Count(&count).Error; err != nil || count != 3 { t.Errorf("failed to query users count, got error: %v, count %v", err, count) } + + dryDB := DB.Session(&gorm.Session{DryRun: true}) + r := dryDB.Distinct("u.id, u.*").Table("user_speaks as s").Joins("inner join users as u on u.id = s.user_id").Where("s.language_code ='US' or s.language_code ='ES'").Find(&User{}) + if !regexp.MustCompile(`SELECT DISTINCT u\.id, u\.\* FROM user_speaks as s inner join users as u`).MatchString(r.Statement.SQL.String()) { + t.Fatalf("Build Distinct with u.*, but got %v", r.Statement.SQL.String()) + } }