Add Count tests

This commit is contained in:
Jinzhu 2020-05-24 11:32:59 +08:00
parent 1c39ac921b
commit cbc4a81140
9 changed files with 108 additions and 34 deletions

View File

@ -247,11 +247,12 @@ func (association *Association) Clear() error {
return association.Replace() return association.Replace()
} }
func (association *Association) Count() (count int) { func (association *Association) Count() (count int64) {
if association.Error == nil { if association.Error == nil {
var ( var (
tx = association.DB conds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue)
conds = association.Relationship.ToQueryConditions(tx.Statement.ReflectValue) modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface()
tx = association.DB.Model(modelValue)
) )
if association.Relationship.JoinTable != nil { if association.Relationship.JoinTable != nil {

View File

@ -73,6 +73,7 @@ func (cs *callbacks) Raw() *processor {
func (p *processor) Execute(db *DB) { func (p *processor) Execute(db *DB) {
curTime := time.Now() curTime := time.Now()
db.RowsAffected = 0
if stmt := db.Statement; stmt != nil { if stmt := db.Statement; stmt != nil {
if stmt.Model == nil { if stmt.Model == nil {
stmt.Model = stmt.Dest stmt.Model = stmt.Dest
@ -102,7 +103,7 @@ func (p *processor) Execute(db *DB) {
}, db.Error) }, db.Error)
stmt.reinit() stmt.reinit()
db.Config.statementPool.Put(stmt) // db.Config.statementPool.Put(stmt)
} }
} }

View File

@ -21,6 +21,11 @@ func Query(db *gorm.DB) {
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
Name: f.DBName, Name: f.DBName,
}) })
} else {
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
Name: name,
Raw: true,
})
} }
} }
} }
@ -85,7 +90,7 @@ func Query(db *gorm.DB) {
db.Statement.AddClauseIfNotExists(clause.From{}) db.Statement.AddClauseIfNotExists(clause.From{})
} }
db.Statement.AddClauseIfNotExists(clauseSelect) db.Statement.AddClause(clauseSelect)
db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR")
} }

View File

@ -49,6 +49,11 @@ func Scan(rows *sql.Rows, db *gorm.DB) {
} }
*dest = append(*dest, v) *dest = append(*dest, v)
} }
case *int, *int64, *uint, *uint64:
for rows.Next() {
db.RowsAffected++
rows.Scan(dest)
}
default: default:
switch db.Statement.ReflectValue.Kind() { switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:

View File

@ -41,8 +41,5 @@ func (values Values) Build(builder Builder) {
// MergeClause merge values clauses // MergeClause merge values clauses
func (values Values) MergeClause(clause *Clause) { func (values Values) MergeClause(clause *Clause) {
clause.Name = "" clause.Name = ""
if v, ok := clause.Expression.(Values); ok {
values.Values = append(v.Values, values.Values...)
}
clause.Expression = values clause.Expression = values
} }

View File

@ -145,8 +145,19 @@ func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) {
return return
} }
func (db *DB) Count(value interface{}) (tx *DB) { func (db *DB) Count(count *int64) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
if len(tx.Statement.Selects) == 0 {
tx.Statement.Selects = []string{"count(1)"}
}
if tx.Statement.Model == nil {
tx.Statement.Model = tx.Statement.Dest
}
tx.Statement.Dest = count
tx.callbacks.Query().Execute(tx)
if db.RowsAffected != 1 {
*count = db.RowsAffected
}
return return
} }

View File

@ -63,6 +63,8 @@ func (stmt Statement) QuoteTo(writer clause.Writer, field interface{}) {
case clause.Table: case clause.Table:
if v.Name == clause.CurrentTable { if v.Name == clause.CurrentTable {
stmt.DB.Dialector.QuoteTo(writer, stmt.Table) stmt.DB.Dialector.QuoteTo(writer, stmt.Table)
} else if v.Raw {
writer.WriteString(v.Name)
} else { } else {
stmt.DB.Dialector.QuoteTo(writer, v.Name) stmt.DB.Dialector.QuoteTo(writer, v.Name)
} }
@ -85,6 +87,8 @@ func (stmt Statement) QuoteTo(writer clause.Writer, field interface{}) {
if stmt.Schema != nil && stmt.Schema.PrioritizedPrimaryField != nil { if stmt.Schema != nil && stmt.Schema.PrioritizedPrimaryField != nil {
stmt.DB.Dialector.QuoteTo(writer, stmt.Schema.PrioritizedPrimaryField.DBName) stmt.DB.Dialector.QuoteTo(writer, stmt.Schema.PrioritizedPrimaryField.DBName)
} }
} else if v.Raw {
writer.WriteString(v.Name)
} else { } else {
stmt.DB.Dialector.QuoteTo(writer, v.Name) stmt.DB.Dialector.QuoteTo(writer, v.Name)
} }
@ -275,33 +279,33 @@ func (stmt *Statement) Parse(value interface{}) (err error) {
} }
func (stmt *Statement) reinit() { func (stmt *Statement) reinit() {
stmt.Table = "" // stmt.Table = ""
stmt.Model = nil // stmt.Model = nil
stmt.Selects = nil // stmt.Selects = nil
stmt.Omits = nil // stmt.Omits = nil
stmt.ConnPool = stmt.DB.Config.ConnPool // stmt.ConnPool = stmt.DB.Config.ConnPool
stmt.Schema = nil // stmt.Context = context.Background()
stmt.Context = context.Background() // stmt.RaiseErrorOnNotFound = false
stmt.RaiseErrorOnNotFound = false
// for k := range stmt.Clauses {
// delete(stmt.Clauses, k)
// }
// for k := range stmt.Joins {
// delete(stmt.Joins, k)
// }
// for k := range stmt.Preloads {
// delete(stmt.Preloads, k)
// }
// stmt.Settings.Range(func(k, _ interface{}) bool {
// stmt.Settings.Delete(k)
// return true
// })
stmt.Schema = nil
stmt.SQL.Reset() stmt.SQL.Reset()
stmt.Vars = nil stmt.Vars = nil
stmt.NamedVars = nil stmt.NamedVars = nil
for k := range stmt.Clauses {
delete(stmt.Clauses, k)
}
for k := range stmt.Joins {
delete(stmt.Joins, k)
}
for k := range stmt.Preloads {
delete(stmt.Preloads, k)
}
stmt.Settings.Range(func(k, _ interface{}) bool {
stmt.Settings.Delete(k)
return true
})
} }

View File

@ -21,4 +21,12 @@ func TestAssociationForBelongsTo(t *testing.T) {
user2.Manager = &User{} user2.Manager = &User{}
DB.Model(&user2).Association("Manager").Find(user2.Manager) DB.Model(&user2).Association("Manager").Find(user2.Manager)
CheckUser(t, user2, user) CheckUser(t, user2, user)
if count := DB.Model(&user).Association("Company").Count(); count != 1 {
t.Errorf("invalid company count, got %v", count)
}
if count := DB.Model(&user).Association("Manager").Count(); count != 1 {
t.Errorf("invalid manager count, got %v", count)
}
} }

42
tests/count_test.go Normal file
View File

@ -0,0 +1,42 @@
package tests_test
import (
"fmt"
"testing"
. "github.com/jinzhu/gorm/tests"
)
func TestCount(t *testing.T) {
var (
user1 = *GetUser("count-1", Config{})
user2 = *GetUser("count-2", Config{})
user3 = *GetUser("count-3", Config{})
users []User
count, count1, count2 int64
)
DB.Save(&user1).Save(&user2).Save(&user3)
if err := DB.Where("name = ?", user1.Name).Or("name = ?", user3.Name).Find(&users).Count(&count).Error; err != nil {
t.Errorf(fmt.Sprintf("Count should work, but got err %v", err))
}
if count != int64(len(users)) {
t.Errorf("Count() method should get correct value, expect: %v, got %v", count, len(users))
}
DB.Model(&User{}).Where("name = ?", user1.Name).Count(&count1).Or("name in ?", []string{user2.Name, user3.Name}).Count(&count2)
if count1 != 1 || count2 != 3 {
t.Errorf("multiple count in chain should works")
}
var count3 int64
if err := DB.Model(&User{}).Where("name in ?", []string{user2.Name, user2.Name, user3.Name}).Group("id").Count(&count3).Error; err != nil {
t.Errorf("No error should happen when count with group, but got %v", err)
}
if count3 != 2 {
t.Errorf("Should get correct count for count with group, but got %v", count3)
}
}