Refactor Scope updatedAttrsWithValues

This commit is contained in:
Jinzhu 2016-03-09 16:18:01 +08:00
parent a0aa21aec5
commit 8de97c2883
5 changed files with 50 additions and 73 deletions

View File

@ -21,15 +21,13 @@ func init() {
// assignUpdatingAttributesCallback assign updating attributes to model // assignUpdatingAttributesCallback assign updating attributes to model
func assignUpdatingAttributesCallback(scope *Scope) { func assignUpdatingAttributesCallback(scope *Scope) {
if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok { if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok {
if maps := convertInterfaceToMap(attrs); len(maps) > 0 { if updateMaps, hasUpdate := scope.updatedAttrsWithValues(attrs); hasUpdate {
if updateMaps, hasUpdate := scope.updatedAttrsWithValues(maps); hasUpdate {
scope.InstanceSet("gorm:update_attrs", updateMaps) scope.InstanceSet("gorm:update_attrs", updateMaps)
} else { } else {
scope.SkipLeft() scope.SkipLeft()
} }
} }
} }
}
// beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating // beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating
func beforeUpdateCallback(scope *Scope) { func beforeUpdateCallback(scope *Scope) {

View File

@ -310,7 +310,7 @@ func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
} }
c.NewScope(out).inlineCondition(where...).initialize() c.NewScope(out).inlineCondition(where...).initialize()
} else { } else {
c.NewScope(out).updatedAttrsWithValues(convertInterfaceToMap(c.search.assignAttrs)) c.NewScope(out).updatedAttrsWithValues(c.search.assignAttrs)
} }
return c return c
} }

View File

@ -793,28 +793,56 @@ func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope {
return scope return scope
} }
func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}) (results map[string]interface{}, hasUpdate bool) { func convertInterfaceToMap(values interface{}) map[string]interface{} {
var attrs = map[string]interface{}{}
switch value := values.(type) {
case map[string]interface{}:
return value
case []interface{}:
for _, v := range value {
for key, value := range convertInterfaceToMap(v) {
attrs[key] = value
}
}
case interface{}:
reflectValue := reflect.ValueOf(values)
switch reflectValue.Kind() {
case reflect.Map:
for _, key := range reflectValue.MapKeys() {
attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface()
}
default:
for _, field := range (&Scope{Value: values}).Fields() {
if !field.IsBlank {
attrs[field.DBName] = field.Field.Interface()
}
}
}
}
return attrs
}
func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[string]interface{}, hasUpdate bool) {
if scope.IndirectValue().Kind() != reflect.Struct { if scope.IndirectValue().Kind() != reflect.Struct {
return values, true return convertInterfaceToMap(value), true
} }
results = map[string]interface{}{} results = map[string]interface{}{}
for key, value := range values {
for key, value := range convertInterfaceToMap(value) {
if field, ok := scope.FieldByName(key); ok && scope.changeableField(field) { if field, ok := scope.FieldByName(key); ok && scope.changeableField(field) {
if !reflect.DeepEqual(field.Field, reflect.ValueOf(value)) {
if _, ok := value.(*expr); ok { if _, ok := value.(*expr); ok {
hasUpdate = true hasUpdate = true
results[field.DBName] = value results[field.DBName] = value
} else if !equalAsString(field.Field.Interface(), value) { } else {
field.Set(value) field.Set(value)
if field.IsNormal { if field.IsNormal {
hasUpdate = true hasUpdate = true
results[field.DBName] = field.Field.Interface() results[field.DBName] = field.Field.Interface()
} }
} }
} else {
field.Set(value)
}
} }
} }
return return
@ -836,10 +864,10 @@ func (scope *Scope) rows() (*sql.Rows, error) {
func (scope *Scope) initialize() *Scope { func (scope *Scope) initialize() *Scope {
for _, clause := range scope.Search.whereConditions { for _, clause := range scope.Search.whereConditions {
scope.updatedAttrsWithValues(convertInterfaceToMap(clause["query"])) scope.updatedAttrsWithValues(clause["query"])
} }
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.initAttrs)) scope.updatedAttrsWithValues(scope.Search.initAttrs)
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.assignAttrs)) scope.updatedAttrsWithValues(scope.Search.assignAttrs)
return scope return scope
} }

View File

@ -20,13 +20,6 @@ func TestUpdate(t *testing.T) {
DB.First(&product1, product1.Id) DB.First(&product1, product1.Id)
DB.First(&product2, product2.Id) DB.First(&product2, product2.Id)
updatedAt1 := product1.UpdatedAt updatedAt1 := product1.UpdatedAt
updatedAt2 := product2.UpdatedAt
var product3 Product
DB.First(&product3, product2.Id).Update("code", "product2newcode")
if updatedAt2.Format(time.RFC3339Nano) != product3.UpdatedAt.Format(time.RFC3339Nano) {
t.Errorf("updatedAt should not be updated if nothing changed")
}
if DB.First(&Product{}, "code = ?", product1.Code).RecordNotFound() { if DB.First(&Product{}, "code = ?", product1.Code).RecordNotFound() {
t.Errorf("Product1 should not be updated") t.Errorf("Product1 should not be updated")
@ -135,19 +128,8 @@ func TestUpdates(t *testing.T) {
DB.First(&product1, product1.Id) DB.First(&product1, product1.Id)
DB.First(&product2, product2.Id) DB.First(&product2, product2.Id)
updatedAt1 := product1.UpdatedAt
updatedAt2 := product2.UpdatedAt updatedAt2 := product2.UpdatedAt
var product3 Product
DB.First(&product3, product1.Id).Updates(Product{Code: "product1newcode", Price: 100})
if product3.Code != "product1newcode" || product3.Price != 100 {
t.Errorf("Record should be updated with struct")
}
if updatedAt1.Format(time.RFC3339Nano) != product3.UpdatedAt.Format(time.RFC3339Nano) {
t.Errorf("updatedAt should not be updated if nothing changed")
}
if DB.First(&Product{}, "code = ? and price = ?", product2.Code, product2.Price).RecordNotFound() { if DB.First(&Product{}, "code = ? and price = ?", product2.Code, product2.Price).RecordNotFound() {
t.Errorf("Product2 should not be updated") t.Errorf("Product2 should not be updated")
} }

View File

@ -199,37 +199,6 @@ func toSearchableMap(attrs ...interface{}) (result interface{}) {
return return
} }
func convertInterfaceToMap(values interface{}) map[string]interface{} {
attrs := map[string]interface{}{}
switch value := values.(type) {
case map[string]interface{}:
return value
case []interface{}:
for _, v := range value {
for key, value := range convertInterfaceToMap(v) {
attrs[key] = value
}
}
case interface{}:
reflectValue := reflect.ValueOf(values)
switch reflectValue.Kind() {
case reflect.Map:
for _, key := range reflectValue.MapKeys() {
attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface()
}
default:
for _, field := range (&Scope{Value: values}).Fields() {
if !field.IsBlank {
attrs[field.DBName] = field.Field.Interface()
}
}
}
}
return attrs
}
func equalAsString(a interface{}, b interface{}) bool { func equalAsString(a interface{}, b interface{}) bool {
return toString(a) == toString(b) return toString(a) == toString(b)
} }