diff --git a/do.go b/do.go index d78880f9..2aa7acda 100644 --- a/do.go +++ b/do.go @@ -77,7 +77,9 @@ func (s *Do) exec(sql ...string) { var err error if len(sql) == 0 { - s.sqlResult, err = s.db.Exec(s.sql, s.sqlVars...) + if len(s.sql) > 0 { + s.sqlResult, err = s.db.Exec(s.sql, s.sqlVars...) + } } else { s.sqlResult, err = s.db.Exec(sql[0]) } @@ -108,7 +110,6 @@ func (s *Do) prepareCreateSql() { strings.Join(sqls, ","), s.model.returningStr(), ) - debug(s.sql) return } @@ -145,13 +146,20 @@ func (s *Do) create() { return } -func (s *Do) prepareUpdateSql() { - update_attrs := s.updateAttrs - if len(update_attrs) == 0 { - update_attrs = s.model.columnsAndValues("update") +func (s *Do) prepareUpdateAttrs() (results map[string]interface{}, update bool) { + if len(s.updateAttrs) > 0 { + results, update = s.model.updatedColumnsAndValues(s.updateAttrs) + } + return +} + +func (s *Do) prepareUpdateSql(results map[string]interface{}) { + var sqls []string + for key, value := range results { + sqls = append(sqls, fmt.Sprintf("%v = %v", key, s.addToVars(value))) } - var sqls []string + update_attrs := s.model.columnsAndValues("update") for key, value := range update_attrs { sqls = append(sqls, fmt.Sprintf("%v = %v", key, s.addToVars(value))) } @@ -166,10 +174,20 @@ func (s *Do) prepareUpdateSql() { } func (s *Do) update() { + update_attrs := s.updateAttrs + if len(update_attrs) > 0 { + var need_update bool + update_attrs, need_update = s.prepareUpdateAttrs() + if !need_update { + return + } + } + s.err(s.model.callMethod("BeforeUpdate")) s.err(s.model.callMethod("BeforeSave")) - s.prepareUpdateSql() + s.prepareUpdateSql(update_attrs) + if !s.hasError() { s.exec() diff --git a/gorm_test.go b/gorm_test.go index 46ccd3c2..fe225343 100644 --- a/gorm_test.go +++ b/gorm_test.go @@ -717,6 +717,20 @@ func TestUpdate(t *testing.T) { product1 := Product{Code: "123"} product2 := Product{Code: "234"} db.Save(&product1).Save(&product2).Update("code", "456") + if product2.Code != "456" { + t.Errorf("Object should be updated also with update attributes") + } + + db.First(&product1, product1.Id) + db.First(&product2, product2.Id) + updated_at1 := product1.UpdatedAt + updated_at2 := product2.UpdatedAt + + var product3 Product + db.First(&product3, product2.Id).Update("code", "456") + if updated_at2.Format(time.RFC3339Nano) != product3.UpdatedAt.Format(time.RFC3339Nano) { + t.Errorf("updated_at should not be updated if nothing really new") + } if db.First(&Product{}, "code = '123'").Error != nil { t.Errorf("Product 123's code should not be changed!") @@ -731,6 +745,13 @@ func TestUpdate(t *testing.T) { } db.Table("products").Where("code in (?)", []string{"123"}).Update("code", "789") + + var product4 Product + db.First(&product4, product1.Id) + if updated_at1.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) { + t.Errorf("updated_at should be updated if something updated") + } + if db.First(&Product{}, "code = '123'").Error == nil { t.Errorf("Product 123's code should be changed to 789") } @@ -748,6 +769,20 @@ func TestUpdates(t *testing.T) { product1 := Product{Code: "abc", Price: 10} product2 := Product{Code: "cde", Price: 20} db.Save(&product1).Save(&product2).Updates(map[string]interface{}{"code": "edf", "price": 100}) + if product2.Code != "edf" || product2.Price != 100 { + t.Errorf("Object should be updated also with update attributes") + } + db.First(&product1, product1.Id) + db.First(&product2, product2.Id) + updated_at1 := product1.UpdatedAt + updated_at2 := product2.UpdatedAt + + var product3 Product + db.First(&product3, product2.Id).Updates(map[string]interface{}{"code": "edf", "price": 100}) + + if updated_at2.Format(time.RFC3339Nano) != product3.UpdatedAt.Format(time.RFC3339Nano) { + t.Errorf("updated_at should not be updated if nothing really new") + } if db.First(&Product{}, "code = 'abc' and price = 10").Error != nil { t.Errorf("Product abc should not be updated!") @@ -765,6 +800,11 @@ func TestUpdates(t *testing.T) { if db.First(&Product{}, "code = 'abc'").Error == nil { t.Errorf("Product abc's code should be changed to fgh") } + var product4 Product + db.First(&product4, product1.Id) + if updated_at1.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) { + t.Errorf("updated_at should be updated if something updated") + } if db.First(&Product{}, "code = 'edf' and price = ?", 100).Error != nil { t.Errorf("Product cde's code should not be changed to fgh") diff --git a/model.go b/model.go index facb9671..037ac8d8 100644 --- a/model.go +++ b/model.go @@ -125,7 +125,45 @@ func (m *Model) columnsHasValue(operation string) (fields []Field) { return } +func (m *Model) updatedColumnsAndValues(values map[string]interface{}) (map[string]interface{}, bool) { + if m.data == nil { + return values, true + } + + data := reflect.Indirect(reflect.ValueOf(m.data)) + results := map[string]interface{}{} + + for key, value := range values { + field := data.FieldByName(snakeToUpperCamel(key)) + if field.IsValid() { + if field.Interface() != value { + switch field.Kind() { + case reflect.Int, reflect.Int32, reflect.Int64: + field.SetInt(reflect.ValueOf(value).Int()) + if field.Int() != reflect.ValueOf(value).Int() { + results[key] = value + } + default: + results[key] = value + field.Set(reflect.ValueOf(value)) + } + } + } + } + + field := data.FieldByName("UpdatedAt") + if field.IsValid() && values["updated_at"] != nil && len(results) > 0 { + data.FieldByName("UpdatedAt").Set(reflect.ValueOf(time.Now())) + } + result := len(results) > 0 + return map[string]interface{}{}, result +} + func (m *Model) columnsAndValues(operation string) map[string]interface{} { + if m.data == nil { + return map[string]interface{}{} + } + results := map[string]interface{}{} for _, field := range m.fields(operation) { if !field.IsPrimaryKey {