diff --git a/finisher_api.go b/finisher_api.go index af040106..25c56e49 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -269,6 +269,7 @@ func (db *DB) Count(count *int64) (tx *DB) { if len(tx.Statement.Selects) == 0 { tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(1)"}}) + defer tx.Statement.AddClause(clause.Select{}) } else if !strings.Contains(strings.ToLower(tx.Statement.Selects[0]), "count(") { expr := clause.Expr{SQL: "count(1)"} @@ -281,6 +282,7 @@ func (db *DB) Count(count *int64) (tx *DB) { } tx.Statement.AddClause(clause.Select{Expression: expr}) + defer tx.Statement.AddClause(clause.Select{}) } tx.Statement.Dest = count diff --git a/tests/count_test.go b/tests/count_test.go index 0662ae5c..826d6a36 100644 --- a/tests/count_test.go +++ b/tests/count_test.go @@ -27,6 +27,14 @@ func TestCount(t *testing.T) { t.Errorf("Count() method should get correct value, expect: %v, got %v", count, len(users)) } + if err := DB.Model(&User{}).Where("name = ?", user1.Name).Or("name = ?", user3.Name).Count(&count).Find(&users).Error; err != nil { + t.Errorf(fmt.Sprintf("Count should work, but got err %v", err)) + } + + if count != int64(len(users)) { + t.Errorf("Count() method should get correct value, expect: %v, got %v", count, len(users)) + } + DB.Model(&User{}).Where("name = ?", user1.Name).Count(&count1).Or("name in ?", []string{user2.Name, user3.Name}).Count(&count2) if count1 != 1 || count2 != 3 { t.Errorf("multiple count in chain should works")