Add SetColumn, Changed method

This commit is contained in:
Jinzhu 2020-06-30 16:53:54 +08:00
parent e308b103c0
commit 66dcd7e3ca
8 changed files with 221 additions and 60 deletions

View File

@ -11,7 +11,7 @@ import (
func SaveBeforeAssociations(db *gorm.DB) { func SaveBeforeAssociations(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil { if db.Error == nil && db.Statement.Schema != nil {
selectColumns, restricted := SelectAndOmitColumns(db.Statement, true, false) selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false)
// Save Belongs To associations // Save Belongs To associations
for _, rel := range db.Statement.Schema.Relationships.BelongsTo { for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
@ -90,7 +90,7 @@ func SaveBeforeAssociations(db *gorm.DB) {
func SaveAfterAssociations(db *gorm.DB) { func SaveAfterAssociations(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil { if db.Error == nil && db.Statement.Schema != nil {
selectColumns, restricted := SelectAndOmitColumns(db.Statement, true, false) selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false)
// Save Has One associations // Save Has One associations
for _, rel := range db.Statement.Schema.Relationships.HasOne { for _, rel := range db.Statement.Schema.Relationships.HasOne {

View File

@ -218,7 +218,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
values = ConvertSliceOfMapToValuesForCreate(stmt, value) values = ConvertSliceOfMapToValuesForCreate(stmt, value)
default: default:
var ( var (
selectColumns, restricted = SelectAndOmitColumns(stmt, true, false) selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
curTime = stmt.DB.NowFunc() curTime = stmt.DB.NowFunc()
isZero bool isZero bool
) )

View File

@ -7,64 +7,10 @@ import (
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
) )
// SelectAndOmitColumns get select and omit columns, select -> true, omit -> false
func SelectAndOmitColumns(stmt *gorm.Statement, requireCreate, requireUpdate bool) (map[string]bool, bool) {
results := map[string]bool{}
notRestricted := false
// select columns
for _, column := range stmt.Selects {
if column == "*" {
notRestricted = true
for _, dbName := range stmt.Schema.DBNames {
results[dbName] = true
}
} else if column == clause.Associations {
for _, rel := range stmt.Schema.Relationships.Relations {
results[rel.Name] = true
}
} else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" {
results[field.DBName] = true
} else {
results[column] = true
}
}
// omit columns
for _, omit := range stmt.Omits {
if omit == clause.Associations {
for _, rel := range stmt.Schema.Relationships.Relations {
results[rel.Name] = false
}
} else if field := stmt.Schema.LookUpField(omit); field != nil && field.DBName != "" {
results[field.DBName] = false
} else {
results[omit] = false
}
}
if stmt.Schema != nil {
for _, field := range stmt.Schema.Fields {
name := field.DBName
if name == "" {
name = field.Name
}
if requireCreate && !field.Creatable {
results[name] = false
} else if requireUpdate && !field.Updatable {
results[name] = false
}
}
}
return results, !notRestricted && len(stmt.Selects) > 0
}
// ConvertMapToValuesForCreate convert map to values // ConvertMapToValuesForCreate convert map to values
func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]interface{}) (values clause.Values) { func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]interface{}) (values clause.Values) {
columns := make([]string, 0, len(mapValue)) columns := make([]string, 0, len(mapValue))
selectColumns, restricted := SelectAndOmitColumns(stmt, true, false) selectColumns, restricted := stmt.SelectAndOmitColumns(true, false)
var keys []string var keys []string
for k := range mapValue { for k := range mapValue {
@ -91,7 +37,7 @@ func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[st
var ( var (
columns = []string{} columns = []string{}
result = map[string][]interface{}{} result = map[string][]interface{}{}
selectColumns, restricted = SelectAndOmitColumns(stmt, true, false) selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
) )
for idx, mapValue := range mapValues { for idx, mapValue := range mapValues {

View File

@ -110,7 +110,7 @@ func AfterUpdate(db *gorm.DB) {
// ConvertToAssignments convert to update assignments // ConvertToAssignments convert to update assignments
func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
var ( var (
selectColumns, restricted = SelectAndOmitColumns(stmt, false, true) selectColumns, restricted = stmt.SelectAndOmitColumns(false, true)
assignValue func(field *schema.Field, value interface{}) assignValue func(field *schema.Field, value interface{})
) )

View File

@ -29,4 +29,6 @@ var (
ErrUnsupportedDriver = errors.New("unsupported driver") ErrUnsupportedDriver = errors.New("unsupported driver")
// ErrRegistered registered // ErrRegistered registered
ErrRegistered = errors.New("registered") ErrRegistered = errors.New("registered")
// ErrInvalidField invalid field
ErrInvalidField = errors.New("invalid field")
) )

View File

@ -12,6 +12,7 @@ import (
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
"gorm.io/gorm/utils"
) )
// Statement statement // Statement statement
@ -370,3 +371,119 @@ func (stmt *Statement) clone() *Statement {
return newStmt return newStmt
} }
// Helpers
// SetColumn set column's value
func (stmt *Statement) SetColumn(name string, value interface{}) {
if v, ok := stmt.Dest.(map[string]interface{}); ok {
v[name] = value
} else if stmt.Schema != nil {
if field := stmt.Schema.LookUpField(name); field != nil {
field.Set(stmt.ReflectValue, value)
} else {
stmt.AddError(ErrInvalidField)
}
} else {
stmt.AddError(ErrInvalidField)
}
}
// Changed check model changed or not when updating
func (stmt *Statement) Changed(fields ...string) bool {
modelValue := reflect.ValueOf(stmt.Model)
for modelValue.Kind() == reflect.Ptr {
modelValue = modelValue.Elem()
}
selectColumns, restricted := stmt.SelectAndOmitColumns(false, true)
changed := func(field *schema.Field) bool {
fieldValue, isZero := field.ValueOf(modelValue)
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
if v, ok := stmt.Dest.(map[string]interface{}); ok {
if fv, ok := v[field.Name]; ok {
return !utils.AssertEqual(fv, fieldValue)
} else if fv, ok := v[field.DBName]; ok {
return !utils.AssertEqual(fv, fieldValue)
} else if isZero {
return true
}
} else {
changedValue, _ := field.ValueOf(stmt.ReflectValue)
return !utils.AssertEqual(changedValue, fieldValue)
}
}
return false
}
if len(fields) == 0 {
for _, field := range stmt.Schema.FieldsByDBName {
if changed(field) {
return true
}
}
} else {
for _, name := range fields {
if field := stmt.Schema.LookUpField(name); field != nil {
if changed(field) {
return true
}
}
}
}
return false
}
// SelectAndOmitColumns get select and omit columns, select -> true, omit -> false
func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) {
results := map[string]bool{}
notRestricted := false
// select columns
for _, column := range stmt.Selects {
if column == "*" {
notRestricted = true
for _, dbName := range stmt.Schema.DBNames {
results[dbName] = true
}
} else if column == clause.Associations {
for _, rel := range stmt.Schema.Relationships.Relations {
results[rel.Name] = true
}
} else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" {
results[field.DBName] = true
} else {
results[column] = true
}
}
// omit columns
for _, omit := range stmt.Omits {
if omit == clause.Associations {
for _, rel := range stmt.Schema.Relationships.Relations {
results[rel.Name] = false
}
} else if field := stmt.Schema.LookUpField(omit); field != nil && field.DBName != "" {
results[field.DBName] = false
} else {
results[omit] = false
}
}
if stmt.Schema != nil {
for _, field := range stmt.Schema.Fields {
name := field.DBName
if name == "" {
name = field.Name
}
if requireCreate && !field.Creatable {
results[name] = false
} else if requireUpdate && !field.Updatable {
results[name] = false
}
}
}
return results, !notRestricted && len(stmt.Selects) > 0
}

View File

@ -285,3 +285,84 @@ func TestUseDBInHooks(t *testing.T) {
t.Errorf("Admin product's price should not be changed, expects: %v, got %v", 600, result4.Price) t.Errorf("Admin product's price should not be changed, expects: %v, got %v", 600, result4.Price)
} }
} }
type Product3 struct {
gorm.Model
Name string
Code string
Price int64
Owner string
}
func (s Product3) BeforeCreate(tx *gorm.DB) (err error) {
tx.Statement.SetColumn("Price", s.Price+100)
return nil
}
func (s Product3) BeforeUpdate(tx *gorm.DB) (err error) {
if tx.Statement.Changed() {
tx.Statement.SetColumn("Price", s.Price+10)
}
if tx.Statement.Changed("Code") {
s.Price += 20
tx.Statement.SetColumn("Price", s.Price+30)
}
return nil
}
func TestSetColumn(t *testing.T) {
DB.Migrator().DropTable(&Product3{})
DB.AutoMigrate(&Product3{})
product := Product3{Name: "Product", Price: 0}
DB.Create(&product)
if product.Price != 100 {
t.Errorf("invalid price after create, got %+v", product)
}
DB.Model(&product).Select("code", "price").Updates(map[string]interface{}{"code": "L1212"})
if product.Price != 150 || product.Code != "L1212" {
t.Errorf("invalid data after update, got %+v", product)
}
// Code not changed, price should not change
DB.Model(&product).Updates(map[string]interface{}{"Name": "Product New"})
if product.Name != "Product New" || product.Price != 160 || product.Code != "L1212" {
t.Errorf("invalid data after update, got %+v", product)
}
// Code changed, but not selected, price should not change
DB.Model(&product).Select("Name", "Price").Updates(map[string]interface{}{"Name": "Product New2", "code": "L1213"})
if product.Name != "Product New2" || product.Price != 170 || product.Code != "L1212" {
t.Errorf("invalid data after update, got %+v", product)
}
// Code changed, price should changed
DB.Model(&product).Select("Name", "Code", "Price").Updates(map[string]interface{}{"Name": "Product New3", "code": "L1213"})
if product.Name != "Product New3" || product.Price != 220 || product.Code != "L1213" {
t.Errorf("invalid data after update, got %+v", product)
}
var result Product3
DB.First(&result, product.ID)
AssertEqual(t, result, product)
// Code changed, price not selected, price should not change
DB.Model(&product).Select("code").Updates(map[string]interface{}{"name": "L1214"})
if product.Price != 220 || product.Code != "L1213" {
t.Errorf("invalid data after update, got %+v", product)
}
var result2 Product3
DB.First(&result2, product.ID)
AssertEqual(t, result2, product)
}

View File

@ -68,3 +68,18 @@ func ToStringKey(values ...interface{}) string {
return strings.Join(results, "_") return strings.Join(results, "_")
} }
func AssertEqual(src, dst interface{}) bool {
if !reflect.DeepEqual(src, dst) {
if valuer, ok := src.(driver.Valuer); ok {
src, _ = valuer.Value()
}
if valuer, ok := dst.(driver.Valuer); ok {
dst, _ = valuer.Value()
}
return reflect.DeepEqual(src, dst)
}
return true
}