mirror of https://github.com/go-gorm/gorm.git
Add Distinct support
This commit is contained in:
parent
d50879cc28
commit
eda2f023b0
|
@ -37,7 +37,7 @@ func Query(db *gorm.DB) {
|
|||
}
|
||||
|
||||
func BuildQuerySQL(db *gorm.DB) {
|
||||
clauseSelect := clause.Select{}
|
||||
clauseSelect := clause.Select{Distinct: db.Statement.Distinct}
|
||||
|
||||
if db.Statement.ReflectValue.Kind() == reflect.Struct {
|
||||
var conds []clause.Expression
|
||||
|
|
|
@ -45,6 +45,16 @@ func (db *DB) Table(name string) (tx *DB) {
|
|||
return
|
||||
}
|
||||
|
||||
// Distinct specify distinct fields that you want querying
|
||||
func (db *DB) Distinct(args ...interface{}) (tx *DB) {
|
||||
tx = db
|
||||
if len(args) > 0 {
|
||||
tx = tx.Select(args[0], args[1:]...)
|
||||
}
|
||||
tx.Statement.Distinct = true
|
||||
return tx
|
||||
}
|
||||
|
||||
// Select specify fields that you want when querying, creating, updating
|
||||
func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
|
|
|
@ -2,6 +2,7 @@ package clause
|
|||
|
||||
// Select select attrs when querying, updating, creating
|
||||
type Select struct {
|
||||
Distinct bool
|
||||
Columns []Column
|
||||
Expression Expression
|
||||
}
|
||||
|
@ -12,6 +13,10 @@ func (s Select) Name() string {
|
|||
|
||||
func (s Select) Build(builder Builder) {
|
||||
if len(s.Columns) > 0 {
|
||||
if s.Distinct {
|
||||
builder.WriteString(" DISTINCT ")
|
||||
}
|
||||
|
||||
for idx, column := range s.Columns {
|
||||
if idx > 0 {
|
||||
builder.WriteByte(',')
|
||||
|
|
|
@ -23,4 +23,6 @@ var (
|
|||
ErrPtrStructSupported = errors.New("only ptr of struct supported")
|
||||
// ErrorPrimaryKeyRequired primary keys required
|
||||
ErrorPrimaryKeyRequired = errors.New("primary key required")
|
||||
// ErrorModelValueRequired model value required
|
||||
ErrorModelValueRequired = errors.New("model value required")
|
||||
)
|
||||
|
|
|
@ -233,13 +233,24 @@ func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) {
|
|||
|
||||
func (db *DB) Count(count *int64) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if s, ok := tx.Statement.Clauses["SELECT"].Expression.(clause.Select); !ok || len(s.Columns) == 0 {
|
||||
tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(1)"}})
|
||||
}
|
||||
|
||||
if tx.Statement.Model == nil {
|
||||
tx.Statement.Model = tx.Statement.Dest
|
||||
}
|
||||
|
||||
if len(tx.Statement.Selects) == 0 {
|
||||
tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(1)"}})
|
||||
} else if len(tx.Statement.Selects) == 1 && !strings.Contains(strings.ToLower(tx.Statement.Selects[0]), "count(") {
|
||||
column := tx.Statement.Selects[0]
|
||||
if tx.Statement.Parse(tx.Statement.Model) == nil {
|
||||
if f := tx.Statement.Schema.LookUpField(column); f != nil {
|
||||
column = f.DBName
|
||||
}
|
||||
}
|
||||
tx.Statement.AddClause(clause.Select{
|
||||
Expression: clause.Expr{SQL: "COUNT(DISTINCT(?))", Vars: []interface{}{clause.Column{Name: column}}},
|
||||
})
|
||||
}
|
||||
|
||||
tx.Statement.Dest = count
|
||||
tx.callbacks.Query().Execute(tx)
|
||||
if db.RowsAffected != 1 {
|
||||
|
@ -273,9 +284,22 @@ func (db *DB) Scan(dest interface{}) (tx *DB) {
|
|||
// db.Find(&users).Pluck("age", &ages)
|
||||
func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.AddClauseIfNotExists(clause.Select{Columns: []clause.Column{{Name: column}}})
|
||||
if tx.Statement.Model != nil {
|
||||
if tx.Statement.Parse(tx.Statement.Model) == nil {
|
||||
if f := tx.Statement.Schema.LookUpField(column); f != nil {
|
||||
column = f.DBName
|
||||
}
|
||||
}
|
||||
|
||||
tx.Statement.AddClauseIfNotExists(clause.Select{
|
||||
Distinct: tx.Statement.Distinct,
|
||||
Columns: []clause.Column{{Name: column}},
|
||||
})
|
||||
tx.Statement.Dest = dest
|
||||
tx.callbacks.Query().Execute(tx)
|
||||
} else {
|
||||
tx.AddError(ErrorModelValueRequired)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -23,6 +23,7 @@ type Statement struct {
|
|||
Dest interface{}
|
||||
ReflectValue reflect.Value
|
||||
Clauses map[string]clause.Clause
|
||||
Distinct bool
|
||||
Selects []string // selected columns
|
||||
Omits []string // omit columns
|
||||
Joins map[string][]interface{}
|
||||
|
@ -331,6 +332,7 @@ func (stmt *Statement) clone() *Statement {
|
|||
Dest: stmt.Dest,
|
||||
ReflectValue: stmt.ReflectValue,
|
||||
Clauses: map[string]clause.Clause{},
|
||||
Distinct: stmt.Distinct,
|
||||
Selects: stmt.Selects,
|
||||
Omits: stmt.Omits,
|
||||
Joins: map[string][]interface{}{},
|
||||
|
|
|
@ -0,0 +1,60 @@
|
|||
package tests_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestDistinct(t *testing.T) {
|
||||
var users = []User{
|
||||
*GetUser("distinct", Config{}),
|
||||
*GetUser("distinct", Config{}),
|
||||
*GetUser("distinct", Config{}),
|
||||
*GetUser("distinct-2", Config{}),
|
||||
*GetUser("distinct-3", Config{}),
|
||||
}
|
||||
users[0].Age = 20
|
||||
|
||||
if err := DB.Create(&users).Error; err != nil {
|
||||
t.Fatalf("errors happened when create users: %v", err)
|
||||
}
|
||||
|
||||
var names []string
|
||||
DB.Model(&User{}).Where("name like ?", "distinct%").Order("name").Pluck("Name", &names)
|
||||
AssertEqual(t, names, []string{"distinct", "distinct", "distinct", "distinct-2", "distinct-3"})
|
||||
|
||||
var names1 []string
|
||||
DB.Model(&User{}).Where("name like ?", "distinct%").Distinct().Order("name").Pluck("Name", &names1)
|
||||
|
||||
AssertEqual(t, names1, []string{"distinct", "distinct-2", "distinct-3"})
|
||||
|
||||
var results []User
|
||||
if err := DB.Distinct("name", "age").Where("name like ?", "distinct%").Order("name, age desc").Find(&results).Error; err != nil {
|
||||
t.Errorf("failed to query users, got error: %v", err)
|
||||
}
|
||||
|
||||
expects := []User{
|
||||
{Name: "distinct", Age: 20},
|
||||
{Name: "distinct", Age: 18},
|
||||
{Name: "distinct-2", Age: 18},
|
||||
{Name: "distinct-3", Age: 18},
|
||||
}
|
||||
|
||||
if len(results) != 4 {
|
||||
t.Fatalf("invalid results length found, expects: %v, got %v", len(expects), len(results))
|
||||
}
|
||||
|
||||
for idx, expect := range expects {
|
||||
AssertObjEqual(t, results[idx], expect, "Name", "Age")
|
||||
}
|
||||
|
||||
var count int64
|
||||
if err := DB.Model(&User{}).Where("name like ?", "distinct%").Count(&count).Error; err != nil || count != 5 {
|
||||
t.Errorf("failed to query users count, got error: %v, count: %v", err, count)
|
||||
}
|
||||
|
||||
if err := DB.Model(&User{}).Distinct("name").Where("name like ?", "distinct%").Count(&count).Error; err != nil || count != 3 {
|
||||
t.Errorf("failed to query users count, got error: %v, count %v", err, count)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue