gorm/callbacks/update.go

291 lines
8.9 KiB
Go
Raw Normal View History

2020-02-02 14:32:27 +03:00
package callbacks
2020-02-23 16:22:35 +03:00
import (
"reflect"
2020-03-07 08:43:20 +03:00
"sort"
2020-02-23 16:22:35 +03:00
2020-06-02 04:16:07 +03:00
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
2020-02-23 16:22:35 +03:00
)
2020-02-02 14:32:27 +03:00
2020-06-01 14:41:33 +03:00
func SetupUpdateReflectValue(db *gorm.DB) {
2020-06-05 14:18:22 +03:00
if db.Error == nil && db.Statement.Schema != nil {
2020-06-01 14:41:33 +03:00
if !db.Statement.ReflectValue.CanAddr() || db.Statement.Model != db.Statement.Dest {
db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model)
for db.Statement.ReflectValue.Kind() == reflect.Ptr {
db.Statement.ReflectValue = db.Statement.ReflectValue.Elem()
}
if dest, ok := db.Statement.Dest.(map[string]interface{}); ok {
for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
if _, ok := dest[rel.Name]; ok {
rel.Field.Set(db.Statement.ReflectValue, dest[rel.Name])
}
}
2021-09-05 18:12:24 +03:00
} else if modelType, destType := findType(db.Statement.Model), findType(db.Statement.Dest); modelType.Kind() == reflect.Struct && destType.Kind() == reflect.Struct {
db.Statement.Dest = transToModel(reflect.Indirect(reflect.ValueOf(db.Statement.Dest)), reflect.New(modelType).Elem())
2020-06-01 14:41:33 +03:00
}
}
}
}
2021-09-05 18:12:24 +03:00
// findType returns the type of target, looking through exactly one level of
// pointer: for a *T it returns T, for any non-pointer it returns the dynamic
// type itself, and for a **T it returns *T.
//
// target must not be a nil interface value: reflect.TypeOf(nil) is nil and
// calling Kind on it panics.
func findType(target interface{}) reflect.Type {
	typ := reflect.TypeOf(target)
	if typ.Kind() != reflect.Ptr {
		return typ
	}
	return typ.Elem()
}
// transToModel copies the fields of the struct value from into the struct
// value to and returns to as an interface{}. In this file it is called with
// a freshly allocated model (reflect.New(modelType).Elem()).
//
// Fields are matched by name; direct and promoted fields of to are both
// considered. A field is copied only when the target field is exported and
// settable, has the same reflect.Kind as the source field, and the source
// value is assignable to the target type, or convertible to it (for example
// `type Level int64` into an int64 field). All other fields are skipped.
// When a matching target field is promoted through a nil embedded pointer,
// the embedded struct is allocated so the field can be set.
//
// If from already has exactly the same type as to, from is returned as is.
// from must be a valid struct value and to must be settable.
func transToModel(from, to reflect.Value) interface{} {
	// Compare types, not their String() forms: distinct types can share a
	// printed name (e.g. "models.User" declared in two packages).
	if from.Type() == to.Type() {
		return from.Interface()
	}

	fromType, toType := from.Type(), to.Type()
	for i := 0; i < fromType.NumField(); i++ {
		target, found := toType.FieldByName(fromType.Field(i).Name)
		// Skip names the model lacks and unexported model fields.
		if !found || target.PkgPath != "" {
			continue
		}

		src := from.Field(i)
		srcType, dstType := src.Type(), target.Type
		if srcType.Kind() != dstType.Kind() {
			continue
		}
		// Equal kinds do not imply assignability (named types, unrelated
		// struct types); reflect.Value.Set would panic on those.
		assignable := srcType.AssignableTo(dstType)
		if !assignable && !srcType.ConvertibleTo(dstType) {
			continue
		}

		// Walk the index path ourselves: reflect.Value.FieldByName panics when
		// the path crosses a nil embedded pointer, which is always the case in
		// a freshly allocated model.
		dst := to
		reachable := true
		for depth, idx := range target.Index {
			if depth > 0 && dst.Kind() == reflect.Ptr {
				if dst.IsNil() {
					if !dst.CanSet() {
						reachable = false
						break
					}
					dst.Set(reflect.New(dst.Type().Elem()))
				}
				dst = dst.Elem()
			}
			dst = dst.Field(idx)
		}
		if !reachable || !dst.CanSet() {
			continue
		}

		if assignable {
			dst.Set(src)
		} else {
			dst.Set(src.Convert(dstType))
		}
	}
	return to.Interface()
}
2020-02-02 14:32:27 +03:00
func BeforeUpdate(db *gorm.DB) {
2020-11-17 12:49:43 +03:00
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
2020-02-23 16:22:35 +03:00
if db.Statement.Schema.BeforeSave {
2020-08-27 10:03:57 +03:00
if i, ok := value.(BeforeSaveInterface); ok {
2020-06-05 15:24:15 +03:00
called = true
2020-05-31 18:55:56 +03:00
db.AddError(i.BeforeSave(tx))
2020-02-23 16:22:35 +03:00
}
}
if db.Statement.Schema.BeforeUpdate {
2020-08-27 10:03:57 +03:00
if i, ok := value.(BeforeUpdateInterface); ok {
2020-06-05 15:24:15 +03:00
called = true
2020-05-31 18:55:56 +03:00
db.AddError(i.BeforeUpdate(tx))
2020-02-23 16:22:35 +03:00
}
}
return called
})
2020-02-23 16:22:35 +03:00
}
2020-02-02 14:32:27 +03:00
}
func Update(db *gorm.DB) {
2021-09-03 18:09:20 +03:00
if db.Error != nil {
return
}
2020-05-29 02:35:45 +03:00
2021-09-03 18:09:20 +03:00
if db.Statement.Schema != nil && !db.Statement.Unscoped {
for _, c := range db.Statement.Schema.UpdateClauses {
db.Statement.AddClause(c)
2020-05-29 02:35:45 +03:00
}
2021-09-03 18:09:20 +03:00
}
2020-03-07 08:43:20 +03:00
2021-09-03 18:09:20 +03:00
if db.Statement.SQL.String() == "" {
db.Statement.SQL.Grow(180)
db.Statement.AddClauseIfNotExists(clause.Update{})
if set := ConvertToAssignments(db.Statement); len(set) != 0 {
db.Statement.AddClause(set)
} else {
2020-05-31 18:55:56 +03:00
return
}
2021-09-03 18:09:20 +03:00
db.Statement.Build(db.Statement.BuildClauses...)
}
if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok {
db.AddError(gorm.ErrMissingWhereClause)
return
}
2020-05-31 15:42:07 +03:00
2021-09-03 18:09:20 +03:00
if !db.DryRun && db.Error == nil {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
2020-03-07 08:43:20 +03:00
2021-09-03 18:09:20 +03:00
if err == nil {
db.RowsAffected, _ = result.RowsAffected()
} else {
db.AddError(err)
2020-05-31 18:55:56 +03:00
}
2020-03-07 08:43:20 +03:00
}
2020-02-02 14:32:27 +03:00
}
func AfterUpdate(db *gorm.DB) {
2020-11-17 12:49:43 +03:00
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
2020-02-23 16:22:35 +03:00
if db.Statement.Schema.AfterSave {
2020-08-27 10:03:57 +03:00
if i, ok := value.(AfterSaveInterface); ok {
2020-06-05 15:24:15 +03:00
called = true
2020-05-31 18:55:56 +03:00
db.AddError(i.AfterSave(tx))
2020-02-23 16:22:35 +03:00
}
}
if db.Statement.Schema.AfterUpdate {
2020-08-27 10:03:57 +03:00
if i, ok := value.(AfterUpdateInterface); ok {
2020-06-05 15:24:15 +03:00
called = true
2020-05-31 18:55:56 +03:00
db.AddError(i.AfterUpdate(tx))
2020-02-23 16:22:35 +03:00
}
}
2020-06-05 15:24:15 +03:00
return called
})
2020-02-23 16:22:35 +03:00
}
2020-02-02 14:32:27 +03:00
}
2020-03-07 08:43:20 +03:00
// ConvertToAssignments convert to update assignments
2020-03-08 08:24:08 +03:00
func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
2020-05-24 12:24:23 +03:00
var (
2020-06-30 11:53:54 +03:00
selectColumns, restricted = stmt.SelectAndOmitColumns(false, true)
2020-05-24 12:24:23 +03:00
assignValue func(field *schema.Field, value interface{})
)
2020-06-01 14:41:33 +03:00
switch stmt.ReflectValue.Kind() {
2020-05-24 12:24:23 +03:00
case reflect.Slice, reflect.Array:
assignValue = func(field *schema.Field, value interface{}) {
2020-06-01 14:41:33 +03:00
for i := 0; i < stmt.ReflectValue.Len(); i++ {
field.Set(stmt.ReflectValue.Index(i), value)
2020-05-24 12:24:23 +03:00
}
}
case reflect.Struct:
assignValue = func(field *schema.Field, value interface{}) {
2020-06-01 14:41:33 +03:00
if stmt.ReflectValue.CanAddr() {
field.Set(stmt.ReflectValue, value)
}
2020-05-24 12:24:23 +03:00
}
default:
assignValue = func(field *schema.Field, value interface{}) {
}
}
2020-03-07 08:43:20 +03:00
2020-06-01 14:41:33 +03:00
updatingValue := reflect.ValueOf(stmt.Dest)
for updatingValue.Kind() == reflect.Ptr {
updatingValue = updatingValue.Elem()
}
2020-06-19 19:48:15 +03:00
if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
switch stmt.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
var primaryKeyExprs []clause.Expression
2020-06-19 19:48:15 +03:00
for i := 0; i < stmt.ReflectValue.Len(); i++ {
var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields))
var notZero bool
for idx, field := range stmt.Schema.PrimaryFields {
value, isZero := field.ValueOf(stmt.ReflectValue.Index(i))
exprs[idx] = clause.Eq{Column: field.DBName, Value: value}
notZero = notZero || !isZero
}
if notZero {
primaryKeyExprs = append(primaryKeyExprs, clause.And(exprs...))
2020-06-19 19:48:15 +03:00
}
}
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(primaryKeyExprs...)}})
2020-06-19 19:48:15 +03:00
case reflect.Struct:
for _, field := range stmt.Schema.PrimaryFields {
if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero {
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
}
}
}
}
2020-06-01 14:41:33 +03:00
switch value := updatingValue.Interface().(type) {
2020-03-07 08:43:20 +03:00
case map[string]interface{}:
2020-03-08 08:24:08 +03:00
set = make([]clause.Assignment, 0, len(value))
2020-03-07 08:43:20 +03:00
2020-06-08 08:45:41 +03:00
keys := make([]string, 0, len(value))
2020-06-07 10:24:34 +03:00
for k := range value {
2020-03-07 08:43:20 +03:00
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys {
2020-08-13 11:05:06 +03:00
kv := value[k]
if _, ok := kv.(*gorm.DB); ok {
kv = []interface{}{kv}
}
2020-06-05 14:18:22 +03:00
if stmt.Schema != nil {
if field := stmt.Schema.LookUpField(k); field != nil {
if field.DBName != "" {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
2020-08-13 11:05:06 +03:00
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: kv})
2020-06-05 14:18:22 +03:00
assignValue(field, value[k])
}
} else if v, ok := selectColumns[field.Name]; (ok && v) || (!ok && !restricted) {
2020-06-01 14:41:33 +03:00
assignValue(field, value[k])
}
2020-06-05 14:18:22 +03:00
continue
2020-03-07 08:43:20 +03:00
}
2020-06-05 14:18:22 +03:00
}
if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
2020-08-13 11:05:06 +03:00
set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: kv})
2020-03-07 08:43:20 +03:00
}
}
2020-11-17 12:49:43 +03:00
if !stmt.SkipHooks && stmt.Schema != nil {
2020-08-27 14:52:01 +03:00
for _, dbName := range stmt.Schema.DBNames {
field := stmt.Schema.LookUpField(dbName)
2020-05-30 12:34:22 +03:00
if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil {
if v, ok := selectColumns[field.DBName]; (ok && v) || !ok {
2020-07-06 06:20:43 +03:00
now := stmt.DB.NowFunc()
assignValue(field, now)
if field.AutoUpdateTime == schema.UnixNanosecond {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()})
} else if field.AutoUpdateTime == schema.UnixMillisecond {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano() / 1e6})
2020-07-20 13:59:28 +03:00
} else if field.GORMDataType == schema.Time {
2020-07-06 06:20:43 +03:00
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now})
} else {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()})
}
2020-06-05 14:18:22 +03:00
}
2020-05-30 12:34:22 +03:00
}
2020-03-08 08:24:08 +03:00
}
}
2020-03-07 08:43:20 +03:00
default:
2020-06-01 14:41:33 +03:00
switch updatingValue.Kind() {
2020-03-07 08:43:20 +03:00
case reflect.Struct:
2020-03-08 08:24:08 +03:00
set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName))
2020-08-27 14:52:01 +03:00
for _, dbName := range stmt.Schema.DBNames {
field := stmt.Schema.LookUpField(dbName)
2021-09-05 18:12:24 +03:00
if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) {
2020-06-01 14:41:33 +03:00
value, isZero := field.ValueOf(updatingValue)
if !stmt.SkipHooks && field.AutoUpdateTime > 0 {
if field.AutoUpdateTime == schema.UnixNanosecond {
value = stmt.DB.NowFunc().UnixNano()
} else if field.AutoUpdateTime == schema.UnixMillisecond {
value = stmt.DB.NowFunc().UnixNano() / 1e6
} else if field.GORMDataType == schema.Time {
value = stmt.DB.NowFunc()
} else {
value = stmt.DB.NowFunc().Unix()
2020-05-30 12:34:22 +03:00
}
isZero = false
2020-03-08 08:24:08 +03:00
}
if ok || !isZero {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value})
2020-05-24 12:24:23 +03:00
assignValue(field, value)
2020-03-08 08:24:08 +03:00
}
}
} else {
2020-06-01 14:41:33 +03:00
if value, isZero := field.ValueOf(updatingValue); !isZero {
2020-03-08 08:24:08 +03:00
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
}
2020-03-07 08:43:20 +03:00
}
}
2020-08-18 13:00:36 +03:00
default:
stmt.AddError(gorm.ErrInvalidData)
2020-03-07 08:43:20 +03:00
}
}
2020-03-08 08:24:08 +03:00
return
2020-03-07 08:43:20 +03:00
}