From ec49f57394e5a77e294aabb7b005f875cacf1111 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 27 Oct 2013 20:07:13 +0800 Subject: [PATCH] Make Count works --- chain.go | 6 ++++-- do.go | 14 ++++++++++++++ orm_test.go | 20 +++++++++++++++++--- 3 files changed, 35 insertions(+), 5 deletions(-) diff --git a/chain.go b/chain.go index 2eb98dc7..58300400 100644 --- a/chain.go +++ b/chain.go @@ -45,6 +45,7 @@ func (s *Chain) do(value interface{}) *Do { do.offsetStr = s.offsetStr do.limitStr = s.limitStr + s.value = value do.setModel(value) return &do } @@ -101,8 +102,9 @@ func (s *Chain) Order(value string, reorder ...bool) *Chain { return s } -func (s *Chain) Count() int64 { - return 0 +func (s *Chain) Count(value interface{}) *Chain { + s.Select("count(*)").do(s.value).count(value) + return s } func (s *Chain) Select(value interface{}) *Chain { diff --git a/do.go b/do.go index d3a1d19e..c8d7d45a 100644 --- a/do.go +++ b/do.go @@ -208,6 +208,20 @@ func (s *Do) query() { } } +func (s *Do) count(value interface{}) { + dest_out := reflect.Indirect(reflect.ValueOf(value)) + + s.prepareQuerySql() + rows, err := s.db.Query(s.Sql, s.SqlVars...) + s.err(err) + for rows.Next() { + var dest int64 + s.err(rows.Scan(&dest)) + dest_out.Set(reflect.ValueOf(dest)) + } + return +} + func (s *Do) pluck(value interface{}) *Do { dest_out := reflect.Indirect(reflect.ValueOf(value)) dest_type := dest_out.Type().Elem() diff --git a/orm_test.go b/orm_test.go index 6a9da3e3..f7bf4cfd 100644 --- a/orm_test.go +++ b/orm_test.go @@ -63,8 +63,8 @@ func init() { } func TestFirst(t *testing.T) { - // var u1, u2 User - // db.Where("name = ?", "3").Order("age desc").First(&u1).First(&u2) + var u1, u2 User + db.Where("name = ?", "3").Order("age desc").First(&u1).First(&u2) } func TestSaveAndFind(t *testing.T) { @@ -302,7 +302,7 @@ func TestOffset(t *testing.T) { } } -func TestOrAndNot(t *testing.T) { +func TestOr(t *testing.T) { var users []User db.Where("name = ?", "1").Or("name = ?", "3").Find(&users) if len(users) != 3 { @@ -310,6 +310,20 @@ func TestOrAndNot(t *testing.T) { } } +func TestCount(t *testing.T) { + var count, count1, count2 int64 + var users []User + db.Where("name = ?", "1").Or("name = ?", "3").Find(&users).Count(&count) + if count != int64(len(users)) { + t.Errorf("Count() method should get same value of users count") + } + + db.Model(&User{}).Where("name = ?", "1").Count(&count1).Or("name = ?", "3").Count(&count2) + if !(count1 == 1 && count2 == 3) { + t.Errorf("Multiple count should works well also") + } +} + func TestCreatedAtAndUpdatedAt(t *testing.T) { name := "check_created_at_and_updated_at" u := User{Name: name, Age: 1}