forked from mirror/gorm
Make save sub structs works
This commit is contained in:
parent
5b671a84b6
commit
aa352d405b
32
do.go
32
do.go
|
@ -88,11 +88,11 @@ func (s *Do) exec(sql ...string) {
|
|||
s.err(err)
|
||||
}
|
||||
|
||||
func (s *Do) save() {
|
||||
func (s *Do) save() (i int64) {
|
||||
if s.model.primaryKeyZero() {
|
||||
s.create()
|
||||
return s.create()
|
||||
} else {
|
||||
s.update()
|
||||
return s.update()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -118,7 +118,10 @@ func (s *Do) prepareCreateSql() {
|
|||
func (s *Do) saveBeforeAssociations() {
|
||||
for _, field := range s.model.beforeAssociations() {
|
||||
do := &Do{chain: s.chain, db: s.db, driver: s.driver}
|
||||
do.setModel(field.Value).save()
|
||||
id := do.setModel(field.Value).save()
|
||||
if len(field.foreignKey) > 0 {
|
||||
s.model.setValueByColumn(field.foreignKey, id, s.model.data)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -128,8 +131,12 @@ func (s *Do) saveAfterAssociations() {
|
|||
switch reflect.TypeOf(field.Value).Kind() {
|
||||
case reflect.Slice:
|
||||
for i := 0; i < reflect_value.Len(); i++ {
|
||||
value := reflect_value.Index(i).Addr().Interface()
|
||||
do := &Do{chain: s.chain, db: s.db, driver: s.driver}
|
||||
do.setModel(reflect_value.Index(i).Addr().Interface()).save()
|
||||
if len(field.foreignKey) > 0 {
|
||||
s.model.setValueByColumn(field.foreignKey, s.model.primaryKeyValue(), value)
|
||||
}
|
||||
do.setModel(value).save()
|
||||
}
|
||||
default:
|
||||
do := &Do{chain: s.chain, db: s.db, driver: s.driver}
|
||||
|
@ -138,7 +145,7 @@ func (s *Do) saveAfterAssociations() {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *Do) create() {
|
||||
func (s *Do) create() (i int64) {
|
||||
s.err(s.model.callMethod("BeforeCreate"))
|
||||
s.err(s.model.callMethod("BeforeSave"))
|
||||
|
||||
|
@ -167,6 +174,7 @@ func (s *Do) create() {
|
|||
s.err(s.model.callMethod("AfterCreate"))
|
||||
s.err(s.model.callMethod("AfterSave"))
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
return
|
||||
|
@ -221,7 +229,7 @@ func (s *Do) prepareUpdateSql(results map[string]interface{}) {
|
|||
return
|
||||
}
|
||||
|
||||
func (s *Do) update() {
|
||||
func (s *Do) update() (i int64) {
|
||||
update_attrs := s.updateAttrs
|
||||
if len(update_attrs) > 0 {
|
||||
var need_update bool
|
||||
|
@ -246,7 +254,8 @@ func (s *Do) update() {
|
|||
s.err(s.model.callMethod("AfterSave"))
|
||||
}
|
||||
}
|
||||
return
|
||||
|
||||
return s.model.primaryKeyValue()
|
||||
}
|
||||
|
||||
func (s *Do) prepareDeleteSql() {
|
||||
|
@ -318,7 +327,12 @@ func (s *Do) query() {
|
|||
for _, value := range columns {
|
||||
field := dest.FieldByName(snakeToUpperCamel(value))
|
||||
if field.IsValid() {
|
||||
values = append(values, field.Addr().Interface())
|
||||
if field.CanAddr() {
|
||||
values = append(values, field.Addr().Interface())
|
||||
} else {
|
||||
s.err(errors.New(fmt.Sprintf("Can't take address of %v, should be ptr", dest)))
|
||||
return
|
||||
}
|
||||
} else {
|
||||
var null interface{}
|
||||
values = append(values, &null)
|
||||
|
|
19
gorm_test.go
19
gorm_test.go
|
@ -1041,11 +1041,28 @@ func TestSubStruct(t *testing.T) {
|
|||
t.Errorf("Category should be saved")
|
||||
}
|
||||
|
||||
var p Post
|
||||
db.First(&p, post.Id)
|
||||
if post.CategoryId == 0 || p.CategoryId == 0 {
|
||||
t.Errorf("Category Id should exist")
|
||||
}
|
||||
|
||||
if db.First(&Comment{}, "content = ?", "Comment 1").Error != nil {
|
||||
t.Errorf("Comment 1 should be saved")
|
||||
}
|
||||
if post.Comments[0].PostId == 0 {
|
||||
t.Errorf("Comment Should have post id")
|
||||
}
|
||||
|
||||
if db.First(&Comment{}, "content = ?", "Comment 2").Error != nil {
|
||||
var comment Comment
|
||||
if db.First(&comment, "content = ?", "Comment 2").Error != nil {
|
||||
t.Errorf("Comment 2 should be saved")
|
||||
}
|
||||
|
||||
if comment.PostId == 0 {
|
||||
t.Errorf("Comment 2 Should have post id")
|
||||
}
|
||||
|
||||
comment3 := Comment{Content: "Comment 3", Post: Post{Title: "Title 3", Body: "Body 3"}}
|
||||
db.Save(&comment3)
|
||||
}
|
||||
|
|
6
model.go
6
model.go
|
@ -28,6 +28,7 @@ type Field struct {
|
|||
|
||||
beforeAssociation bool
|
||||
afterAssociation bool
|
||||
foreignKey string
|
||||
}
|
||||
|
||||
func (m *Model) primaryKeyZero() bool {
|
||||
|
@ -139,12 +140,17 @@ func (m *Model) fields(operation string) (fields []Field) {
|
|||
|
||||
switch field_value.Kind() {
|
||||
case reflect.Slice:
|
||||
foreign_key := typ.Name() + "Id"
|
||||
if reflect.New(field_value.Type().Elem()).Elem().FieldByName(foreign_key).IsValid() {
|
||||
field.foreignKey = foreign_key
|
||||
}
|
||||
field.afterAssociation = true
|
||||
case reflect.Struct:
|
||||
if is_time {
|
||||
field.SqlType = getSqlType(m.driver, field.Value, 0)
|
||||
} else {
|
||||
if indirect_value.FieldByName(p.Name + "Id").IsValid() {
|
||||
field.foreignKey = p.Name + "Id"
|
||||
field.beforeAssociation = true
|
||||
} else {
|
||||
field.afterAssociation = true
|
||||
|
|
Loading…
Reference in New Issue