forked from mirror/gorm
Fix Count with complicated Select, close #3826
This commit is contained in:
parent
f655041908
commit
1ef1f0bfe4
|
@ -93,10 +93,12 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
|
|||
}
|
||||
delete(tx.Statement.Clauses, "SELECT")
|
||||
case string:
|
||||
fields := strings.FieldsFunc(v, utils.IsValidDBNameChar)
|
||||
|
||||
// normal field names
|
||||
if len(fields) == 1 || (len(fields) == 3 && strings.ToUpper(fields[1]) == "AS") {
|
||||
if (strings.Contains(v, " ?") || strings.Contains(v, "(?")) && len(args) > 0 {
|
||||
tx.Statement.AddClause(clause.Select{
|
||||
Distinct: db.Statement.Distinct,
|
||||
Expression: clause.Expr{SQL: v, Vars: args},
|
||||
})
|
||||
} else {
|
||||
tx.Statement.Selects = []string{v}
|
||||
|
||||
for _, arg := range args {
|
||||
|
@ -115,11 +117,6 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
|
|||
}
|
||||
|
||||
delete(tx.Statement.Clauses, "SELECT")
|
||||
} else {
|
||||
tx.Statement.AddClause(clause.Select{
|
||||
Distinct: db.Statement.Distinct,
|
||||
Expression: clause.Expr{SQL: v, Vars: args},
|
||||
})
|
||||
}
|
||||
default:
|
||||
tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args))
|
||||
|
|
|
@ -355,29 +355,38 @@ func (db *DB) Count(count *int64) (tx *DB) {
|
|||
}()
|
||||
}
|
||||
|
||||
if selectClause, ok := db.Statement.Clauses["SELECT"]; ok {
|
||||
defer func() {
|
||||
db.Statement.Clauses["SELECT"] = selectClause
|
||||
}()
|
||||
} else {
|
||||
defer delete(tx.Statement.Clauses, "SELECT")
|
||||
}
|
||||
|
||||
if len(tx.Statement.Selects) == 0 {
|
||||
tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(1)"}})
|
||||
defer delete(tx.Statement.Clauses, "SELECT")
|
||||
} else if !strings.Contains(strings.ToLower(tx.Statement.Selects[0]), "count(") {
|
||||
expr := clause.Expr{SQL: "count(1)"}
|
||||
|
||||
if len(tx.Statement.Selects) == 1 {
|
||||
dbName := tx.Statement.Selects[0]
|
||||
if tx.Statement.Parse(tx.Statement.Model) == nil {
|
||||
if f := tx.Statement.Schema.LookUpField(dbName); f != nil {
|
||||
dbName = f.DBName
|
||||
fields := strings.FieldsFunc(dbName, utils.IsValidDBNameChar)
|
||||
if len(fields) == 1 || (len(fields) == 3 && strings.ToUpper(fields[1]) == "AS") {
|
||||
if tx.Statement.Parse(tx.Statement.Model) == nil {
|
||||
if f := tx.Statement.Schema.LookUpField(dbName); f != nil {
|
||||
dbName = f.DBName
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if tx.Statement.Distinct {
|
||||
expr = clause.Expr{SQL: "COUNT(DISTINCT(?))", Vars: []interface{}{clause.Column{Name: dbName}}}
|
||||
} else {
|
||||
expr = clause.Expr{SQL: "COUNT(?)", Vars: []interface{}{clause.Column{Name: dbName}}}
|
||||
if tx.Statement.Distinct {
|
||||
expr = clause.Expr{SQL: "COUNT(DISTINCT(?))", Vars: []interface{}{clause.Column{Name: dbName}}}
|
||||
} else {
|
||||
expr = clause.Expr{SQL: "COUNT(?)", Vars: []interface{}{clause.Column{Name: dbName}}}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tx.Statement.AddClause(clause.Select{Expression: expr})
|
||||
defer delete(tx.Statement.Clauses, "SELECT")
|
||||
}
|
||||
|
||||
if orderByClause, ok := db.Statement.Clauses["ORDER BY"]; ok {
|
||||
|
@ -457,11 +466,13 @@ func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {
|
|||
tx.AddError(ErrModelValueRequired)
|
||||
}
|
||||
|
||||
fields := strings.FieldsFunc(column, utils.IsValidDBNameChar)
|
||||
tx.Statement.AddClauseIfNotExists(clause.Select{
|
||||
Distinct: tx.Statement.Distinct,
|
||||
Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}},
|
||||
})
|
||||
if len(tx.Statement.Selects) != 1 {
|
||||
fields := strings.FieldsFunc(column, utils.IsValidDBNameChar)
|
||||
tx.Statement.AddClauseIfNotExists(clause.Select{
|
||||
Distinct: tx.Statement.Distinct,
|
||||
Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}},
|
||||
})
|
||||
}
|
||||
tx.Statement.Dest = dest
|
||||
tx.callbacks.Query().Execute(tx)
|
||||
return
|
||||
|
|
|
@ -3,6 +3,8 @@ package tests_test
|
|||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
@ -77,4 +79,46 @@ func TestCount(t *testing.T) {
|
|||
if err := DB.Table("users").Where("users.name = ?", user1.Name).Order("name").Count(&count5).Error; err != nil || count5 != 1 {
|
||||
t.Errorf("count with join, got error: %v, count %v", err, count)
|
||||
}
|
||||
|
||||
var count6 int64
|
||||
if err := DB.Model(&User{}).Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).Select(
|
||||
"(CASE WHEN name=? THEN ? ELSE ? END) as name", "count-1", "main", "other",
|
||||
).Count(&count6).Find(&users).Error; err != nil || count6 != 3 {
|
||||
t.Fatalf(fmt.Sprintf("Count should work, but got err %v", err))
|
||||
}
|
||||
|
||||
expects := []User{User{Name: "main"}, {Name: "other"}, {Name: "other"}}
|
||||
sort.SliceStable(users, func(i, j int) bool {
|
||||
return strings.Compare(users[i].Name, users[j].Name) < 0
|
||||
})
|
||||
|
||||
AssertEqual(t, users, expects)
|
||||
|
||||
var count7 int64
|
||||
if err := DB.Model(&User{}).Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).Select(
|
||||
"(CASE WHEN name=? THEN ? ELSE ? END) as name, age", "count-1", "main", "other",
|
||||
).Count(&count7).Find(&users).Error; err != nil || count7 != 3 {
|
||||
t.Fatalf(fmt.Sprintf("Count should work, but got err %v", err))
|
||||
}
|
||||
|
||||
expects = []User{User{Name: "main", Age: 18}, {Name: "other", Age: 18}, {Name: "other", Age: 18}}
|
||||
sort.SliceStable(users, func(i, j int) bool {
|
||||
return strings.Compare(users[i].Name, users[j].Name) < 0
|
||||
})
|
||||
|
||||
AssertEqual(t, users, expects)
|
||||
|
||||
var count8 int64
|
||||
if err := DB.Model(&User{}).Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).Select(
|
||||
"(CASE WHEN age=18 THEN 1 ELSE 2 END) as age", "name",
|
||||
).Count(&count8).Find(&users).Error; err != nil || count8 != 3 {
|
||||
t.Fatalf(fmt.Sprintf("Count should work, but got err %v", err))
|
||||
}
|
||||
|
||||
expects = []User{User{Name: "count-1", Age: 1}, {Name: "count-2", Age: 1}, {Name: "count-3", Age: 1}}
|
||||
sort.SliceStable(users, func(i, j int) bool {
|
||||
return strings.Compare(users[i].Name, users[j].Name) < 0
|
||||
})
|
||||
|
||||
AssertEqual(t, users, expects)
|
||||
}
|
||||
|
|
|
@ -677,7 +677,7 @@ func TestPluckWithSelect(t *testing.T) {
|
|||
DB.Create(&users)
|
||||
|
||||
var userAges []int
|
||||
err := DB.Model(&User{}).Where("name like ?", "pluck_with_select%").Select("age + 1 as user_age").Pluck("user_age", &userAges).Error
|
||||
err := DB.Model(&User{}).Where("name like ?", "pluck_with_select%").Select("age + 1 as user_age").Pluck("user_age", &userAges).Error
|
||||
if err != nil {
|
||||
t.Fatalf("got error when pluck user_age: %v", err)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue