diff --git a/do.go b/do.go index 8f3b0213..168a3f21 100644 --- a/do.go +++ b/do.go @@ -114,6 +114,12 @@ func (s *Do) prepareCreateSql() { return } +func (s *Do) saveAssociation(typ string) { + if typ == "before" { + } else if typ == "after" { + } +} + func (s *Do) create() { s.err(s.model.callMethod("BeforeCreate")) s.err(s.model.callMethod("BeforeSave")) @@ -154,7 +160,7 @@ func (s *Do) setUpdateAttrs(values interface{}, ignore_protected_attrs ...bool) } case interface{}: m := &Model{data: values, driver: s.driver} - fields := m.columnsHasValue("") + fields := m.columnsHasValue("other") s.updateAttrs = make(map[string]interface{}, len(fields)) for _, field := range fields { @@ -401,7 +407,7 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) { case interface{}: m := &Model{data: query, driver: s.driver} var sqls []string - for _, field := range m.columnsHasValue("") { + for _, field := range m.columnsHasValue("other") { sqls = append(sqls, fmt.Sprintf(" ( %v = %v ) ", field.DbName, s.addToVars(field.Value))) } return strings.Join(sqls, " AND ") @@ -459,7 +465,7 @@ func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) { case interface{}: m := &Model{data: query, driver: s.driver} var sqls []string - for _, field := range m.columnsHasValue("") { + for _, field := range m.columnsHasValue("other") { sqls = append(sqls, fmt.Sprintf(" ( %v <> %v ) ", field.DbName, s.addToVars(field.Value))) } return strings.Join(sqls, " AND ") @@ -565,7 +571,7 @@ func (s *Do) combinedSql() string { func (s *Do) createTable() *Do { var sqls []string - for _, field := range s.model.fields("") { + for _, field := range s.model.fields("create") { if len(field.SqlType) > 0 { sqls = append(sqls, field.DbName+" "+field.SqlType) } @@ -602,7 +608,7 @@ func (s *Do) initializeWithSearchCondition() { switch reflect.ValueOf(obj).Kind() { case reflect.Struct: m := &Model{data: obj, driver: s.driver} - for _, field := range m.columnsHasValue("") { + for _, field := range m.columnsHasValue("other") { m.setValueByColumn(field.DbName, field.Value, s.value) } case reflect.Map: @@ -613,7 +619,7 @@ func (s *Do) initializeWithSearchCondition() { } case interface{}: m := &Model{data: query, driver: s.driver} - for _, field := range m.columnsHasValue("") { + for _, field := range m.columnsHasValue("other") { m.setValueByColumn(field.DbName, field.Value, s.value) } } diff --git a/gorm_test.go b/gorm_test.go index e975a4ae..d3d5152b 100644 --- a/gorm_test.go +++ b/gorm_test.go @@ -1024,10 +1024,14 @@ func TestSubStruct(t *testing.T) { Title: "post 1", Body: "body 1", Comments: []Comment{{Content: "Comment 1"}, {Content: "Comment 2"}}, - Category: Category{Name: "category"}, + Category: Category{Name: "Category 1"}, } if err := db.Save(&post).Error; err != nil { t.Errorf("Got errors when save post", err) } + + if db.First(&Category{}, "name = ?", "Category 1").Error != nil { + t.Errorf("Category should be saved") + } } diff --git a/model.go b/model.go index 17f80388..fb30bd8e 100644 --- a/model.go +++ b/model.go @@ -11,8 +11,9 @@ import ( ) type Model struct { - data interface{} - driver string + data interface{} + driver string + _cache_fields map[string][]Field } type Field struct { @@ -64,6 +65,10 @@ func (m *Model) primaryKeyDb() string { } func (m *Model) fields(operation string) (fields []Field) { + if len(m._cache_fields[operation]) > 0 { + return + } + indirect_value := reflect.Indirect(reflect.ValueOf(m.data)) typ := indirect_value.Type() @@ -121,8 +126,12 @@ func (m *Model) fields(operation string) (fields []Field) { fields = append(fields, field) } } - return + if len(m._cache_fields) == 0 { + m._cache_fields = make(map[string][]Field) + } + m._cache_fields[operation] = fields + return } func (m *Model) columnsHasValue(operation string) (fields []Field) {