Save substructs successfully

This commit is contained in:
Jinzhu 2013-11-02 17:29:56 +08:00
parent 8c36a5d193
commit b9f4a59772
3 changed files with 84 additions and 8 deletions

32
do.go
View File

@ -61,9 +61,10 @@ func (s *Do) hasError() bool {
return len(s.Errors) > 0 return len(s.Errors) > 0
} }
func (s *Do) setModel(value interface{}) { func (s *Do) setModel(value interface{}) *Do {
s.model = &Model{data: value, driver: s.driver} s.model = &Model{data: value, driver: s.driver}
s.value = value s.value = value
return s
} }
func (s *Do) addToVars(value interface{}) string { func (s *Do) addToVars(value interface{}) string {
@ -114,9 +115,26 @@ func (s *Do) prepareCreateSql() {
return return
} }
func (s *Do) saveAssociation(typ string) { func (s *Do) saveBeforeAssociations() {
if typ == "before" { for _, field := range s.model.beforeAssociations() {
} else if typ == "after" { do := &Do{chain: s.chain, db: s.db, driver: s.driver}
do.setModel(field.Value).save()
}
}
func (s *Do) saveAfterAssociations() {
for _, field := range s.model.afterAssociations() {
reflect_value := reflect.ValueOf(field.Value)
switch reflect.TypeOf(field.Value).Kind() {
case reflect.Slice:
for i := 0; i < reflect_value.Len(); i++ {
do := &Do{chain: s.chain, db: s.db, driver: s.driver}
do.setModel(reflect_value.Index(i).Addr().Interface()).save()
}
default:
do := &Do{chain: s.chain, db: s.db, driver: s.driver}
do.setModel(field.Value).save()
}
} }
} }
@ -124,6 +142,7 @@ func (s *Do) create() {
s.err(s.model.callMethod("BeforeCreate")) s.err(s.model.callMethod("BeforeCreate"))
s.err(s.model.callMethod("BeforeSave")) s.err(s.model.callMethod("BeforeSave"))
s.saveBeforeAssociations()
s.prepareCreateSql() s.prepareCreateSql()
if !s.hasError() { if !s.hasError() {
@ -139,8 +158,9 @@ func (s *Do) create() {
} }
if !s.hasError() { if !s.hasError() {
result := reflect.ValueOf(s.value).Elem() result := reflect.Indirect(reflect.ValueOf(s.value))
setFieldValue(result.FieldByName(s.model.primaryKey()), id) setFieldValue(result.FieldByName(s.model.primaryKey()), id)
s.saveAfterAssociations()
s.err(s.model.callMethod("AfterCreate")) s.err(s.model.callMethod("AfterCreate"))
s.err(s.model.callMethod("AfterSave")) s.err(s.model.callMethod("AfterSave"))
@ -212,10 +232,12 @@ func (s *Do) update() {
s.err(s.model.callMethod("BeforeUpdate")) s.err(s.model.callMethod("BeforeUpdate"))
s.err(s.model.callMethod("BeforeSave")) s.err(s.model.callMethod("BeforeSave"))
s.saveBeforeAssociations()
s.prepareUpdateSql(update_attrs) s.prepareUpdateSql(update_attrs)
if !s.hasError() { if !s.hasError() {
s.exec() s.exec()
s.saveAfterAssociations()
if !s.hasError() { if !s.hasError() {
s.err(s.model.callMethod("AfterUpdate")) s.err(s.model.callMethod("AfterUpdate"))

View File

@ -1016,6 +1016,10 @@ type Comment struct {
} }
func TestSubStruct(t *testing.T) { func TestSubStruct(t *testing.T) {
db.DropTable(Category{})
db.DropTable(Post{})
db.DropTable(Comment{})
db.CreateTable(Category{}) db.CreateTable(Category{})
db.CreateTable(Post{}) db.CreateTable(Post{})
db.CreateTable(Comment{}) db.CreateTable(Comment{})
@ -1034,4 +1038,12 @@ func TestSubStruct(t *testing.T) {
if db.First(&Category{}, "name = ?", "Category 1").Error != nil { if db.First(&Category{}, "name = ?", "Category 1").Error != nil {
t.Errorf("Category should be saved") t.Errorf("Category should be saved")
} }
if db.First(&Comment{}, "content = ?", "Comment 1").Error != nil {
t.Errorf("Comment 1 should be saved")
}
if db.First(&Comment{}, "content = ?", "Comment 2").Error != nil {
t.Errorf("Comment 2 should be saved")
}
} }

View File

@ -25,6 +25,9 @@ type Field struct {
AutoUpdateTime bool AutoUpdateTime bool
IsPrimaryKey bool IsPrimaryKey bool
IsBlank bool IsBlank bool
beforeAssociation bool
afterAssociation bool
} }
func (m *Model) primaryKeyZero() bool { func (m *Model) primaryKeyZero() bool {
@ -66,10 +69,14 @@ func (m *Model) primaryKeyDb() string {
func (m *Model) fields(operation string) (fields []Field) { func (m *Model) fields(operation string) (fields []Field) {
if len(m._cache_fields[operation]) > 0 { if len(m._cache_fields[operation]) > 0 {
return return m._cache_fields[operation]
} }
indirect_value := reflect.Indirect(reflect.ValueOf(m.data)) indirect_value := reflect.Indirect(reflect.ValueOf(m.data))
if !indirect_value.IsValid() {
return
}
typ := indirect_value.Type() typ := indirect_value.Type()
for i := 0; i < typ.NumField(); i++ { for i := 0; i < typ.NumField(); i++ {
@ -89,9 +96,19 @@ func (m *Model) fields(operation string) (fields []Field) {
field.IsBlank = value.Int() == 0 field.IsBlank = value.Int() == 0
case reflect.String: case reflect.String:
field.IsBlank = value.String() == "" field.IsBlank = value.String() == ""
default: case reflect.Slice:
if value.Len() == 0 {
field.IsBlank = true
}
case reflect.Struct:
if is_time { if is_time {
field.IsBlank = time_value.IsZero() field.IsBlank = time_value.IsZero()
} else {
m := &Model{data: value.Interface(), driver: m.driver}
fields := m.columnsHasValue("other")
if len(fields) == 0 {
field.IsBlank = true
}
} }
} }
@ -115,9 +132,16 @@ func (m *Model) fields(operation string) (fields []Field) {
} else { } else {
switch reflect.TypeOf(field.Value).Kind() { switch reflect.TypeOf(field.Value).Kind() {
case reflect.Slice: case reflect.Slice:
field.afterAssociation = true
case reflect.Struct: case reflect.Struct:
if is_time { if is_time {
field.SqlType = getSqlType(m.driver, field.Value, 0) field.SqlType = getSqlType(m.driver, field.Value, 0)
} else {
if indirect_value.FieldByName(p.Name + "Id").IsValid() {
field.beforeAssociation = true
} else {
field.afterAssociation = true
}
} }
default: default:
field.SqlType = getSqlType(m.driver, field.Value, 0) field.SqlType = getSqlType(m.driver, field.Value, 0)
@ -258,8 +282,26 @@ func (m *Model) setValueByColumn(name string, value interface{}, out interface{}
setFieldValue(data.FieldByName(snakeToUpperCamel(name)), value) setFieldValue(data.FieldByName(snakeToUpperCamel(name)), value)
} }
func (m *Model) beforeAssociations() (fields []Field) {
for _, field := range m.fields("null") {
if field.beforeAssociation && !field.IsBlank {
fields = append(fields, field)
}
}
return
}
func (m *Model) afterAssociations() (fields []Field) {
for _, field := range m.fields("null") {
if field.afterAssociation && !field.IsBlank {
fields = append(fields, field)
}
}
return
}
func setFieldValue(field reflect.Value, value interface{}) { func setFieldValue(field reflect.Value, value interface{}) {
if field.IsValid() { if field.IsValid() && field.CanAddr() {
switch field.Kind() { switch field.Kind() {
case reflect.Int, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int32, reflect.Int64:
if str, ok := value.(string); ok { if str, ok := value.(string); ok {