Only update non blank fields that has been changed

This commit is contained in:
Jinzhu 2016-02-18 22:24:35 +08:00
parent 52ae6df6fd
commit 6bd0862811
7 changed files with 39 additions and 52 deletions

View File

@ -177,7 +177,7 @@ func (association *Association) Delete(values ...interface{}) *Association {
modelValue := reflect.New(scope.GetModelStruct().ModelType).Interface() modelValue := reflect.New(scope.GetModelStruct().ModelType).Interface()
if results := newDB.Model(modelValue).UpdateColumn(foreignKeyMap); results.Error == nil { if results := newDB.Model(modelValue).UpdateColumn(foreignKeyMap); results.Error == nil {
if results.RowsAffected > 0 { if results.RowsAffected > 0 {
scope.updatedAttrsWithValues(foreignKeyMap, false) scope.updatedAttrsWithValues(foreignKeyMap)
} }
} else { } else {
association.setErr(results.Error) association.setErr(results.Error)

View File

@ -22,17 +22,10 @@ func init() {
func assignUpdatingAttributesCallback(scope *Scope) { func assignUpdatingAttributesCallback(scope *Scope) {
if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok { if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok {
if maps := convertInterfaceToMap(attrs); len(maps) > 0 { if maps := convertInterfaceToMap(attrs); len(maps) > 0 {
protected, ok := scope.Get("gorm:ignore_protected_attrs") if updateMaps, hasUpdate := scope.updatedAttrsWithValues(maps); hasUpdate {
_, updateColumn := scope.Get("gorm:update_column") scope.InstanceSet("gorm:update_attrs", updateMaps)
updateAttrs, hasUpdate := scope.updatedAttrsWithValues(maps, ok && protected.(bool)) } else {
if updateColumn {
scope.InstanceSet("gorm:update_attrs", maps)
} else if len(updateAttrs) > 0 {
scope.InstanceSet("gorm:update_attrs", updateAttrs)
} else if !hasUpdate {
scope.SkipLeft() scope.SkipLeft()
return
} }
} }
} }
@ -64,14 +57,8 @@ func updateCallback(scope *Scope) {
if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok { if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok {
for column, value := range updateAttrs.(map[string]interface{}) { for column, value := range updateAttrs.(map[string]interface{}) {
if field, ok := scope.FieldByName(column); ok {
if scope.changeableField(field) {
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(value)))
}
} else {
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(column), scope.AddToVars(value))) sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(column), scope.AddToVars(value)))
} }
}
} else { } else {
fields := scope.Fields() fields := scope.Fields()
for _, field := range fields { for _, field := range fields {

View File

@ -258,7 +258,7 @@ func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
} }
c.NewScope(out).inlineCondition(where...).initialize() c.NewScope(out).inlineCondition(where...).initialize()
} else { } else {
c.NewScope(out).updatedAttrsWithValues(convertInterfaceToMap(c.search.assignAttrs), false) c.NewScope(out).updatedAttrsWithValues(convertInterfaceToMap(c.search.assignAttrs))
} }
return c return c
} }

View File

@ -154,20 +154,29 @@ func (scope *Scope) HasColumn(column string) bool {
// SetColumn to set the column's value // SetColumn to set the column's value
func (scope *Scope) SetColumn(column interface{}, value interface{}) error { func (scope *Scope) SetColumn(column interface{}, value interface{}) error {
var updateAttrs = map[string]interface{}{}
if attrs, ok := scope.InstanceGet("gorm:update_attrs"); ok {
updateAttrs = attrs.(map[string]interface{})
defer scope.InstanceSet("gorm:update_attrs", updateAttrs)
}
if field, ok := column.(*Field); ok { if field, ok := column.(*Field); ok {
updateAttrs[field.DBName] = value
return field.Set(value) return field.Set(value)
} else if name, ok := column.(string); ok { } else if name, ok := column.(string); ok {
if field, ok := scope.Fields()[name]; ok { if field, ok := scope.Fields()[name]; ok {
updateAttrs[field.DBName] = value
return field.Set(value) return field.Set(value)
} }
dbName := ToDBName(name) dbName := ToDBName(name)
if field, ok := scope.Fields()[dbName]; ok { if field, ok := scope.Fields()[dbName]; ok {
updateAttrs[field.DBName] = value
return field.Set(value) return field.Set(value)
} }
if field, ok := scope.FieldByName(name); ok { if field, ok := scope.FieldByName(name); ok {
updateAttrs[field.DBName] = value
return field.Set(value) return field.Set(value)
} }
} }

View File

@ -319,38 +319,30 @@ func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope {
return scope return scope
} }
func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignoreProtectedAttrs bool) (results map[string]interface{}, hasUpdate bool) { func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}) (results map[string]interface{}, hasUpdate bool) {
if !scope.IndirectValue().CanAddr() { if scope.IndirectValue().Kind() != reflect.Struct {
return values, true return values, true
} }
var hasExpr bool results = map[string]interface{}{}
for key, value := range values { for key, value := range values {
if field, ok := scope.FieldByName(key); ok && field.Field.IsValid() { if field, ok := scope.FieldByName(key); ok && scope.changeableField(field) {
if !reflect.DeepEqual(field.Field, reflect.ValueOf(value)) { if !reflect.DeepEqual(field.Field, reflect.ValueOf(value)) {
if field.IsNormal {
if _, ok := value.(*expr); ok { if _, ok := value.(*expr); ok {
hasExpr = true hasUpdate = true
results[field.DBName] = value
} else if !equalAsString(field.Field.Interface(), value) { } else if !equalAsString(field.Field.Interface(), value) {
hasUpdate = true hasUpdate = true
field.Set(value) field.Set(value)
results[field.DBName] = field.Field.Interface()
}
} else {
field.Set(value)
} }
} }
} }
} }
if hasExpr {
var updateMap = map[string]interface{}{}
for key, field := range scope.Fields() {
if field.IsNormal {
if v, ok := values[key]; ok {
updateMap[key] = v
} else {
updateMap[key] = field.Field.Interface()
}
}
}
return updateMap, true
}
return return
} }
@ -370,10 +362,10 @@ func (scope *Scope) rows() (*sql.Rows, error) {
func (scope *Scope) initialize() *Scope { func (scope *Scope) initialize() *Scope {
for _, clause := range scope.Search.whereConditions { for _, clause := range scope.Search.whereConditions {
scope.updatedAttrsWithValues(convertInterfaceToMap(clause["query"]), false) scope.updatedAttrsWithValues(convertInterfaceToMap(clause["query"]))
} }
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.initAttrs), false) scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.initAttrs))
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.assignAttrs), false) scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.assignAttrs))
return scope return scope
} }

