Fix Count with complicated Select, close #3826

This commit is contained in:
Jinzhu 2020-12-06 14:04:37 +08:00
parent f655041908
commit 1ef1f0bfe4
4 changed files with 77 additions and 25 deletions

View File

@ -93,10 +93,12 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
} }
delete(tx.Statement.Clauses, "SELECT") delete(tx.Statement.Clauses, "SELECT")
case string: case string:
fields := strings.FieldsFunc(v, utils.IsValidDBNameChar) if (strings.Contains(v, " ?") || strings.Contains(v, "(?")) && len(args) > 0 {
tx.Statement.AddClause(clause.Select{
// normal field names Distinct: db.Statement.Distinct,
if len(fields) == 1 || (len(fields) == 3 && strings.ToUpper(fields[1]) == "AS") { Expression: clause.Expr{SQL: v, Vars: args},
})
} else {
tx.Statement.Selects = []string{v} tx.Statement.Selects = []string{v}
for _, arg := range args { for _, arg := range args {
@ -115,11 +117,6 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
} }
delete(tx.Statement.Clauses, "SELECT") delete(tx.Statement.Clauses, "SELECT")
} else {
tx.Statement.AddClause(clause.Select{
Distinct: db.Statement.Distinct,
Expression: clause.Expr{SQL: v, Vars: args},
})
} }
default: default:
tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args)) tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args))

View File

@ -355,14 +355,23 @@ func (db *DB) Count(count *int64) (tx *DB) {
}() }()
} }
if selectClause, ok := db.Statement.Clauses["SELECT"]; ok {
defer func() {
db.Statement.Clauses["SELECT"] = selectClause
}()
} else {
defer delete(tx.Statement.Clauses, "SELECT")
}
if len(tx.Statement.Selects) == 0 { if len(tx.Statement.Selects) == 0 {
tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(1)"}}) tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(1)"}})
defer delete(tx.Statement.Clauses, "SELECT")
} else if !strings.Contains(strings.ToLower(tx.Statement.Selects[0]), "count(") { } else if !strings.Contains(strings.ToLower(tx.Statement.Selects[0]), "count(") {
expr := clause.Expr{SQL: "count(1)"} expr := clause.Expr{SQL: "count(1)"}
if len(tx.Statement.Selects) == 1 { if len(tx.Statement.Selects) == 1 {
dbName := tx.Statement.Selects[0] dbName := tx.Statement.Selects[0]
fields := strings.FieldsFunc(dbName, utils.IsValidDBNameChar)
if len(fields) == 1 || (len(fields) == 3 && strings.ToUpper(fields[1]) == "AS") {
if tx.Statement.Parse(tx.Statement.Model) == nil { if tx.Statement.Parse(tx.Statement.Model) == nil {
if f := tx.Statement.Schema.LookUpField(dbName); f != nil { if f := tx.Statement.Schema.LookUpField(dbName); f != nil {
dbName = f.DBName dbName = f.DBName
@ -375,9 +384,9 @@ func (db *DB) Count(count *int64) (tx *DB) {
expr = clause.Expr{SQL: "COUNT(?)", Vars: []interface{}{clause.Column{Name: dbName}}} expr = clause.Expr{SQL: "COUNT(?)", Vars: []interface{}{clause.Column{Name: dbName}}}
} }
} }
}
tx.Statement.AddClause(clause.Select{Expression: expr}) tx.Statement.AddClause(clause.Select{Expression: expr})
defer delete(tx.Statement.Clauses, "SELECT")
} }
if orderByClause, ok := db.Statement.Clauses["ORDER BY"]; ok { if orderByClause, ok := db.Statement.Clauses["ORDER BY"]; ok {
@ -457,11 +466,13 @@ func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {
tx.AddError(ErrModelValueRequired) tx.AddError(ErrModelValueRequired)
} }
if len(tx.Statement.Selects) != 1 {
fields := strings.FieldsFunc(column, utils.IsValidDBNameChar) fields := strings.FieldsFunc(column, utils.IsValidDBNameChar)
tx.Statement.AddClauseIfNotExists(clause.Select{ tx.Statement.AddClauseIfNotExists(clause.Select{
Distinct: tx.Statement.Distinct, Distinct: tx.Statement.Distinct,
Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}}, Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}},
}) })
}
tx.Statement.Dest = dest tx.Statement.Dest = dest
tx.callbacks.Query().Execute(tx) tx.callbacks.Query().Execute(tx)
return return

View File

@ -3,6 +3,8 @@ package tests_test
import ( import (
"fmt" "fmt"
"regexp" "regexp"
"sort"
"strings"
"testing" "testing"
"gorm.io/gorm" "gorm.io/gorm"
@ -77,4 +79,46 @@ func TestCount(t *testing.T) {
if err := DB.Table("users").Where("users.name = ?", user1.Name).Order("name").Count(&count5).Error; err != nil || count5 != 1 { if err := DB.Table("users").Where("users.name = ?", user1.Name).Order("name").Count(&count5).Error; err != nil || count5 != 1 {
t.Errorf("count with join, got error: %v, count %v", err, count) t.Errorf("count with join, got error: %v, count %v", err, count)
} }
var count6 int64
if err := DB.Model(&User{}).Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).Select(
"(CASE WHEN name=? THEN ? ELSE ? END) as name", "count-1", "main", "other",
).Count(&count6).Find(&users).Error; err != nil || count6 != 3 {
t.Fatalf(fmt.Sprintf("Count should work, but got err %v", err))
}
expects := []User{User{Name: "main"}, {Name: "other"}, {Name: "other"}}
sort.SliceStable(users, func(i, j int) bool {
return strings.Compare(users[i].Name, users[j].Name) < 0
})
AssertEqual(t, users, expects)
var count7 int64
if err := DB.Model(&User{}).Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).Select(
"(CASE WHEN name=? THEN ? ELSE ? END) as name, age", "count-1", "main", "other",
).Count(&count7).Find(&users).Error; err != nil || count7 != 3 {
t.Fatalf(fmt.Sprintf("Count should work, but got err %v", err))
}
expects = []User{User{Name: "main", Age: 18}, {Name: "other", Age: 18}, {Name: "other", Age: 18}}
sort.SliceStable(users, func(i, j int) bool {
return strings.Compare(users[i].Name, users[j].Name) < 0
})
AssertEqual(t, users, expects)
var count8 int64
if err := DB.Model(&User{}).Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).Select(
"(CASE WHEN age=18 THEN 1 ELSE 2 END) as age", "name",
).Count(&count8).Find(&users).Error; err != nil || count8 != 3 {
t.Fatalf(fmt.Sprintf("Count should work, but got err %v", err))
}
expects = []User{User{Name: "count-1", Age: 1}, {Name: "count-2", Age: 1}, {Name: "count-3", Age: 1}}
sort.SliceStable(users, func(i, j int) bool {
return strings.Compare(users[i].Name, users[j].Name) < 0
})
AssertEqual(t, users, expects)
} }