Reflect do.go

This commit is contained in:
Jinzhu 2013-11-16 20:16:23 +08:00
parent 0ad707b410
commit 6404f803e8
3 changed files with 38 additions and 46 deletions

34
do.go
View File

@ -18,11 +18,10 @@ type Do struct {
search *search search *search
model *Model model *Model
tableName string tableName string
usingUpdate bool
value interface{} value interface{}
update_attrs map[string]interface{} usingUpdate bool
hasUpdate bool hasUpdate bool
ignoreProtectedAttrs bool update_attrs map[string]interface{}
sql string sql string
sqlVars []interface{} sqlVars []interface{}
startedTransaction bool startedTransaction bool
@ -53,11 +52,7 @@ func (s *Do) err(err error) error {
func (s *Do) setModel(value interface{}) *Do { func (s *Do) setModel(value interface{}) *Do {
s.model = &Model{data: value, do: s} s.model = &Model{data: value, do: s}
s.value = value s.value = value
if s.db.search == nil {
s.search = &search{}
} else {
s.search = s.db.search s.search = s.db.search
}
return s return s
} }
@ -67,7 +62,9 @@ func (s *Do) addToVars(value interface{}) string {
} }
func (s *Do) trace(t time.Time) { func (s *Do) trace(t time.Time) {
if len(s.sql) > 0 {
s.db.slog(s.sql, t, s.sqlVars...) s.db.slog(s.sql, t, s.sqlVars...)
}
} }
func (s *Do) exec(sqls ...string) *Do { func (s *Do) exec(sqls ...string) *Do {
@ -113,12 +110,11 @@ func (s *Do) saveBeforeAssociations() {
for _, field := range s.model.beforeAssociations() { for _, field := range s.model.beforeAssociations() {
do := &Do{db: s.db} do := &Do{db: s.db}
reflect_value := reflect.ValueOf(field.Value) if field.reflectValue.CanAddr() {
if reflect_value.CanAddr() { do.setModel(field.reflectValue.Addr().Interface()).save()
do.setModel(reflect_value.Addr().Interface()).save()
} else { } else {
// If can't take address, then clone the value and set it back // If can't take address, then clone the value and set it back
dest_value := reflect.New(reflect_value.Type()).Elem() dest_value := reflect.New(field.reflectValue.Type()).Elem()
m := &Model{data: field.Value, do: s} m := &Model{data: field.Value, do: s}
for _, f := range m.columnsHasValue("other") { for _, f := range m.columnsHasValue("other") {
dest_value.FieldByName(f.Name).Set(reflect.ValueOf(f.Value)) dest_value.FieldByName(f.Name).Set(reflect.ValueOf(f.Value))
@ -170,6 +166,7 @@ func (s *Do) saveAfterAssociations() {
} }
func (s *Do) create() (i interface{}) { func (s *Do) create() (i interface{}) {
defer s.trace(time.Now())
s.model.callMethod("BeforeCreate") s.model.callMethod("BeforeCreate")
s.model.callMethod("BeforeSave") s.model.callMethod("BeforeSave")
@ -178,8 +175,6 @@ func (s *Do) create() (i interface{}) {
if !s.db.hasError() { if !s.db.hasError() {
var id interface{} var id interface{}
now := time.Now()
if s.dialect().SupportLastInsertId() { if s.dialect().SupportLastInsertId() {
if sql_result, err := s.db.db.Exec(s.sql, s.sqlVars...); s.err(err) == nil { if sql_result, err := s.db.db.Exec(s.sql, s.sqlVars...); s.err(err) == nil {
id, err = sql_result.LastInsertId() id, err = sql_result.LastInsertId()
@ -188,7 +183,6 @@ func (s *Do) create() (i interface{}) {
} else { } else {
s.err(s.db.db.QueryRow(s.sql, s.sqlVars...).Scan(&id)) s.err(s.db.db.QueryRow(s.sql, s.sqlVars...).Scan(&id))
} }
s.db.slog(s.sql, now, s.sqlVars...)
if !s.db.hasError() { if !s.db.hasError() {
s.model.setValueByColumn(s.model.primaryKey(), id, s.value) s.model.setValueByColumn(s.model.primaryKey(), id, s.value)
@ -348,6 +342,7 @@ func (s *Do) related(value interface{}, foreign_keys ...string) *Do {
} }
func (s *Do) query() *Do { func (s *Do) query() *Do {
defer s.trace(time.Now())
var ( var (
is_slice bool is_slice bool
dest_type reflect.Type dest_type reflect.Type
@ -363,9 +358,7 @@ func (s *Do) query() *Do {
s.prepareQuerySql() s.prepareQuerySql()
if !s.db.hasError() { if !s.db.hasError() {
now := time.Now()
rows, err := s.db.db.Query(s.sql, s.sqlVars...) rows, err := s.db.db.Query(s.sql, s.sqlVars...)
s.db.slog(s.sql, now, s.sqlVars...)
if s.err(err) != nil { if s.err(err) != nil {
return s return s
@ -406,17 +399,19 @@ func (s *Do) query() *Do {
} }
func (s *Do) count(value interface{}) *Do { func (s *Do) count(value interface{}) *Do {
defer s.trace(time.Now())
s.search = s.search.clone().selects("count(*)") s.search = s.search.clone().selects("count(*)")
s.prepareQuerySql() s.prepareQuerySql()
if !s.db.hasError() { if !s.db.hasError() {
now := time.Now()
s.err(s.db.db.QueryRow(s.sql, s.sqlVars...).Scan(value)) s.err(s.db.db.QueryRow(s.sql, s.sqlVars...).Scan(value))
s.db.slog(s.sql, now, s.sqlVars...)
} }
return s return s
} }
func (s *Do) pluck(column string, value interface{}) *Do { func (s *Do) pluck(column string, value interface{}) *Do {
defer s.trace(time.Now())
dest_out := reflect.Indirect(reflect.ValueOf(value)) dest_out := reflect.Indirect(reflect.ValueOf(value))
s.search = s.search.clone().selects(column) s.search = s.search.clone().selects(column)
if dest_out.Kind() != reflect.Slice { if dest_out.Kind() != reflect.Slice {
@ -427,9 +422,7 @@ func (s *Do) pluck(column string, value interface{}) *Do {
s.prepareQuerySql() s.prepareQuerySql()
if !s.db.hasError() { if !s.db.hasError() {
now := time.Now()
rows, err := s.db.db.Query(s.sql, s.sqlVars...) rows, err := s.db.db.Query(s.sql, s.sqlVars...)
s.db.slog(s.sql, now, s.sqlVars...)
if s.err(err) == nil { if s.err(err) == nil {
defer rows.Close() defer rows.Close()
@ -653,7 +646,6 @@ func (s *Do) createTable() *Do {
} }
s.sql = fmt.Sprintf("CREATE TABLE %v (%v)", s.table(), strings.Join(sqls, ",")) s.sql = fmt.Sprintf("CREATE TABLE %v (%v)", s.table(), strings.Join(sqls, ","))
s.exec() s.exec()
return s return s
} }

View File

@ -1242,6 +1242,7 @@ func (c Cart) TableName() string {
} }
func TestTableName(t *testing.T) { func TestTableName(t *testing.T) {
db := db.clone()
if db.do(Order{}).table() != "orders" { if db.do(Order{}).table() != "orders" {
t.Errorf("Order table name should be orders") t.Errorf("Order table name should be orders")
} }
@ -1393,7 +1394,7 @@ func (s *CreditCard) BeforeSave() (err error) {
} }
func BenchmarkGorm(b *testing.B) { func BenchmarkGorm(b *testing.B) {
b.N = 5000 b.N = 2000
for x := 0; x < b.N; x++ { for x := 0; x < b.N; x++ {
e := strconv.Itoa(x) + "benchmark@example.org" e := strconv.Itoa(x) + "benchmark@example.org"
email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()} email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()}
@ -1416,7 +1417,7 @@ func BenchmarkRawSql(b *testing.B) {
update_sql := "UPDATE emails SET email = $1, updated_at = $2 WHERE id = $3" update_sql := "UPDATE emails SET email = $1, updated_at = $2 WHERE id = $3"
delete_sql := "DELETE FROM orders WHERE id = $1" delete_sql := "DELETE FROM orders WHERE id = $1"
b.N = 5000 b.N = 2000
for x := 0; x < b.N; x++ { for x := 0; x < b.N; x++ {
var id int64 var id int64
e := strconv.Itoa(x) + "benchmark@example.org" e := strconv.Itoa(x) + "benchmark@example.org"

View File

@ -24,8 +24,7 @@ func (m *Model) primaryKeyZero() bool {
func (m *Model) primaryKeyValue() interface{} { func (m *Model) primaryKeyValue() interface{} {
if data := m.reflectData(); data.Kind() == reflect.Struct { if data := m.reflectData(); data.Kind() == reflect.Struct {
field := data.FieldByName(m.primaryKey()) if field := data.FieldByName(m.primaryKey()); field.IsValid() {
if data.FieldByName(m.primaryKey()).IsValid() {
return field.Interface() return field.Interface()
} }
} }
@ -133,7 +132,7 @@ func (m *Model) columnsAndValues(operation string) map[string]interface{} {
results := map[string]interface{}{} results := map[string]interface{}{}
for _, field := range m.fields(operation) { for _, field := range m.fields(operation) {
if !field.isPrimaryKey && (len(field.sqlTag()) > 0) { if !field.isPrimaryKey && len(field.sqlTag()) > 0 {
results[field.dbName] = field.Value results[field.dbName] = field.Value
} }
} }
@ -141,9 +140,7 @@ func (m *Model) columnsAndValues(operation string) map[string]interface{} {
} }
func (m *Model) hasColumn(name string) bool { func (m *Model) hasColumn(name string) bool {
data := m.reflectData() if data := m.reflectData(); data.Kind() == reflect.Struct {
if data.Kind() == reflect.Struct {
return data.FieldByName(name).IsValid() return data.FieldByName(name).IsValid()
} else if data.Kind() == reflect.Slice { } else if data.Kind() == reflect.Slice {
return reflect.New(data.Type().Elem()).Elem().FieldByName(name).IsValid() return reflect.New(data.Type().Elem()).Elem().FieldByName(name).IsValid()
@ -152,9 +149,7 @@ func (m *Model) hasColumn(name string) bool {
} }
func (m *Model) columnAndValue(name string) (has_column bool, is_slice bool, value interface{}) { func (m *Model) columnAndValue(name string) (has_column bool, is_slice bool, value interface{}) {
data := m.reflectData() if data := m.reflectData(); data.Kind() == reflect.Struct {
if data.Kind() == reflect.Struct {
if has_column = data.FieldByName(name).IsValid(); has_column { if has_column = data.FieldByName(name).IsValid(); has_column {
value = data.FieldByName(name).Interface() value = data.FieldByName(name).Interface()
} }
@ -165,15 +160,19 @@ func (m *Model) columnAndValue(name string) (has_column bool, is_slice bool, val
return return
} }
func (m *Model) typeName() string { func (m *Model) typ() reflect.Type {
typ := m.reflectData().Type() typ := m.reflectData().Type()
if typ.Kind() == reflect.Slice { if typ.Kind() == reflect.Slice {
return typ.Elem().Name() return typ.Elem()
} else { } else {
return typ.Name() return typ
} }
} }
func (m *Model) typeName() string {
return m.typ().Name()
}
func (m *Model) tableName() (str string) { func (m *Model) tableName() (str string) {
if m.data == nil { if m.data == nil {
m.do.err(errors.New("Model haven't been set")) m.do.err(errors.New("Model haven't been set"))