mirror of https://github.com/go-gorm/gorm.git
Add Count tests
This commit is contained in:
parent
1c39ac921b
commit
cbc4a81140
|
@ -247,11 +247,12 @@ func (association *Association) Clear() error {
|
||||||
return association.Replace()
|
return association.Replace()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (association *Association) Count() (count int) {
|
func (association *Association) Count() (count int64) {
|
||||||
if association.Error == nil {
|
if association.Error == nil {
|
||||||
var (
|
var (
|
||||||
tx = association.DB
|
conds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue)
|
||||||
conds = association.Relationship.ToQueryConditions(tx.Statement.ReflectValue)
|
modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface()
|
||||||
|
tx = association.DB.Model(modelValue)
|
||||||
)
|
)
|
||||||
|
|
||||||
if association.Relationship.JoinTable != nil {
|
if association.Relationship.JoinTable != nil {
|
||||||
|
|
|
@ -73,6 +73,7 @@ func (cs *callbacks) Raw() *processor {
|
||||||
|
|
||||||
func (p *processor) Execute(db *DB) {
|
func (p *processor) Execute(db *DB) {
|
||||||
curTime := time.Now()
|
curTime := time.Now()
|
||||||
|
db.RowsAffected = 0
|
||||||
if stmt := db.Statement; stmt != nil {
|
if stmt := db.Statement; stmt != nil {
|
||||||
if stmt.Model == nil {
|
if stmt.Model == nil {
|
||||||
stmt.Model = stmt.Dest
|
stmt.Model = stmt.Dest
|
||||||
|
@ -102,7 +103,7 @@ func (p *processor) Execute(db *DB) {
|
||||||
}, db.Error)
|
}, db.Error)
|
||||||
|
|
||||||
stmt.reinit()
|
stmt.reinit()
|
||||||
db.Config.statementPool.Put(stmt)
|
// db.Config.statementPool.Put(stmt)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -21,6 +21,11 @@ func Query(db *gorm.DB) {
|
||||||
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
|
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
|
||||||
Name: f.DBName,
|
Name: f.DBName,
|
||||||
})
|
})
|
||||||
|
} else {
|
||||||
|
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
|
||||||
|
Name: name,
|
||||||
|
Raw: true,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -85,7 +90,7 @@ func Query(db *gorm.DB) {
|
||||||
db.Statement.AddClauseIfNotExists(clause.From{})
|
db.Statement.AddClauseIfNotExists(clause.From{})
|
||||||
}
|
}
|
||||||
|
|
||||||
db.Statement.AddClauseIfNotExists(clauseSelect)
|
db.Statement.AddClause(clauseSelect)
|
||||||
db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR")
|
db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -49,6 +49,11 @@ func Scan(rows *sql.Rows, db *gorm.DB) {
|
||||||
}
|
}
|
||||||
*dest = append(*dest, v)
|
*dest = append(*dest, v)
|
||||||
}
|
}
|
||||||
|
case *int, *int64, *uint, *uint64:
|
||||||
|
for rows.Next() {
|
||||||
|
db.RowsAffected++
|
||||||
|
rows.Scan(dest)
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
switch db.Statement.ReflectValue.Kind() {
|
switch db.Statement.ReflectValue.Kind() {
|
||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
|
|
|
@ -41,8 +41,5 @@ func (values Values) Build(builder Builder) {
|
||||||
// MergeClause merge values clauses
|
// MergeClause merge values clauses
|
||||||
func (values Values) MergeClause(clause *Clause) {
|
func (values Values) MergeClause(clause *Clause) {
|
||||||
clause.Name = ""
|
clause.Name = ""
|
||||||
if v, ok := clause.Expression.(Values); ok {
|
|
||||||
values.Values = append(v.Values, values.Values...)
|
|
||||||
}
|
|
||||||
clause.Expression = values
|
clause.Expression = values
|
||||||
}
|
}
|
||||||
|
|
|
@ -145,8 +145,19 @@ func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) Count(value interface{}) (tx *DB) {
|
func (db *DB) Count(count *int64) (tx *DB) {
|
||||||
tx = db.getInstance()
|
tx = db.getInstance()
|
||||||
|
if len(tx.Statement.Selects) == 0 {
|
||||||
|
tx.Statement.Selects = []string{"count(1)"}
|
||||||
|
}
|
||||||
|
if tx.Statement.Model == nil {
|
||||||
|
tx.Statement.Model = tx.Statement.Dest
|
||||||
|
}
|
||||||
|
tx.Statement.Dest = count
|
||||||
|
tx.callbacks.Query().Execute(tx)
|
||||||
|
if db.RowsAffected != 1 {
|
||||||
|
*count = db.RowsAffected
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
54
statement.go
54
statement.go
|
@ -63,6 +63,8 @@ func (stmt Statement) QuoteTo(writer clause.Writer, field interface{}) {
|
||||||
case clause.Table:
|
case clause.Table:
|
||||||
if v.Name == clause.CurrentTable {
|
if v.Name == clause.CurrentTable {
|
||||||
stmt.DB.Dialector.QuoteTo(writer, stmt.Table)
|
stmt.DB.Dialector.QuoteTo(writer, stmt.Table)
|
||||||
|
} else if v.Raw {
|
||||||
|
writer.WriteString(v.Name)
|
||||||
} else {
|
} else {
|
||||||
stmt.DB.Dialector.QuoteTo(writer, v.Name)
|
stmt.DB.Dialector.QuoteTo(writer, v.Name)
|
||||||
}
|
}
|
||||||
|
@ -85,6 +87,8 @@ func (stmt Statement) QuoteTo(writer clause.Writer, field interface{}) {
|
||||||
if stmt.Schema != nil && stmt.Schema.PrioritizedPrimaryField != nil {
|
if stmt.Schema != nil && stmt.Schema.PrioritizedPrimaryField != nil {
|
||||||
stmt.DB.Dialector.QuoteTo(writer, stmt.Schema.PrioritizedPrimaryField.DBName)
|
stmt.DB.Dialector.QuoteTo(writer, stmt.Schema.PrioritizedPrimaryField.DBName)
|
||||||
}
|
}
|
||||||
|
} else if v.Raw {
|
||||||
|
writer.WriteString(v.Name)
|
||||||
} else {
|
} else {
|
||||||
stmt.DB.Dialector.QuoteTo(writer, v.Name)
|
stmt.DB.Dialector.QuoteTo(writer, v.Name)
|
||||||
}
|
}
|
||||||
|
@ -275,33 +279,33 @@ func (stmt *Statement) Parse(value interface{}) (err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (stmt *Statement) reinit() {
|
func (stmt *Statement) reinit() {
|
||||||
stmt.Table = ""
|
// stmt.Table = ""
|
||||||
stmt.Model = nil
|
// stmt.Model = nil
|
||||||
stmt.Selects = nil
|
// stmt.Selects = nil
|
||||||
stmt.Omits = nil
|
// stmt.Omits = nil
|
||||||
stmt.ConnPool = stmt.DB.Config.ConnPool
|
// stmt.ConnPool = stmt.DB.Config.ConnPool
|
||||||
stmt.Schema = nil
|
// stmt.Context = context.Background()
|
||||||
stmt.Context = context.Background()
|
// stmt.RaiseErrorOnNotFound = false
|
||||||
stmt.RaiseErrorOnNotFound = false
|
|
||||||
|
|
||||||
|
// for k := range stmt.Clauses {
|
||||||
|
// delete(stmt.Clauses, k)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// for k := range stmt.Joins {
|
||||||
|
// delete(stmt.Joins, k)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// for k := range stmt.Preloads {
|
||||||
|
// delete(stmt.Preloads, k)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// stmt.Settings.Range(func(k, _ interface{}) bool {
|
||||||
|
// stmt.Settings.Delete(k)
|
||||||
|
// return true
|
||||||
|
// })
|
||||||
|
|
||||||
|
stmt.Schema = nil
|
||||||
stmt.SQL.Reset()
|
stmt.SQL.Reset()
|
||||||
stmt.Vars = nil
|
stmt.Vars = nil
|
||||||
stmt.NamedVars = nil
|
stmt.NamedVars = nil
|
||||||
|
|
||||||
for k := range stmt.Clauses {
|
|
||||||
delete(stmt.Clauses, k)
|
|
||||||
}
|
|
||||||
|
|
||||||
for k := range stmt.Joins {
|
|
||||||
delete(stmt.Joins, k)
|
|
||||||
}
|
|
||||||
|
|
||||||
for k := range stmt.Preloads {
|
|
||||||
delete(stmt.Preloads, k)
|
|
||||||
}
|
|
||||||
|
|
||||||
stmt.Settings.Range(func(k, _ interface{}) bool {
|
|
||||||
stmt.Settings.Delete(k)
|
|
||||||
return true
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,4 +21,12 @@ func TestAssociationForBelongsTo(t *testing.T) {
|
||||||
user2.Manager = &User{}
|
user2.Manager = &User{}
|
||||||
DB.Model(&user2).Association("Manager").Find(user2.Manager)
|
DB.Model(&user2).Association("Manager").Find(user2.Manager)
|
||||||
CheckUser(t, user2, user)
|
CheckUser(t, user2, user)
|
||||||
|
|
||||||
|
if count := DB.Model(&user).Association("Company").Count(); count != 1 {
|
||||||
|
t.Errorf("invalid company count, got %v", count)
|
||||||
|
}
|
||||||
|
|
||||||
|
if count := DB.Model(&user).Association("Manager").Count(); count != 1 {
|
||||||
|
t.Errorf("invalid manager count, got %v", count)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,42 @@
|
||||||
|
package tests_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
. "github.com/jinzhu/gorm/tests"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCount(t *testing.T) {
|
||||||
|
var (
|
||||||
|
user1 = *GetUser("count-1", Config{})
|
||||||
|
user2 = *GetUser("count-2", Config{})
|
||||||
|
user3 = *GetUser("count-3", Config{})
|
||||||
|
users []User
|
||||||
|
count, count1, count2 int64
|
||||||
|
)
|
||||||
|
|
||||||
|
DB.Save(&user1).Save(&user2).Save(&user3)
|
||||||
|
|
||||||
|
if err := DB.Where("name = ?", user1.Name).Or("name = ?", user3.Name).Find(&users).Count(&count).Error; err != nil {
|
||||||
|
t.Errorf(fmt.Sprintf("Count should work, but got err %v", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
if count != int64(len(users)) {
|
||||||
|
t.Errorf("Count() method should get correct value, expect: %v, got %v", count, len(users))
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Model(&User{}).Where("name = ?", user1.Name).Count(&count1).Or("name in ?", []string{user2.Name, user3.Name}).Count(&count2)
|
||||||
|
if count1 != 1 || count2 != 3 {
|
||||||
|
t.Errorf("multiple count in chain should works")
|
||||||
|
}
|
||||||
|
|
||||||
|
var count3 int64
|
||||||
|
if err := DB.Model(&User{}).Where("name in ?", []string{user2.Name, user2.Name, user3.Name}).Group("id").Count(&count3).Error; err != nil {
|
||||||
|
t.Errorf("No error should happen when count with group, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if count3 != 2 {
|
||||||
|
t.Errorf("Should get correct count for count with group, but got %v", count3)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue