diff --git a/do.go b/do.go index 168a3f21..a6fe828e 100644 --- a/do.go +++ b/do.go @@ -61,9 +61,10 @@ func (s *Do) hasError() bool { return len(s.Errors) > 0 } -func (s *Do) setModel(value interface{}) { +func (s *Do) setModel(value interface{}) *Do { s.model = &Model{data: value, driver: s.driver} s.value = value + return s } func (s *Do) addToVars(value interface{}) string { @@ -114,9 +115,26 @@ func (s *Do) prepareCreateSql() { return } -func (s *Do) saveAssociation(typ string) { - if typ == "before" { - } else if typ == "after" { +func (s *Do) saveBeforeAssociations() { + for _, field := range s.model.beforeAssociations() { + do := &Do{chain: s.chain, db: s.db, driver: s.driver} + do.setModel(field.Value).save() + } +} + +func (s *Do) saveAfterAssociations() { + for _, field := range s.model.afterAssociations() { + reflect_value := reflect.ValueOf(field.Value) + switch reflect.TypeOf(field.Value).Kind() { + case reflect.Slice: + for i := 0; i < reflect_value.Len(); i++ { + do := &Do{chain: s.chain, db: s.db, driver: s.driver} + do.setModel(reflect_value.Index(i).Addr().Interface()).save() + } + default: + do := &Do{chain: s.chain, db: s.db, driver: s.driver} + do.setModel(field.Value).save() + } } } @@ -124,6 +142,7 @@ func (s *Do) create() { s.err(s.model.callMethod("BeforeCreate")) s.err(s.model.callMethod("BeforeSave")) + s.saveBeforeAssociations() s.prepareCreateSql() if !s.hasError() { @@ -139,8 +158,9 @@ func (s *Do) create() { } if !s.hasError() { - result := reflect.ValueOf(s.value).Elem() + result := reflect.Indirect(reflect.ValueOf(s.value)) setFieldValue(result.FieldByName(s.model.primaryKey()), id) + s.saveAfterAssociations() s.err(s.model.callMethod("AfterCreate")) s.err(s.model.callMethod("AfterSave")) @@ -212,10 +232,12 @@ func (s *Do) update() { s.err(s.model.callMethod("BeforeUpdate")) s.err(s.model.callMethod("BeforeSave")) + s.saveBeforeAssociations() s.prepareUpdateSql(update_attrs) if !s.hasError() { s.exec() + s.saveAfterAssociations() if !s.hasError() { s.err(s.model.callMethod("AfterUpdate")) diff --git a/gorm_test.go b/gorm_test.go index d3d5152b..6cacb7de 100644 --- a/gorm_test.go +++ b/gorm_test.go @@ -1016,6 +1016,10 @@ type Comment struct { } func TestSubStruct(t *testing.T) { + db.DropTable(Category{}) + db.DropTable(Post{}) + db.DropTable(Comment{}) + db.CreateTable(Category{}) db.CreateTable(Post{}) db.CreateTable(Comment{}) @@ -1034,4 +1038,12 @@ func TestSubStruct(t *testing.T) { if db.First(&Category{}, "name = ?", "Category 1").Error != nil { t.Errorf("Category should be saved") } + + if db.First(&Comment{}, "content = ?", "Comment 1").Error != nil { + t.Errorf("Comment 1 should be saved") + } + + if db.First(&Comment{}, "content = ?", "Comment 2").Error != nil { + t.Errorf("Comment 2 should be saved") + } } diff --git a/model.go b/model.go index fb30bd8e..8bb39e58 100644 --- a/model.go +++ b/model.go @@ -25,6 +25,9 @@ type Field struct { AutoUpdateTime bool IsPrimaryKey bool IsBlank bool + + beforeAssociation bool + afterAssociation bool } func (m *Model) primaryKeyZero() bool { @@ -66,10 +69,14 @@ func (m *Model) primaryKeyDb() string { func (m *Model) fields(operation string) (fields []Field) { if len(m._cache_fields[operation]) > 0 { - return + return m._cache_fields[operation] } indirect_value := reflect.Indirect(reflect.ValueOf(m.data)) + if !indirect_value.IsValid() { + return + } + typ := indirect_value.Type() for i := 0; i < typ.NumField(); i++ { @@ -89,9 +96,19 @@ func (m *Model) fields(operation string) (fields []Field) { field.IsBlank = value.Int() == 0 case reflect.String: field.IsBlank = value.String() == "" - default: + case reflect.Slice: + if value.Len() == 0 { + field.IsBlank = true + } + case reflect.Struct: if is_time { field.IsBlank = time_value.IsZero() + } else { + m := &Model{data: value.Interface(), driver: m.driver} + fields := m.columnsHasValue("other") + if len(fields) == 0 { + field.IsBlank = true + } } } @@ -115,9 +132,16 @@ func (m *Model) fields(operation string) (fields []Field) { } else { switch reflect.TypeOf(field.Value).Kind() { case reflect.Slice: + field.afterAssociation = true case reflect.Struct: if is_time { field.SqlType = getSqlType(m.driver, field.Value, 0) + } else { + if indirect_value.FieldByName(p.Name + "Id").IsValid() { + field.beforeAssociation = true + } else { + field.afterAssociation = true + } } default: field.SqlType = getSqlType(m.driver, field.Value, 0) @@ -258,8 +282,26 @@ func (m *Model) setValueByColumn(name string, value interface{}, out interface{} setFieldValue(data.FieldByName(snakeToUpperCamel(name)), value) } +func (m *Model) beforeAssociations() (fields []Field) { + for _, field := range m.fields("null") { + if field.beforeAssociation && !field.IsBlank { + fields = append(fields, field) + } + } + return +} + +func (m *Model) afterAssociations() (fields []Field) { + for _, field := range m.fields("null") { + if field.afterAssociation && !field.IsBlank { + fields = append(fields, field) + } + } + return +} + func setFieldValue(field reflect.Value, value interface{}) { - if field.IsValid() { + if field.IsValid() && field.CanAddr() { switch field.Kind() { case reflect.Int, reflect.Int32, reflect.Int64: if str, ok := value.(string); ok {