diff --git a/finisher_api.go b/finisher_api.go index 77bea578..33a4f121 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -274,11 +274,18 @@ func (db *DB) Count(count *int64) (tx *DB) { expr := clause.Expr{SQL: "count(1)"} if len(tx.Statement.Selects) == 1 { + dbName := tx.Statement.Selects[0] if tx.Statement.Parse(tx.Statement.Model) == nil { - if f := tx.Statement.Schema.LookUpField(tx.Statement.Selects[0]); f != nil { - expr = clause.Expr{SQL: "COUNT(DISTINCT(?))", Vars: []interface{}{clause.Column{Name: f.DBName}}} + if f := tx.Statement.Schema.LookUpField(dbName); f != nil { + dbName = f.DBName } } + + if tx.Statement.Distinct { + expr = clause.Expr{SQL: "COUNT(DISTINCT(?))", Vars: []interface{}{clause.Column{Name: dbName}}} + } else { + expr = clause.Expr{SQL: "COUNT(?)", Vars: []interface{}{clause.Column{Name: dbName}}} + } } tx.Statement.AddClause(clause.Select{Expression: expr}) diff --git a/schema/schema.go b/schema/schema.go index 1106f0c5..9206c24e 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -72,6 +72,10 @@ type Tabler interface { // get data type from dialector func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { + if dest == nil { + return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) + } + modelType := reflect.ValueOf(dest).Type() for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { modelType = modelType.Elem() diff --git a/tests/count_test.go b/tests/count_test.go index 826d6a36..05661ae8 100644 --- a/tests/count_test.go +++ b/tests/count_test.go @@ -2,6 +2,7 @@ package tests_test import ( "fmt" + "regexp" "testing" "gorm.io/gorm" @@ -55,4 +56,15 @@ func TestCount(t *testing.T) { if count3 != 2 { t.Errorf("Should get correct count for count with group, but got %v", count3) } + + dryDB := DB.Session(&gorm.Session{DryRun: true}) + result := dryDB.Table("users").Select("name").Count(&count) + if !regexp.MustCompile(`SELECT COUNT\(.name.\) FROM .*users.*`).MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build count with select, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Table("users").Distinct("name").Count(&count) + if !regexp.MustCompile(`SELECT COUNT\(DISTINCT\(.name.\)\) FROM .*users.*`).MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build count with select, but got %v", result.Statement.SQL.String()) + } }