forked from mirror/gorm
Save substructs successfully
This commit is contained in:
parent
8c36a5d193
commit
b9f4a59772
32
do.go
32
do.go
|
@ -61,9 +61,10 @@ func (s *Do) hasError() bool {
|
||||||
return len(s.Errors) > 0
|
return len(s.Errors) > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) setModel(value interface{}) {
|
func (s *Do) setModel(value interface{}) *Do {
|
||||||
s.model = &Model{data: value, driver: s.driver}
|
s.model = &Model{data: value, driver: s.driver}
|
||||||
s.value = value
|
s.value = value
|
||||||
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) addToVars(value interface{}) string {
|
func (s *Do) addToVars(value interface{}) string {
|
||||||
|
@ -114,9 +115,26 @@ func (s *Do) prepareCreateSql() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) saveAssociation(typ string) {
|
func (s *Do) saveBeforeAssociations() {
|
||||||
if typ == "before" {
|
for _, field := range s.model.beforeAssociations() {
|
||||||
} else if typ == "after" {
|
do := &Do{chain: s.chain, db: s.db, driver: s.driver}
|
||||||
|
do.setModel(field.Value).save()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Do) saveAfterAssociations() {
|
||||||
|
for _, field := range s.model.afterAssociations() {
|
||||||
|
reflect_value := reflect.ValueOf(field.Value)
|
||||||
|
switch reflect.TypeOf(field.Value).Kind() {
|
||||||
|
case reflect.Slice:
|
||||||
|
for i := 0; i < reflect_value.Len(); i++ {
|
||||||
|
do := &Do{chain: s.chain, db: s.db, driver: s.driver}
|
||||||
|
do.setModel(reflect_value.Index(i).Addr().Interface()).save()
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
do := &Do{chain: s.chain, db: s.db, driver: s.driver}
|
||||||
|
do.setModel(field.Value).save()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -124,6 +142,7 @@ func (s *Do) create() {
|
||||||
s.err(s.model.callMethod("BeforeCreate"))
|
s.err(s.model.callMethod("BeforeCreate"))
|
||||||
s.err(s.model.callMethod("BeforeSave"))
|
s.err(s.model.callMethod("BeforeSave"))
|
||||||
|
|
||||||
|
s.saveBeforeAssociations()
|
||||||
s.prepareCreateSql()
|
s.prepareCreateSql()
|
||||||
|
|
||||||
if !s.hasError() {
|
if !s.hasError() {
|
||||||
|
@ -139,8 +158,9 @@ func (s *Do) create() {
|
||||||
}
|
}
|
||||||
|
|
||||||
if !s.hasError() {
|
if !s.hasError() {
|
||||||
result := reflect.ValueOf(s.value).Elem()
|
result := reflect.Indirect(reflect.ValueOf(s.value))
|
||||||
setFieldValue(result.FieldByName(s.model.primaryKey()), id)
|
setFieldValue(result.FieldByName(s.model.primaryKey()), id)
|
||||||
|
s.saveAfterAssociations()
|
||||||
|
|
||||||
s.err(s.model.callMethod("AfterCreate"))
|
s.err(s.model.callMethod("AfterCreate"))
|
||||||
s.err(s.model.callMethod("AfterSave"))
|
s.err(s.model.callMethod("AfterSave"))
|
||||||
|
@ -212,10 +232,12 @@ func (s *Do) update() {
|
||||||
s.err(s.model.callMethod("BeforeUpdate"))
|
s.err(s.model.callMethod("BeforeUpdate"))
|
||||||
s.err(s.model.callMethod("BeforeSave"))
|
s.err(s.model.callMethod("BeforeSave"))
|
||||||
|
|
||||||
|
s.saveBeforeAssociations()
|
||||||
s.prepareUpdateSql(update_attrs)
|
s.prepareUpdateSql(update_attrs)
|
||||||
|
|
||||||
if !s.hasError() {
|
if !s.hasError() {
|
||||||
s.exec()
|
s.exec()
|
||||||
|
s.saveAfterAssociations()
|
||||||
|
|
||||||
if !s.hasError() {
|
if !s.hasError() {
|
||||||
s.err(s.model.callMethod("AfterUpdate"))
|
s.err(s.model.callMethod("AfterUpdate"))
|
||||||
|
|
12
gorm_test.go
12
gorm_test.go
|
@ -1016,6 +1016,10 @@ type Comment struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSubStruct(t *testing.T) {
|
func TestSubStruct(t *testing.T) {
|
||||||
|
db.DropTable(Category{})
|
||||||
|
db.DropTable(Post{})
|
||||||
|
db.DropTable(Comment{})
|
||||||
|
|
||||||
db.CreateTable(Category{})
|
db.CreateTable(Category{})
|
||||||
db.CreateTable(Post{})
|
db.CreateTable(Post{})
|
||||||
db.CreateTable(Comment{})
|
db.CreateTable(Comment{})
|
||||||
|
@ -1034,4 +1038,12 @@ func TestSubStruct(t *testing.T) {
|
||||||
if db.First(&Category{}, "name = ?", "Category 1").Error != nil {
|
if db.First(&Category{}, "name = ?", "Category 1").Error != nil {
|
||||||
t.Errorf("Category should be saved")
|
t.Errorf("Category should be saved")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if db.First(&Comment{}, "content = ?", "Comment 1").Error != nil {
|
||||||
|
t.Errorf("Comment 1 should be saved")
|
||||||
|
}
|
||||||
|
|
||||||
|
if db.First(&Comment{}, "content = ?", "Comment 2").Error != nil {
|
||||||
|
t.Errorf("Comment 2 should be saved")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
48
model.go
48
model.go
|
@ -25,6 +25,9 @@ type Field struct {
|
||||||
AutoUpdateTime bool
|
AutoUpdateTime bool
|
||||||
IsPrimaryKey bool
|
IsPrimaryKey bool
|
||||||
IsBlank bool
|
IsBlank bool
|
||||||
|
|
||||||
|
beforeAssociation bool
|
||||||
|
afterAssociation bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) primaryKeyZero() bool {
|
func (m *Model) primaryKeyZero() bool {
|
||||||
|
@ -66,10 +69,14 @@ func (m *Model) primaryKeyDb() string {
|
||||||
|
|
||||||
func (m *Model) fields(operation string) (fields []Field) {
|
func (m *Model) fields(operation string) (fields []Field) {
|
||||||
if len(m._cache_fields[operation]) > 0 {
|
if len(m._cache_fields[operation]) > 0 {
|
||||||
return
|
return m._cache_fields[operation]
|
||||||
}
|
}
|
||||||
|
|
||||||
indirect_value := reflect.Indirect(reflect.ValueOf(m.data))
|
indirect_value := reflect.Indirect(reflect.ValueOf(m.data))
|
||||||
|
if !indirect_value.IsValid() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
typ := indirect_value.Type()
|
typ := indirect_value.Type()
|
||||||
|
|
||||||
for i := 0; i < typ.NumField(); i++ {
|
for i := 0; i < typ.NumField(); i++ {
|
||||||
|
@ -89,9 +96,19 @@ func (m *Model) fields(operation string) (fields []Field) {
|
||||||
field.IsBlank = value.Int() == 0
|
field.IsBlank = value.Int() == 0
|
||||||
case reflect.String:
|
case reflect.String:
|
||||||
field.IsBlank = value.String() == ""
|
field.IsBlank = value.String() == ""
|
||||||
default:
|
case reflect.Slice:
|
||||||
|
if value.Len() == 0 {
|
||||||
|
field.IsBlank = true
|
||||||
|
}
|
||||||
|
case reflect.Struct:
|
||||||
if is_time {
|
if is_time {
|
||||||
field.IsBlank = time_value.IsZero()
|
field.IsBlank = time_value.IsZero()
|
||||||
|
} else {
|
||||||
|
m := &Model{data: value.Interface(), driver: m.driver}
|
||||||
|
fields := m.columnsHasValue("other")
|
||||||
|
if len(fields) == 0 {
|
||||||
|
field.IsBlank = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -115,9 +132,16 @@ func (m *Model) fields(operation string) (fields []Field) {
|
||||||
} else {
|
} else {
|
||||||
switch reflect.TypeOf(field.Value).Kind() {
|
switch reflect.TypeOf(field.Value).Kind() {
|
||||||
case reflect.Slice:
|
case reflect.Slice:
|
||||||
|
field.afterAssociation = true
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
if is_time {
|
if is_time {
|
||||||
field.SqlType = getSqlType(m.driver, field.Value, 0)
|
field.SqlType = getSqlType(m.driver, field.Value, 0)
|
||||||
|
} else {
|
||||||
|
if indirect_value.FieldByName(p.Name + "Id").IsValid() {
|
||||||
|
field.beforeAssociation = true
|
||||||
|
} else {
|
||||||
|
field.afterAssociation = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
field.SqlType = getSqlType(m.driver, field.Value, 0)
|
field.SqlType = getSqlType(m.driver, field.Value, 0)
|
||||||
|
@ -258,8 +282,26 @@ func (m *Model) setValueByColumn(name string, value interface{}, out interface{}
|
||||||
setFieldValue(data.FieldByName(snakeToUpperCamel(name)), value)
|
setFieldValue(data.FieldByName(snakeToUpperCamel(name)), value)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Model) beforeAssociations() (fields []Field) {
|
||||||
|
for _, field := range m.fields("null") {
|
||||||
|
if field.beforeAssociation && !field.IsBlank {
|
||||||
|
fields = append(fields, field)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) afterAssociations() (fields []Field) {
|
||||||
|
for _, field := range m.fields("null") {
|
||||||
|
if field.afterAssociation && !field.IsBlank {
|
||||||
|
fields = append(fields, field)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func setFieldValue(field reflect.Value, value interface{}) {
|
func setFieldValue(field reflect.Value, value interface{}) {
|
||||||
if field.IsValid() {
|
if field.IsValid() && field.CanAddr() {
|
||||||
switch field.Kind() {
|
switch field.Kind() {
|
||||||
case reflect.Int, reflect.Int32, reflect.Int64:
|
case reflect.Int, reflect.Int32, reflect.Int64:
|
||||||
if str, ok := value.(string); ok {
|
if str, ok := value.(string); ok {
|
||||||
|
|
Loading…
Reference in New Issue