diff --git a/callbacks/query.go b/callbacks/query.go index b3293576..16202187 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -37,7 +37,7 @@ func Query(db *gorm.DB) { } func BuildQuerySQL(db *gorm.DB) { - clauseSelect := clause.Select{} + clauseSelect := clause.Select{Distinct: db.Statement.Distinct} if db.Statement.ReflectValue.Kind() == reflect.Struct { var conds []clause.Expression diff --git a/chainable_api.go b/chainable_api.go index b1ae3132..6c5a6f77 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -45,6 +45,16 @@ func (db *DB) Table(name string) (tx *DB) { return } +// Distinct specify distinct fields that you want querying +func (db *DB) Distinct(args ...interface{}) (tx *DB) { + tx = db + if len(args) > 0 { + tx = tx.Select(args[0], args[1:]...) + } + tx.Statement.Distinct = true + return tx +} + // Select specify fields that you want when querying, creating, updating func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() diff --git a/clause/select.go b/clause/select.go index 20b17e07..a1b77de8 100644 --- a/clause/select.go +++ b/clause/select.go @@ -2,6 +2,7 @@ package clause // Select select attrs when querying, updating, creating type Select struct { + Distinct bool Columns []Column Expression Expression } @@ -12,6 +13,10 @@ func (s Select) Name() string { func (s Select) Build(builder Builder) { if len(s.Columns) > 0 { + if s.Distinct { + builder.WriteString(" DISTINCT ") + } + for idx, column := range s.Columns { if idx > 0 { builder.WriteByte(',') diff --git a/errors.go b/errors.go index 82f24df2..ff06f24e 100644 --- a/errors.go +++ b/errors.go @@ -23,4 +23,6 @@ var ( ErrPtrStructSupported = errors.New("only ptr of struct supported") // ErrorPrimaryKeyRequired primary keys required ErrorPrimaryKeyRequired = errors.New("primary key required") + // ErrorModelValueRequired model value required + ErrorModelValueRequired = errors.New("model value required") ) diff --git a/finisher_api.go b/finisher_api.go index e493b406..d6de7aa3 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -233,13 +233,24 @@ func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) { func (db *DB) Count(count *int64) (tx *DB) { tx = db.getInstance() - if s, ok := tx.Statement.Clauses["SELECT"].Expression.(clause.Select); !ok || len(s.Columns) == 0 { - tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(1)"}}) - } - if tx.Statement.Model == nil { tx.Statement.Model = tx.Statement.Dest } + + if len(tx.Statement.Selects) == 0 { + tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(1)"}}) + } else if len(tx.Statement.Selects) == 1 && !strings.Contains(strings.ToLower(tx.Statement.Selects[0]), "count(") { + column := tx.Statement.Selects[0] + if tx.Statement.Parse(tx.Statement.Model) == nil { + if f := tx.Statement.Schema.LookUpField(column); f != nil { + column = f.DBName + } + } + tx.Statement.AddClause(clause.Select{ + Expression: clause.Expr{SQL: "COUNT(DISTINCT(?))", Vars: []interface{}{clause.Column{Name: column}}}, + }) + } + tx.Statement.Dest = count tx.callbacks.Query().Execute(tx) if db.RowsAffected != 1 { @@ -273,9 +284,22 @@ func (db *DB) Scan(dest interface{}) (tx *DB) { // db.Find(&users).Pluck("age", &ages) func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { tx = db.getInstance() - tx.Statement.AddClauseIfNotExists(clause.Select{Columns: []clause.Column{{Name: column}}}) - tx.Statement.Dest = dest - tx.callbacks.Query().Execute(tx) + if tx.Statement.Model != nil { + if tx.Statement.Parse(tx.Statement.Model) == nil { + if f := tx.Statement.Schema.LookUpField(column); f != nil { + column = f.DBName + } + } + + tx.Statement.AddClauseIfNotExists(clause.Select{ + Distinct: tx.Statement.Distinct, + Columns: []clause.Column{{Name: column}}, + }) + tx.Statement.Dest = dest + tx.callbacks.Query().Execute(tx) + } else { + tx.AddError(ErrorModelValueRequired) + } return } diff --git a/statement.go b/statement.go index ffe3c75b..755d93ac 100644 --- a/statement.go +++ b/statement.go @@ -23,6 +23,7 @@ type Statement struct { Dest interface{} ReflectValue reflect.Value Clauses map[string]clause.Clause + Distinct bool Selects []string // selected columns Omits []string // omit columns Joins map[string][]interface{} @@ -331,6 +332,7 @@ func (stmt *Statement) clone() *Statement { Dest: stmt.Dest, ReflectValue: stmt.ReflectValue, Clauses: map[string]clause.Clause{}, + Distinct: stmt.Distinct, Selects: stmt.Selects, Omits: stmt.Omits, Joins: map[string][]interface{}{}, diff --git a/tests/distinct_test.go b/tests/distinct_test.go new file mode 100644 index 00000000..f5a969a8 --- /dev/null +++ b/tests/distinct_test.go @@ -0,0 +1,60 @@ +package tests_test + +import ( + "testing" + + . "gorm.io/gorm/utils/tests" +) + +func TestDistinct(t *testing.T) { + var users = []User{ + *GetUser("distinct", Config{}), + *GetUser("distinct", Config{}), + *GetUser("distinct", Config{}), + *GetUser("distinct-2", Config{}), + *GetUser("distinct-3", Config{}), + } + users[0].Age = 20 + + if err := DB.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create users: %v", err) + } + + var names []string + DB.Model(&User{}).Where("name like ?", "distinct%").Order("name").Pluck("Name", &names) + AssertEqual(t, names, []string{"distinct", "distinct", "distinct", "distinct-2", "distinct-3"}) + + var names1 []string + DB.Model(&User{}).Where("name like ?", "distinct%").Distinct().Order("name").Pluck("Name", &names1) + + AssertEqual(t, names1, []string{"distinct", "distinct-2", "distinct-3"}) + + var results []User + if err := DB.Distinct("name", "age").Where("name like ?", "distinct%").Order("name, age desc").Find(&results).Error; err != nil { + t.Errorf("failed to query users, got error: %v", err) + } + + expects := []User{ + {Name: "distinct", Age: 20}, + {Name: "distinct", Age: 18}, + {Name: "distinct-2", Age: 18}, + {Name: "distinct-3", Age: 18}, + } + + if len(results) != 4 { + t.Fatalf("invalid results length found, expects: %v, got %v", len(expects), len(results)) + } + + for idx, expect := range expects { + AssertObjEqual(t, results[idx], expect, "Name", "Age") + } + + var count int64 + if err := DB.Model(&User{}).Where("name like ?", "distinct%").Count(&count).Error; err != nil || count != 5 { + t.Errorf("failed to query users count, got error: %v, count: %v", err, count) + } + + if err := DB.Model(&User{}).Distinct("name").Where("name like ?", "distinct%").Count(&count).Error; err != nil || count != 3 { + t.Errorf("failed to query users count, got error: %v, count %v", err, count) + } +}