diff --git a/callback_update.go b/callback_update.go index 192d8a9e..aa27b5fb 100644 --- a/callback_update.go +++ b/callback_update.go @@ -21,12 +21,10 @@ func init() { // assignUpdatingAttributesCallback assign updating attributes to model func assignUpdatingAttributesCallback(scope *Scope) { if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok { - if maps := convertInterfaceToMap(attrs); len(maps) > 0 { - if updateMaps, hasUpdate := scope.updatedAttrsWithValues(maps); hasUpdate { - scope.InstanceSet("gorm:update_attrs", updateMaps) - } else { - scope.SkipLeft() - } + if updateMaps, hasUpdate := scope.updatedAttrsWithValues(attrs); hasUpdate { + scope.InstanceSet("gorm:update_attrs", updateMaps) + } else { + scope.SkipLeft() } } } diff --git a/main.go b/main.go index a5201d48..d09cf416 100644 --- a/main.go +++ b/main.go @@ -310,7 +310,7 @@ func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB { } c.NewScope(out).inlineCondition(where...).initialize() } else { - c.NewScope(out).updatedAttrsWithValues(convertInterfaceToMap(c.search.assignAttrs)) + c.NewScope(out).updatedAttrsWithValues(c.search.assignAttrs) } return c } diff --git a/scope.go b/scope.go index a5eedbac..da5f7ff3 100644 --- a/scope.go +++ b/scope.go @@ -793,27 +793,55 @@ func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope { return scope } -func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}) (results map[string]interface{}, hasUpdate bool) { +func convertInterfaceToMap(values interface{}) map[string]interface{} { + var attrs = map[string]interface{}{} + + switch value := values.(type) { + case map[string]interface{}: + return value + case []interface{}: + for _, v := range value { + for key, value := range convertInterfaceToMap(v) { + attrs[key] = value + } + } + case interface{}: + reflectValue := reflect.ValueOf(values) + + switch reflectValue.Kind() { + case reflect.Map: + for _, key := range reflectValue.MapKeys() { + attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface() + } + default: + for _, field := range (&Scope{Value: values}).Fields() { + if !field.IsBlank { + attrs[field.DBName] = field.Field.Interface() + } + } + } + } + return attrs +} + +func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[string]interface{}, hasUpdate bool) { if scope.IndirectValue().Kind() != reflect.Struct { - return values, true + return convertInterfaceToMap(value), true } results = map[string]interface{}{} - for key, value := range values { + + for key, value := range convertInterfaceToMap(value) { if field, ok := scope.FieldByName(key); ok && scope.changeableField(field) { - if !reflect.DeepEqual(field.Field, reflect.ValueOf(value)) { - if _, ok := value.(*expr); ok { - hasUpdate = true - results[field.DBName] = value - } else if !equalAsString(field.Field.Interface(), value) { - field.Set(value) - if field.IsNormal { - hasUpdate = true - results[field.DBName] = field.Field.Interface() - } - } + if _, ok := value.(*expr); ok { + hasUpdate = true + results[field.DBName] = value } else { field.Set(value) + if field.IsNormal { + hasUpdate = true + results[field.DBName] = field.Field.Interface() + } } } } @@ -836,10 +864,10 @@ func (scope *Scope) rows() (*sql.Rows, error) { func (scope *Scope) initialize() *Scope { for _, clause := range scope.Search.whereConditions { - scope.updatedAttrsWithValues(convertInterfaceToMap(clause["query"])) + scope.updatedAttrsWithValues(clause["query"]) } - scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.initAttrs)) - scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.assignAttrs)) + scope.updatedAttrsWithValues(scope.Search.initAttrs) + scope.updatedAttrsWithValues(scope.Search.assignAttrs) return scope } diff --git a/update_test.go b/update_test.go index 218c5834..bdf01091 100644 --- a/update_test.go +++ b/update_test.go @@ -20,13 +20,6 @@ func TestUpdate(t *testing.T) { DB.First(&product1, product1.Id) DB.First(&product2, product2.Id) updatedAt1 := product1.UpdatedAt - updatedAt2 := product2.UpdatedAt - - var product3 Product - DB.First(&product3, product2.Id).Update("code", "product2newcode") - if updatedAt2.Format(time.RFC3339Nano) != product3.UpdatedAt.Format(time.RFC3339Nano) { - t.Errorf("updatedAt should not be updated if nothing changed") - } if DB.First(&Product{}, "code = ?", product1.Code).RecordNotFound() { t.Errorf("Product1 should not be updated") @@ -135,19 +128,8 @@ func TestUpdates(t *testing.T) { DB.First(&product1, product1.Id) DB.First(&product2, product2.Id) - updatedAt1 := product1.UpdatedAt updatedAt2 := product2.UpdatedAt - var product3 Product - DB.First(&product3, product1.Id).Updates(Product{Code: "product1newcode", Price: 100}) - if product3.Code != "product1newcode" || product3.Price != 100 { - t.Errorf("Record should be updated with struct") - } - - if updatedAt1.Format(time.RFC3339Nano) != product3.UpdatedAt.Format(time.RFC3339Nano) { - t.Errorf("updatedAt should not be updated if nothing changed") - } - if DB.First(&Product{}, "code = ? and price = ?", product2.Code, product2.Price).RecordNotFound() { t.Errorf("Product2 should not be updated") } diff --git a/utils.go b/utils.go index c525631c..8ac4fa7d 100644 --- a/utils.go +++ b/utils.go @@ -199,37 +199,6 @@ func toSearchableMap(attrs ...interface{}) (result interface{}) { return } -func convertInterfaceToMap(values interface{}) map[string]interface{} { - attrs := map[string]interface{}{} - - switch value := values.(type) { - case map[string]interface{}: - return value - case []interface{}: - for _, v := range value { - for key, value := range convertInterfaceToMap(v) { - attrs[key] = value - } - } - case interface{}: - reflectValue := reflect.ValueOf(values) - - switch reflectValue.Kind() { - case reflect.Map: - for _, key := range reflectValue.MapKeys() { - attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface() - } - default: - for _, field := range (&Scope{Value: values}).Fields() { - if !field.IsBlank { - attrs[field.DBName] = field.Field.Interface() - } - } - } - } - return attrs -} - func equalAsString(a interface{}, b interface{}) bool { return toString(a) == toString(b) }