forked from mirror/gorm
Save substructs successfully
This commit is contained in:
parent
8c36a5d193
commit
b9f4a59772
32
do.go
32
do.go
|
@ -61,9 +61,10 @@ func (s *Do) hasError() bool {
|
|||
return len(s.Errors) > 0
|
||||
}
|
||||
|
||||
func (s *Do) setModel(value interface{}) {
|
||||
func (s *Do) setModel(value interface{}) *Do {
|
||||
s.model = &Model{data: value, driver: s.driver}
|
||||
s.value = value
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Do) addToVars(value interface{}) string {
|
||||
|
@ -114,9 +115,26 @@ func (s *Do) prepareCreateSql() {
|
|||
return
|
||||
}
|
||||
|
||||
func (s *Do) saveAssociation(typ string) {
|
||||
if typ == "before" {
|
||||
} else if typ == "after" {
|
||||
func (s *Do) saveBeforeAssociations() {
|
||||
for _, field := range s.model.beforeAssociations() {
|
||||
do := &Do{chain: s.chain, db: s.db, driver: s.driver}
|
||||
do.setModel(field.Value).save()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Do) saveAfterAssociations() {
|
||||
for _, field := range s.model.afterAssociations() {
|
||||
reflect_value := reflect.ValueOf(field.Value)
|
||||
switch reflect.TypeOf(field.Value).Kind() {
|
||||
case reflect.Slice:
|
||||
for i := 0; i < reflect_value.Len(); i++ {
|
||||
do := &Do{chain: s.chain, db: s.db, driver: s.driver}
|
||||
do.setModel(reflect_value.Index(i).Addr().Interface()).save()
|
||||
}
|
||||
default:
|
||||
do := &Do{chain: s.chain, db: s.db, driver: s.driver}
|
||||
do.setModel(field.Value).save()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -124,6 +142,7 @@ func (s *Do) create() {
|
|||
s.err(s.model.callMethod("BeforeCreate"))
|
||||
s.err(s.model.callMethod("BeforeSave"))
|
||||
|
||||
s.saveBeforeAssociations()
|
||||
s.prepareCreateSql()
|
||||
|
||||
if !s.hasError() {
|
||||
|
@ -139,8 +158,9 @@ func (s *Do) create() {
|
|||
}
|
||||
|
||||
if !s.hasError() {
|
||||
result := reflect.ValueOf(s.value).Elem()
|
||||
result := reflect.Indirect(reflect.ValueOf(s.value))
|
||||
setFieldValue(result.FieldByName(s.model.primaryKey()), id)
|
||||
s.saveAfterAssociations()
|
||||
|
||||
s.err(s.model.callMethod("AfterCreate"))
|
||||
s.err(s.model.callMethod("AfterSave"))
|
||||
|
@ -212,10 +232,12 @@ func (s *Do) update() {
|
|||
s.err(s.model.callMethod("BeforeUpdate"))
|
||||
s.err(s.model.callMethod("BeforeSave"))
|
||||
|
||||
s.saveBeforeAssociations()
|
||||
s.prepareUpdateSql(update_attrs)
|
||||
|
||||
if !s.hasError() {
|
||||
s.exec()
|
||||
s.saveAfterAssociations()
|
||||
|
||||
if !s.hasError() {
|
||||
s.err(s.model.callMethod("AfterUpdate"))
|
||||
|
|
12
gorm_test.go
12
gorm_test.go
|
@ -1016,6 +1016,10 @@ type Comment struct {
|
|||
}
|
||||
|
||||
func TestSubStruct(t *testing.T) {
|
||||
db.DropTable(Category{})
|
||||
db.DropTable(Post{})
|
||||
db.DropTable(Comment{})
|
||||
|
||||
db.CreateTable(Category{})
|
||||
db.CreateTable(Post{})
|
||||
db.CreateTable(Comment{})
|
||||
|
@ -1034,4 +1038,12 @@ func TestSubStruct(t *testing.T) {
|
|||
if db.First(&Category{}, "name = ?", "Category 1").Error != nil {
|
||||
t.Errorf("Category should be saved")
|
||||
}
|
||||
|
||||
if db.First(&Comment{}, "content = ?", "Comment 1").Error != nil {
|
||||
t.Errorf("Comment 1 should be saved")
|
||||
}
|
||||
|
||||
if db.First(&Comment{}, "content = ?", "Comment 2").Error != nil {
|
||||
t.Errorf("Comment 2 should be saved")
|
||||
}
|
||||
}
|
||||
|
|
48
model.go
48
model.go
|
@ -25,6 +25,9 @@ type Field struct {
|
|||
AutoUpdateTime bool
|
||||
IsPrimaryKey bool
|
||||
IsBlank bool
|
||||
|
||||
beforeAssociation bool
|
||||
afterAssociation bool
|
||||
}
|
||||
|
||||
func (m *Model) primaryKeyZero() bool {
|
||||
|
@ -66,10 +69,14 @@ func (m *Model) primaryKeyDb() string {
|
|||
|
||||
func (m *Model) fields(operation string) (fields []Field) {
|
||||
if len(m._cache_fields[operation]) > 0 {
|
||||
return
|
||||
return m._cache_fields[operation]
|
||||
}
|
||||
|
||||
indirect_value := reflect.Indirect(reflect.ValueOf(m.data))
|
||||
if !indirect_value.IsValid() {
|
||||
return
|
||||
}
|
||||
|
||||
typ := indirect_value.Type()
|
||||
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
|
@ -89,9 +96,19 @@ func (m *Model) fields(operation string) (fields []Field) {
|
|||
field.IsBlank = value.Int() == 0
|
||||
case reflect.String:
|
||||
field.IsBlank = value.String() == ""
|
||||
default:
|
||||
case reflect.Slice:
|
||||
if value.Len() == 0 {
|
||||
field.IsBlank = true
|
||||
}
|
||||
case reflect.Struct:
|
||||
if is_time {
|
||||
field.IsBlank = time_value.IsZero()
|
||||
} else {
|
||||
m := &Model{data: value.Interface(), driver: m.driver}
|
||||
fields := m.columnsHasValue("other")
|
||||
if len(fields) == 0 {
|
||||
field.IsBlank = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -115,9 +132,16 @@ func (m *Model) fields(operation string) (fields []Field) {
|
|||
} else {
|
||||
switch reflect.TypeOf(field.Value).Kind() {
|
||||
case reflect.Slice:
|
||||
field.afterAssociation = true
|
||||
case reflect.Struct:
|
||||
if is_time {
|
||||
field.SqlType = getSqlType(m.driver, field.Value, 0)
|
||||
} else {
|
||||
if indirect_value.FieldByName(p.Name + "Id").IsValid() {
|
||||
field.beforeAssociation = true
|
||||
} else {
|
||||
field.afterAssociation = true
|
||||
}
|
||||
}
|
||||
default:
|
||||
field.SqlType = getSqlType(m.driver, field.Value, 0)
|
||||
|
@ -258,8 +282,26 @@ func (m *Model) setValueByColumn(name string, value interface{}, out interface{}
|
|||
setFieldValue(data.FieldByName(snakeToUpperCamel(name)), value)
|
||||
}
|
||||
|
||||
func (m *Model) beforeAssociations() (fields []Field) {
|
||||
for _, field := range m.fields("null") {
|
||||
if field.beforeAssociation && !field.IsBlank {
|
||||
fields = append(fields, field)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (m *Model) afterAssociations() (fields []Field) {
|
||||
for _, field := range m.fields("null") {
|
||||
if field.afterAssociation && !field.IsBlank {
|
||||
fields = append(fields, field)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func setFieldValue(field reflect.Value, value interface{}) {
|
||||
if field.IsValid() {
|
||||
if field.IsValid() && field.CanAddr() {
|
||||
switch field.Kind() {
|
||||
case reflect.Int, reflect.Int32, reflect.Int64:
|
||||
if str, ok := value.(string); ok {
|
||||
|
|
Loading…
Reference in New Issue