Add Distinct support

This commit is contained in:
Jinzhu 2020-06-05 19:19:08 +08:00
parent d50879cc28
commit eda2f023b0
7 changed files with 111 additions and 8 deletions

View File

@ -37,7 +37,7 @@ func Query(db *gorm.DB) {
} }
func BuildQuerySQL(db *gorm.DB) { func BuildQuerySQL(db *gorm.DB) {
clauseSelect := clause.Select{} clauseSelect := clause.Select{Distinct: db.Statement.Distinct}
if db.Statement.ReflectValue.Kind() == reflect.Struct { if db.Statement.ReflectValue.Kind() == reflect.Struct {
var conds []clause.Expression var conds []clause.Expression

View File

@ -45,6 +45,16 @@ func (db *DB) Table(name string) (tx *DB) {
return return
} }
// Distinct specify distinct fields that you want querying
func (db *DB) Distinct(args ...interface{}) (tx *DB) {
tx = db
if len(args) > 0 {
tx = tx.Select(args[0], args[1:]...)
}
tx.Statement.Distinct = true
return tx
}
// Select specify fields that you want when querying, creating, updating // Select specify fields that you want when querying, creating, updating
func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()

View File

@ -2,6 +2,7 @@ package clause
// Select select attrs when querying, updating, creating // Select select attrs when querying, updating, creating
type Select struct { type Select struct {
Distinct bool
Columns []Column Columns []Column
Expression Expression Expression Expression
} }
@ -12,6 +13,10 @@ func (s Select) Name() string {
func (s Select) Build(builder Builder) { func (s Select) Build(builder Builder) {
if len(s.Columns) > 0 { if len(s.Columns) > 0 {
if s.Distinct {
builder.WriteString(" DISTINCT ")
}
for idx, column := range s.Columns { for idx, column := range s.Columns {
if idx > 0 { if idx > 0 {
builder.WriteByte(',') builder.WriteByte(',')

View File

@ -23,4 +23,6 @@ var (
ErrPtrStructSupported = errors.New("only ptr of struct supported") ErrPtrStructSupported = errors.New("only ptr of struct supported")
// ErrorPrimaryKeyRequired primary keys required // ErrorPrimaryKeyRequired primary keys required
ErrorPrimaryKeyRequired = errors.New("primary key required") ErrorPrimaryKeyRequired = errors.New("primary key required")
// ErrorModelValueRequired model value required
ErrorModelValueRequired = errors.New("model value required")
) )

View File

@ -233,13 +233,24 @@ func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) {
func (db *DB) Count(count *int64) (tx *DB) { func (db *DB) Count(count *int64) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
if s, ok := tx.Statement.Clauses["SELECT"].Expression.(clause.Select); !ok || len(s.Columns) == 0 {
tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(1)"}})
}
if tx.Statement.Model == nil { if tx.Statement.Model == nil {
tx.Statement.Model = tx.Statement.Dest tx.Statement.Model = tx.Statement.Dest
} }
if len(tx.Statement.Selects) == 0 {
tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(1)"}})
} else if len(tx.Statement.Selects) == 1 && !strings.Contains(strings.ToLower(tx.Statement.Selects[0]), "count(") {
column := tx.Statement.Selects[0]
if tx.Statement.Parse(tx.Statement.Model) == nil {
if f := tx.Statement.Schema.LookUpField(column); f != nil {
column = f.DBName
}
}
tx.Statement.AddClause(clause.Select{
Expression: clause.Expr{SQL: "COUNT(DISTINCT(?))", Vars: []interface{}{clause.Column{Name: column}}},
})
}
tx.Statement.Dest = count tx.Statement.Dest = count
tx.callbacks.Query().Execute(tx) tx.callbacks.Query().Execute(tx)
if db.RowsAffected != 1 { if db.RowsAffected != 1 {
@ -273,9 +284,22 @@ func (db *DB) Scan(dest interface{}) (tx *DB) {
// db.Find(&users).Pluck("age", &ages) // db.Find(&users).Pluck("age", &ages)
func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.AddClauseIfNotExists(clause.Select{Columns: []clause.Column{{Name: column}}}) if tx.Statement.Model != nil {
tx.Statement.Dest = dest if tx.Statement.Parse(tx.Statement.Model) == nil {
tx.callbacks.Query().Execute(tx) if f := tx.Statement.Schema.LookUpField(column); f != nil {
column = f.DBName
}
}
tx.Statement.AddClauseIfNotExists(clause.Select{
Distinct: tx.Statement.Distinct,
Columns: []clause.Column{{Name: column}},
})
tx.Statement.Dest = dest
tx.callbacks.Query().Execute(tx)
} else {
tx.AddError(ErrorModelValueRequired)
}
return return
} }

View File

@ -23,6 +23,7 @@ type Statement struct {
Dest interface{} Dest interface{}
ReflectValue reflect.Value ReflectValue reflect.Value
Clauses map[string]clause.Clause Clauses map[string]clause.Clause
Distinct bool
Selects []string // selected columns Selects []string // selected columns
Omits []string // omit columns Omits []string // omit columns
Joins map[string][]interface{} Joins map[string][]interface{}
@ -331,6 +332,7 @@ func (stmt *Statement) clone() *Statement {
Dest: stmt.Dest, Dest: stmt.Dest,
ReflectValue: stmt.ReflectValue, ReflectValue: stmt.ReflectValue,
Clauses: map[string]clause.Clause{}, Clauses: map[string]clause.Clause{},
Distinct: stmt.Distinct,
Selects: stmt.Selects, Selects: stmt.Selects,
Omits: stmt.Omits, Omits: stmt.Omits,
Joins: map[string][]interface{}{}, Joins: map[string][]interface{}{},

60
tests/distinct_test.go Normal file
View File

@ -0,0 +1,60 @@
package tests_test
import (
"testing"
. "gorm.io/gorm/utils/tests"
)
func TestDistinct(t *testing.T) {
var users = []User{
*GetUser("distinct", Config{}),
*GetUser("distinct", Config{}),
*GetUser("distinct", Config{}),
*GetUser("distinct-2", Config{}),
*GetUser("distinct-3", Config{}),
}
users[0].Age = 20
if err := DB.Create(&users).Error; err != nil {
t.Fatalf("errors happened when create users: %v", err)
}
var names []string
DB.Model(&User{}).Where("name like ?", "distinct%").Order("name").Pluck("Name", &names)
AssertEqual(t, names, []string{"distinct", "distinct", "distinct", "distinct-2", "distinct-3"})
var names1 []string
DB.Model(&User{}).Where("name like ?", "distinct%").Distinct().Order("name").Pluck("Name", &names1)
AssertEqual(t, names1, []string{"distinct", "distinct-2", "distinct-3"})
var results []User
if err := DB.Distinct("name", "age").Where("name like ?", "distinct%").Order("name, age desc").Find(&results).Error; err != nil {
t.Errorf("failed to query users, got error: %v", err)
}
expects := []User{
{Name: "distinct", Age: 20},
{Name: "distinct", Age: 18},
{Name: "distinct-2", Age: 18},
{Name: "distinct-3", Age: 18},
}
if len(results) != 4 {
t.Fatalf("invalid results length found, expects: %v, got %v", len(expects), len(results))
}
for idx, expect := range expects {
AssertObjEqual(t, results[idx], expect, "Name", "Age")
}
var count int64
if err := DB.Model(&User{}).Where("name like ?", "distinct%").Count(&count).Error; err != nil || count != 5 {
t.Errorf("failed to query users count, got error: %v, count: %v", err, count)
}
if err := DB.Model(&User{}).Distinct("name").Where("name like ?", "distinct%").Count(&count).Error; err != nil || count != 3 {
t.Errorf("failed to query users count, got error: %v, count %v", err, count)
}
}