mirror of https://github.com/go-gorm/gorm.git
Refactor Scope updatedAttrsWithValues
This commit is contained in:
parent
a0aa21aec5
commit
8de97c2883
|
@ -21,14 +21,12 @@ func init() {
|
|||
// assignUpdatingAttributesCallback assign updating attributes to model
|
||||
func assignUpdatingAttributesCallback(scope *Scope) {
|
||||
if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok {
|
||||
if maps := convertInterfaceToMap(attrs); len(maps) > 0 {
|
||||
if updateMaps, hasUpdate := scope.updatedAttrsWithValues(maps); hasUpdate {
|
||||
if updateMaps, hasUpdate := scope.updatedAttrsWithValues(attrs); hasUpdate {
|
||||
scope.InstanceSet("gorm:update_attrs", updateMaps)
|
||||
} else {
|
||||
scope.SkipLeft()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating
|
||||
|
|
2
main.go
2
main.go
|
@ -310,7 +310,7 @@ func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
|
|||
}
|
||||
c.NewScope(out).inlineCondition(where...).initialize()
|
||||
} else {
|
||||
c.NewScope(out).updatedAttrsWithValues(convertInterfaceToMap(c.search.assignAttrs))
|
||||
c.NewScope(out).updatedAttrsWithValues(c.search.assignAttrs)
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
|
50
scope.go
50
scope.go
|
@ -793,28 +793,56 @@ func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope {
|
|||
return scope
|
||||
}
|
||||
|
||||
func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}) (results map[string]interface{}, hasUpdate bool) {
|
||||
func convertInterfaceToMap(values interface{}) map[string]interface{} {
|
||||
var attrs = map[string]interface{}{}
|
||||
|
||||
switch value := values.(type) {
|
||||
case map[string]interface{}:
|
||||
return value
|
||||
case []interface{}:
|
||||
for _, v := range value {
|
||||
for key, value := range convertInterfaceToMap(v) {
|
||||
attrs[key] = value
|
||||
}
|
||||
}
|
||||
case interface{}:
|
||||
reflectValue := reflect.ValueOf(values)
|
||||
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Map:
|
||||
for _, key := range reflectValue.MapKeys() {
|
||||
attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface()
|
||||
}
|
||||
default:
|
||||
for _, field := range (&Scope{Value: values}).Fields() {
|
||||
if !field.IsBlank {
|
||||
attrs[field.DBName] = field.Field.Interface()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return attrs
|
||||
}
|
||||
|
||||
func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[string]interface{}, hasUpdate bool) {
|
||||
if scope.IndirectValue().Kind() != reflect.Struct {
|
||||
return values, true
|
||||
return convertInterfaceToMap(value), true
|
||||
}
|
||||
|
||||
results = map[string]interface{}{}
|
||||
for key, value := range values {
|
||||
|
||||
for key, value := range convertInterfaceToMap(value) {
|
||||
if field, ok := scope.FieldByName(key); ok && scope.changeableField(field) {
|
||||
if !reflect.DeepEqual(field.Field, reflect.ValueOf(value)) {
|
||||
if _, ok := value.(*expr); ok {
|
||||
hasUpdate = true
|
||||
results[field.DBName] = value
|
||||
} else if !equalAsString(field.Field.Interface(), value) {
|
||||
} else {
|
||||
field.Set(value)
|
||||
if field.IsNormal {
|
||||
hasUpdate = true
|
||||
results[field.DBName] = field.Field.Interface()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
field.Set(value)
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
|
@ -836,10 +864,10 @@ func (scope *Scope) rows() (*sql.Rows, error) {
|
|||
|
||||
func (scope *Scope) initialize() *Scope {
|
||||
for _, clause := range scope.Search.whereConditions {
|
||||
scope.updatedAttrsWithValues(convertInterfaceToMap(clause["query"]))
|
||||
scope.updatedAttrsWithValues(clause["query"])
|
||||
}
|
||||
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.initAttrs))
|
||||
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.assignAttrs))
|
||||
scope.updatedAttrsWithValues(scope.Search.initAttrs)
|
||||
scope.updatedAttrsWithValues(scope.Search.assignAttrs)
|
||||
return scope
|
||||
}
|
||||
|
||||
|
|
|
@ -20,13 +20,6 @@ func TestUpdate(t *testing.T) {
|
|||
DB.First(&product1, product1.Id)
|
||||
DB.First(&product2, product2.Id)
|
||||
updatedAt1 := product1.UpdatedAt
|
||||
updatedAt2 := product2.UpdatedAt
|
||||
|
||||
var product3 Product
|
||||
DB.First(&product3, product2.Id).Update("code", "product2newcode")
|
||||
if updatedAt2.Format(time.RFC3339Nano) != product3.UpdatedAt.Format(time.RFC3339Nano) {
|
||||
t.Errorf("updatedAt should not be updated if nothing changed")
|
||||
}
|
||||
|
||||
if DB.First(&Product{}, "code = ?", product1.Code).RecordNotFound() {
|
||||
t.Errorf("Product1 should not be updated")
|
||||
|
@ -135,19 +128,8 @@ func TestUpdates(t *testing.T) {
|
|||
|
||||
DB.First(&product1, product1.Id)
|
||||
DB.First(&product2, product2.Id)
|
||||
updatedAt1 := product1.UpdatedAt
|
||||
updatedAt2 := product2.UpdatedAt
|
||||
|
||||
var product3 Product
|
||||
DB.First(&product3, product1.Id).Updates(Product{Code: "product1newcode", Price: 100})
|
||||
if product3.Code != "product1newcode" || product3.Price != 100 {
|
||||
t.Errorf("Record should be updated with struct")
|
||||
}
|
||||
|
||||
if updatedAt1.Format(time.RFC3339Nano) != product3.UpdatedAt.Format(time.RFC3339Nano) {
|
||||
t.Errorf("updatedAt should not be updated if nothing changed")
|
||||
}
|
||||
|
||||
if DB.First(&Product{}, "code = ? and price = ?", product2.Code, product2.Price).RecordNotFound() {
|
||||
t.Errorf("Product2 should not be updated")
|
||||
}
|
||||
|
|
31
utils.go
31
utils.go
|
@ -199,37 +199,6 @@ func toSearchableMap(attrs ...interface{}) (result interface{}) {
|
|||
return
|
||||
}
|
||||
|
||||
func convertInterfaceToMap(values interface{}) map[string]interface{} {
|
||||
attrs := map[string]interface{}{}
|
||||
|
||||
switch value := values.(type) {
|
||||
case map[string]interface{}:
|
||||
return value
|
||||
case []interface{}:
|
||||
for _, v := range value {
|
||||
for key, value := range convertInterfaceToMap(v) {
|
||||
attrs[key] = value
|
||||
}
|
||||
}
|
||||
case interface{}:
|
||||
reflectValue := reflect.ValueOf(values)
|
||||
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Map:
|
||||
for _, key := range reflectValue.MapKeys() {
|
||||
attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface()
|
||||
}
|
||||
default:
|
||||
for _, field := range (&Scope{Value: values}).Fields() {
|
||||
if !field.IsBlank {
|
||||
attrs[field.DBName] = field.Field.Interface()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return attrs
|
||||
}
|
||||
|
||||
func equalAsString(a interface{}, b interface{}) bool {
|
||||
return toString(a) == toString(b)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue