mirror of https://github.com/go-gorm/gorm.git
Refactor Scope updatedAttrsWithValues
This commit is contained in:
parent
a0aa21aec5
commit
8de97c2883
|
@ -21,15 +21,13 @@ func init() {
|
||||||
// assignUpdatingAttributesCallback assign updating attributes to model
|
// assignUpdatingAttributesCallback assign updating attributes to model
|
||||||
func assignUpdatingAttributesCallback(scope *Scope) {
|
func assignUpdatingAttributesCallback(scope *Scope) {
|
||||||
if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok {
|
if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok {
|
||||||
if maps := convertInterfaceToMap(attrs); len(maps) > 0 {
|
if updateMaps, hasUpdate := scope.updatedAttrsWithValues(attrs); hasUpdate {
|
||||||
if updateMaps, hasUpdate := scope.updatedAttrsWithValues(maps); hasUpdate {
|
|
||||||
scope.InstanceSet("gorm:update_attrs", updateMaps)
|
scope.InstanceSet("gorm:update_attrs", updateMaps)
|
||||||
} else {
|
} else {
|
||||||
scope.SkipLeft()
|
scope.SkipLeft()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating
|
// beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating
|
||||||
func beforeUpdateCallback(scope *Scope) {
|
func beforeUpdateCallback(scope *Scope) {
|
||||||
|
|
2
main.go
2
main.go
|
@ -310,7 +310,7 @@ func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
|
||||||
}
|
}
|
||||||
c.NewScope(out).inlineCondition(where...).initialize()
|
c.NewScope(out).inlineCondition(where...).initialize()
|
||||||
} else {
|
} else {
|
||||||
c.NewScope(out).updatedAttrsWithValues(convertInterfaceToMap(c.search.assignAttrs))
|
c.NewScope(out).updatedAttrsWithValues(c.search.assignAttrs)
|
||||||
}
|
}
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
50
scope.go
50
scope.go
|
@ -793,28 +793,56 @@ func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope {
|
||||||
return scope
|
return scope
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}) (results map[string]interface{}, hasUpdate bool) {
|
func convertInterfaceToMap(values interface{}) map[string]interface{} {
|
||||||
|
var attrs = map[string]interface{}{}
|
||||||
|
|
||||||
|
switch value := values.(type) {
|
||||||
|
case map[string]interface{}:
|
||||||
|
return value
|
||||||
|
case []interface{}:
|
||||||
|
for _, v := range value {
|
||||||
|
for key, value := range convertInterfaceToMap(v) {
|
||||||
|
attrs[key] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case interface{}:
|
||||||
|
reflectValue := reflect.ValueOf(values)
|
||||||
|
|
||||||
|
switch reflectValue.Kind() {
|
||||||
|
case reflect.Map:
|
||||||
|
for _, key := range reflectValue.MapKeys() {
|
||||||
|
attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface()
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
for _, field := range (&Scope{Value: values}).Fields() {
|
||||||
|
if !field.IsBlank {
|
||||||
|
attrs[field.DBName] = field.Field.Interface()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return attrs
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[string]interface{}, hasUpdate bool) {
|
||||||
if scope.IndirectValue().Kind() != reflect.Struct {
|
if scope.IndirectValue().Kind() != reflect.Struct {
|
||||||
return values, true
|
return convertInterfaceToMap(value), true
|
||||||
}
|
}
|
||||||
|
|
||||||
results = map[string]interface{}{}
|
results = map[string]interface{}{}
|
||||||
for key, value := range values {
|
|
||||||
|
for key, value := range convertInterfaceToMap(value) {
|
||||||
if field, ok := scope.FieldByName(key); ok && scope.changeableField(field) {
|
if field, ok := scope.FieldByName(key); ok && scope.changeableField(field) {
|
||||||
if !reflect.DeepEqual(field.Field, reflect.ValueOf(value)) {
|
|
||||||
if _, ok := value.(*expr); ok {
|
if _, ok := value.(*expr); ok {
|
||||||
hasUpdate = true
|
hasUpdate = true
|
||||||
results[field.DBName] = value
|
results[field.DBName] = value
|
||||||
} else if !equalAsString(field.Field.Interface(), value) {
|
} else {
|
||||||
field.Set(value)
|
field.Set(value)
|
||||||
if field.IsNormal {
|
if field.IsNormal {
|
||||||
hasUpdate = true
|
hasUpdate = true
|
||||||
results[field.DBName] = field.Field.Interface()
|
results[field.DBName] = field.Field.Interface()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
field.Set(value)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
@ -836,10 +864,10 @@ func (scope *Scope) rows() (*sql.Rows, error) {
|
||||||
|
|
||||||
func (scope *Scope) initialize() *Scope {
|
func (scope *Scope) initialize() *Scope {
|
||||||
for _, clause := range scope.Search.whereConditions {
|
for _, clause := range scope.Search.whereConditions {
|
||||||
scope.updatedAttrsWithValues(convertInterfaceToMap(clause["query"]))
|
scope.updatedAttrsWithValues(clause["query"])
|
||||||
}
|
}
|
||||||
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.initAttrs))
|
scope.updatedAttrsWithValues(scope.Search.initAttrs)
|
||||||
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.assignAttrs))
|
scope.updatedAttrsWithValues(scope.Search.assignAttrs)
|
||||||
return scope
|
return scope
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -20,13 +20,6 @@ func TestUpdate(t *testing.T) {
|
||||||
DB.First(&product1, product1.Id)
|
DB.First(&product1, product1.Id)
|
||||||
DB.First(&product2, product2.Id)
|
DB.First(&product2, product2.Id)
|
||||||
updatedAt1 := product1.UpdatedAt
|
updatedAt1 := product1.UpdatedAt
|
||||||
updatedAt2 := product2.UpdatedAt
|
|
||||||
|
|
||||||
var product3 Product
|
|
||||||
DB.First(&product3, product2.Id).Update("code", "product2newcode")
|
|
||||||
if updatedAt2.Format(time.RFC3339Nano) != product3.UpdatedAt.Format(time.RFC3339Nano) {
|
|
||||||
t.Errorf("updatedAt should not be updated if nothing changed")
|
|
||||||
}
|
|
||||||
|
|
||||||
if DB.First(&Product{}, "code = ?", product1.Code).RecordNotFound() {
|
if DB.First(&Product{}, "code = ?", product1.Code).RecordNotFound() {
|
||||||
t.Errorf("Product1 should not be updated")
|
t.Errorf("Product1 should not be updated")
|
||||||
|
@ -135,19 +128,8 @@ func TestUpdates(t *testing.T) {
|
||||||
|
|
||||||
DB.First(&product1, product1.Id)
|
DB.First(&product1, product1.Id)
|
||||||
DB.First(&product2, product2.Id)
|
DB.First(&product2, product2.Id)
|
||||||
updatedAt1 := product1.UpdatedAt
|
|
||||||
updatedAt2 := product2.UpdatedAt
|
updatedAt2 := product2.UpdatedAt
|
||||||
|
|
||||||
var product3 Product
|
|
||||||
DB.First(&product3, product1.Id).Updates(Product{Code: "product1newcode", Price: 100})
|
|
||||||
if product3.Code != "product1newcode" || product3.Price != 100 {
|
|
||||||
t.Errorf("Record should be updated with struct")
|
|
||||||
}
|
|
||||||
|
|
||||||
if updatedAt1.Format(time.RFC3339Nano) != product3.UpdatedAt.Format(time.RFC3339Nano) {
|
|
||||||
t.Errorf("updatedAt should not be updated if nothing changed")
|
|
||||||
}
|
|
||||||
|
|
||||||
if DB.First(&Product{}, "code = ? and price = ?", product2.Code, product2.Price).RecordNotFound() {
|
if DB.First(&Product{}, "code = ? and price = ?", product2.Code, product2.Price).RecordNotFound() {
|
||||||
t.Errorf("Product2 should not be updated")
|
t.Errorf("Product2 should not be updated")
|
||||||
}
|
}
|
||||||
|
|
31
utils.go
31
utils.go
|
@ -199,37 +199,6 @@ func toSearchableMap(attrs ...interface{}) (result interface{}) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertInterfaceToMap(values interface{}) map[string]interface{} {
|
|
||||||
attrs := map[string]interface{}{}
|
|
||||||
|
|
||||||
switch value := values.(type) {
|
|
||||||
case map[string]interface{}:
|
|
||||||
return value
|
|
||||||
case []interface{}:
|
|
||||||
for _, v := range value {
|
|
||||||
for key, value := range convertInterfaceToMap(v) {
|
|
||||||
attrs[key] = value
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case interface{}:
|
|
||||||
reflectValue := reflect.ValueOf(values)
|
|
||||||
|
|
||||||
switch reflectValue.Kind() {
|
|
||||||
case reflect.Map:
|
|
||||||
for _, key := range reflectValue.MapKeys() {
|
|
||||||
attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface()
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
for _, field := range (&Scope{Value: values}).Fields() {
|
|
||||||
if !field.IsBlank {
|
|
||||||
attrs[field.DBName] = field.Field.Interface()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return attrs
|
|
||||||
}
|
|
||||||
|
|
||||||
func equalAsString(a interface{}, b interface{}) bool {
|
func equalAsString(a interface{}, b interface{}) bool {
|
||||||
return toString(a) == toString(b)
|
return toString(a) == toString(b)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue