forked from mirror/gorm
Add Count tests
This commit is contained in:
parent
1c39ac921b
commit
cbc4a81140
|
@ -247,11 +247,12 @@ func (association *Association) Clear() error {
|
|||
return association.Replace()
|
||||
}
|
||||
|
||||
func (association *Association) Count() (count int) {
|
||||
func (association *Association) Count() (count int64) {
|
||||
if association.Error == nil {
|
||||
var (
|
||||
tx = association.DB
|
||||
conds = association.Relationship.ToQueryConditions(tx.Statement.ReflectValue)
|
||||
conds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue)
|
||||
modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface()
|
||||
tx = association.DB.Model(modelValue)
|
||||
)
|
||||
|
||||
if association.Relationship.JoinTable != nil {
|
||||
|
|
|
@ -73,6 +73,7 @@ func (cs *callbacks) Raw() *processor {
|
|||
|
||||
func (p *processor) Execute(db *DB) {
|
||||
curTime := time.Now()
|
||||
db.RowsAffected = 0
|
||||
if stmt := db.Statement; stmt != nil {
|
||||
if stmt.Model == nil {
|
||||
stmt.Model = stmt.Dest
|
||||
|
@ -102,7 +103,7 @@ func (p *processor) Execute(db *DB) {
|
|||
}, db.Error)
|
||||
|
||||
stmt.reinit()
|
||||
db.Config.statementPool.Put(stmt)
|
||||
// db.Config.statementPool.Put(stmt)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -21,6 +21,11 @@ func Query(db *gorm.DB) {
|
|||
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
|
||||
Name: f.DBName,
|
||||
})
|
||||
} else {
|
||||
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
|
||||
Name: name,
|
||||
Raw: true,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -85,7 +90,7 @@ func Query(db *gorm.DB) {
|
|||
db.Statement.AddClauseIfNotExists(clause.From{})
|
||||
}
|
||||
|
||||
db.Statement.AddClauseIfNotExists(clauseSelect)
|
||||
db.Statement.AddClause(clauseSelect)
|
||||
db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR")
|
||||
}
|
||||
|
||||
|
|
|
@ -49,6 +49,11 @@ func Scan(rows *sql.Rows, db *gorm.DB) {
|
|||
}
|
||||
*dest = append(*dest, v)
|
||||
}
|
||||
case *int, *int64, *uint, *uint64:
|
||||
for rows.Next() {
|
||||
db.RowsAffected++
|
||||
rows.Scan(dest)
|
||||
}
|
||||
default:
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
|
|
|
@ -41,8 +41,5 @@ func (values Values) Build(builder Builder) {
|
|||
// MergeClause merge values clauses
|
||||
func (values Values) MergeClause(clause *Clause) {
|
||||
clause.Name = ""
|
||||
if v, ok := clause.Expression.(Values); ok {
|
||||
values.Values = append(v.Values, values.Values...)
|
||||
}
|
||||
clause.Expression = values
|
||||
}
|
||||
|
|
|
@ -145,8 +145,19 @@ func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) {
|
|||
return
|
||||
}
|
||||
|
||||
func (db *DB) Count(value interface{}) (tx *DB) {
|
||||
func (db *DB) Count(count *int64) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if len(tx.Statement.Selects) == 0 {
|
||||
tx.Statement.Selects = []string{"count(1)"}
|
||||
}
|
||||
if tx.Statement.Model == nil {
|
||||
tx.Statement.Model = tx.Statement.Dest
|
||||
}
|
||||
tx.Statement.Dest = count
|
||||
tx.callbacks.Query().Execute(tx)
|
||||
if db.RowsAffected != 1 {
|
||||
*count = db.RowsAffected
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
|
|
54
statement.go
54
statement.go
|
@ -63,6 +63,8 @@ func (stmt Statement) QuoteTo(writer clause.Writer, field interface{}) {
|
|||
case clause.Table:
|
||||
if v.Name == clause.CurrentTable {
|
||||
stmt.DB.Dialector.QuoteTo(writer, stmt.Table)
|
||||
} else if v.Raw {
|
||||
writer.WriteString(v.Name)
|
||||
} else {
|
||||
stmt.DB.Dialector.QuoteTo(writer, v.Name)
|
||||
}
|
||||
|
@ -85,6 +87,8 @@ func (stmt Statement) QuoteTo(writer clause.Writer, field interface{}) {
|
|||
if stmt.Schema != nil && stmt.Schema.PrioritizedPrimaryField != nil {
|
||||
stmt.DB.Dialector.QuoteTo(writer, stmt.Schema.PrioritizedPrimaryField.DBName)
|
||||
}
|
||||
} else if v.Raw {
|
||||
writer.WriteString(v.Name)
|
||||
} else {
|
||||
stmt.DB.Dialector.QuoteTo(writer, v.Name)
|
||||
}
|
||||
|
@ -275,33 +279,33 @@ func (stmt *Statement) Parse(value interface{}) (err error) {
|
|||
}
|
||||
|
||||
func (stmt *Statement) reinit() {
|
||||
stmt.Table = ""
|
||||
stmt.Model = nil
|
||||
stmt.Selects = nil
|
||||
stmt.Omits = nil
|
||||
stmt.ConnPool = stmt.DB.Config.ConnPool
|
||||
stmt.Schema = nil
|
||||
stmt.Context = context.Background()
|
||||
stmt.RaiseErrorOnNotFound = false
|
||||
// stmt.Table = ""
|
||||
// stmt.Model = nil
|
||||
// stmt.Selects = nil
|
||||
// stmt.Omits = nil
|
||||
// stmt.ConnPool = stmt.DB.Config.ConnPool
|
||||
// stmt.Context = context.Background()
|
||||
// stmt.RaiseErrorOnNotFound = false
|
||||
|
||||
// for k := range stmt.Clauses {
|
||||
// delete(stmt.Clauses, k)
|
||||
// }
|
||||
|
||||
// for k := range stmt.Joins {
|
||||
// delete(stmt.Joins, k)
|
||||
// }
|
||||
|
||||
// for k := range stmt.Preloads {
|
||||
// delete(stmt.Preloads, k)
|
||||
// }
|
||||
|
||||
// stmt.Settings.Range(func(k, _ interface{}) bool {
|
||||
// stmt.Settings.Delete(k)
|
||||
// return true
|
||||
// })
|
||||
|
||||
stmt.Schema = nil
|
||||
stmt.SQL.Reset()
|
||||
stmt.Vars = nil
|
||||
stmt.NamedVars = nil
|
||||
|
||||
for k := range stmt.Clauses {
|
||||
delete(stmt.Clauses, k)
|
||||
}
|
||||
|
||||
for k := range stmt.Joins {
|
||||
delete(stmt.Joins, k)
|
||||
}
|
||||
|
||||
for k := range stmt.Preloads {
|
||||
delete(stmt.Preloads, k)
|
||||
}
|
||||
|
||||
stmt.Settings.Range(func(k, _ interface{}) bool {
|
||||
stmt.Settings.Delete(k)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
|
|
@ -21,4 +21,12 @@ func TestAssociationForBelongsTo(t *testing.T) {
|
|||
user2.Manager = &User{}
|
||||
DB.Model(&user2).Association("Manager").Find(user2.Manager)
|
||||
CheckUser(t, user2, user)
|
||||
|
||||
if count := DB.Model(&user).Association("Company").Count(); count != 1 {
|
||||
t.Errorf("invalid company count, got %v", count)
|
||||
}
|
||||
|
||||
if count := DB.Model(&user).Association("Manager").Count(); count != 1 {
|
||||
t.Errorf("invalid manager count, got %v", count)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
package tests_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
. "github.com/jinzhu/gorm/tests"
|
||||
)
|
||||
|
||||
func TestCount(t *testing.T) {
|
||||
var (
|
||||
user1 = *GetUser("count-1", Config{})
|
||||
user2 = *GetUser("count-2", Config{})
|
||||
user3 = *GetUser("count-3", Config{})
|
||||
users []User
|
||||
count, count1, count2 int64
|
||||
)
|
||||
|
||||
DB.Save(&user1).Save(&user2).Save(&user3)
|
||||
|
||||
if err := DB.Where("name = ?", user1.Name).Or("name = ?", user3.Name).Find(&users).Count(&count).Error; err != nil {
|
||||
t.Errorf(fmt.Sprintf("Count should work, but got err %v", err))
|
||||
}
|
||||
|
||||
if count != int64(len(users)) {
|
||||
t.Errorf("Count() method should get correct value, expect: %v, got %v", count, len(users))
|
||||
}
|
||||
|
||||
DB.Model(&User{}).Where("name = ?", user1.Name).Count(&count1).Or("name in ?", []string{user2.Name, user3.Name}).Count(&count2)
|
||||
if count1 != 1 || count2 != 3 {
|
||||
t.Errorf("multiple count in chain should works")
|
||||
}
|
||||
|
||||
var count3 int64
|
||||
if err := DB.Model(&User{}).Where("name in ?", []string{user2.Name, user2.Name, user3.Name}).Group("id").Count(&count3).Error; err != nil {
|
||||
t.Errorf("No error should happen when count with group, but got %v", err)
|
||||
}
|
||||
|
||||
if count3 != 2 {
|
||||
t.Errorf("Should get correct count for count with group, but got %v", count3)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue