From 788da015d1dd4384b15cd20958d65b5b8037470c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 11 Nov 2013 21:55:44 +0800 Subject: [PATCH] Clean up code --- README.md | 5 +++- do.go | 66 +++++++++++++--------------------------------------- gorm_test.go | 2 ++ model.go | 2 +- 4 files changed, 23 insertions(+), 52 deletions(-) diff --git a/README.md b/README.md index de1070c3..a508f7bb 100644 --- a/README.md +++ b/README.md @@ -555,6 +555,9 @@ db.Model(&User{}).Pluck("name", &names) // Set Table With Table db.Table("deleted_users").Pluck("name", &names) //// SELECT name FROM deleted_users; + +// Pluck more than one column? Do it like this +db.Select("name, age").Find(&users) ``` ## Callbacks @@ -719,7 +722,7 @@ db.Where("email = ?", "x@example.org").Attrs(User{FromIp: "111.111.111.111"}).Fi ``` ## TODO -* Join, Having, Group, Includes, Pluck (distinct) +* Join, Having, Group, Includes * Scopes * Index, Unique, Valiations diff --git a/do.go b/do.go index 7fe13012..89525bf4 100644 --- a/do.go +++ b/do.go @@ -52,10 +52,6 @@ func (s *Do) err(err error) error { return err } -func (s *Do) hasError() bool { - return s.chain.hasError() -} - func (s *Do) setModel(value interface{}) *Do { s.model = &Model{data: value, do: s} s.value = value @@ -72,7 +68,7 @@ func (s *Do) addToVars(value interface{}) string { } func (s *Do) exec(sqls ...string) (err error) { - if s.hasError() { + if s.chain.hasError() { return } else { if len(sqls) > 0 { @@ -178,7 +174,7 @@ func (s *Do) create() (i interface{}) { s.saveBeforeAssociations() s.prepareCreateSql() - if !s.hasError() { + if !s.chain.hasError() { now := time.Now() var id interface{} if s.chain.driver() == "postgres" { @@ -191,7 +187,7 @@ func (s *Do) create() (i interface{}) { } s.chain.slog(s.sql, now, s.sqlVars...) - if !s.hasError() { + if !s.chain.hasError() { result := reflect.Indirect(reflect.ValueOf(s.value)) if !setFieldValue(result.FieldByName(s.model.primaryKey()), id) { fmt.Printf("Can't set primary key for %#v\n", result.Interface()) @@ -272,7 +268,7 @@ func (s *Do) update() (i interface{}) { s.saveBeforeAssociations() s.prepareUpdateSql(update_attrs) - if !s.hasError() { + if !s.chain.hasError() { s.exec() s.saveAfterAssociations() @@ -291,7 +287,7 @@ func (s *Do) prepareDeleteSql() { func (s *Do) delete() { s.model.callMethod("BeforeDelete") - if !s.hasError() { + if !s.chain.hasError() { if !s.unscoped && s.model.hasColumn("DeletedAt") { delete_sql := "deleted_at=" + s.addToVars(time.Now()) s.sql = fmt.Sprintf("UPDATE %v SET %v %v", s.tableName(), delete_sql, s.combinedSql()) @@ -379,7 +375,7 @@ func (s *Do) query() { } s.prepareQuerySql() - if !s.hasError() { + if !s.chain.hasError() { now := time.Now() rows, err := s.db.Query(s.sql, s.sqlVars...) s.chain.slog(s.sql, now, s.sqlVars...) @@ -433,26 +429,12 @@ func (s *Do) query() { } func (s *Do) count(value interface{}) { - dest_out := reflect.Indirect(reflect.ValueOf(value)) - s.prepareQuerySql() - if !s.hasError() { + if !s.chain.hasError() { now := time.Now() - rows, err := s.db.Query(s.sql, s.sqlVars...) + s.err(s.db.QueryRow(s.sql, s.sqlVars...).Scan(value)) s.chain.slog(s.sql, now, s.sqlVars...) - if s.err(err) != nil { - return - } - - defer rows.Close() - for rows.Next() { - var dest int64 - if s.err(rows.Scan(&dest)) == nil { - setFieldValue(dest_out, dest) - } - } } - return } func (s *Do) pluck(column string, value interface{}) { @@ -460,42 +442,26 @@ func (s *Do) pluck(column string, value interface{}) { dest_out := reflect.Indirect(reflect.ValueOf(value)) if dest_out.Kind() != reflect.Slice { - s.err(errors.New("Return results should be a slice")) + s.err(errors.New("Results should be a slice")) return } - dest_type := dest_out.Type().Elem() s.prepareQuerySql() - if !s.hasError() { + if !s.chain.hasError() { now := time.Now() rows, err := s.db.Query(s.sql, s.sqlVars...) s.chain.slog(s.sql, now, s.sqlVars...) - if s.err(err) != nil { - return - } - defer rows.Close() - for rows.Next() { - dest := reflect.New(dest_type).Elem().Interface() - s.err(rows.Scan(&dest)) - - switch dest.(type) { - case []uint8: - if dest_type.String() == "string" { - dest = string(dest.([]uint8)) - } else if dest_type.String() == "int64" { - dest, _ = strconv.Atoi(string(dest.([]uint8))) - dest = int64(dest.(int)) - } - - dest_out.Set(reflect.Append(dest_out, reflect.ValueOf(dest))) - default: - dest_out.Set(reflect.Append(dest_out, reflect.ValueOf(dest))) + if s.err(err) == nil { + defer rows.Close() + for rows.Next() { + dest := reflect.New(dest_out.Type().Elem()).Interface() + s.err(rows.Scan(dest)) + dest_out.Set(reflect.Append(dest_out, reflect.ValueOf(dest).Elem())) } } } - return } func (s *Do) where(where ...interface{}) *Do { diff --git a/gorm_test.go b/gorm_test.go index 6a445a84..225f2668 100644 --- a/gorm_test.go +++ b/gorm_test.go @@ -505,6 +505,8 @@ func TestOrderAndPluck(t *testing.T) { if !(names[0] == "1" && names[2] == "3" && names[3] == "3" && ages[2] == 24 && ages[3] == 22) { t.Errorf("Should be ordered correctly with multiple orders") } + + db.Model(User{}).Select("name, age").Find(&[]User{}) } func TestLimit(t *testing.T) { diff --git a/model.go b/model.go index 9be43bfd..17c872fd 100644 --- a/model.go +++ b/model.go @@ -309,7 +309,7 @@ func (m *Model) tableName() (str string) { } func (m *Model) callMethod(method string) { - if m.data == nil || m.do.hasError() { + if m.data == nil || m.do.chain.hasError() { return }