mirror of https://github.com/go-gorm/gorm.git
Make Update, Updates smart, Only run sql when there are really some new things and reflect changes back to the struct
This commit is contained in:
parent
721090d39a
commit
a63b3247f6
34
do.go
34
do.go
|
@ -77,7 +77,9 @@ func (s *Do) exec(sql ...string) {
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
if len(sql) == 0 {
|
if len(sql) == 0 {
|
||||||
s.sqlResult, err = s.db.Exec(s.sql, s.sqlVars...)
|
if len(s.sql) > 0 {
|
||||||
|
s.sqlResult, err = s.db.Exec(s.sql, s.sqlVars...)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
s.sqlResult, err = s.db.Exec(sql[0])
|
s.sqlResult, err = s.db.Exec(sql[0])
|
||||||
}
|
}
|
||||||
|
@ -108,7 +110,6 @@ func (s *Do) prepareCreateSql() {
|
||||||
strings.Join(sqls, ","),
|
strings.Join(sqls, ","),
|
||||||
s.model.returningStr(),
|
s.model.returningStr(),
|
||||||
)
|
)
|
||||||
debug(s.sql)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -145,13 +146,20 @@ func (s *Do) create() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) prepareUpdateSql() {
|
func (s *Do) prepareUpdateAttrs() (results map[string]interface{}, update bool) {
|
||||||
update_attrs := s.updateAttrs
|
if len(s.updateAttrs) > 0 {
|
||||||
if len(update_attrs) == 0 {
|
results, update = s.model.updatedColumnsAndValues(s.updateAttrs)
|
||||||
update_attrs = s.model.columnsAndValues("update")
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Do) prepareUpdateSql(results map[string]interface{}) {
|
||||||
|
var sqls []string
|
||||||
|
for key, value := range results {
|
||||||
|
sqls = append(sqls, fmt.Sprintf("%v = %v", key, s.addToVars(value)))
|
||||||
}
|
}
|
||||||
|
|
||||||
var sqls []string
|
update_attrs := s.model.columnsAndValues("update")
|
||||||
for key, value := range update_attrs {
|
for key, value := range update_attrs {
|
||||||
sqls = append(sqls, fmt.Sprintf("%v = %v", key, s.addToVars(value)))
|
sqls = append(sqls, fmt.Sprintf("%v = %v", key, s.addToVars(value)))
|
||||||
}
|
}
|
||||||
|
@ -166,10 +174,20 @@ func (s *Do) prepareUpdateSql() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) update() {
|
func (s *Do) update() {
|
||||||
|
update_attrs := s.updateAttrs
|
||||||
|
if len(update_attrs) > 0 {
|
||||||
|
var need_update bool
|
||||||
|
update_attrs, need_update = s.prepareUpdateAttrs()
|
||||||
|
if !need_update {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
s.err(s.model.callMethod("BeforeUpdate"))
|
s.err(s.model.callMethod("BeforeUpdate"))
|
||||||
s.err(s.model.callMethod("BeforeSave"))
|
s.err(s.model.callMethod("BeforeSave"))
|
||||||
|
|
||||||
s.prepareUpdateSql()
|
s.prepareUpdateSql(update_attrs)
|
||||||
|
|
||||||
if !s.hasError() {
|
if !s.hasError() {
|
||||||
s.exec()
|
s.exec()
|
||||||
|
|
||||||
|
|
40
gorm_test.go
40
gorm_test.go
|
@ -717,6 +717,20 @@ func TestUpdate(t *testing.T) {
|
||||||
product1 := Product{Code: "123"}
|
product1 := Product{Code: "123"}
|
||||||
product2 := Product{Code: "234"}
|
product2 := Product{Code: "234"}
|
||||||
db.Save(&product1).Save(&product2).Update("code", "456")
|
db.Save(&product1).Save(&product2).Update("code", "456")
|
||||||
|
if product2.Code != "456" {
|
||||||
|
t.Errorf("Object should be updated also with update attributes")
|
||||||
|
}
|
||||||
|
|
||||||
|
db.First(&product1, product1.Id)
|
||||||
|
db.First(&product2, product2.Id)
|
||||||
|
updated_at1 := product1.UpdatedAt
|
||||||
|
updated_at2 := product2.UpdatedAt
|
||||||
|
|
||||||
|
var product3 Product
|
||||||
|
db.First(&product3, product2.Id).Update("code", "456")
|
||||||
|
if updated_at2.Format(time.RFC3339Nano) != product3.UpdatedAt.Format(time.RFC3339Nano) {
|
||||||
|
t.Errorf("updated_at should not be updated if nothing really new")
|
||||||
|
}
|
||||||
|
|
||||||
if db.First(&Product{}, "code = '123'").Error != nil {
|
if db.First(&Product{}, "code = '123'").Error != nil {
|
||||||
t.Errorf("Product 123's code should not be changed!")
|
t.Errorf("Product 123's code should not be changed!")
|
||||||
|
@ -731,6 +745,13 @@ func TestUpdate(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
db.Table("products").Where("code in (?)", []string{"123"}).Update("code", "789")
|
db.Table("products").Where("code in (?)", []string{"123"}).Update("code", "789")
|
||||||
|
|
||||||
|
var product4 Product
|
||||||
|
db.First(&product4, product1.Id)
|
||||||
|
if updated_at1.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) {
|
||||||
|
t.Errorf("updated_at should be updated if something updated")
|
||||||
|
}
|
||||||
|
|
||||||
if db.First(&Product{}, "code = '123'").Error == nil {
|
if db.First(&Product{}, "code = '123'").Error == nil {
|
||||||
t.Errorf("Product 123's code should be changed to 789")
|
t.Errorf("Product 123's code should be changed to 789")
|
||||||
}
|
}
|
||||||
|
@ -748,6 +769,20 @@ func TestUpdates(t *testing.T) {
|
||||||
product1 := Product{Code: "abc", Price: 10}
|
product1 := Product{Code: "abc", Price: 10}
|
||||||
product2 := Product{Code: "cde", Price: 20}
|
product2 := Product{Code: "cde", Price: 20}
|
||||||
db.Save(&product1).Save(&product2).Updates(map[string]interface{}{"code": "edf", "price": 100})
|
db.Save(&product1).Save(&product2).Updates(map[string]interface{}{"code": "edf", "price": 100})
|
||||||
|
if product2.Code != "edf" || product2.Price != 100 {
|
||||||
|
t.Errorf("Object should be updated also with update attributes")
|
||||||
|
}
|
||||||
|
db.First(&product1, product1.Id)
|
||||||
|
db.First(&product2, product2.Id)
|
||||||
|
updated_at1 := product1.UpdatedAt
|
||||||
|
updated_at2 := product2.UpdatedAt
|
||||||
|
|
||||||
|
var product3 Product
|
||||||
|
db.First(&product3, product2.Id).Updates(map[string]interface{}{"code": "edf", "price": 100})
|
||||||
|
|
||||||
|
if updated_at2.Format(time.RFC3339Nano) != product3.UpdatedAt.Format(time.RFC3339Nano) {
|
||||||
|
t.Errorf("updated_at should not be updated if nothing really new")
|
||||||
|
}
|
||||||
|
|
||||||
if db.First(&Product{}, "code = 'abc' and price = 10").Error != nil {
|
if db.First(&Product{}, "code = 'abc' and price = 10").Error != nil {
|
||||||
t.Errorf("Product abc should not be updated!")
|
t.Errorf("Product abc should not be updated!")
|
||||||
|
@ -765,6 +800,11 @@ func TestUpdates(t *testing.T) {
|
||||||
if db.First(&Product{}, "code = 'abc'").Error == nil {
|
if db.First(&Product{}, "code = 'abc'").Error == nil {
|
||||||
t.Errorf("Product abc's code should be changed to fgh")
|
t.Errorf("Product abc's code should be changed to fgh")
|
||||||
}
|
}
|
||||||
|
var product4 Product
|
||||||
|
db.First(&product4, product1.Id)
|
||||||
|
if updated_at1.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) {
|
||||||
|
t.Errorf("updated_at should be updated if something updated")
|
||||||
|
}
|
||||||
|
|
||||||
if db.First(&Product{}, "code = 'edf' and price = ?", 100).Error != nil {
|
if db.First(&Product{}, "code = 'edf' and price = ?", 100).Error != nil {
|
||||||
t.Errorf("Product cde's code should not be changed to fgh")
|
t.Errorf("Product cde's code should not be changed to fgh")
|
||||||
|
|
38
model.go
38
model.go
|
@ -125,7 +125,45 @@ func (m *Model) columnsHasValue(operation string) (fields []Field) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Model) updatedColumnsAndValues(values map[string]interface{}) (map[string]interface{}, bool) {
|
||||||
|
if m.data == nil {
|
||||||
|
return values, true
|
||||||
|
}
|
||||||
|
|
||||||
|
data := reflect.Indirect(reflect.ValueOf(m.data))
|
||||||
|
results := map[string]interface{}{}
|
||||||
|
|
||||||
|
for key, value := range values {
|
||||||
|
field := data.FieldByName(snakeToUpperCamel(key))
|
||||||
|
if field.IsValid() {
|
||||||
|
if field.Interface() != value {
|
||||||
|
switch field.Kind() {
|
||||||
|
case reflect.Int, reflect.Int32, reflect.Int64:
|
||||||
|
field.SetInt(reflect.ValueOf(value).Int())
|
||||||
|
if field.Int() != reflect.ValueOf(value).Int() {
|
||||||
|
results[key] = value
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
results[key] = value
|
||||||
|
field.Set(reflect.ValueOf(value))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
field := data.FieldByName("UpdatedAt")
|
||||||
|
if field.IsValid() && values["updated_at"] != nil && len(results) > 0 {
|
||||||
|
data.FieldByName("UpdatedAt").Set(reflect.ValueOf(time.Now()))
|
||||||
|
}
|
||||||
|
result := len(results) > 0
|
||||||
|
return map[string]interface{}{}, result
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Model) columnsAndValues(operation string) map[string]interface{} {
|
func (m *Model) columnsAndValues(operation string) map[string]interface{} {
|
||||||
|
if m.data == nil {
|
||||||
|
return map[string]interface{}{}
|
||||||
|
}
|
||||||
|
|
||||||
results := map[string]interface{}{}
|
results := map[string]interface{}{}
|
||||||
for _, field := range m.fields(operation) {
|
for _, field := range m.fields(operation) {
|
||||||
if !field.IsPrimaryKey {
|
if !field.IsPrimaryKey {
|
||||||
|
|
Loading…
Reference in New Issue