From 6404f803e87df6912f641305f040542088e2c066 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 16 Nov 2013 20:16:23 +0800 Subject: [PATCH] Reflect do.go --- do.go | 56 ++++++++++++++++++++++------------------------------ gorm_test.go | 5 +++-- model.go | 23 +++++++++++---------- 3 files changed, 38 insertions(+), 46 deletions(-) diff --git a/do.go b/do.go index a8a554fb..8a7a3aab 100644 --- a/do.go +++ b/do.go @@ -14,18 +14,17 @@ import ( ) type Do struct { - db *DB - search *search - model *Model - tableName string - usingUpdate bool - value interface{} - update_attrs map[string]interface{} - hasUpdate bool - ignoreProtectedAttrs bool - sql string - sqlVars []interface{} - startedTransaction bool + db *DB + search *search + model *Model + tableName string + value interface{} + usingUpdate bool + hasUpdate bool + update_attrs map[string]interface{} + sql string + sqlVars []interface{} + startedTransaction bool } func (s *Do) table() string { @@ -53,11 +52,7 @@ func (s *Do) err(err error) error { func (s *Do) setModel(value interface{}) *Do { s.model = &Model{data: value, do: s} s.value = value - if s.db.search == nil { - s.search = &search{} - } else { - s.search = s.db.search - } + s.search = s.db.search return s } @@ -67,7 +62,9 @@ func (s *Do) addToVars(value interface{}) string { } func (s *Do) trace(t time.Time) { - s.db.slog(s.sql, t, s.sqlVars...) + if len(s.sql) > 0 { + s.db.slog(s.sql, t, s.sqlVars...) + } } func (s *Do) exec(sqls ...string) *Do { @@ -113,12 +110,11 @@ func (s *Do) saveBeforeAssociations() { for _, field := range s.model.beforeAssociations() { do := &Do{db: s.db} - reflect_value := reflect.ValueOf(field.Value) - if reflect_value.CanAddr() { - do.setModel(reflect_value.Addr().Interface()).save() + if field.reflectValue.CanAddr() { + do.setModel(field.reflectValue.Addr().Interface()).save() } else { // If can't take address, then clone the value and set it back - dest_value := reflect.New(reflect_value.Type()).Elem() + dest_value := reflect.New(field.reflectValue.Type()).Elem() m := &Model{data: field.Value, do: s} for _, f := range m.columnsHasValue("other") { dest_value.FieldByName(f.Name).Set(reflect.ValueOf(f.Value)) @@ -170,6 +166,7 @@ func (s *Do) saveAfterAssociations() { } func (s *Do) create() (i interface{}) { + defer s.trace(time.Now()) s.model.callMethod("BeforeCreate") s.model.callMethod("BeforeSave") @@ -178,8 +175,6 @@ func (s *Do) create() (i interface{}) { if !s.db.hasError() { var id interface{} - - now := time.Now() if s.dialect().SupportLastInsertId() { if sql_result, err := s.db.db.Exec(s.sql, s.sqlVars...); s.err(err) == nil { id, err = sql_result.LastInsertId() @@ -188,7 +183,6 @@ func (s *Do) create() (i interface{}) { } else { s.err(s.db.db.QueryRow(s.sql, s.sqlVars...).Scan(&id)) } - s.db.slog(s.sql, now, s.sqlVars...) if !s.db.hasError() { s.model.setValueByColumn(s.model.primaryKey(), id, s.value) @@ -348,6 +342,7 @@ func (s *Do) related(value interface{}, foreign_keys ...string) *Do { } func (s *Do) query() *Do { + defer s.trace(time.Now()) var ( is_slice bool dest_type reflect.Type @@ -363,9 +358,7 @@ func (s *Do) query() *Do { s.prepareQuerySql() if !s.db.hasError() { - now := time.Now() rows, err := s.db.db.Query(s.sql, s.sqlVars...) - s.db.slog(s.sql, now, s.sqlVars...) if s.err(err) != nil { return s @@ -406,17 +399,19 @@ func (s *Do) query() *Do { } func (s *Do) count(value interface{}) *Do { + defer s.trace(time.Now()) + s.search = s.search.clone().selects("count(*)") s.prepareQuerySql() if !s.db.hasError() { - now := time.Now() s.err(s.db.db.QueryRow(s.sql, s.sqlVars...).Scan(value)) - s.db.slog(s.sql, now, s.sqlVars...) } return s } func (s *Do) pluck(column string, value interface{}) *Do { + defer s.trace(time.Now()) + dest_out := reflect.Indirect(reflect.ValueOf(value)) s.search = s.search.clone().selects(column) if dest_out.Kind() != reflect.Slice { @@ -427,9 +422,7 @@ func (s *Do) pluck(column string, value interface{}) *Do { s.prepareQuerySql() if !s.db.hasError() { - now := time.Now() rows, err := s.db.db.Query(s.sql, s.sqlVars...) - s.db.slog(s.sql, now, s.sqlVars...) if s.err(err) == nil { defer rows.Close() @@ -653,7 +646,6 @@ func (s *Do) createTable() *Do { } s.sql = fmt.Sprintf("CREATE TABLE %v (%v)", s.table(), strings.Join(sqls, ",")) - s.exec() return s } diff --git a/gorm_test.go b/gorm_test.go index 8db3598e..945c697f 100644 --- a/gorm_test.go +++ b/gorm_test.go @@ -1242,6 +1242,7 @@ func (c Cart) TableName() string { } func TestTableName(t *testing.T) { + db := db.clone() if db.do(Order{}).table() != "orders" { t.Errorf("Order table name should be orders") } @@ -1393,7 +1394,7 @@ func (s *CreditCard) BeforeSave() (err error) { } func BenchmarkGorm(b *testing.B) { - b.N = 5000 + b.N = 2000 for x := 0; x < b.N; x++ { e := strconv.Itoa(x) + "benchmark@example.org" email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()} @@ -1416,7 +1417,7 @@ func BenchmarkRawSql(b *testing.B) { update_sql := "UPDATE emails SET email = $1, updated_at = $2 WHERE id = $3" delete_sql := "DELETE FROM orders WHERE id = $1" - b.N = 5000 + b.N = 2000 for x := 0; x < b.N; x++ { var id int64 e := strconv.Itoa(x) + "benchmark@example.org" diff --git a/model.go b/model.go index ea1aa603..0bb29e84 100644 --- a/model.go +++ b/model.go @@ -24,8 +24,7 @@ func (m *Model) primaryKeyZero() bool { func (m *Model) primaryKeyValue() interface{} { if data := m.reflectData(); data.Kind() == reflect.Struct { - field := data.FieldByName(m.primaryKey()) - if data.FieldByName(m.primaryKey()).IsValid() { + if field := data.FieldByName(m.primaryKey()); field.IsValid() { return field.Interface() } } @@ -133,7 +132,7 @@ func (m *Model) columnsAndValues(operation string) map[string]interface{} { results := map[string]interface{}{} for _, field := range m.fields(operation) { - if !field.isPrimaryKey && (len(field.sqlTag()) > 0) { + if !field.isPrimaryKey && len(field.sqlTag()) > 0 { results[field.dbName] = field.Value } } @@ -141,9 +140,7 @@ func (m *Model) columnsAndValues(operation string) map[string]interface{} { } func (m *Model) hasColumn(name string) bool { - data := m.reflectData() - - if data.Kind() == reflect.Struct { + if data := m.reflectData(); data.Kind() == reflect.Struct { return data.FieldByName(name).IsValid() } else if data.Kind() == reflect.Slice { return reflect.New(data.Type().Elem()).Elem().FieldByName(name).IsValid() @@ -152,9 +149,7 @@ func (m *Model) hasColumn(name string) bool { } func (m *Model) columnAndValue(name string) (has_column bool, is_slice bool, value interface{}) { - data := m.reflectData() - - if data.Kind() == reflect.Struct { + if data := m.reflectData(); data.Kind() == reflect.Struct { if has_column = data.FieldByName(name).IsValid(); has_column { value = data.FieldByName(name).Interface() } @@ -165,15 +160,19 @@ func (m *Model) columnAndValue(name string) (has_column bool, is_slice bool, val return } -func (m *Model) typeName() string { +func (m *Model) typ() reflect.Type { typ := m.reflectData().Type() if typ.Kind() == reflect.Slice { - return typ.Elem().Name() + return typ.Elem() } else { - return typ.Name() + return typ } } +func (m *Model) typeName() string { + return m.typ().Name() +} + func (m *Model) tableName() (str string) { if m.data == nil { m.do.err(errors.New("Model haven't been set"))