diff --git a/callback_create.go b/callback_create.go index 5ae83d63..6138e9d0 100644 --- a/callback_create.go +++ b/callback_create.go @@ -11,7 +11,7 @@ func BeforeCreate(scope *Scope) { scope.CallMethod("BeforeCreate") } -func UpdateCreateTimeStamp(scope *Scope) { +func UpdateTimeStampWhenCreate(scope *Scope) { if !scope.HasError() { scope.SetColumn("CreatedAt", time.Now()) scope.SetColumn("UpdatedAt", time.Now()) @@ -66,7 +66,7 @@ func init() { DefaultCallback.Create().Register("begin_transaction", BeginTransaction) DefaultCallback.Create().Register("before_create", BeforeCreate) DefaultCallback.Create().Register("save_before_associations", SaveBeforeAssociations) - DefaultCallback.Create().Register("update_create_time_stamp", UpdateCreateTimeStamp) + DefaultCallback.Create().Register("update_time_stamp_when_create", UpdateTimeStampWhenCreate) DefaultCallback.Create().Register("create", Create) DefaultCallback.Create().Register("save_after_associations", SaveAfterAssociations) DefaultCallback.Create().Register("after_create", AfterCreate) diff --git a/callback_update.go b/callback_update.go index e5eb9382..5b1ff338 100644 --- a/callback_update.go +++ b/callback_update.go @@ -6,6 +6,21 @@ import ( "time" ) +func AssignUpdateAttributes(scope *Scope) { + if attrs, ok := scope.Get("gorm:update_interface"); ok { + if maps := convertInterfaceToMap(attrs); len(maps) > 0 { + protected, ok := scope.Get("gorm:ignore_protected_attrs") + updateAttrs, hasUpdate := scope.updatedAttrsWithValues(maps, ok && protected.(bool)) + if len(updateAttrs) > 0 { + scope.Set("gorm:update_attrs", updateAttrs) + } else if !hasUpdate { + scope.SkipLeft() + return + } + } + } +} + func BeforeUpdate(scope *Scope) { scope.CallMethod("BeforeSave") scope.CallMethod("BeforeUpdate") @@ -18,11 +33,21 @@ func UpdateTimeStampWhenUpdate(scope *Scope) { } func Update(scope *Scope) { + defer scope.Trace(time.Now()) + if !scope.HasError() { var sqls []string - for _, field := range scope.Fields() { - if field.DBName != scope.PrimaryKey() && len(field.SqlTag) > 0 && !field.IsIgnored { - sqls = append(sqls, fmt.Sprintf("%v = %v", scope.quote(field.DBName), scope.AddToVars(field.Value))) + + updateAttrs, ok := scope.Get("gorm:update_attrs") + if ok { + for key, value := range updateAttrs.(map[string]interface{}) { + sqls = append(sqls, fmt.Sprintf("%v = %v", scope.quote(key), scope.AddToVars(value))) + } + } else { + for _, field := range scope.Fields() { + if field.DBName != scope.PrimaryKey() && len(field.SqlTag) > 0 && !field.IsIgnored { + sqls = append(sqls, fmt.Sprintf("%v = %v", scope.quote(field.DBName), scope.AddToVars(field.Value))) + } } } @@ -42,6 +67,7 @@ func AfterUpdate(scope *Scope) { } func init() { + DefaultCallback.Update().Register("assign_update_attributes", AssignUpdateAttributes) DefaultCallback.Update().Register("begin_transaction", BeginTransaction) DefaultCallback.Update().Register("before_update", BeforeUpdate) DefaultCallback.Update().Register("save_before_associations", SaveBeforeAssociations) diff --git a/gorm_test.go b/gorm_test.go index 78bf1d3b..e3954577 100644 --- a/gorm_test.go +++ b/gorm_test.go @@ -929,7 +929,8 @@ func TestUpdate(t *testing.T) { func TestUpdates(t *testing.T) { product1 := Product{Code: "abc", Price: 10} product2 := Product{Code: "cde", Price: 20} - db.Save(&product1).Save(&product2).Updates(map[string]interface{}{"code": "edf", "price": 100}) + db.Save(&product1).Save(&product2) + db.Model(&product2).Updates(map[string]interface{}{"code": "edf", "price": 100}) if product2.Code != "edf" || product2.Price != 100 { t.Errorf("Record should be updated also with update attributes") } diff --git a/main.go b/main.go index c3c44601..e0a44caf 100644 --- a/main.go +++ b/main.go @@ -165,8 +165,10 @@ func (s *DB) Update(attrs ...interface{}) *DB { } func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB { - return s.clone().do(s.Value).begin().updateAttrs(values, ignoreProtectedAttrs...).update().commit_or_rollback().db - // return s.clone().NewScope(s.Value).callCallbacks(s.parent.callback.updates).db + return s.clone().NewScope(s.Value). + Set("gorm:update_interface", values). + Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0). + callCallbacks(s.parent.callback.updates).db } func (s *DB) UpdateColumn(attrs ...interface{}) *DB { diff --git a/scope.go b/scope.go index 27efde54..5401556c 100644 --- a/scope.go +++ b/scope.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/jinzhu/gorm/dialect" "go/ast" + "strconv" "strings" "time" @@ -13,12 +14,13 @@ import ( ) type Scope struct { - Value interface{} - Search *search - Sql string - SqlVars []interface{} - db *DB - _values map[string]interface{} + Value interface{} + Search *search + Sql string + SqlVars []interface{} + db *DB + _values map[string]interface{} + skipLeft bool } func (db *DB) NewScope(value interface{}) *Scope { @@ -26,9 +28,16 @@ func (db *DB) NewScope(value interface{}) *Scope { return &Scope{db: db, Search: db.search, Value: value, _values: map[string]interface{}{}} } +func (scope *Scope) SkipLeft() { + scope.skipLeft = true +} + func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope { for _, f := range funcs { (*f)(scope) + if scope.skipLeft { + break + } } return scope } @@ -90,12 +99,54 @@ func (scope *Scope) HasColumn(name string) bool { return false } +func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignoreProtectedAttrs bool) (results map[string]interface{}, hasUpdate bool) { + data := reflect.Indirect(reflect.ValueOf(scope.Value)) + if !data.CanAddr() { + return values, true + } + + for key, value := range values { + if field := data.FieldByName(snakeToUpperCamel(key)); field.IsValid() { + if field.Interface() != value { + + switch field.Kind() { + case reflect.Int, reflect.Int32, reflect.Int64: + if s, ok := value.(string); ok { + i, err := strconv.Atoi(s) + if scope.Err(err) == nil { + value = i + } + } + + scope.db.log(field.Int() != reflect.ValueOf(value).Int()) + if field.Int() != reflect.ValueOf(value).Int() { + hasUpdate = true + setFieldValue(field, value) + } + default: + hasUpdate = true + setFieldValue(field, value) + } + } + } + } + return +} + func (scope *Scope) SetColumn(column string, value interface{}) { + if scope.Value == nil { + return + } + data := reflect.Indirect(reflect.ValueOf(scope.Value)) setFieldValue(data.FieldByName(snakeToUpperCamel(column)), value) } func (scope *Scope) CallMethod(name string) { + if scope.Value == nil { + return + } + if fm := reflect.ValueOf(scope.Value).MethodByName(name); fm.IsValid() { fi := fm.Interface() if f, ok := fi.(func()); ok { @@ -193,65 +244,67 @@ func (scope *Scope) SqlTagForField(field *Field) (tag string) { } func (scope *Scope) Fields() []*Field { - indirect_value := reflect.Indirect(reflect.ValueOf(scope.Value)) + indirectValue := reflect.Indirect(reflect.ValueOf(scope.Value)) fields := []*Field{} - if !indirect_value.IsValid() { + if !indirectValue.IsValid() { return fields } - scope_typ := indirect_value.Type() - for i := 0; i < scope_typ.NumField(); i++ { - field_struct := scope_typ.Field(i) - if field_struct.Anonymous || !ast.IsExported(field_struct.Name) { + scopeTyp := indirectValue.Type() + for i := 0; i < scopeTyp.NumField(); i++ { + fieldStruct := scopeTyp.Field(i) + if fieldStruct.Anonymous || !ast.IsExported(fieldStruct.Name) { continue } var field Field - field.Name = field_struct.Name - field.DBName = toSnake(field_struct.Name) + field.Name = fieldStruct.Name + field.DBName = toSnake(fieldStruct.Name) - value := indirect_value.FieldByName(field_struct.Name) + value := indirectValue.FieldByName(fieldStruct.Name) field.Value = value.Interface() field.IsBlank = isBlank(value) - tag, addational_tag, size := parseSqlTag(field_struct.Tag.Get(scope.db.parent.tagIdentifier)) - field.Tag = tag - field.AddationalTag = addational_tag - field.Size = size - field.SqlTag = scope.SqlTagForField(&field) + if scope.db != nil { + tag, addationalTag, size := parseSqlTag(fieldStruct.Tag.Get(scope.db.parent.tagIdentifier)) + field.Tag = tag + field.AddationalTag = addationalTag + field.Size = size + field.SqlTag = scope.SqlTagForField(&field) - if tag == "-" { - field.IsIgnored = true - } - - // parse association - elem := reflect.Indirect(value) - typ := elem.Type() - - switch elem.Kind() { - case reflect.Slice: - typ = typ.Elem() - - if _, ok := field.Value.([]byte); !ok { - foreignKey := scope_typ.Name() + "Id" - if reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() { - field.ForeignKey = foreignKey - } - field.AfterAssociation = true + if tag == "-" { + field.IsIgnored = true } - case reflect.Struct: - if !field.IsTime() && !field.IsScanner() { - if scope.HasColumn(field.Name + "Id") { - field.ForeignKey = field.Name + "Id" - field.BeforeAssociation = true - } else { - foreignKey := scope_typ.Name() + "Id" + + // parse association + elem := reflect.Indirect(value) + typ := elem.Type() + + switch elem.Kind() { + case reflect.Slice: + typ = typ.Elem() + + if _, ok := field.Value.([]byte); !ok { + foreignKey := scopeTyp.Name() + "Id" if reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() { field.ForeignKey = foreignKey } field.AfterAssociation = true } + case reflect.Struct: + if !field.IsTime() && !field.IsScanner() { + if scope.HasColumn(field.Name + "Id") { + field.ForeignKey = field.Name + "Id" + field.BeforeAssociation = true + } else { + foreignKey := scopeTyp.Name() + "Id" + if reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() { + field.ForeignKey = foreignKey + } + field.AfterAssociation = true + } + } } } fields = append(fields, &field) @@ -276,8 +329,9 @@ func (scope *Scope) Get(name string) (value interface{}, ok bool) { return } -func (scope *Scope) Set(name string, value interface{}) { +func (scope *Scope) Set(name string, value interface{}) *Scope { scope._values[name] = value + return scope } func (scope *Scope) Trace(t time.Time) { diff --git a/utils.go b/utils.go index 6df5a86a..6acad861 100644 --- a/utils.go +++ b/utils.go @@ -124,3 +124,37 @@ func setFieldValue(field reflect.Value, value interface{}) bool { func isBlank(value reflect.Value) bool { return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface()) } + +func convertInterfaceToMap(values interface{}) map[string]interface{} { + attrs := map[string]interface{}{} + + switch value := values.(type) { + case map[string]interface{}: + for k, v := range value { + attrs[toSnake(k)] = v + } + case []interface{}: + for _, v := range value { + for key, value := range convertInterfaceToMap(v) { + attrs[key] = value + } + } + case interface{}: + reflectValue := reflect.ValueOf(values) + + switch reflectValue.Kind() { + case reflect.Map: + for _, key := range reflectValue.MapKeys() { + attrs[toSnake(key.Interface().(string))] = reflectValue.MapIndex(key).Interface() + } + default: + scope := Scope{Value: values} + for _, field := range scope.Fields() { + if !field.IsBlank { + attrs[field.DBName] = field.Value + } + } + } + } + return attrs +}