Save substructs successfully

This commit is contained in:
Jinzhu 2013-11-02 17:29:56 +08:00
parent 8c36a5d193
commit b9f4a59772
3 changed files with 84 additions and 8 deletions

32
do.go
View File

@ -61,9 +61,10 @@ func (s *Do) hasError() bool {
return len(s.Errors) > 0
}
func (s *Do) setModel(value interface{}) {
func (s *Do) setModel(value interface{}) *Do {
s.model = &Model{data: value, driver: s.driver}
s.value = value
return s
}
func (s *Do) addToVars(value interface{}) string {
@ -114,9 +115,26 @@ func (s *Do) prepareCreateSql() {
return
}
func (s *Do) saveAssociation(typ string) {
if typ == "before" {
} else if typ == "after" {
func (s *Do) saveBeforeAssociations() {
for _, field := range s.model.beforeAssociations() {
do := &Do{chain: s.chain, db: s.db, driver: s.driver}
do.setModel(field.Value).save()
}
}
func (s *Do) saveAfterAssociations() {
for _, field := range s.model.afterAssociations() {
reflect_value := reflect.ValueOf(field.Value)
switch reflect.TypeOf(field.Value).Kind() {
case reflect.Slice:
for i := 0; i < reflect_value.Len(); i++ {
do := &Do{chain: s.chain, db: s.db, driver: s.driver}
do.setModel(reflect_value.Index(i).Addr().Interface()).save()
}
default:
do := &Do{chain: s.chain, db: s.db, driver: s.driver}
do.setModel(field.Value).save()
}
}
}
@ -124,6 +142,7 @@ func (s *Do) create() {
s.err(s.model.callMethod("BeforeCreate"))
s.err(s.model.callMethod("BeforeSave"))
s.saveBeforeAssociations()
s.prepareCreateSql()
if !s.hasError() {
@ -139,8 +158,9 @@ func (s *Do) create() {
}
if !s.hasError() {
result := reflect.ValueOf(s.value).Elem()
result := reflect.Indirect(reflect.ValueOf(s.value))
setFieldValue(result.FieldByName(s.model.primaryKey()), id)
s.saveAfterAssociations()
s.err(s.model.callMethod("AfterCreate"))
s.err(s.model.callMethod("AfterSave"))
@ -212,10 +232,12 @@ func (s *Do) update() {
s.err(s.model.callMethod("BeforeUpdate"))
s.err(s.model.callMethod("BeforeSave"))
s.saveBeforeAssociations()
s.prepareUpdateSql(update_attrs)
if !s.hasError() {
s.exec()
s.saveAfterAssociations()
if !s.hasError() {
s.err(s.model.callMethod("AfterUpdate"))

View File

@ -1016,6 +1016,10 @@ type Comment struct {
}
func TestSubStruct(t *testing.T) {
db.DropTable(Category{})
db.DropTable(Post{})
db.DropTable(Comment{})
db.CreateTable(Category{})
db.CreateTable(Post{})
db.CreateTable(Comment{})
@ -1034,4 +1038,12 @@ func TestSubStruct(t *testing.T) {
if db.First(&Category{}, "name = ?", "Category 1").Error != nil {
t.Errorf("Category should be saved")
}
if db.First(&Comment{}, "content = ?", "Comment 1").Error != nil {
t.Errorf("Comment 1 should be saved")
}
if db.First(&Comment{}, "content = ?", "Comment 2").Error != nil {
t.Errorf("Comment 2 should be saved")
}
}

View File

@ -25,6 +25,9 @@ type Field struct {
AutoUpdateTime bool
IsPrimaryKey bool
IsBlank bool
beforeAssociation bool
afterAssociation bool
}
func (m *Model) primaryKeyZero() bool {
@ -66,10 +69,14 @@ func (m *Model) primaryKeyDb() string {
func (m *Model) fields(operation string) (fields []Field) {
if len(m._cache_fields[operation]) > 0 {
return
return m._cache_fields[operation]
}
indirect_value := reflect.Indirect(reflect.ValueOf(m.data))
if !indirect_value.IsValid() {
return
}
typ := indirect_value.Type()
for i := 0; i < typ.NumField(); i++ {
@ -89,9 +96,19 @@ func (m *Model) fields(operation string) (fields []Field) {
field.IsBlank = value.Int() == 0
case reflect.String:
field.IsBlank = value.String() == ""
default:
case reflect.Slice:
if value.Len() == 0 {
field.IsBlank = true
}
case reflect.Struct:
if is_time {
field.IsBlank = time_value.IsZero()
} else {
m := &Model{data: value.Interface(), driver: m.driver}
fields := m.columnsHasValue("other")
if len(fields) == 0 {
field.IsBlank = true
}
}
}
@ -115,9 +132,16 @@ func (m *Model) fields(operation string) (fields []Field) {
} else {
switch reflect.TypeOf(field.Value).Kind() {
case reflect.Slice:
field.afterAssociation = true
case reflect.Struct:
if is_time {
field.SqlType = getSqlType(m.driver, field.Value, 0)
} else {
if indirect_value.FieldByName(p.Name + "Id").IsValid() {
field.beforeAssociation = true
} else {
field.afterAssociation = true
}
}
default:
field.SqlType = getSqlType(m.driver, field.Value, 0)
@ -258,8 +282,26 @@ func (m *Model) setValueByColumn(name string, value interface{}, out interface{}
setFieldValue(data.FieldByName(snakeToUpperCamel(name)), value)
}
func (m *Model) beforeAssociations() (fields []Field) {
for _, field := range m.fields("null") {
if field.beforeAssociation && !field.IsBlank {
fields = append(fields, field)
}
}
return
}
func (m *Model) afterAssociations() (fields []Field) {
for _, field := range m.fields("null") {
if field.afterAssociation && !field.IsBlank {
fields = append(fields, field)
}
}
return
}
func setFieldValue(field reflect.Value, value interface{}) {
if field.IsValid() {
if field.IsValid() && field.CanAddr() {
switch field.Kind() {
case reflect.Int, reflect.Int32, reflect.Int64:
if str, ok := value.(string); ok {