Make save sub structs works

This commit is contained in:
Jinzhu 2013-11-02 20:05:05 +08:00
parent 5b671a84b6
commit aa352d405b
3 changed files with 47 additions and 10 deletions

30
do.go
View File

@ -88,11 +88,11 @@ func (s *Do) exec(sql ...string) {
s.err(err) s.err(err)
} }
func (s *Do) save() { func (s *Do) save() (i int64) {
if s.model.primaryKeyZero() { if s.model.primaryKeyZero() {
s.create() return s.create()
} else { } else {
s.update() return s.update()
} }
return return
} }
@ -118,7 +118,10 @@ func (s *Do) prepareCreateSql() {
func (s *Do) saveBeforeAssociations() { func (s *Do) saveBeforeAssociations() {
for _, field := range s.model.beforeAssociations() { for _, field := range s.model.beforeAssociations() {
do := &Do{chain: s.chain, db: s.db, driver: s.driver} do := &Do{chain: s.chain, db: s.db, driver: s.driver}
do.setModel(field.Value).save() id := do.setModel(field.Value).save()
if len(field.foreignKey) > 0 {
s.model.setValueByColumn(field.foreignKey, id, s.model.data)
}
} }
} }
@ -128,8 +131,12 @@ func (s *Do) saveAfterAssociations() {
switch reflect.TypeOf(field.Value).Kind() { switch reflect.TypeOf(field.Value).Kind() {
case reflect.Slice: case reflect.Slice:
for i := 0; i < reflect_value.Len(); i++ { for i := 0; i < reflect_value.Len(); i++ {
value := reflect_value.Index(i).Addr().Interface()
do := &Do{chain: s.chain, db: s.db, driver: s.driver} do := &Do{chain: s.chain, db: s.db, driver: s.driver}
do.setModel(reflect_value.Index(i).Addr().Interface()).save() if len(field.foreignKey) > 0 {
s.model.setValueByColumn(field.foreignKey, s.model.primaryKeyValue(), value)
}
do.setModel(value).save()
} }
default: default:
do := &Do{chain: s.chain, db: s.db, driver: s.driver} do := &Do{chain: s.chain, db: s.db, driver: s.driver}
@ -138,7 +145,7 @@ func (s *Do) saveAfterAssociations() {
} }
} }
func (s *Do) create() { func (s *Do) create() (i int64) {
s.err(s.model.callMethod("BeforeCreate")) s.err(s.model.callMethod("BeforeCreate"))
s.err(s.model.callMethod("BeforeSave")) s.err(s.model.callMethod("BeforeSave"))
@ -167,6 +174,7 @@ func (s *Do) create() {
s.err(s.model.callMethod("AfterCreate")) s.err(s.model.callMethod("AfterCreate"))
s.err(s.model.callMethod("AfterSave")) s.err(s.model.callMethod("AfterSave"))
} }
return id
} }
return return
@ -221,7 +229,7 @@ func (s *Do) prepareUpdateSql(results map[string]interface{}) {
return return
} }
func (s *Do) update() { func (s *Do) update() (i int64) {
update_attrs := s.updateAttrs update_attrs := s.updateAttrs
if len(update_attrs) > 0 { if len(update_attrs) > 0 {
var need_update bool var need_update bool
@ -246,7 +254,8 @@ func (s *Do) update() {
s.err(s.model.callMethod("AfterSave")) s.err(s.model.callMethod("AfterSave"))
} }
} }
return
return s.model.primaryKeyValue()
} }
func (s *Do) prepareDeleteSql() { func (s *Do) prepareDeleteSql() {
@ -318,7 +327,12 @@ func (s *Do) query() {
for _, value := range columns { for _, value := range columns {
field := dest.FieldByName(snakeToUpperCamel(value)) field := dest.FieldByName(snakeToUpperCamel(value))
if field.IsValid() { if field.IsValid() {
if field.CanAddr() {
values = append(values, field.Addr().Interface()) values = append(values, field.Addr().Interface())
} else {
s.err(errors.New(fmt.Sprintf("Can't take address of %v, should be ptr", dest)))
return
}
} else { } else {
var null interface{} var null interface{}
values = append(values, &null) values = append(values, &null)

View File

@ -1041,11 +1041,28 @@ func TestSubStruct(t *testing.T) {
t.Errorf("Category should be saved") t.Errorf("Category should be saved")
} }
var p Post
db.First(&p, post.Id)
if post.CategoryId == 0 || p.CategoryId == 0 {
t.Errorf("Category Id should exist")
}
if db.First(&Comment{}, "content = ?", "Comment 1").Error != nil { if db.First(&Comment{}, "content = ?", "Comment 1").Error != nil {
t.Errorf("Comment 1 should be saved") t.Errorf("Comment 1 should be saved")
} }
if post.Comments[0].PostId == 0 {
t.Errorf("Comment Should have post id")
}
if db.First(&Comment{}, "content = ?", "Comment 2").Error != nil { var comment Comment
if db.First(&comment, "content = ?", "Comment 2").Error != nil {
t.Errorf("Comment 2 should be saved") t.Errorf("Comment 2 should be saved")
} }
if comment.PostId == 0 {
t.Errorf("Comment 2 Should have post id")
}
comment3 := Comment{Content: "Comment 3", Post: Post{Title: "Title 3", Body: "Body 3"}}
db.Save(&comment3)
} }

View File

@ -28,6 +28,7 @@ type Field struct {
beforeAssociation bool beforeAssociation bool
afterAssociation bool afterAssociation bool
foreignKey string
} }
func (m *Model) primaryKeyZero() bool { func (m *Model) primaryKeyZero() bool {
@ -139,12 +140,17 @@ func (m *Model) fields(operation string) (fields []Field) {
switch field_value.Kind() { switch field_value.Kind() {
case reflect.Slice: case reflect.Slice:
foreign_key := typ.Name() + "Id"
if reflect.New(field_value.Type().Elem()).Elem().FieldByName(foreign_key).IsValid() {
field.foreignKey = foreign_key
}
field.afterAssociation = true field.afterAssociation = true
case reflect.Struct: case reflect.Struct:
if is_time { if is_time {
field.SqlType = getSqlType(m.driver, field.Value, 0) field.SqlType = getSqlType(m.driver, field.Value, 0)
} else { } else {
if indirect_value.FieldByName(p.Name + "Id").IsValid() { if indirect_value.FieldByName(p.Name + "Id").IsValid() {
field.foreignKey = p.Name + "Id"
field.beforeAssociation = true field.beforeAssociation = true
} else { } else {
field.afterAssociation = true field.afterAssociation = true