From f5566288de9b58172f4796053055abde57988b7f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 30 Jun 2020 16:53:54 +0800 Subject: [PATCH] Add SetColumn, Changed method --- callbacks/associations.go | 4 +- callbacks/create.go | 2 +- callbacks/helper.go | 58 +------------------ callbacks/update.go | 2 +- errors.go | 2 + statement.go | 117 ++++++++++++++++++++++++++++++++++++++ tests/hooks_test.go | 81 ++++++++++++++++++++++++++ utils/utils.go | 15 +++++ 8 files changed, 221 insertions(+), 60 deletions(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index 3ff0f4b0..bcb6c414 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -11,7 +11,7 @@ import ( func SaveBeforeAssociations(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil { - selectColumns, restricted := SelectAndOmitColumns(db.Statement, true, false) + selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false) // Save Belongs To associations for _, rel := range db.Statement.Schema.Relationships.BelongsTo { @@ -90,7 +90,7 @@ func SaveBeforeAssociations(db *gorm.DB) { func SaveAfterAssociations(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil { - selectColumns, restricted := SelectAndOmitColumns(db.Statement, true, false) + selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false) // Save Has One associations for _, rel := range db.Statement.Schema.Relationships.HasOne { diff --git a/callbacks/create.go b/callbacks/create.go index 283d3fd1..eecb80a1 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -218,7 +218,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { values = ConvertSliceOfMapToValuesForCreate(stmt, value) default: var ( - selectColumns, restricted = SelectAndOmitColumns(stmt, true, false) + selectColumns, restricted = stmt.SelectAndOmitColumns(true, false) curTime = stmt.DB.NowFunc() isZero bool ) diff --git a/callbacks/helper.go b/callbacks/helper.go index 3b0cca16..1b06e0b7 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -7,64 +7,10 @@ import ( "gorm.io/gorm/clause" ) -// SelectAndOmitColumns get select and omit columns, select -> true, omit -> false -func SelectAndOmitColumns(stmt *gorm.Statement, requireCreate, requireUpdate bool) (map[string]bool, bool) { - results := map[string]bool{} - notRestricted := false - - // select columns - for _, column := range stmt.Selects { - if column == "*" { - notRestricted = true - for _, dbName := range stmt.Schema.DBNames { - results[dbName] = true - } - } else if column == clause.Associations { - for _, rel := range stmt.Schema.Relationships.Relations { - results[rel.Name] = true - } - } else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" { - results[field.DBName] = true - } else { - results[column] = true - } - } - - // omit columns - for _, omit := range stmt.Omits { - if omit == clause.Associations { - for _, rel := range stmt.Schema.Relationships.Relations { - results[rel.Name] = false - } - } else if field := stmt.Schema.LookUpField(omit); field != nil && field.DBName != "" { - results[field.DBName] = false - } else { - results[omit] = false - } - } - - if stmt.Schema != nil { - for _, field := range stmt.Schema.Fields { - name := field.DBName - if name == "" { - name = field.Name - } - - if requireCreate && !field.Creatable { - results[name] = false - } else if requireUpdate && !field.Updatable { - results[name] = false - } - } - } - - return results, !notRestricted && len(stmt.Selects) > 0 -} - // ConvertMapToValuesForCreate convert map to values func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]interface{}) (values clause.Values) { columns := make([]string, 0, len(mapValue)) - selectColumns, restricted := SelectAndOmitColumns(stmt, true, false) + selectColumns, restricted := stmt.SelectAndOmitColumns(true, false) var keys []string for k := range mapValue { @@ -91,7 +37,7 @@ func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[st var ( columns = []string{} result = map[string][]interface{}{} - selectColumns, restricted = SelectAndOmitColumns(stmt, true, false) + selectColumns, restricted = stmt.SelectAndOmitColumns(true, false) ) for idx, mapValue := range mapValues { diff --git a/callbacks/update.go b/callbacks/update.go index 1ea77552..f84e933c 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -110,7 +110,7 @@ func AfterUpdate(db *gorm.DB) { // ConvertToAssignments convert to update assignments func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { var ( - selectColumns, restricted = SelectAndOmitColumns(stmt, false, true) + selectColumns, restricted = stmt.SelectAndOmitColumns(false, true) assignValue func(field *schema.Field, value interface{}) ) diff --git a/errors.go b/errors.go index b41eefae..e1b58835 100644 --- a/errors.go +++ b/errors.go @@ -29,4 +29,6 @@ var ( ErrUnsupportedDriver = errors.New("unsupported driver") // ErrRegistered registered ErrRegistered = errors.New("registered") + // ErrInvalidField invalid field + ErrInvalidField = errors.New("invalid field") ) diff --git a/statement.go b/statement.go index e902b739..164ddbd7 100644 --- a/statement.go +++ b/statement.go @@ -12,6 +12,7 @@ import ( "gorm.io/gorm/clause" "gorm.io/gorm/schema" + "gorm.io/gorm/utils" ) // Statement statement @@ -370,3 +371,119 @@ func (stmt *Statement) clone() *Statement { return newStmt } + +// Helpers +// SetColumn set column's value +func (stmt *Statement) SetColumn(name string, value interface{}) { + if v, ok := stmt.Dest.(map[string]interface{}); ok { + v[name] = value + } else if stmt.Schema != nil { + if field := stmt.Schema.LookUpField(name); field != nil { + field.Set(stmt.ReflectValue, value) + } else { + stmt.AddError(ErrInvalidField) + } + } else { + stmt.AddError(ErrInvalidField) + } +} + +// Changed check model changed or not when updating +func (stmt *Statement) Changed(fields ...string) bool { + modelValue := reflect.ValueOf(stmt.Model) + for modelValue.Kind() == reflect.Ptr { + modelValue = modelValue.Elem() + } + + selectColumns, restricted := stmt.SelectAndOmitColumns(false, true) + changed := func(field *schema.Field) bool { + fieldValue, isZero := field.ValueOf(modelValue) + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { + if v, ok := stmt.Dest.(map[string]interface{}); ok { + if fv, ok := v[field.Name]; ok { + return !utils.AssertEqual(fv, fieldValue) + } else if fv, ok := v[field.DBName]; ok { + return !utils.AssertEqual(fv, fieldValue) + } else if isZero { + return true + } + } else { + changedValue, _ := field.ValueOf(stmt.ReflectValue) + return !utils.AssertEqual(changedValue, fieldValue) + } + } + return false + } + + if len(fields) == 0 { + for _, field := range stmt.Schema.FieldsByDBName { + if changed(field) { + return true + } + } + } else { + for _, name := range fields { + if field := stmt.Schema.LookUpField(name); field != nil { + if changed(field) { + return true + } + } + } + } + + return false +} + +// SelectAndOmitColumns get select and omit columns, select -> true, omit -> false +func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) { + results := map[string]bool{} + notRestricted := false + + // select columns + for _, column := range stmt.Selects { + if column == "*" { + notRestricted = true + for _, dbName := range stmt.Schema.DBNames { + results[dbName] = true + } + } else if column == clause.Associations { + for _, rel := range stmt.Schema.Relationships.Relations { + results[rel.Name] = true + } + } else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" { + results[field.DBName] = true + } else { + results[column] = true + } + } + + // omit columns + for _, omit := range stmt.Omits { + if omit == clause.Associations { + for _, rel := range stmt.Schema.Relationships.Relations { + results[rel.Name] = false + } + } else if field := stmt.Schema.LookUpField(omit); field != nil && field.DBName != "" { + results[field.DBName] = false + } else { + results[omit] = false + } + } + + if stmt.Schema != nil { + for _, field := range stmt.Schema.Fields { + name := field.DBName + if name == "" { + name = field.Name + } + + if requireCreate && !field.Creatable { + results[name] = false + } else if requireUpdate && !field.Updatable { + results[name] = false + } + } + } + + return results, !notRestricted && len(stmt.Selects) > 0 +} diff --git a/tests/hooks_test.go b/tests/hooks_test.go index c74e8f10..8f8c60f5 100644 --- a/tests/hooks_test.go +++ b/tests/hooks_test.go @@ -285,3 +285,84 @@ func TestUseDBInHooks(t *testing.T) { t.Errorf("Admin product's price should not be changed, expects: %v, got %v", 600, result4.Price) } } + +type Product3 struct { + gorm.Model + Name string + Code string + Price int64 + Owner string +} + +func (s Product3) BeforeCreate(tx *gorm.DB) (err error) { + tx.Statement.SetColumn("Price", s.Price+100) + return nil +} + +func (s Product3) BeforeUpdate(tx *gorm.DB) (err error) { + if tx.Statement.Changed() { + tx.Statement.SetColumn("Price", s.Price+10) + } + + if tx.Statement.Changed("Code") { + s.Price += 20 + tx.Statement.SetColumn("Price", s.Price+30) + } + return nil +} + +func TestSetColumn(t *testing.T) { + DB.Migrator().DropTable(&Product3{}) + DB.AutoMigrate(&Product3{}) + + product := Product3{Name: "Product", Price: 0} + DB.Create(&product) + + if product.Price != 100 { + t.Errorf("invalid price after create, got %+v", product) + } + + DB.Model(&product).Select("code", "price").Updates(map[string]interface{}{"code": "L1212"}) + + if product.Price != 150 || product.Code != "L1212" { + t.Errorf("invalid data after update, got %+v", product) + } + + // Code not changed, price should not change + DB.Model(&product).Updates(map[string]interface{}{"Name": "Product New"}) + + if product.Name != "Product New" || product.Price != 160 || product.Code != "L1212" { + t.Errorf("invalid data after update, got %+v", product) + } + + // Code changed, but not selected, price should not change + DB.Model(&product).Select("Name", "Price").Updates(map[string]interface{}{"Name": "Product New2", "code": "L1213"}) + + if product.Name != "Product New2" || product.Price != 170 || product.Code != "L1212" { + t.Errorf("invalid data after update, got %+v", product) + } + + // Code changed, price should changed + DB.Model(&product).Select("Name", "Code", "Price").Updates(map[string]interface{}{"Name": "Product New3", "code": "L1213"}) + + if product.Name != "Product New3" || product.Price != 220 || product.Code != "L1213" { + t.Errorf("invalid data after update, got %+v", product) + } + + var result Product3 + DB.First(&result, product.ID) + + AssertEqual(t, result, product) + + // Code changed, price not selected, price should not change + DB.Model(&product).Select("code").Updates(map[string]interface{}{"name": "L1214"}) + + if product.Price != 220 || product.Code != "L1213" { + t.Errorf("invalid data after update, got %+v", product) + } + + var result2 Product3 + DB.First(&result2, product.ID) + + AssertEqual(t, result2, product) +} diff --git a/utils/utils.go b/utils/utils.go index 81d2dc34..9bf00683 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -68,3 +68,18 @@ func ToStringKey(values ...interface{}) string { return strings.Join(results, "_") } + +func AssertEqual(src, dst interface{}) bool { + if !reflect.DeepEqual(src, dst) { + if valuer, ok := src.(driver.Valuer); ok { + src, _ = valuer.Value() + } + + if valuer, ok := dst.(driver.Valuer); ok { + dst, _ = valuer.Value() + } + + return reflect.DeepEqual(src, dst) + } + return true +}