forked from mirror/gorm
Only update non blank fields that has been changed
This commit is contained in:
parent
52ae6df6fd
commit
6bd0862811
|
@ -177,7 +177,7 @@ func (association *Association) Delete(values ...interface{}) *Association {
|
||||||
modelValue := reflect.New(scope.GetModelStruct().ModelType).Interface()
|
modelValue := reflect.New(scope.GetModelStruct().ModelType).Interface()
|
||||||
if results := newDB.Model(modelValue).UpdateColumn(foreignKeyMap); results.Error == nil {
|
if results := newDB.Model(modelValue).UpdateColumn(foreignKeyMap); results.Error == nil {
|
||||||
if results.RowsAffected > 0 {
|
if results.RowsAffected > 0 {
|
||||||
scope.updatedAttrsWithValues(foreignKeyMap, false)
|
scope.updatedAttrsWithValues(foreignKeyMap)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
association.setErr(results.Error)
|
association.setErr(results.Error)
|
||||||
|
|
|
@ -22,17 +22,10 @@ func init() {
|
||||||
func assignUpdatingAttributesCallback(scope *Scope) {
|
func assignUpdatingAttributesCallback(scope *Scope) {
|
||||||
if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok {
|
if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok {
|
||||||
if maps := convertInterfaceToMap(attrs); len(maps) > 0 {
|
if maps := convertInterfaceToMap(attrs); len(maps) > 0 {
|
||||||
protected, ok := scope.Get("gorm:ignore_protected_attrs")
|
if updateMaps, hasUpdate := scope.updatedAttrsWithValues(maps); hasUpdate {
|
||||||
_, updateColumn := scope.Get("gorm:update_column")
|
scope.InstanceSet("gorm:update_attrs", updateMaps)
|
||||||
updateAttrs, hasUpdate := scope.updatedAttrsWithValues(maps, ok && protected.(bool))
|
} else {
|
||||||
|
|
||||||
if updateColumn {
|
|
||||||
scope.InstanceSet("gorm:update_attrs", maps)
|
|
||||||
} else if len(updateAttrs) > 0 {
|
|
||||||
scope.InstanceSet("gorm:update_attrs", updateAttrs)
|
|
||||||
} else if !hasUpdate {
|
|
||||||
scope.SkipLeft()
|
scope.SkipLeft()
|
||||||
return
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -64,13 +57,7 @@ func updateCallback(scope *Scope) {
|
||||||
|
|
||||||
if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok {
|
if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok {
|
||||||
for column, value := range updateAttrs.(map[string]interface{}) {
|
for column, value := range updateAttrs.(map[string]interface{}) {
|
||||||
if field, ok := scope.FieldByName(column); ok {
|
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(column), scope.AddToVars(value)))
|
||||||
if scope.changeableField(field) {
|
|
||||||
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(value)))
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(column), scope.AddToVars(value)))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
fields := scope.Fields()
|
fields := scope.Fields()
|
||||||
|
|
2
main.go
2
main.go
|
@ -258,7 +258,7 @@ func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
|
||||||
}
|
}
|
||||||
c.NewScope(out).inlineCondition(where...).initialize()
|
c.NewScope(out).inlineCondition(where...).initialize()
|
||||||
} else {
|
} else {
|
||||||
c.NewScope(out).updatedAttrsWithValues(convertInterfaceToMap(c.search.assignAttrs), false)
|
c.NewScope(out).updatedAttrsWithValues(convertInterfaceToMap(c.search.assignAttrs))
|
||||||
}
|
}
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
11
scope.go
11
scope.go
|
@ -154,20 +154,29 @@ func (scope *Scope) HasColumn(column string) bool {
|
||||||
|
|
||||||
// SetColumn to set the column's value
|
// SetColumn to set the column's value
|
||||||
func (scope *Scope) SetColumn(column interface{}, value interface{}) error {
|
func (scope *Scope) SetColumn(column interface{}, value interface{}) error {
|
||||||
|
var updateAttrs = map[string]interface{}{}
|
||||||
|
if attrs, ok := scope.InstanceGet("gorm:update_attrs"); ok {
|
||||||
|
updateAttrs = attrs.(map[string]interface{})
|
||||||
|
defer scope.InstanceSet("gorm:update_attrs", updateAttrs)
|
||||||
|
}
|
||||||
|
|
||||||
if field, ok := column.(*Field); ok {
|
if field, ok := column.(*Field); ok {
|
||||||
|
updateAttrs[field.DBName] = value
|
||||||
return field.Set(value)
|
return field.Set(value)
|
||||||
} else if name, ok := column.(string); ok {
|
} else if name, ok := column.(string); ok {
|
||||||
|
|
||||||
if field, ok := scope.Fields()[name]; ok {
|
if field, ok := scope.Fields()[name]; ok {
|
||||||
|
updateAttrs[field.DBName] = value
|
||||||
return field.Set(value)
|
return field.Set(value)
|
||||||
}
|
}
|
||||||
|
|
||||||
dbName := ToDBName(name)
|
dbName := ToDBName(name)
|
||||||
if field, ok := scope.Fields()[dbName]; ok {
|
if field, ok := scope.Fields()[dbName]; ok {
|
||||||
|
updateAttrs[field.DBName] = value
|
||||||
return field.Set(value)
|
return field.Set(value)
|
||||||
}
|
}
|
||||||
|
|
||||||
if field, ok := scope.FieldByName(name); ok {
|
if field, ok := scope.FieldByName(name); ok {
|
||||||
|
updateAttrs[field.DBName] = value
|
||||||
return field.Set(value)
|
return field.Set(value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -319,38 +319,30 @@ func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope {
|
||||||
return scope
|
return scope
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignoreProtectedAttrs bool) (results map[string]interface{}, hasUpdate bool) {
|
func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}) (results map[string]interface{}, hasUpdate bool) {
|
||||||
if !scope.IndirectValue().CanAddr() {
|
if scope.IndirectValue().Kind() != reflect.Struct {
|
||||||
return values, true
|
return values, true
|
||||||
}
|
}
|
||||||
|
|
||||||
var hasExpr bool
|
results = map[string]interface{}{}
|
||||||
for key, value := range values {
|
for key, value := range values {
|
||||||
if field, ok := scope.FieldByName(key); ok && field.Field.IsValid() {
|
if field, ok := scope.FieldByName(key); ok && scope.changeableField(field) {
|
||||||
if !reflect.DeepEqual(field.Field, reflect.ValueOf(value)) {
|
if !reflect.DeepEqual(field.Field, reflect.ValueOf(value)) {
|
||||||
if _, ok := value.(*expr); ok {
|
if field.IsNormal {
|
||||||
hasExpr = true
|
if _, ok := value.(*expr); ok {
|
||||||
} else if !equalAsString(field.Field.Interface(), value) {
|
hasUpdate = true
|
||||||
hasUpdate = true
|
results[field.DBName] = value
|
||||||
|
} else if !equalAsString(field.Field.Interface(), value) {
|
||||||
|
hasUpdate = true
|
||||||
|
field.Set(value)
|
||||||
|
results[field.DBName] = field.Field.Interface()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
field.Set(value)
|
field.Set(value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if hasExpr {
|
|
||||||
var updateMap = map[string]interface{}{}
|
|
||||||
for key, field := range scope.Fields() {
|
|
||||||
if field.IsNormal {
|
|
||||||
if v, ok := values[key]; ok {
|
|
||||||
updateMap[key] = v
|
|
||||||
} else {
|
|
||||||
updateMap[key] = field.Field.Interface()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return updateMap, true
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -370,10 +362,10 @@ func (scope *Scope) rows() (*sql.Rows, error) {
|
||||||
|
|
||||||
func (scope *Scope) initialize() *Scope {
|
func (scope *Scope) initialize() *Scope {
|
||||||
for _, clause := range scope.Search.whereConditions {
|
for _, clause := range scope.Search.whereConditions {
|
||||||
scope.updatedAttrsWithValues(convertInterfaceToMap(clause["query"]), false)
|
scope.updatedAttrsWithValues(convertInterfaceToMap(clause["query"]))
|
||||||
}
|
}
|
||||||
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.initAttrs), false)
|
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.initAttrs))
|
||||||
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.assignAttrs), false)
|
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.assignAttrs))
|
||||||
return scope
|
return scope
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -71,13 +71,14 @@ func TestUpdate(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.First(&product4, product4.Id)
|
DB.First(&product4, product4.Id)
|
||||||
|
updatedAt4 := product4.UpdatedAt
|
||||||
DB.Model(&product4).Update("price", gorm.Expr("price + ? - ?", 100, 50))
|
DB.Model(&product4).Update("price", gorm.Expr("price + ? - ?", 100, 50))
|
||||||
var product5 Product
|
var product5 Product
|
||||||
DB.First(&product5, product4.Id)
|
DB.First(&product5, product4.Id)
|
||||||
if product5.Price != product4.Price+100-50 {
|
if product5.Price != product4.Price+100-50 {
|
||||||
t.Errorf("Update with expression")
|
t.Errorf("Update with expression")
|
||||||
}
|
}
|
||||||
if product5.UpdatedAt.Format(time.RFC3339Nano) == product4.UpdatedAt.Format(time.RFC3339Nano) {
|
if product4.UpdatedAt.Format(time.RFC3339Nano) == updatedAt4.Format(time.RFC3339Nano) {
|
||||||
t.Errorf("Update with expression should update UpdatedAt")
|
t.Errorf("Update with expression should update UpdatedAt")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -170,13 +171,15 @@ func TestUpdates(t *testing.T) {
|
||||||
t.Errorf("product2's code should be updated")
|
t.Errorf("product2's code should be updated")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
updatedAt4 := product4.UpdatedAt
|
||||||
DB.Model(&product4).Updates(map[string]interface{}{"price": gorm.Expr("price + ?", 100)})
|
DB.Model(&product4).Updates(map[string]interface{}{"price": gorm.Expr("price + ?", 100)})
|
||||||
var product5 Product
|
var product5 Product
|
||||||
DB.First(&product5, product4.Id)
|
DB.First(&product5, product4.Id)
|
||||||
if product5.Price != product4.Price+100 {
|
if product5.Price != product4.Price+100 {
|
||||||
t.Errorf("Updates with expression")
|
t.Errorf("Updates with expression")
|
||||||
}
|
}
|
||||||
if product5.UpdatedAt.Format(time.RFC3339Nano) == product4.UpdatedAt.Format(time.RFC3339Nano) {
|
// product4's UpdatedAt will be reset when updating
|
||||||
|
if product4.UpdatedAt.Format(time.RFC3339Nano) == updatedAt4.Format(time.RFC3339Nano) {
|
||||||
t.Errorf("Updates with expression should update UpdatedAt")
|
t.Errorf("Updates with expression should update UpdatedAt")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -421,8 +424,6 @@ func TestUpdateColumnsSkipsAssociations(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUpdatesWithBlankValues(t *testing.T) {
|
func TestUpdatesWithBlankValues(t *testing.T) {
|
||||||
t.Skip("not implemented")
|
|
||||||
|
|
||||||
product := Product{Code: "product1", Price: 10}
|
product := Product{Code: "product1", Price: 10}
|
||||||
DB.Save(&product)
|
DB.Save(&product)
|
||||||
|
|
||||||
|
|
4
utils.go
4
utils.go
|
@ -192,9 +192,7 @@ func convertInterfaceToMap(values interface{}) map[string]interface{} {
|
||||||
|
|
||||||
switch value := values.(type) {
|
switch value := values.(type) {
|
||||||
case map[string]interface{}:
|
case map[string]interface{}:
|
||||||
for k, v := range value {
|
return value
|
||||||
attrs[k] = v
|
|
||||||
}
|
|
||||||
case []interface{}:
|
case []interface{}:
|
||||||
for _, v := range value {
|
for _, v := range value {
|
||||||
for key, value := range convertInterfaceToMap(v) {
|
for key, value := range convertInterfaceToMap(v) {
|
||||||
|
|
Loading…
Reference in New Issue