mirror of https://github.com/go-gorm/gorm.git
Make save sub structs works
This commit is contained in:
parent
5b671a84b6
commit
aa352d405b
32
do.go
32
do.go
|
@ -88,11 +88,11 @@ func (s *Do) exec(sql ...string) {
|
||||||
s.err(err)
|
s.err(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) save() {
|
func (s *Do) save() (i int64) {
|
||||||
if s.model.primaryKeyZero() {
|
if s.model.primaryKeyZero() {
|
||||||
s.create()
|
return s.create()
|
||||||
} else {
|
} else {
|
||||||
s.update()
|
return s.update()
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -118,7 +118,10 @@ func (s *Do) prepareCreateSql() {
|
||||||
func (s *Do) saveBeforeAssociations() {
|
func (s *Do) saveBeforeAssociations() {
|
||||||
for _, field := range s.model.beforeAssociations() {
|
for _, field := range s.model.beforeAssociations() {
|
||||||
do := &Do{chain: s.chain, db: s.db, driver: s.driver}
|
do := &Do{chain: s.chain, db: s.db, driver: s.driver}
|
||||||
do.setModel(field.Value).save()
|
id := do.setModel(field.Value).save()
|
||||||
|
if len(field.foreignKey) > 0 {
|
||||||
|
s.model.setValueByColumn(field.foreignKey, id, s.model.data)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -128,8 +131,12 @@ func (s *Do) saveAfterAssociations() {
|
||||||
switch reflect.TypeOf(field.Value).Kind() {
|
switch reflect.TypeOf(field.Value).Kind() {
|
||||||
case reflect.Slice:
|
case reflect.Slice:
|
||||||
for i := 0; i < reflect_value.Len(); i++ {
|
for i := 0; i < reflect_value.Len(); i++ {
|
||||||
|
value := reflect_value.Index(i).Addr().Interface()
|
||||||
do := &Do{chain: s.chain, db: s.db, driver: s.driver}
|
do := &Do{chain: s.chain, db: s.db, driver: s.driver}
|
||||||
do.setModel(reflect_value.Index(i).Addr().Interface()).save()
|
if len(field.foreignKey) > 0 {
|
||||||
|
s.model.setValueByColumn(field.foreignKey, s.model.primaryKeyValue(), value)
|
||||||
|
}
|
||||||
|
do.setModel(value).save()
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
do := &Do{chain: s.chain, db: s.db, driver: s.driver}
|
do := &Do{chain: s.chain, db: s.db, driver: s.driver}
|
||||||
|
@ -138,7 +145,7 @@ func (s *Do) saveAfterAssociations() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) create() {
|
func (s *Do) create() (i int64) {
|
||||||
s.err(s.model.callMethod("BeforeCreate"))
|
s.err(s.model.callMethod("BeforeCreate"))
|
||||||
s.err(s.model.callMethod("BeforeSave"))
|
s.err(s.model.callMethod("BeforeSave"))
|
||||||
|
|
||||||
|
@ -167,6 +174,7 @@ func (s *Do) create() {
|
||||||
s.err(s.model.callMethod("AfterCreate"))
|
s.err(s.model.callMethod("AfterCreate"))
|
||||||
s.err(s.model.callMethod("AfterSave"))
|
s.err(s.model.callMethod("AfterSave"))
|
||||||
}
|
}
|
||||||
|
return id
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
|
@ -221,7 +229,7 @@ func (s *Do) prepareUpdateSql(results map[string]interface{}) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) update() {
|
func (s *Do) update() (i int64) {
|
||||||
update_attrs := s.updateAttrs
|
update_attrs := s.updateAttrs
|
||||||
if len(update_attrs) > 0 {
|
if len(update_attrs) > 0 {
|
||||||
var need_update bool
|
var need_update bool
|
||||||
|
@ -246,7 +254,8 @@ func (s *Do) update() {
|
||||||
s.err(s.model.callMethod("AfterSave"))
|
s.err(s.model.callMethod("AfterSave"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
|
||||||
|
return s.model.primaryKeyValue()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) prepareDeleteSql() {
|
func (s *Do) prepareDeleteSql() {
|
||||||
|
@ -318,7 +327,12 @@ func (s *Do) query() {
|
||||||
for _, value := range columns {
|
for _, value := range columns {
|
||||||
field := dest.FieldByName(snakeToUpperCamel(value))
|
field := dest.FieldByName(snakeToUpperCamel(value))
|
||||||
if field.IsValid() {
|
if field.IsValid() {
|
||||||
values = append(values, field.Addr().Interface())
|
if field.CanAddr() {
|
||||||
|
values = append(values, field.Addr().Interface())
|
||||||
|
} else {
|
||||||
|
s.err(errors.New(fmt.Sprintf("Can't take address of %v, should be ptr", dest)))
|
||||||
|
return
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
var null interface{}
|
var null interface{}
|
||||||
values = append(values, &null)
|
values = append(values, &null)
|
||||||
|
|
19
gorm_test.go
19
gorm_test.go
|
@ -1041,11 +1041,28 @@ func TestSubStruct(t *testing.T) {
|
||||||
t.Errorf("Category should be saved")
|
t.Errorf("Category should be saved")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var p Post
|
||||||
|
db.First(&p, post.Id)
|
||||||
|
if post.CategoryId == 0 || p.CategoryId == 0 {
|
||||||
|
t.Errorf("Category Id should exist")
|
||||||
|
}
|
||||||
|
|
||||||
if db.First(&Comment{}, "content = ?", "Comment 1").Error != nil {
|
if db.First(&Comment{}, "content = ?", "Comment 1").Error != nil {
|
||||||
t.Errorf("Comment 1 should be saved")
|
t.Errorf("Comment 1 should be saved")
|
||||||
}
|
}
|
||||||
|
if post.Comments[0].PostId == 0 {
|
||||||
|
t.Errorf("Comment Should have post id")
|
||||||
|
}
|
||||||
|
|
||||||
if db.First(&Comment{}, "content = ?", "Comment 2").Error != nil {
|
var comment Comment
|
||||||
|
if db.First(&comment, "content = ?", "Comment 2").Error != nil {
|
||||||
t.Errorf("Comment 2 should be saved")
|
t.Errorf("Comment 2 should be saved")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if comment.PostId == 0 {
|
||||||
|
t.Errorf("Comment 2 Should have post id")
|
||||||
|
}
|
||||||
|
|
||||||
|
comment3 := Comment{Content: "Comment 3", Post: Post{Title: "Title 3", Body: "Body 3"}}
|
||||||
|
db.Save(&comment3)
|
||||||
}
|
}
|
||||||
|
|
6
model.go
6
model.go
|
@ -28,6 +28,7 @@ type Field struct {
|
||||||
|
|
||||||
beforeAssociation bool
|
beforeAssociation bool
|
||||||
afterAssociation bool
|
afterAssociation bool
|
||||||
|
foreignKey string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) primaryKeyZero() bool {
|
func (m *Model) primaryKeyZero() bool {
|
||||||
|
@ -139,12 +140,17 @@ func (m *Model) fields(operation string) (fields []Field) {
|
||||||
|
|
||||||
switch field_value.Kind() {
|
switch field_value.Kind() {
|
||||||
case reflect.Slice:
|
case reflect.Slice:
|
||||||
|
foreign_key := typ.Name() + "Id"
|
||||||
|
if reflect.New(field_value.Type().Elem()).Elem().FieldByName(foreign_key).IsValid() {
|
||||||
|
field.foreignKey = foreign_key
|
||||||
|
}
|
||||||
field.afterAssociation = true
|
field.afterAssociation = true
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
if is_time {
|
if is_time {
|
||||||
field.SqlType = getSqlType(m.driver, field.Value, 0)
|
field.SqlType = getSqlType(m.driver, field.Value, 0)
|
||||||
} else {
|
} else {
|
||||||
if indirect_value.FieldByName(p.Name + "Id").IsValid() {
|
if indirect_value.FieldByName(p.Name + "Id").IsValid() {
|
||||||
|
field.foreignKey = p.Name + "Id"
|
||||||
field.beforeAssociation = true
|
field.beforeAssociation = true
|
||||||
} else {
|
} else {
|
||||||
field.afterAssociation = true
|
field.afterAssociation = true
|
||||||
|
|
Loading…
Reference in New Issue