View File

@ -71,13 +71,14 @@ func TestUpdate(t *testing.T) {
} }
DB.First(&product4, product4.Id) DB.First(&product4, product4.Id)
updatedAt4 := product4.UpdatedAt
DB.Model(&product4).Update("price", gorm.Expr("price + ? - ?", 100, 50)) DB.Model(&product4).Update("price", gorm.Expr("price + ? - ?", 100, 50))
var product5 Product var product5 Product
DB.First(&product5, product4.Id) DB.First(&product5, product4.Id)
if product5.Price != product4.Price+100-50 { if product5.Price != product4.Price+100-50 {
t.Errorf("Update with expression") t.Errorf("Update with expression")
} }
if product5.UpdatedAt.Format(time.RFC3339Nano) == product4.UpdatedAt.Format(time.RFC3339Nano) { if product4.UpdatedAt.Format(time.RFC3339Nano) == updatedAt4.Format(time.RFC3339Nano) {
t.Errorf("Update with expression should update UpdatedAt") t.Errorf("Update with expression should update UpdatedAt")
} }
} }
@ -170,13 +171,15 @@ func TestUpdates(t *testing.T) {
t.Errorf("product2's code should be updated") t.Errorf("product2's code should be updated")
} }
updatedAt4 := product4.UpdatedAt
DB.Model(&product4).Updates(map[string]interface{}{"price": gorm.Expr("price + ?", 100)}) DB.Model(&product4).Updates(map[string]interface{}{"price": gorm.Expr("price + ?", 100)})
var product5 Product var product5 Product
DB.First(&product5, product4.Id) DB.First(&product5, product4.Id)
if product5.Price != product4.Price+100 { if product5.Price != product4.Price+100 {
t.Errorf("Updates with expression") t.Errorf("Updates with expression")
} }
if product5.UpdatedAt.Format(time.RFC3339Nano) == product4.UpdatedAt.Format(time.RFC3339Nano) { // product4's UpdatedAt will be reset when updating
if product4.UpdatedAt.Format(time.RFC3339Nano) == updatedAt4.Format(time.RFC3339Nano) {
t.Errorf("Updates with expression should update UpdatedAt") t.Errorf("Updates with expression should update UpdatedAt")
} }
} }
@ -421,8 +424,6 @@ func TestUpdateColumnsSkipsAssociations(t *testing.T) {
} }
func TestUpdatesWithBlankValues(t *testing.T) { func TestUpdatesWithBlankValues(t *testing.T) {
t.Skip("not implemented")
product := Product{Code: "product1", Price: 10} product := Product{Code: "product1", Price: 10}
DB.Save(&product) DB.Save(&product)

View File

@ -192,9 +192,7 @@ func convertInterfaceToMap(values interface{}) map[string]interface{} {
switch value := values.(type) { switch value := values.(type) {
case map[string]interface{}: case map[string]interface{}:
for k, v := range value { return value
attrs[k] = v
}
case []interface{}: case []interface{}:
for _, v := range value { for _, v := range value {
for key, value := range convertInterfaceToMap(v) { for key, value := range convertInterfaceToMap(v) {