forked from mirror/gorm
Add SetColumn, Changed method
This commit is contained in:
parent
2d048d9ece
commit
f5566288de
|
@ -11,7 +11,7 @@ import (
|
||||||
|
|
||||||
func SaveBeforeAssociations(db *gorm.DB) {
|
func SaveBeforeAssociations(db *gorm.DB) {
|
||||||
if db.Error == nil && db.Statement.Schema != nil {
|
if db.Error == nil && db.Statement.Schema != nil {
|
||||||
selectColumns, restricted := SelectAndOmitColumns(db.Statement, true, false)
|
selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false)
|
||||||
|
|
||||||
// Save Belongs To associations
|
// Save Belongs To associations
|
||||||
for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
|
for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
|
||||||
|
@ -90,7 +90,7 @@ func SaveBeforeAssociations(db *gorm.DB) {
|
||||||
|
|
||||||
func SaveAfterAssociations(db *gorm.DB) {
|
func SaveAfterAssociations(db *gorm.DB) {
|
||||||
if db.Error == nil && db.Statement.Schema != nil {
|
if db.Error == nil && db.Statement.Schema != nil {
|
||||||
selectColumns, restricted := SelectAndOmitColumns(db.Statement, true, false)
|
selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false)
|
||||||
|
|
||||||
// Save Has One associations
|
// Save Has One associations
|
||||||
for _, rel := range db.Statement.Schema.Relationships.HasOne {
|
for _, rel := range db.Statement.Schema.Relationships.HasOne {
|
||||||
|
|
|
@ -218,7 +218,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
|
||||||
values = ConvertSliceOfMapToValuesForCreate(stmt, value)
|
values = ConvertSliceOfMapToValuesForCreate(stmt, value)
|
||||||
default:
|
default:
|
||||||
var (
|
var (
|
||||||
selectColumns, restricted = SelectAndOmitColumns(stmt, true, false)
|
selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
|
||||||
curTime = stmt.DB.NowFunc()
|
curTime = stmt.DB.NowFunc()
|
||||||
isZero bool
|
isZero bool
|
||||||
)
|
)
|
||||||
|
|
|
@ -7,64 +7,10 @@ import (
|
||||||
"gorm.io/gorm/clause"
|
"gorm.io/gorm/clause"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SelectAndOmitColumns get select and omit columns, select -> true, omit -> false
|
|
||||||
func SelectAndOmitColumns(stmt *gorm.Statement, requireCreate, requireUpdate bool) (map[string]bool, bool) {
|
|
||||||
results := map[string]bool{}
|
|
||||||
notRestricted := false
|
|
||||||
|
|
||||||
// select columns
|
|
||||||
for _, column := range stmt.Selects {
|
|
||||||
if column == "*" {
|
|
||||||
notRestricted = true
|
|
||||||
for _, dbName := range stmt.Schema.DBNames {
|
|
||||||
results[dbName] = true
|
|
||||||
}
|
|
||||||
} else if column == clause.Associations {
|
|
||||||
for _, rel := range stmt.Schema.Relationships.Relations {
|
|
||||||
results[rel.Name] = true
|
|
||||||
}
|
|
||||||
} else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" {
|
|
||||||
results[field.DBName] = true
|
|
||||||
} else {
|
|
||||||
results[column] = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// omit columns
|
|
||||||
for _, omit := range stmt.Omits {
|
|
||||||
if omit == clause.Associations {
|
|
||||||
for _, rel := range stmt.Schema.Relationships.Relations {
|
|
||||||
results[rel.Name] = false
|
|
||||||
}
|
|
||||||
} else if field := stmt.Schema.LookUpField(omit); field != nil && field.DBName != "" {
|
|
||||||
results[field.DBName] = false
|
|
||||||
} else {
|
|
||||||
results[omit] = false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if stmt.Schema != nil {
|
|
||||||
for _, field := range stmt.Schema.Fields {
|
|
||||||
name := field.DBName
|
|
||||||
if name == "" {
|
|
||||||
name = field.Name
|
|
||||||
}
|
|
||||||
|
|
||||||
if requireCreate && !field.Creatable {
|
|
||||||
results[name] = false
|
|
||||||
} else if requireUpdate && !field.Updatable {
|
|
||||||
results[name] = false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return results, !notRestricted && len(stmt.Selects) > 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// ConvertMapToValuesForCreate convert map to values
|
// ConvertMapToValuesForCreate convert map to values
|
||||||
func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]interface{}) (values clause.Values) {
|
func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]interface{}) (values clause.Values) {
|
||||||
columns := make([]string, 0, len(mapValue))
|
columns := make([]string, 0, len(mapValue))
|
||||||
selectColumns, restricted := SelectAndOmitColumns(stmt, true, false)
|
selectColumns, restricted := stmt.SelectAndOmitColumns(true, false)
|
||||||
|
|
||||||
var keys []string
|
var keys []string
|
||||||
for k := range mapValue {
|
for k := range mapValue {
|
||||||
|
@ -91,7 +37,7 @@ func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[st
|
||||||
var (
|
var (
|
||||||
columns = []string{}
|
columns = []string{}
|
||||||
result = map[string][]interface{}{}
|
result = map[string][]interface{}{}
|
||||||
selectColumns, restricted = SelectAndOmitColumns(stmt, true, false)
|
selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
|
||||||
)
|
)
|
||||||
|
|
||||||
for idx, mapValue := range mapValues {
|
for idx, mapValue := range mapValues {
|
||||||
|
|
|
@ -110,7 +110,7 @@ func AfterUpdate(db *gorm.DB) {
|
||||||
// ConvertToAssignments convert to update assignments
|
// ConvertToAssignments convert to update assignments
|
||||||
func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
||||||
var (
|
var (
|
||||||
selectColumns, restricted = SelectAndOmitColumns(stmt, false, true)
|
selectColumns, restricted = stmt.SelectAndOmitColumns(false, true)
|
||||||
assignValue func(field *schema.Field, value interface{})
|
assignValue func(field *schema.Field, value interface{})
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -29,4 +29,6 @@ var (
|
||||||
ErrUnsupportedDriver = errors.New("unsupported driver")
|
ErrUnsupportedDriver = errors.New("unsupported driver")
|
||||||
// ErrRegistered registered
|
// ErrRegistered registered
|
||||||
ErrRegistered = errors.New("registered")
|
ErrRegistered = errors.New("registered")
|
||||||
|
// ErrInvalidField invalid field
|
||||||
|
ErrInvalidField = errors.New("invalid field")
|
||||||
)
|
)
|
||||||
|
|
117
statement.go
117
statement.go
|
@ -12,6 +12,7 @@ import (
|
||||||
|
|
||||||
"gorm.io/gorm/clause"
|
"gorm.io/gorm/clause"
|
||||||
"gorm.io/gorm/schema"
|
"gorm.io/gorm/schema"
|
||||||
|
"gorm.io/gorm/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Statement statement
|
// Statement statement
|
||||||
|
@ -370,3 +371,119 @@ func (stmt *Statement) clone() *Statement {
|
||||||
|
|
||||||
return newStmt
|
return newStmt
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Helpers
|
||||||
|
// SetColumn set column's value
|
||||||
|
func (stmt *Statement) SetColumn(name string, value interface{}) {
|
||||||
|
if v, ok := stmt.Dest.(map[string]interface{}); ok {
|
||||||
|
v[name] = value
|
||||||
|
} else if stmt.Schema != nil {
|
||||||
|
if field := stmt.Schema.LookUpField(name); field != nil {
|
||||||
|
field.Set(stmt.ReflectValue, value)
|
||||||
|
} else {
|
||||||
|
stmt.AddError(ErrInvalidField)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
stmt.AddError(ErrInvalidField)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Changed check model changed or not when updating
|
||||||
|
func (stmt *Statement) Changed(fields ...string) bool {
|
||||||
|
modelValue := reflect.ValueOf(stmt.Model)
|
||||||
|
for modelValue.Kind() == reflect.Ptr {
|
||||||
|
modelValue = modelValue.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
selectColumns, restricted := stmt.SelectAndOmitColumns(false, true)
|
||||||
|
changed := func(field *schema.Field) bool {
|
||||||
|
fieldValue, isZero := field.ValueOf(modelValue)
|
||||||
|
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
||||||
|
if v, ok := stmt.Dest.(map[string]interface{}); ok {
|
||||||
|
if fv, ok := v[field.Name]; ok {
|
||||||
|
return !utils.AssertEqual(fv, fieldValue)
|
||||||
|
} else if fv, ok := v[field.DBName]; ok {
|
||||||
|
return !utils.AssertEqual(fv, fieldValue)
|
||||||
|
} else if isZero {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
changedValue, _ := field.ValueOf(stmt.ReflectValue)
|
||||||
|
return !utils.AssertEqual(changedValue, fieldValue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(fields) == 0 {
|
||||||
|
for _, field := range stmt.Schema.FieldsByDBName {
|
||||||
|
if changed(field) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for _, name := range fields {
|
||||||
|
if field := stmt.Schema.LookUpField(name); field != nil {
|
||||||
|
if changed(field) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// SelectAndOmitColumns get select and omit columns, select -> true, omit -> false
|
||||||
|
func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) {
|
||||||
|
results := map[string]bool{}
|
||||||
|
notRestricted := false
|
||||||
|
|
||||||
|
// select columns
|
||||||
|
for _, column := range stmt.Selects {
|
||||||
|
if column == "*" {
|
||||||
|
notRestricted = true
|
||||||
|
for _, dbName := range stmt.Schema.DBNames {
|
||||||
|
results[dbName] = true
|
||||||
|
}
|
||||||
|
} else if column == clause.Associations {
|
||||||
|
for _, rel := range stmt.Schema.Relationships.Relations {
|
||||||
|
results[rel.Name] = true
|
||||||
|
}
|
||||||
|
} else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" {
|
||||||
|
results[field.DBName] = true
|
||||||
|
} else {
|
||||||
|
results[column] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// omit columns
|
||||||
|
for _, omit := range stmt.Omits {
|
||||||
|
if omit == clause.Associations {
|
||||||
|
for _, rel := range stmt.Schema.Relationships.Relations {
|
||||||
|
results[rel.Name] = false
|
||||||
|
}
|
||||||
|
} else if field := stmt.Schema.LookUpField(omit); field != nil && field.DBName != "" {
|
||||||
|
results[field.DBName] = false
|
||||||
|
} else {
|
||||||
|
results[omit] = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if stmt.Schema != nil {
|
||||||
|
for _, field := range stmt.Schema.Fields {
|
||||||
|
name := field.DBName
|
||||||
|
if name == "" {
|
||||||
|
name = field.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
if requireCreate && !field.Creatable {
|
||||||
|
results[name] = false
|
||||||
|
} else if requireUpdate && !field.Updatable {
|
||||||
|
results[name] = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return results, !notRestricted && len(stmt.Selects) > 0
|
||||||
|
}
|
||||||
|
|
|
@ -285,3 +285,84 @@ func TestUseDBInHooks(t *testing.T) {
|
||||||
t.Errorf("Admin product's price should not be changed, expects: %v, got %v", 600, result4.Price)
|
t.Errorf("Admin product's price should not be changed, expects: %v, got %v", 600, result4.Price)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Product3 struct {
|
||||||
|
gorm.Model
|
||||||
|
Name string
|
||||||
|
Code string
|
||||||
|
Price int64
|
||||||
|
Owner string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s Product3) BeforeCreate(tx *gorm.DB) (err error) {
|
||||||
|
tx.Statement.SetColumn("Price", s.Price+100)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s Product3) BeforeUpdate(tx *gorm.DB) (err error) {
|
||||||
|
if tx.Statement.Changed() {
|
||||||
|
tx.Statement.SetColumn("Price", s.Price+10)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tx.Statement.Changed("Code") {
|
||||||
|
s.Price += 20
|
||||||
|
tx.Statement.SetColumn("Price", s.Price+30)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetColumn(t *testing.T) {
|
||||||
|
DB.Migrator().DropTable(&Product3{})
|
||||||
|
DB.AutoMigrate(&Product3{})
|
||||||
|
|
||||||
|
product := Product3{Name: "Product", Price: 0}
|
||||||
|
DB.Create(&product)
|
||||||
|
|
||||||
|
if product.Price != 100 {
|
||||||
|
t.Errorf("invalid price after create, got %+v", product)
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Model(&product).Select("code", "price").Updates(map[string]interface{}{"code": "L1212"})
|
||||||
|
|
||||||
|
if product.Price != 150 || product.Code != "L1212" {
|
||||||
|
t.Errorf("invalid data after update, got %+v", product)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Code not changed, price should not change
|
||||||
|
DB.Model(&product).Updates(map[string]interface{}{"Name": "Product New"})
|
||||||
|
|
||||||
|
if product.Name != "Product New" || product.Price != 160 || product.Code != "L1212" {
|
||||||
|
t.Errorf("invalid data after update, got %+v", product)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Code changed, but not selected, price should not change
|
||||||
|
DB.Model(&product).Select("Name", "Price").Updates(map[string]interface{}{"Name": "Product New2", "code": "L1213"})
|
||||||
|
|
||||||
|
if product.Name != "Product New2" || product.Price != 170 || product.Code != "L1212" {
|
||||||
|
t.Errorf("invalid data after update, got %+v", product)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Code changed, price should changed
|
||||||
|
DB.Model(&product).Select("Name", "Code", "Price").Updates(map[string]interface{}{"Name": "Product New3", "code": "L1213"})
|
||||||
|
|
||||||
|
if product.Name != "Product New3" || product.Price != 220 || product.Code != "L1213" {
|
||||||
|
t.Errorf("invalid data after update, got %+v", product)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result Product3
|
||||||
|
DB.First(&result, product.ID)
|
||||||
|
|
||||||
|
AssertEqual(t, result, product)
|
||||||
|
|
||||||
|
// Code changed, price not selected, price should not change
|
||||||
|
DB.Model(&product).Select("code").Updates(map[string]interface{}{"name": "L1214"})
|
||||||
|
|
||||||
|
if product.Price != 220 || product.Code != "L1213" {
|
||||||
|
t.Errorf("invalid data after update, got %+v", product)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result2 Product3
|
||||||
|
DB.First(&result2, product.ID)
|
||||||
|
|
||||||
|
AssertEqual(t, result2, product)
|
||||||
|
}
|
||||||
|
|
|
@ -68,3 +68,18 @@ func ToStringKey(values ...interface{}) string {
|
||||||
|
|
||||||
return strings.Join(results, "_")
|
return strings.Join(results, "_")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func AssertEqual(src, dst interface{}) bool {
|
||||||
|
if !reflect.DeepEqual(src, dst) {
|
||||||
|
if valuer, ok := src.(driver.Valuer); ok {
|
||||||
|
src, _ = valuer.Value()
|
||||||
|
}
|
||||||
|
|
||||||
|
if valuer, ok := dst.(driver.Valuer); ok {
|
||||||
|
dst, _ = valuer.Value()
|
||||||
|
}
|
||||||
|
|
||||||
|
return reflect.DeepEqual(src, dst)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue