forked from mirror/gorm
Add SetColumn, Changed method
This commit is contained in:
parent
e308b103c0
commit
66dcd7e3ca
|
@ -11,7 +11,7 @@ import (
|
|||
|
||||
func SaveBeforeAssociations(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil {
|
||||
selectColumns, restricted := SelectAndOmitColumns(db.Statement, true, false)
|
||||
selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false)
|
||||
|
||||
// Save Belongs To associations
|
||||
for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
|
||||
|
@ -90,7 +90,7 @@ func SaveBeforeAssociations(db *gorm.DB) {
|
|||
|
||||
func SaveAfterAssociations(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil {
|
||||
selectColumns, restricted := SelectAndOmitColumns(db.Statement, true, false)
|
||||
selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false)
|
||||
|
||||
// Save Has One associations
|
||||
for _, rel := range db.Statement.Schema.Relationships.HasOne {
|
||||
|
|
|
@ -218,7 +218,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
|
|||
values = ConvertSliceOfMapToValuesForCreate(stmt, value)
|
||||
default:
|
||||
var (
|
||||
selectColumns, restricted = SelectAndOmitColumns(stmt, true, false)
|
||||
selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
|
||||
curTime = stmt.DB.NowFunc()
|
||||
isZero bool
|
||||
)
|
||||
|
|
|
@ -7,64 +7,10 @@ import (
|
|||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
// SelectAndOmitColumns get select and omit columns, select -> true, omit -> false
|
||||
func SelectAndOmitColumns(stmt *gorm.Statement, requireCreate, requireUpdate bool) (map[string]bool, bool) {
|
||||
results := map[string]bool{}
|
||||
notRestricted := false
|
||||
|
||||
// select columns
|
||||
for _, column := range stmt.Selects {
|
||||
if column == "*" {
|
||||
notRestricted = true
|
||||
for _, dbName := range stmt.Schema.DBNames {
|
||||
results[dbName] = true
|
||||
}
|
||||
} else if column == clause.Associations {
|
||||
for _, rel := range stmt.Schema.Relationships.Relations {
|
||||
results[rel.Name] = true
|
||||
}
|
||||
} else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" {
|
||||
results[field.DBName] = true
|
||||
} else {
|
||||
results[column] = true
|
||||
}
|
||||
}
|
||||
|
||||
// omit columns
|
||||
for _, omit := range stmt.Omits {
|
||||
if omit == clause.Associations {
|
||||
for _, rel := range stmt.Schema.Relationships.Relations {
|
||||
results[rel.Name] = false
|
||||
}
|
||||
} else if field := stmt.Schema.LookUpField(omit); field != nil && field.DBName != "" {
|
||||
results[field.DBName] = false
|
||||
} else {
|
||||
results[omit] = false
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Schema != nil {
|
||||
for _, field := range stmt.Schema.Fields {
|
||||
name := field.DBName
|
||||
if name == "" {
|
||||
name = field.Name
|
||||
}
|
||||
|
||||
if requireCreate && !field.Creatable {
|
||||
results[name] = false
|
||||
} else if requireUpdate && !field.Updatable {
|
||||
results[name] = false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return results, !notRestricted && len(stmt.Selects) > 0
|
||||
}
|
||||
|
||||
// ConvertMapToValuesForCreate convert map to values
|
||||
func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]interface{}) (values clause.Values) {
|
||||
columns := make([]string, 0, len(mapValue))
|
||||
selectColumns, restricted := SelectAndOmitColumns(stmt, true, false)
|
||||
selectColumns, restricted := stmt.SelectAndOmitColumns(true, false)
|
||||
|
||||
var keys []string
|
||||
for k := range mapValue {
|
||||
|
@ -91,7 +37,7 @@ func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[st
|
|||
var (
|
||||
columns = []string{}
|
||||
result = map[string][]interface{}{}
|
||||
selectColumns, restricted = SelectAndOmitColumns(stmt, true, false)
|
||||
selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
|
||||
)
|
||||
|
||||
for idx, mapValue := range mapValues {
|
||||
|
|
|
@ -110,7 +110,7 @@ func AfterUpdate(db *gorm.DB) {
|
|||
// ConvertToAssignments convert to update assignments
|
||||
func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
||||
var (
|
||||
selectColumns, restricted = SelectAndOmitColumns(stmt, false, true)
|
||||
selectColumns, restricted = stmt.SelectAndOmitColumns(false, true)
|
||||
assignValue func(field *schema.Field, value interface{})
|
||||
)
|
||||
|
||||
|
|
|
@ -29,4 +29,6 @@ var (
|
|||
ErrUnsupportedDriver = errors.New("unsupported driver")
|
||||
// ErrRegistered registered
|
||||
ErrRegistered = errors.New("registered")
|
||||
// ErrInvalidField invalid field
|
||||
ErrInvalidField = errors.New("invalid field")
|
||||
)
|
||||
|
|
117
statement.go
117
statement.go
|
@ -12,6 +12,7 @@ import (
|
|||
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
// Statement statement
|
||||
|
@ -370,3 +371,119 @@ func (stmt *Statement) clone() *Statement {
|
|||
|
||||
return newStmt
|
||||
}
|
||||
|
||||
// Helpers
|
||||
// SetColumn set column's value
|
||||
func (stmt *Statement) SetColumn(name string, value interface{}) {
|
||||
if v, ok := stmt.Dest.(map[string]interface{}); ok {
|
||||
v[name] = value
|
||||
} else if stmt.Schema != nil {
|
||||
if field := stmt.Schema.LookUpField(name); field != nil {
|
||||
field.Set(stmt.ReflectValue, value)
|
||||
} else {
|
||||
stmt.AddError(ErrInvalidField)
|
||||
}
|
||||
} else {
|
||||
stmt.AddError(ErrInvalidField)
|
||||
}
|
||||
}
|
||||
|
||||
// Changed check model changed or not when updating
|
||||
func (stmt *Statement) Changed(fields ...string) bool {
|
||||
modelValue := reflect.ValueOf(stmt.Model)
|
||||
for modelValue.Kind() == reflect.Ptr {
|
||||
modelValue = modelValue.Elem()
|
||||
}
|
||||
|
||||
selectColumns, restricted := stmt.SelectAndOmitColumns(false, true)
|
||||
changed := func(field *schema.Field) bool {
|
||||
fieldValue, isZero := field.ValueOf(modelValue)
|
||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
||||
if v, ok := stmt.Dest.(map[string]interface{}); ok {
|
||||
if fv, ok := v[field.Name]; ok {
|
||||
return !utils.AssertEqual(fv, fieldValue)
|
||||
} else if fv, ok := v[field.DBName]; ok {
|
||||
return !utils.AssertEqual(fv, fieldValue)
|
||||
} else if isZero {
|
||||
return true
|
||||
}
|
||||
} else {
|
||||
changedValue, _ := field.ValueOf(stmt.ReflectValue)
|
||||
return !utils.AssertEqual(changedValue, fieldValue)
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
if len(fields) == 0 {
|
||||
for _, field := range stmt.Schema.FieldsByDBName {
|
||||
if changed(field) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for _, name := range fields {
|
||||
if field := stmt.Schema.LookUpField(name); field != nil {
|
||||
if changed(field) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// SelectAndOmitColumns get select and omit columns, select -> true, omit -> false
|
||||
func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) {
|
||||
results := map[string]bool{}
|
||||
notRestricted := false
|
||||
|
||||
// select columns
|
||||
for _, column := range stmt.Selects {
|
||||
if column == "*" {
|
||||
notRestricted = true
|
||||
for _, dbName := range stmt.Schema.DBNames {
|
||||
results[dbName] = true
|
||||
}
|
||||
} else if column == clause.Associations {
|
||||
for _, rel := range stmt.Schema.Relationships.Relations {
|
||||
results[rel.Name] = true
|
||||
}
|
||||
} else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" {
|
||||
results[field.DBName] = true
|
||||
} else {
|
||||
results[column] = true
|
||||
}
|
||||
}
|
||||
|
||||
// omit columns
|
||||
for _, omit := range stmt.Omits {
|
||||
if omit == clause.Associations {
|
||||
for _, rel := range stmt.Schema.Relationships.Relations {
|
||||
results[rel.Name] = false
|
||||
}
|
||||
} else if field := stmt.Schema.LookUpField(omit); field != nil && field.DBName != "" {
|
||||
results[field.DBName] = false
|
||||
} else {
|
||||
results[omit] = false
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Schema != nil {
|
||||
for _, field := range stmt.Schema.Fields {
|
||||
name := field.DBName
|
||||
if name == "" {
|
||||
name = field.Name
|
||||
}
|
||||
|
||||
if requireCreate && !field.Creatable {
|
||||
results[name] = false
|
||||
} else if requireUpdate && !field.Updatable {
|
||||
results[name] = false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return results, !notRestricted && len(stmt.Selects) > 0
|
||||
}
|
||||
|
|
|
@ -285,3 +285,84 @@ func TestUseDBInHooks(t *testing.T) {
|
|||
t.Errorf("Admin product's price should not be changed, expects: %v, got %v", 600, result4.Price)
|
||||
}
|
||||
}
|
||||
|
||||
type Product3 struct {
|
||||
gorm.Model
|
||||
Name string
|
||||
Code string
|
||||
Price int64
|
||||
Owner string
|
||||
}
|
||||
|
||||
func (s Product3) BeforeCreate(tx *gorm.DB) (err error) {
|
||||
tx.Statement.SetColumn("Price", s.Price+100)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s Product3) BeforeUpdate(tx *gorm.DB) (err error) {
|
||||
if tx.Statement.Changed() {
|
||||
tx.Statement.SetColumn("Price", s.Price+10)
|
||||
}
|
||||
|
||||
if tx.Statement.Changed("Code") {
|
||||
s.Price += 20
|
||||
tx.Statement.SetColumn("Price", s.Price+30)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestSetColumn(t *testing.T) {
|
||||
DB.Migrator().DropTable(&Product3{})
|
||||
DB.AutoMigrate(&Product3{})
|
||||
|
||||
product := Product3{Name: "Product", Price: 0}
|
||||
DB.Create(&product)
|
||||
|
||||
if product.Price != 100 {
|
||||
t.Errorf("invalid price after create, got %+v", product)
|
||||
}
|
||||
|
||||
DB.Model(&product).Select("code", "price").Updates(map[string]interface{}{"code": "L1212"})
|
||||
|
||||
if product.Price != 150 || product.Code != "L1212" {
|
||||
t.Errorf("invalid data after update, got %+v", product)
|
||||
}
|
||||
|
||||
// Code not changed, price should not change
|
||||
DB.Model(&product).Updates(map[string]interface{}{"Name": "Product New"})
|
||||
|
||||
if product.Name != "Product New" || product.Price != 160 || product.Code != "L1212" {
|
||||
t.Errorf("invalid data after update, got %+v", product)
|
||||
}
|
||||
|
||||
// Code changed, but not selected, price should not change
|
||||
DB.Model(&product).Select("Name", "Price").Updates(map[string]interface{}{"Name": "Product New2", "code": "L1213"})
|
||||
|
||||
if product.Name != "Product New2" || product.Price != 170 || product.Code != "L1212" {
|
||||
t.Errorf("invalid data after update, got %+v", product)
|
||||
}
|
||||
|
||||
// Code changed, price should changed
|
||||
DB.Model(&product).Select("Name", "Code", "Price").Updates(map[string]interface{}{"Name": "Product New3", "code": "L1213"})
|
||||
|
||||
if product.Name != "Product New3" || product.Price != 220 || product.Code != "L1213" {
|
||||
t.Errorf("invalid data after update, got %+v", product)
|
||||
}
|
||||
|
||||
var result Product3
|
||||
DB.First(&result, product.ID)
|
||||
|
||||
AssertEqual(t, result, product)
|
||||
|
||||
// Code changed, price not selected, price should not change
|
||||
DB.Model(&product).Select("code").Updates(map[string]interface{}{"name": "L1214"})
|
||||
|
||||
if product.Price != 220 || product.Code != "L1213" {
|
||||
t.Errorf("invalid data after update, got %+v", product)
|
||||
}
|
||||
|
||||
var result2 Product3
|
||||
DB.First(&result2, product.ID)
|
||||
|
||||
AssertEqual(t, result2, product)
|
||||
}
|
||||
|
|
|
@ -68,3 +68,18 @@ func ToStringKey(values ...interface{}) string {
|
|||
|
||||
return strings.Join(results, "_")
|
||||
}
|
||||
|
||||
func AssertEqual(src, dst interface{}) bool {
|
||||
if !reflect.DeepEqual(src, dst) {
|
||||
if valuer, ok := src.(driver.Valuer); ok {
|
||||
src, _ = valuer.Value()
|
||||
}
|
||||
|
||||
if valuer, ok := dst.(driver.Valuer); ok {
|
||||
dst, _ = valuer.Value()
|
||||
}
|
||||
|
||||
return reflect.DeepEqual(src, dst)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue