diff --git a/do.go b/do.go index 787ce376..20744818 100644 --- a/do.go +++ b/do.go @@ -88,11 +88,11 @@ func (s *Do) exec(sql ...string) { s.err(err) } -func (s *Do) save() { +func (s *Do) save() (i int64) { if s.model.primaryKeyZero() { - s.create() + return s.create() } else { - s.update() + return s.update() } return } @@ -118,7 +118,10 @@ func (s *Do) prepareCreateSql() { func (s *Do) saveBeforeAssociations() { for _, field := range s.model.beforeAssociations() { do := &Do{chain: s.chain, db: s.db, driver: s.driver} - do.setModel(field.Value).save() + id := do.setModel(field.Value).save() + if len(field.foreignKey) > 0 { + s.model.setValueByColumn(field.foreignKey, id, s.model.data) + } } } @@ -128,8 +131,12 @@ func (s *Do) saveAfterAssociations() { switch reflect.TypeOf(field.Value).Kind() { case reflect.Slice: for i := 0; i < reflect_value.Len(); i++ { + value := reflect_value.Index(i).Addr().Interface() do := &Do{chain: s.chain, db: s.db, driver: s.driver} - do.setModel(reflect_value.Index(i).Addr().Interface()).save() + if len(field.foreignKey) > 0 { + s.model.setValueByColumn(field.foreignKey, s.model.primaryKeyValue(), value) + } + do.setModel(value).save() } default: do := &Do{chain: s.chain, db: s.db, driver: s.driver} @@ -138,7 +145,7 @@ func (s *Do) saveAfterAssociations() { } } -func (s *Do) create() { +func (s *Do) create() (i int64) { s.err(s.model.callMethod("BeforeCreate")) s.err(s.model.callMethod("BeforeSave")) @@ -167,6 +174,7 @@ func (s *Do) create() { s.err(s.model.callMethod("AfterCreate")) s.err(s.model.callMethod("AfterSave")) } + return id } return @@ -221,7 +229,7 @@ func (s *Do) prepareUpdateSql(results map[string]interface{}) { return } -func (s *Do) update() { +func (s *Do) update() (i int64) { update_attrs := s.updateAttrs if len(update_attrs) > 0 { var need_update bool @@ -246,7 +254,8 @@ func (s *Do) update() { s.err(s.model.callMethod("AfterSave")) } } - return + + return s.model.primaryKeyValue() } func (s *Do) prepareDeleteSql() { @@ -318,7 +327,12 @@ func (s *Do) query() { for _, value := range columns { field := dest.FieldByName(snakeToUpperCamel(value)) if field.IsValid() { - values = append(values, field.Addr().Interface()) + if field.CanAddr() { + values = append(values, field.Addr().Interface()) + } else { + s.err(errors.New(fmt.Sprintf("Can't take address of %v, should be ptr", dest))) + return + } } else { var null interface{} values = append(values, &null) diff --git a/gorm_test.go b/gorm_test.go index ece9db44..de9d8dc5 100644 --- a/gorm_test.go +++ b/gorm_test.go @@ -1041,11 +1041,28 @@ func TestSubStruct(t *testing.T) { t.Errorf("Category should be saved") } + var p Post + db.First(&p, post.Id) + if post.CategoryId == 0 || p.CategoryId == 0 { + t.Errorf("Category Id should exist") + } + if db.First(&Comment{}, "content = ?", "Comment 1").Error != nil { t.Errorf("Comment 1 should be saved") } + if post.Comments[0].PostId == 0 { + t.Errorf("Comment Should have post id") + } - if db.First(&Comment{}, "content = ?", "Comment 2").Error != nil { + var comment Comment + if db.First(&comment, "content = ?", "Comment 2").Error != nil { t.Errorf("Comment 2 should be saved") } + + if comment.PostId == 0 { + t.Errorf("Comment 2 Should have post id") + } + + comment3 := Comment{Content: "Comment 3", Post: Post{Title: "Title 3", Body: "Body 3"}} + db.Save(&comment3) } diff --git a/model.go b/model.go index f4b85e19..9e068c35 100644 --- a/model.go +++ b/model.go @@ -28,6 +28,7 @@ type Field struct { beforeAssociation bool afterAssociation bool + foreignKey string } func (m *Model) primaryKeyZero() bool { @@ -139,12 +140,17 @@ func (m *Model) fields(operation string) (fields []Field) { switch field_value.Kind() { case reflect.Slice: + foreign_key := typ.Name() + "Id" + if reflect.New(field_value.Type().Elem()).Elem().FieldByName(foreign_key).IsValid() { + field.foreignKey = foreign_key + } field.afterAssociation = true case reflect.Struct: if is_time { field.SqlType = getSqlType(m.driver, field.Value, 0) } else { if indirect_value.FieldByName(p.Name + "Id").IsValid() { + field.foreignKey = p.Name + "Id" field.beforeAssociation = true } else { field.afterAssociation = true