Fix Count with Select when Model not specfied, close #3220

This commit is contained in:
Jinzhu 2020-08-03 10:30:25 +08:00
parent 2676fa4fb8
commit f83b00d20d
3 changed files with 25 additions and 2 deletions

View File

@ -274,11 +274,18 @@ func (db *DB) Count(count *int64) (tx *DB) {
expr := clause.Expr{SQL: "count(1)"} expr := clause.Expr{SQL: "count(1)"}
if len(tx.Statement.Selects) == 1 { if len(tx.Statement.Selects) == 1 {
dbName := tx.Statement.Selects[0]
if tx.Statement.Parse(tx.Statement.Model) == nil { if tx.Statement.Parse(tx.Statement.Model) == nil {
if f := tx.Statement.Schema.LookUpField(tx.Statement.Selects[0]); f != nil { if f := tx.Statement.Schema.LookUpField(dbName); f != nil {
expr = clause.Expr{SQL: "COUNT(DISTINCT(?))", Vars: []interface{}{clause.Column{Name: f.DBName}}} dbName = f.DBName
} }
} }
if tx.Statement.Distinct {
expr = clause.Expr{SQL: "COUNT(DISTINCT(?))", Vars: []interface{}{clause.Column{Name: dbName}}}
} else {
expr = clause.Expr{SQL: "COUNT(?)", Vars: []interface{}{clause.Column{Name: dbName}}}
}
} }
tx.Statement.AddClause(clause.Select{Expression: expr}) tx.Statement.AddClause(clause.Select{Expression: expr})

View File

@ -72,6 +72,10 @@ type Tabler interface {
// get data type from dialector // get data type from dialector
func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
if dest == nil {
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
}
modelType := reflect.ValueOf(dest).Type() modelType := reflect.ValueOf(dest).Type()
for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem() modelType = modelType.Elem()

View File

@ -2,6 +2,7 @@ package tests_test
import ( import (
"fmt" "fmt"
"regexp"
"testing" "testing"
"gorm.io/gorm" "gorm.io/gorm"
@ -55,4 +56,15 @@ func TestCount(t *testing.T) {
if count3 != 2 { if count3 != 2 {
t.Errorf("Should get correct count for count with group, but got %v", count3) t.Errorf("Should get correct count for count with group, but got %v", count3)
} }
dryDB := DB.Session(&gorm.Session{DryRun: true})
result := dryDB.Table("users").Select("name").Count(&count)
if !regexp.MustCompile(`SELECT COUNT\(.name.\) FROM .*users.*`).MatchString(result.Statement.SQL.String()) {
t.Fatalf("Build count with select, but got %v", result.Statement.SQL.String())
}
result = dryDB.Table("users").Distinct("name").Count(&count)
if !regexp.MustCompile(`SELECT COUNT\(DISTINCT\(.name.\)\) FROM .*users.*`).MatchString(result.Statement.SQL.String()) {
t.Fatalf("Build count with select, but got %v", result.Statement.SQL.String())
}
} }