forked from mirror/gorm
Fix Count with complicated Select, close #3826
This commit is contained in:
parent
f655041908
commit
1ef1f0bfe4
|
@ -93,10 +93,12 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
|
||||||
}
|
}
|
||||||
delete(tx.Statement.Clauses, "SELECT")
|
delete(tx.Statement.Clauses, "SELECT")
|
||||||
case string:
|
case string:
|
||||||
fields := strings.FieldsFunc(v, utils.IsValidDBNameChar)
|
if (strings.Contains(v, " ?") || strings.Contains(v, "(?")) && len(args) > 0 {
|
||||||
|
tx.Statement.AddClause(clause.Select{
|
||||||
// normal field names
|
Distinct: db.Statement.Distinct,
|
||||||
if len(fields) == 1 || (len(fields) == 3 && strings.ToUpper(fields[1]) == "AS") {
|
Expression: clause.Expr{SQL: v, Vars: args},
|
||||||
|
})
|
||||||
|
} else {
|
||||||
tx.Statement.Selects = []string{v}
|
tx.Statement.Selects = []string{v}
|
||||||
|
|
||||||
for _, arg := range args {
|
for _, arg := range args {
|
||||||
|
@ -115,11 +117,6 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
|
||||||
}
|
}
|
||||||
|
|
||||||
delete(tx.Statement.Clauses, "SELECT")
|
delete(tx.Statement.Clauses, "SELECT")
|
||||||
} else {
|
|
||||||
tx.Statement.AddClause(clause.Select{
|
|
||||||
Distinct: db.Statement.Distinct,
|
|
||||||
Expression: clause.Expr{SQL: v, Vars: args},
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args))
|
tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args))
|
||||||
|
|
|
@ -355,14 +355,23 @@ func (db *DB) Count(count *int64) (tx *DB) {
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if selectClause, ok := db.Statement.Clauses["SELECT"]; ok {
|
||||||
|
defer func() {
|
||||||
|
db.Statement.Clauses["SELECT"] = selectClause
|
||||||
|
}()
|
||||||
|
} else {
|
||||||
|
defer delete(tx.Statement.Clauses, "SELECT")
|
||||||
|
}
|
||||||
|
|
||||||
if len(tx.Statement.Selects) == 0 {
|
if len(tx.Statement.Selects) == 0 {
|
||||||
tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(1)"}})
|
tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(1)"}})
|
||||||
defer delete(tx.Statement.Clauses, "SELECT")
|
|
||||||
} else if !strings.Contains(strings.ToLower(tx.Statement.Selects[0]), "count(") {
|
} else if !strings.Contains(strings.ToLower(tx.Statement.Selects[0]), "count(") {
|
||||||
expr := clause.Expr{SQL: "count(1)"}
|
expr := clause.Expr{SQL: "count(1)"}
|
||||||
|
|
||||||
if len(tx.Statement.Selects) == 1 {
|
if len(tx.Statement.Selects) == 1 {
|
||||||
dbName := tx.Statement.Selects[0]
|
dbName := tx.Statement.Selects[0]
|
||||||
|
fields := strings.FieldsFunc(dbName, utils.IsValidDBNameChar)
|
||||||
|
if len(fields) == 1 || (len(fields) == 3 && strings.ToUpper(fields[1]) == "AS") {
|
||||||
if tx.Statement.Parse(tx.Statement.Model) == nil {
|
if tx.Statement.Parse(tx.Statement.Model) == nil {
|
||||||
if f := tx.Statement.Schema.LookUpField(dbName); f != nil {
|
if f := tx.Statement.Schema.LookUpField(dbName); f != nil {
|
||||||
dbName = f.DBName
|
dbName = f.DBName
|
||||||
|
@ -375,9 +384,9 @@ func (db *DB) Count(count *int64) (tx *DB) {
|
||||||
expr = clause.Expr{SQL: "COUNT(?)", Vars: []interface{}{clause.Column{Name: dbName}}}
|
expr = clause.Expr{SQL: "COUNT(?)", Vars: []interface{}{clause.Column{Name: dbName}}}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
tx.Statement.AddClause(clause.Select{Expression: expr})
|
tx.Statement.AddClause(clause.Select{Expression: expr})
|
||||||
defer delete(tx.Statement.Clauses, "SELECT")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if orderByClause, ok := db.Statement.Clauses["ORDER BY"]; ok {
|
if orderByClause, ok := db.Statement.Clauses["ORDER BY"]; ok {
|
||||||
|
@ -457,11 +466,13 @@ func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {
|
||||||
tx.AddError(ErrModelValueRequired)
|
tx.AddError(ErrModelValueRequired)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(tx.Statement.Selects) != 1 {
|
||||||
fields := strings.FieldsFunc(column, utils.IsValidDBNameChar)
|
fields := strings.FieldsFunc(column, utils.IsValidDBNameChar)
|
||||||
tx.Statement.AddClauseIfNotExists(clause.Select{
|
tx.Statement.AddClauseIfNotExists(clause.Select{
|
||||||
Distinct: tx.Statement.Distinct,
|
Distinct: tx.Statement.Distinct,
|
||||||
Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}},
|
Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}},
|
||||||
})
|
})
|
||||||
|
}
|
||||||
tx.Statement.Dest = dest
|
tx.Statement.Dest = dest
|
||||||
tx.callbacks.Query().Execute(tx)
|
tx.callbacks.Query().Execute(tx)
|
||||||
return
|
return
|
||||||
|
|
|
@ -3,6 +3,8 @@ package tests_test
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
@ -77,4 +79,46 @@ func TestCount(t *testing.T) {
|
||||||
if err := DB.Table("users").Where("users.name = ?", user1.Name).Order("name").Count(&count5).Error; err != nil || count5 != 1 {
|
if err := DB.Table("users").Where("users.name = ?", user1.Name).Order("name").Count(&count5).Error; err != nil || count5 != 1 {
|
||||||
t.Errorf("count with join, got error: %v, count %v", err, count)
|
t.Errorf("count with join, got error: %v, count %v", err, count)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var count6 int64
|
||||||
|
if err := DB.Model(&User{}).Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).Select(
|
||||||
|
"(CASE WHEN name=? THEN ? ELSE ? END) as name", "count-1", "main", "other",
|
||||||
|
).Count(&count6).Find(&users).Error; err != nil || count6 != 3 {
|
||||||
|
t.Fatalf(fmt.Sprintf("Count should work, but got err %v", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
expects := []User{User{Name: "main"}, {Name: "other"}, {Name: "other"}}
|
||||||
|
sort.SliceStable(users, func(i, j int) bool {
|
||||||
|
return strings.Compare(users[i].Name, users[j].Name) < 0
|
||||||
|
})
|
||||||
|
|
||||||
|
AssertEqual(t, users, expects)
|
||||||
|
|
||||||
|
var count7 int64
|
||||||
|
if err := DB.Model(&User{}).Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).Select(
|
||||||
|
"(CASE WHEN name=? THEN ? ELSE ? END) as name, age", "count-1", "main", "other",
|
||||||
|
).Count(&count7).Find(&users).Error; err != nil || count7 != 3 {
|
||||||
|
t.Fatalf(fmt.Sprintf("Count should work, but got err %v", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
expects = []User{User{Name: "main", Age: 18}, {Name: "other", Age: 18}, {Name: "other", Age: 18}}
|
||||||
|
sort.SliceStable(users, func(i, j int) bool {
|
||||||
|
return strings.Compare(users[i].Name, users[j].Name) < 0
|
||||||
|
})
|
||||||
|
|
||||||
|
AssertEqual(t, users, expects)
|
||||||
|
|
||||||
|
var count8 int64
|
||||||
|
if err := DB.Model(&User{}).Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).Select(
|
||||||
|
"(CASE WHEN age=18 THEN 1 ELSE 2 END) as age", "name",
|
||||||
|
).Count(&count8).Find(&users).Error; err != nil || count8 != 3 {
|
||||||
|
t.Fatalf(fmt.Sprintf("Count should work, but got err %v", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
expects = []User{User{Name: "count-1", Age: 1}, {Name: "count-2", Age: 1}, {Name: "count-3", Age: 1}}
|
||||||
|
sort.SliceStable(users, func(i, j int) bool {
|
||||||
|
return strings.Compare(users[i].Name, users[j].Name) < 0
|
||||||
|
})
|
||||||
|
|
||||||
|
AssertEqual(t, users, expects)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue