Cache fields

This commit is contained in:
Jinzhu 2013-11-02 15:17:11 +08:00
parent 28b49124eb
commit 8c36a5d193
3 changed files with 29 additions and 10 deletions

18
do.go
View File

@ -114,6 +114,12 @@ func (s *Do) prepareCreateSql() {
return return
} }
func (s *Do) saveAssociation(typ string) {
if typ == "before" {
} else if typ == "after" {
}
}
func (s *Do) create() { func (s *Do) create() {
s.err(s.model.callMethod("BeforeCreate")) s.err(s.model.callMethod("BeforeCreate"))
s.err(s.model.callMethod("BeforeSave")) s.err(s.model.callMethod("BeforeSave"))
@ -154,7 +160,7 @@ func (s *Do) setUpdateAttrs(values interface{}, ignore_protected_attrs ...bool)
} }
case interface{}: case interface{}:
m := &Model{data: values, driver: s.driver} m := &Model{data: values, driver: s.driver}
fields := m.columnsHasValue("") fields := m.columnsHasValue("other")
s.updateAttrs = make(map[string]interface{}, len(fields)) s.updateAttrs = make(map[string]interface{}, len(fields))
for _, field := range fields { for _, field := range fields {
@ -401,7 +407,7 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
case interface{}: case interface{}:
m := &Model{data: query, driver: s.driver} m := &Model{data: query, driver: s.driver}
var sqls []string var sqls []string
for _, field := range m.columnsHasValue("") { for _, field := range m.columnsHasValue("other") {
sqls = append(sqls, fmt.Sprintf(" ( %v = %v ) ", field.DbName, s.addToVars(field.Value))) sqls = append(sqls, fmt.Sprintf(" ( %v = %v ) ", field.DbName, s.addToVars(field.Value)))
} }
return strings.Join(sqls, " AND ") return strings.Join(sqls, " AND ")
@ -459,7 +465,7 @@ func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) {
case interface{}: case interface{}:
m := &Model{data: query, driver: s.driver} m := &Model{data: query, driver: s.driver}
var sqls []string var sqls []string
for _, field := range m.columnsHasValue("") { for _, field := range m.columnsHasValue("other") {
sqls = append(sqls, fmt.Sprintf(" ( %v <> %v ) ", field.DbName, s.addToVars(field.Value))) sqls = append(sqls, fmt.Sprintf(" ( %v <> %v ) ", field.DbName, s.addToVars(field.Value)))
} }
return strings.Join(sqls, " AND ") return strings.Join(sqls, " AND ")
@ -565,7 +571,7 @@ func (s *Do) combinedSql() string {
func (s *Do) createTable() *Do { func (s *Do) createTable() *Do {
var sqls []string var sqls []string
for _, field := range s.model.fields("") { for _, field := range s.model.fields("create") {
if len(field.SqlType) > 0 { if len(field.SqlType) > 0 {
sqls = append(sqls, field.DbName+" "+field.SqlType) sqls = append(sqls, field.DbName+" "+field.SqlType)
} }
@ -602,7 +608,7 @@ func (s *Do) initializeWithSearchCondition() {
switch reflect.ValueOf(obj).Kind() { switch reflect.ValueOf(obj).Kind() {
case reflect.Struct: case reflect.Struct:
m := &Model{data: obj, driver: s.driver} m := &Model{data: obj, driver: s.driver}
for _, field := range m.columnsHasValue("") { for _, field := range m.columnsHasValue("other") {
m.setValueByColumn(field.DbName, field.Value, s.value) m.setValueByColumn(field.DbName, field.Value, s.value)
} }
case reflect.Map: case reflect.Map:
@ -613,7 +619,7 @@ func (s *Do) initializeWithSearchCondition() {
} }
case interface{}: case interface{}:
m := &Model{data: query, driver: s.driver} m := &Model{data: query, driver: s.driver}
for _, field := range m.columnsHasValue("") { for _, field := range m.columnsHasValue("other") {
m.setValueByColumn(field.DbName, field.Value, s.value) m.setValueByColumn(field.DbName, field.Value, s.value)
} }
} }

View File

@ -1024,10 +1024,14 @@ func TestSubStruct(t *testing.T) {
Title: "post 1", Title: "post 1",
Body: "body 1", Body: "body 1",
Comments: []Comment{{Content: "Comment 1"}, {Content: "Comment 2"}}, Comments: []Comment{{Content: "Comment 1"}, {Content: "Comment 2"}},
Category: Category{Name: "category"}, Category: Category{Name: "Category 1"},
} }
if err := db.Save(&post).Error; err != nil { if err := db.Save(&post).Error; err != nil {
t.Errorf("Got errors when save post", err) t.Errorf("Got errors when save post", err)
} }
if db.First(&Category{}, "name = ?", "Category 1").Error != nil {
t.Errorf("Category should be saved")
}
} }

View File

@ -11,8 +11,9 @@ import (
) )
type Model struct { type Model struct {
data interface{} data interface{}
driver string driver string
_cache_fields map[string][]Field
} }
type Field struct { type Field struct {
@ -64,6 +65,10 @@ func (m *Model) primaryKeyDb() string {
} }
func (m *Model) fields(operation string) (fields []Field) { func (m *Model) fields(operation string) (fields []Field) {
if len(m._cache_fields[operation]) > 0 {
return
}
indirect_value := reflect.Indirect(reflect.ValueOf(m.data)) indirect_value := reflect.Indirect(reflect.ValueOf(m.data))
typ := indirect_value.Type() typ := indirect_value.Type()
@ -121,8 +126,12 @@ func (m *Model) fields(operation string) (fields []Field) {
fields = append(fields, field) fields = append(fields, field)
} }
} }
return
if len(m._cache_fields) == 0 {
m._cache_fields = make(map[string][]Field)
}
m._cache_fields[operation] = fields
return
} }
func (m *Model) columnsHasValue(operation string) (fields []Field) { func (m *Model) columnsHasValue(operation string) (fields []Field) {