mirror of https://github.com/go-gorm/gorm.git
Only update non blank fields that has been changed
This commit is contained in:
parent
52ae6df6fd
commit
6bd0862811
|
@ -177,7 +177,7 @@ func (association *Association) Delete(values ...interface{}) *Association {
|
|||
modelValue := reflect.New(scope.GetModelStruct().ModelType).Interface()
|
||||
if results := newDB.Model(modelValue).UpdateColumn(foreignKeyMap); results.Error == nil {
|
||||
if results.RowsAffected > 0 {
|
||||
scope.updatedAttrsWithValues(foreignKeyMap, false)
|
||||
scope.updatedAttrsWithValues(foreignKeyMap)
|
||||
}
|
||||
} else {
|
||||
association.setErr(results.Error)
|
||||
|
|
|
@ -22,17 +22,10 @@ func init() {
|
|||
func assignUpdatingAttributesCallback(scope *Scope) {
|
||||
if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok {
|
||||
if maps := convertInterfaceToMap(attrs); len(maps) > 0 {
|
||||
protected, ok := scope.Get("gorm:ignore_protected_attrs")
|
||||
_, updateColumn := scope.Get("gorm:update_column")
|
||||
updateAttrs, hasUpdate := scope.updatedAttrsWithValues(maps, ok && protected.(bool))
|
||||
|
||||
if updateColumn {
|
||||
scope.InstanceSet("gorm:update_attrs", maps)
|
||||
} else if len(updateAttrs) > 0 {
|
||||
scope.InstanceSet("gorm:update_attrs", updateAttrs)
|
||||
} else if !hasUpdate {
|
||||
if updateMaps, hasUpdate := scope.updatedAttrsWithValues(maps); hasUpdate {
|
||||
scope.InstanceSet("gorm:update_attrs", updateMaps)
|
||||
} else {
|
||||
scope.SkipLeft()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -64,13 +57,7 @@ func updateCallback(scope *Scope) {
|
|||
|
||||
if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok {
|
||||
for column, value := range updateAttrs.(map[string]interface{}) {
|
||||
if field, ok := scope.FieldByName(column); ok {
|
||||
if scope.changeableField(field) {
|
||||
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(value)))
|
||||
}
|
||||
} else {
|
||||
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(column), scope.AddToVars(value)))
|
||||
}
|
||||
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(column), scope.AddToVars(value)))
|
||||
}
|
||||
} else {
|
||||
fields := scope.Fields()
|
||||
|
|
2
main.go
2
main.go
|
@ -258,7 +258,7 @@ func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
|
|||
}
|
||||
c.NewScope(out).inlineCondition(where...).initialize()
|
||||
} else {
|
||||
c.NewScope(out).updatedAttrsWithValues(convertInterfaceToMap(c.search.assignAttrs), false)
|
||||
c.NewScope(out).updatedAttrsWithValues(convertInterfaceToMap(c.search.assignAttrs))
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
|
11
scope.go
11
scope.go
|
@ -154,20 +154,29 @@ func (scope *Scope) HasColumn(column string) bool {
|
|||
|
||||
// SetColumn to set the column's value
|
||||
func (scope *Scope) SetColumn(column interface{}, value interface{}) error {
|
||||
var updateAttrs = map[string]interface{}{}
|
||||
if attrs, ok := scope.InstanceGet("gorm:update_attrs"); ok {
|
||||
updateAttrs = attrs.(map[string]interface{})
|
||||
defer scope.InstanceSet("gorm:update_attrs", updateAttrs)
|
||||
}
|
||||
|
||||
if field, ok := column.(*Field); ok {
|
||||
updateAttrs[field.DBName] = value
|
||||
return field.Set(value)
|
||||
} else if name, ok := column.(string); ok {
|
||||
|
||||
if field, ok := scope.Fields()[name]; ok {
|
||||
updateAttrs[field.DBName] = value
|
||||
return field.Set(value)
|
||||
}
|
||||
|
||||
dbName := ToDBName(name)
|
||||
if field, ok := scope.Fields()[dbName]; ok {
|
||||
updateAttrs[field.DBName] = value
|
||||
return field.Set(value)
|
||||
}
|
||||
|
||||
if field, ok := scope.FieldByName(name); ok {
|
||||
updateAttrs[field.DBName] = value
|
||||
return field.Set(value)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -319,38 +319,30 @@ func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope {
|
|||
return scope
|
||||
}
|
||||
|
||||
func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignoreProtectedAttrs bool) (results map[string]interface{}, hasUpdate bool) {
|
||||
if !scope.IndirectValue().CanAddr() {
|
||||
func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}) (results map[string]interface{}, hasUpdate bool) {
|
||||
if scope.IndirectValue().Kind() != reflect.Struct {
|
||||
return values, true
|
||||
}
|
||||
|
||||
var hasExpr bool
|
||||
results = map[string]interface{}{}
|
||||
for key, value := range values {
|
||||
if field, ok := scope.FieldByName(key); ok && field.Field.IsValid() {
|
||||
if field, ok := scope.FieldByName(key); ok && scope.changeableField(field) {
|
||||
if !reflect.DeepEqual(field.Field, reflect.ValueOf(value)) {
|
||||
if _, ok := value.(*expr); ok {
|
||||
hasExpr = true
|
||||
} else if !equalAsString(field.Field.Interface(), value) {
|
||||
hasUpdate = true
|
||||
if field.IsNormal {
|
||||
if _, ok := value.(*expr); ok {
|
||||
hasUpdate = true
|
||||
results[field.DBName] = value
|
||||
} else if !equalAsString(field.Field.Interface(), value) {
|
||||
hasUpdate = true
|
||||
field.Set(value)
|
||||
results[field.DBName] = field.Field.Interface()
|
||||
}
|
||||
} else {
|
||||
field.Set(value)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if hasExpr {
|
||||
var updateMap = map[string]interface{}{}
|
||||
for key, field := range scope.Fields() {
|
||||
if field.IsNormal {
|
||||
if v, ok := values[key]; ok {
|
||||
updateMap[key] = v
|
||||
} else {
|
||||
updateMap[key] = field.Field.Interface()
|
||||
}
|
||||
}
|
||||
}
|
||||
return updateMap, true
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -370,10 +362,10 @@ func (scope *Scope) rows() (*sql.Rows, error) {
|
|||
|
||||
func (scope *Scope) initialize() *Scope {
|
||||
for _, clause := range scope.Search.whereConditions {
|
||||
scope.updatedAttrsWithValues(convertInterfaceToMap(clause["query"]), false)
|
||||
scope.updatedAttrsWithValues(convertInterfaceToMap(clause["query"]))
|
||||
}
|
||||
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.initAttrs), false)
|
||||
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.assignAttrs), false)
|
||||
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.initAttrs))
|
||||
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.assignAttrs))
|
||||
return scope
|
||||
}
|
||||
|
||||
|
|
|
@ -71,13 +71,14 @@ func TestUpdate(t *testing.T) {
|
|||
}
|
||||
|
||||
DB.First(&product4, product4.Id)
|
||||
updatedAt4 := product4.UpdatedAt
|
||||
DB.Model(&product4).Update("price", gorm.Expr("price + ? - ?", 100, 50))
|
||||
var product5 Product
|
||||
DB.First(&product5, product4.Id)
|
||||
if product5.Price != product4.Price+100-50 {
|
||||
t.Errorf("Update with expression")
|
||||
}
|
||||
if product5.UpdatedAt.Format(time.RFC3339Nano) == product4.UpdatedAt.Format(time.RFC3339Nano) {
|
||||
if product4.UpdatedAt.Format(time.RFC3339Nano) == updatedAt4.Format(time.RFC3339Nano) {
|
||||
t.Errorf("Update with expression should update UpdatedAt")
|
||||
}
|
||||
}
|
||||
|
@ -170,13 +171,15 @@ func TestUpdates(t *testing.T) {
|
|||
t.Errorf("product2's code should be updated")
|
||||
}
|
||||
|
||||
updatedAt4 := product4.UpdatedAt
|
||||
DB.Model(&product4).Updates(map[string]interface{}{"price": gorm.Expr("price + ?", 100)})
|
||||
var product5 Product
|
||||
DB.First(&product5, product4.Id)
|
||||
if product5.Price != product4.Price+100 {
|
||||
t.Errorf("Updates with expression")
|
||||
}
|
||||
if product5.UpdatedAt.Format(time.RFC3339Nano) == product4.UpdatedAt.Format(time.RFC3339Nano) {
|
||||
// product4's UpdatedAt will be reset when updating
|
||||
if product4.UpdatedAt.Format(time.RFC3339Nano) == updatedAt4.Format(time.RFC3339Nano) {
|
||||
t.Errorf("Updates with expression should update UpdatedAt")
|
||||
}
|
||||
}
|
||||
|
@ -421,8 +424,6 @@ func TestUpdateColumnsSkipsAssociations(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestUpdatesWithBlankValues(t *testing.T) {
|
||||
t.Skip("not implemented")
|
||||
|
||||
product := Product{Code: "product1", Price: 10}
|
||||
DB.Save(&product)
|
||||
|
||||
|
|
4
utils.go
4
utils.go
|
@ -192,9 +192,7 @@ func convertInterfaceToMap(values interface{}) map[string]interface{} {
|
|||
|
||||
switch value := values.(type) {
|
||||
case map[string]interface{}:
|
||||
for k, v := range value {
|
||||
attrs[k] = v
|
||||
}
|
||||
return value
|
||||
case []interface{}:
|
||||
for _, v := range value {
|
||||
for key, value := range convertInterfaceToMap(v) {
|
||||
|
|
Loading…
Reference in New Issue