Make Update, Updates smart, Only run sql when there are really some new things and reflect changes back to the struct

This commit is contained in:
Jinzhu 2013-10-30 23:19:00 +08:00
parent 721090d39a
commit a63b3247f6
3 changed files with 104 additions and 8 deletions

34
do.go
View File

@ -77,7 +77,9 @@ func (s *Do) exec(sql ...string) {
var err error var err error
if len(sql) == 0 { if len(sql) == 0 {
s.sqlResult, err = s.db.Exec(s.sql, s.sqlVars...) if len(s.sql) > 0 {
s.sqlResult, err = s.db.Exec(s.sql, s.sqlVars...)
}
} else { } else {
s.sqlResult, err = s.db.Exec(sql[0]) s.sqlResult, err = s.db.Exec(sql[0])
} }
@ -108,7 +110,6 @@ func (s *Do) prepareCreateSql() {
strings.Join(sqls, ","), strings.Join(sqls, ","),
s.model.returningStr(), s.model.returningStr(),
) )
debug(s.sql)
return return
} }
@ -145,13 +146,20 @@ func (s *Do) create() {
return return
} }
func (s *Do) prepareUpdateSql() { func (s *Do) prepareUpdateAttrs() (results map[string]interface{}, update bool) {
update_attrs := s.updateAttrs if len(s.updateAttrs) > 0 {
if len(update_attrs) == 0 { results, update = s.model.updatedColumnsAndValues(s.updateAttrs)
update_attrs = s.model.columnsAndValues("update") }
return
}
func (s *Do) prepareUpdateSql(results map[string]interface{}) {
var sqls []string
for key, value := range results {
sqls = append(sqls, fmt.Sprintf("%v = %v", key, s.addToVars(value)))
} }
var sqls []string update_attrs := s.model.columnsAndValues("update")
for key, value := range update_attrs { for key, value := range update_attrs {
sqls = append(sqls, fmt.Sprintf("%v = %v", key, s.addToVars(value))) sqls = append(sqls, fmt.Sprintf("%v = %v", key, s.addToVars(value)))
} }
@ -166,10 +174,20 @@ func (s *Do) prepareUpdateSql() {
} }
func (s *Do) update() { func (s *Do) update() {
update_attrs := s.updateAttrs
if len(update_attrs) > 0 {
var need_update bool
update_attrs, need_update = s.prepareUpdateAttrs()
if !need_update {
return
}
}
s.err(s.model.callMethod("BeforeUpdate")) s.err(s.model.callMethod("BeforeUpdate"))
s.err(s.model.callMethod("BeforeSave")) s.err(s.model.callMethod("BeforeSave"))
s.prepareUpdateSql() s.prepareUpdateSql(update_attrs)
if !s.hasError() { if !s.hasError() {
s.exec() s.exec()

View File

@ -717,6 +717,20 @@ func TestUpdate(t *testing.T) {
product1 := Product{Code: "123"} product1 := Product{Code: "123"}
product2 := Product{Code: "234"} product2 := Product{Code: "234"}
db.Save(&product1).Save(&product2).Update("code", "456") db.Save(&product1).Save(&product2).Update("code", "456")
if product2.Code != "456" {
t.Errorf("Object should be updated also with update attributes")
}
db.First(&product1, product1.Id)
db.First(&product2, product2.Id)
updated_at1 := product1.UpdatedAt
updated_at2 := product2.UpdatedAt
var product3 Product
db.First(&product3, product2.Id).Update("code", "456")
if updated_at2.Format(time.RFC3339Nano) != product3.UpdatedAt.Format(time.RFC3339Nano) {
t.Errorf("updated_at should not be updated if nothing really new")
}
if db.First(&Product{}, "code = '123'").Error != nil { if db.First(&Product{}, "code = '123'").Error != nil {
t.Errorf("Product 123's code should not be changed!") t.Errorf("Product 123's code should not be changed!")
@ -731,6 +745,13 @@ func TestUpdate(t *testing.T) {
} }
db.Table("products").Where("code in (?)", []string{"123"}).Update("code", "789") db.Table("products").Where("code in (?)", []string{"123"}).Update("code", "789")
var product4 Product
db.First(&product4, product1.Id)
if updated_at1.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) {
t.Errorf("updated_at should be updated if something updated")
}
if db.First(&Product{}, "code = '123'").Error == nil { if db.First(&Product{}, "code = '123'").Error == nil {
t.Errorf("Product 123's code should be changed to 789") t.Errorf("Product 123's code should be changed to 789")
} }
@ -748,6 +769,20 @@ func TestUpdates(t *testing.T) {
product1 := Product{Code: "abc", Price: 10} product1 := Product{Code: "abc", Price: 10}
product2 := Product{Code: "cde", Price: 20} product2 := Product{Code: "cde", Price: 20}
db.Save(&product1).Save(&product2).Updates(map[string]interface{}{"code": "edf", "price": 100}) db.Save(&product1).Save(&product2).Updates(map[string]interface{}{"code": "edf", "price": 100})
if product2.Code != "edf" || product2.Price != 100 {
t.Errorf("Object should be updated also with update attributes")
}
db.First(&product1, product1.Id)
db.First(&product2, product2.Id)
updated_at1 := product1.UpdatedAt
updated_at2 := product2.UpdatedAt
var product3 Product
db.First(&product3, product2.Id).Updates(map[string]interface{}{"code": "edf", "price": 100})
if updated_at2.Format(time.RFC3339Nano) != product3.UpdatedAt.Format(time.RFC3339Nano) {
t.Errorf("updated_at should not be updated if nothing really new")
}
if db.First(&Product{}, "code = 'abc' and price = 10").Error != nil { if db.First(&Product{}, "code = 'abc' and price = 10").Error != nil {
t.Errorf("Product abc should not be updated!") t.Errorf("Product abc should not be updated!")
@ -765,6 +800,11 @@ func TestUpdates(t *testing.T) {
if db.First(&Product{}, "code = 'abc'").Error == nil { if db.First(&Product{}, "code = 'abc'").Error == nil {
t.Errorf("Product abc's code should be changed to fgh") t.Errorf("Product abc's code should be changed to fgh")
} }
var product4 Product
db.First(&product4, product1.Id)
if updated_at1.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) {
t.Errorf("updated_at should be updated if something updated")
}
if db.First(&Product{}, "code = 'edf' and price = ?", 100).Error != nil { if db.First(&Product{}, "code = 'edf' and price = ?", 100).Error != nil {
t.Errorf("Product cde's code should not be changed to fgh") t.Errorf("Product cde's code should not be changed to fgh")

View File

@ -125,7 +125,45 @@ func (m *Model) columnsHasValue(operation string) (fields []Field) {
return return
} }
func (m *Model) updatedColumnsAndValues(values map[string]interface{}) (map[string]interface{}, bool) {
if m.data == nil {
return values, true
}
data := reflect.Indirect(reflect.ValueOf(m.data))
results := map[string]interface{}{}
for key, value := range values {
field := data.FieldByName(snakeToUpperCamel(key))
if field.IsValid() {
if field.Interface() != value {
switch field.Kind() {
case reflect.Int, reflect.Int32, reflect.Int64:
field.SetInt(reflect.ValueOf(value).Int())
if field.Int() != reflect.ValueOf(value).Int() {
results[key] = value
}
default:
results[key] = value
field.Set(reflect.ValueOf(value))
}
}
}
}
field := data.FieldByName("UpdatedAt")
if field.IsValid() && values["updated_at"] != nil && len(results) > 0 {
data.FieldByName("UpdatedAt").Set(reflect.ValueOf(time.Now()))
}
result := len(results) > 0
return map[string]interface{}{}, result
}
func (m *Model) columnsAndValues(operation string) map[string]interface{} { func (m *Model) columnsAndValues(operation string) map[string]interface{} {
if m.data == nil {
return map[string]interface{}{}
}
results := map[string]interface{}{} results := map[string]interface{}{}
for _, field := range m.fields(operation) { for _, field := range m.fields(operation) {
if !field.IsPrimaryKey { if !field.IsPrimaryKey {