2020-02-02 14:32:27 +03:00
|
|
|
package callbacks
|
|
|
|
|
2020-02-23 16:22:35 +03:00
|
|
|
import (
|
|
|
|
"reflect"
|
2020-03-07 08:43:20 +03:00
|
|
|
"sort"
|
2020-02-23 16:22:35 +03:00
|
|
|
|
2020-06-02 04:16:07 +03:00
|
|
|
"gorm.io/gorm"
|
|
|
|
"gorm.io/gorm/clause"
|
|
|
|
"gorm.io/gorm/schema"
|
2021-10-28 02:24:38 +03:00
|
|
|
"gorm.io/gorm/utils"
|
2020-02-23 16:22:35 +03:00
|
|
|
)
|
2020-02-02 14:32:27 +03:00
|
|
|
|
2020-06-01 14:41:33 +03:00
|
|
|
func SetupUpdateReflectValue(db *gorm.DB) {
|
2020-06-05 14:18:22 +03:00
|
|
|
if db.Error == nil && db.Statement.Schema != nil {
|
2020-06-01 14:41:33 +03:00
|
|
|
if !db.Statement.ReflectValue.CanAddr() || db.Statement.Model != db.Statement.Dest {
|
|
|
|
db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model)
|
|
|
|
for db.Statement.ReflectValue.Kind() == reflect.Ptr {
|
|
|
|
db.Statement.ReflectValue = db.Statement.ReflectValue.Elem()
|
|
|
|
}
|
|
|
|
|
|
|
|
if dest, ok := db.Statement.Dest.(map[string]interface{}); ok {
|
|
|
|
for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
|
|
|
|
if _, ok := dest[rel.Name]; ok {
|
2022-03-23 12:24:25 +03:00
|
|
|
db.AddError(rel.Field.Set(db.Statement.Context, db.Statement.ReflectValue, dest[rel.Name]))
|
2020-06-01 14:41:33 +03:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2022-03-22 17:42:36 +03:00
|
|
|
// BeforeUpdate before update hooks
|
2020-02-02 14:32:27 +03:00
|
|
|
func BeforeUpdate(db *gorm.DB) {
|
2020-11-17 12:49:43 +03:00
|
|
|
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
|
2020-06-05 16:23:20 +03:00
|
|
|
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
2020-02-23 16:22:35 +03:00
|
|
|
if db.Statement.Schema.BeforeSave {
|
2020-08-27 10:03:57 +03:00
|
|
|
if i, ok := value.(BeforeSaveInterface); ok {
|
2020-06-05 15:24:15 +03:00
|
|
|
called = true
|
2020-05-31 18:55:56 +03:00
|
|
|
db.AddError(i.BeforeSave(tx))
|
2020-02-23 16:22:35 +03:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if db.Statement.Schema.BeforeUpdate {
|
2020-08-27 10:03:57 +03:00
|
|
|
if i, ok := value.(BeforeUpdateInterface); ok {
|
2020-06-05 15:24:15 +03:00
|
|
|
called = true
|
2020-05-31 18:55:56 +03:00
|
|
|
db.AddError(i.BeforeUpdate(tx))
|
2020-02-23 16:22:35 +03:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2020-06-05 16:23:20 +03:00
|
|
|
return called
|
|
|
|
})
|
2020-02-23 16:22:35 +03:00
|
|
|
}
|
2020-02-02 14:32:27 +03:00
|
|
|
}
|
|
|
|
|
2022-03-22 17:42:36 +03:00
|
|
|
// Update update hook
|
2021-10-26 17:36:37 +03:00
|
|
|
func Update(config *Config) func(db *gorm.DB) {
|
2021-10-28 02:24:38 +03:00
|
|
|
supportReturning := utils.Contains(config.UpdateClauses, "RETURNING")
|
2020-03-07 08:43:20 +03:00
|
|
|
|
2021-10-26 17:36:37 +03:00
|
|
|
return func(db *gorm.DB) {
|
|
|
|
if db.Error != nil {
|
2020-05-31 18:55:56 +03:00
|
|
|
return
|
|
|
|
}
|
2021-09-03 18:09:20 +03:00
|
|
|
|
2022-02-25 05:48:23 +03:00
|
|
|
if db.Statement.Schema != nil {
|
|
|
|
for _, c := range db.Statement.Schema.UpdateClauses {
|
|
|
|
db.Statement.AddClause(c)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2021-12-08 08:58:06 +03:00
|
|
|
if db.Statement.SQL.Len() == 0 {
|
2021-10-26 17:36:37 +03:00
|
|
|
db.Statement.SQL.Grow(180)
|
|
|
|
db.Statement.AddClauseIfNotExists(clause.Update{})
|
2022-09-16 10:02:44 +03:00
|
|
|
if _, ok := db.Statement.Clauses["SET"]; !ok {
|
|
|
|
if set := ConvertToAssignments(db.Statement); len(set) != 0 {
|
|
|
|
db.Statement.AddClause(set)
|
|
|
|
} else {
|
|
|
|
return
|
|
|
|
}
|
2021-10-26 17:36:37 +03:00
|
|
|
}
|
2021-12-08 08:58:06 +03:00
|
|
|
|
2021-10-26 17:36:37 +03:00
|
|
|
db.Statement.Build(db.Statement.BuildClauses...)
|
|
|
|
}
|
2020-03-07 08:43:20 +03:00
|
|
|
|
2022-02-25 05:48:23 +03:00
|
|
|
checkMissingWhereConditions(db)
|
2021-10-26 17:36:37 +03:00
|
|
|
|
|
|
|
if !db.DryRun && db.Error == nil {
|
2021-10-28 02:24:38 +03:00
|
|
|
if ok, mode := hasReturning(db, supportReturning); ok {
|
2021-10-26 17:36:37 +03:00
|
|
|
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
|
2021-10-28 03:03:23 +03:00
|
|
|
dest := db.Statement.Dest
|
|
|
|
db.Statement.Dest = db.Statement.ReflectValue.Addr().Interface()
|
2021-10-28 02:24:38 +03:00
|
|
|
gorm.Scan(rows, db, mode)
|
2021-10-28 03:03:23 +03:00
|
|
|
db.Statement.Dest = dest
|
2021-12-02 05:39:24 +03:00
|
|
|
db.AddError(rows.Close())
|
2021-10-26 17:36:37 +03:00
|
|
|
}
|
|
|
|
} else {
|
|
|
|
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
|
|
|
|
2021-10-28 02:24:38 +03:00
|
|
|
if db.AddError(err) == nil {
|
2021-10-26 17:36:37 +03:00
|
|
|
db.RowsAffected, _ = result.RowsAffected()
|
|
|
|
}
|
|
|
|
}
|
2020-05-31 18:55:56 +03:00
|
|
|
}
|
2020-03-07 08:43:20 +03:00
|
|
|
}
|
2020-02-02 14:32:27 +03:00
|
|
|
}
|
|
|
|
|
2022-03-22 17:42:36 +03:00
|
|
|
// AfterUpdate after update hooks
|
2020-02-02 14:32:27 +03:00
|
|
|
func AfterUpdate(db *gorm.DB) {
|
2020-11-17 12:49:43 +03:00
|
|
|
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {
|
2020-06-05 16:23:20 +03:00
|
|
|
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
2022-03-22 17:42:36 +03:00
|
|
|
if db.Statement.Schema.AfterUpdate {
|
|
|
|
if i, ok := value.(AfterUpdateInterface); ok {
|
2020-06-05 15:24:15 +03:00
|
|
|
called = true
|
2022-03-22 17:42:36 +03:00
|
|
|
db.AddError(i.AfterUpdate(tx))
|
2020-02-23 16:22:35 +03:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2022-03-22 17:42:36 +03:00
|
|
|
if db.Statement.Schema.AfterSave {
|
|
|
|
if i, ok := value.(AfterSaveInterface); ok {
|
2020-06-05 15:24:15 +03:00
|
|
|
called = true
|
2022-03-22 17:42:36 +03:00
|
|
|
db.AddError(i.AfterSave(tx))
|
2020-02-23 16:22:35 +03:00
|
|
|
}
|
|
|
|
}
|
2022-03-22 17:42:36 +03:00
|
|
|
|
2020-06-05 15:24:15 +03:00
|
|
|
return called
|
2020-06-05 16:23:20 +03:00
|
|
|
})
|
2020-02-23 16:22:35 +03:00
|
|
|
}
|
2020-02-02 14:32:27 +03:00
|
|
|
}
|
2020-03-07 08:43:20 +03:00
|
|
|
|
|
|
|
// ConvertToAssignments convert to update assignments
|
2020-03-08 08:24:08 +03:00
|
|
|
func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
2020-05-24 12:24:23 +03:00
|
|
|
var (
|
2020-06-30 11:53:54 +03:00
|
|
|
selectColumns, restricted = stmt.SelectAndOmitColumns(false, true)
|
2020-05-24 12:24:23 +03:00
|
|
|
assignValue func(field *schema.Field, value interface{})
|
|
|
|
)
|
|
|
|
|
2020-06-01 14:41:33 +03:00
|
|
|
switch stmt.ReflectValue.Kind() {
|
2020-05-24 12:24:23 +03:00
|
|
|
case reflect.Slice, reflect.Array:
|
|
|
|
assignValue = func(field *schema.Field, value interface{}) {
|
2020-06-01 14:41:33 +03:00
|
|
|
for i := 0; i < stmt.ReflectValue.Len(); i++ {
|
2023-02-18 04:20:29 +03:00
|
|
|
if stmt.ReflectValue.CanAddr() {
|
|
|
|
field.Set(stmt.Context, stmt.ReflectValue.Index(i), value)
|
|
|
|
}
|
2020-05-24 12:24:23 +03:00
|
|
|
}
|
|
|
|
}
|
|
|
|
case reflect.Struct:
|
|
|
|
assignValue = func(field *schema.Field, value interface{}) {
|
2020-06-01 14:41:33 +03:00
|
|
|
if stmt.ReflectValue.CanAddr() {
|
2022-02-16 10:30:43 +03:00
|
|
|
field.Set(stmt.Context, stmt.ReflectValue, value)
|
2020-05-30 11:03:27 +03:00
|
|
|
}
|
2020-05-24 12:24:23 +03:00
|
|
|
}
|
|
|
|
default:
|
|
|
|
assignValue = func(field *schema.Field, value interface{}) {
|
|
|
|
}
|
|
|
|
}
|
2020-03-07 08:43:20 +03:00
|
|
|
|
2020-06-01 14:41:33 +03:00
|
|
|
updatingValue := reflect.ValueOf(stmt.Dest)
|
|
|
|
for updatingValue.Kind() == reflect.Ptr {
|
|
|
|
updatingValue = updatingValue.Elem()
|
|
|
|
}
|
|
|
|
|
2020-06-19 19:48:15 +03:00
|
|
|
if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
|
|
|
|
switch stmt.ReflectValue.Kind() {
|
|
|
|
case reflect.Slice, reflect.Array:
|
2021-10-28 03:03:23 +03:00
|
|
|
if size := stmt.ReflectValue.Len(); size > 0 {
|
2022-07-07 10:39:29 +03:00
|
|
|
var isZero bool
|
2021-11-29 04:33:20 +03:00
|
|
|
for i := 0; i < size; i++ {
|
2022-07-07 10:39:29 +03:00
|
|
|
for _, field := range stmt.Schema.PrimaryFields {
|
|
|
|
_, isZero = field.ValueOf(stmt.Context, stmt.ReflectValue.Index(i))
|
|
|
|
if !isZero {
|
|
|
|
break
|
|
|
|
}
|
2021-10-28 03:03:23 +03:00
|
|
|
}
|
2020-06-19 19:48:15 +03:00
|
|
|
}
|
2021-10-28 03:03:23 +03:00
|
|
|
|
2022-07-07 10:39:29 +03:00
|
|
|
if !isZero {
|
|
|
|
_, primaryValues := schema.GetIdentityFieldValuesMap(stmt.Context, stmt.ReflectValue, stmt.Schema.PrimaryFields)
|
|
|
|
column, values := schema.ToQueryValues("", stmt.Schema.PrimaryFieldDBNames, primaryValues)
|
|
|
|
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
|
|
|
|
}
|
2020-06-19 19:48:15 +03:00
|
|
|
}
|
|
|
|
case reflect.Struct:
|
|
|
|
for _, field := range stmt.Schema.PrimaryFields {
|
2022-02-16 10:30:43 +03:00
|
|
|
if value, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero {
|
2020-06-19 19:48:15 +03:00
|
|
|
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2020-06-01 14:41:33 +03:00
|
|
|
switch value := updatingValue.Interface().(type) {
|
2020-03-07 08:43:20 +03:00
|
|
|
case map[string]interface{}:
|
2020-03-08 08:24:08 +03:00
|
|
|
set = make([]clause.Assignment, 0, len(value))
|
2020-03-07 08:43:20 +03:00
|
|
|
|
2020-06-08 08:45:41 +03:00
|
|
|
keys := make([]string, 0, len(value))
|
2020-06-07 10:24:34 +03:00
|
|
|
for k := range value {
|
2020-03-07 08:43:20 +03:00
|
|
|
keys = append(keys, k)
|
|
|
|
}
|
|
|
|
sort.Strings(keys)
|
|
|
|
|
|
|
|
for _, k := range keys {
|
2020-08-13 11:05:06 +03:00
|
|
|
kv := value[k]
|
|
|
|
if _, ok := kv.(*gorm.DB); ok {
|
|
|
|
kv = []interface{}{kv}
|
|
|
|
}
|
|
|
|
|
2020-06-05 14:18:22 +03:00
|
|
|
if stmt.Schema != nil {
|
|
|
|
if field := stmt.Schema.LookUpField(k); field != nil {
|
|
|
|
if field.DBName != "" {
|
|
|
|
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
2020-08-13 11:05:06 +03:00
|
|
|
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: kv})
|
2020-06-05 14:18:22 +03:00
|
|
|
assignValue(field, value[k])
|
|
|
|
}
|
|
|
|
} else if v, ok := selectColumns[field.Name]; (ok && v) || (!ok && !restricted) {
|
2020-06-01 14:41:33 +03:00
|
|
|
assignValue(field, value[k])
|
|
|
|
}
|
2020-06-05 14:18:22 +03:00
|
|
|
continue
|
2020-03-07 08:43:20 +03:00
|
|
|
}
|
2020-06-05 14:18:22 +03:00
|
|
|
}
|
|
|
|
|
|
|
|
if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
|
2020-08-13 11:05:06 +03:00
|
|
|
set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: kv})
|
2020-03-07 08:43:20 +03:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2020-11-17 12:49:43 +03:00
|
|
|
if !stmt.SkipHooks && stmt.Schema != nil {
|
2020-08-27 14:52:01 +03:00
|
|
|
for _, dbName := range stmt.Schema.DBNames {
|
|
|
|
field := stmt.Schema.LookUpField(dbName)
|
2020-05-30 12:34:22 +03:00
|
|
|
if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil {
|
2020-12-15 06:18:29 +03:00
|
|
|
if v, ok := selectColumns[field.DBName]; (ok && v) || !ok {
|
2020-07-06 06:20:43 +03:00
|
|
|
now := stmt.DB.NowFunc()
|
|
|
|
assignValue(field, now)
|
|
|
|
|
|
|
|
if field.AutoUpdateTime == schema.UnixNanosecond {
|
|
|
|
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()})
|
2020-07-30 12:39:57 +03:00
|
|
|
} else if field.AutoUpdateTime == schema.UnixMillisecond {
|
|
|
|
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano() / 1e6})
|
2022-02-23 12:48:13 +03:00
|
|
|
} else if field.AutoUpdateTime == schema.UnixSecond {
|
2020-07-06 06:20:43 +03:00
|
|
|
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()})
|
2022-02-23 12:48:13 +03:00
|
|
|
} else {
|
|
|
|
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now})
|
2020-07-06 06:20:43 +03:00
|
|
|
}
|
2020-06-05 14:18:22 +03:00
|
|
|
}
|
2020-05-30 12:34:22 +03:00
|
|
|
}
|
2020-03-08 08:24:08 +03:00
|
|
|
}
|
|
|
|
}
|
2020-03-07 08:43:20 +03:00
|
|
|
default:
|
2022-01-06 10:02:53 +03:00
|
|
|
updatingSchema := stmt.Schema
|
2023-03-10 12:04:54 +03:00
|
|
|
var isDiffSchema bool
|
2021-09-07 15:04:54 +03:00
|
|
|
if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
|
|
|
|
// different schema
|
|
|
|
updatingStmt := &gorm.Statement{DB: stmt.DB}
|
|
|
|
if err := updatingStmt.Parse(stmt.Dest); err == nil {
|
|
|
|
updatingSchema = updatingStmt.Schema
|
2023-03-10 12:04:54 +03:00
|
|
|
isDiffSchema = true
|
2021-09-07 15:04:54 +03:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2020-06-01 14:41:33 +03:00
|
|
|
switch updatingValue.Kind() {
|
2020-03-07 08:43:20 +03:00
|
|
|
case reflect.Struct:
|
2020-03-08 08:24:08 +03:00
|
|
|
set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName))
|
2020-08-27 14:52:01 +03:00
|
|
|
for _, dbName := range stmt.Schema.DBNames {
|
2021-10-08 08:47:01 +03:00
|
|
|
if field := updatingSchema.LookUpField(dbName); field != nil {
|
2021-09-07 15:04:54 +03:00
|
|
|
if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
|
|
|
|
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) {
|
2022-02-16 10:30:43 +03:00
|
|
|
value, isZero := field.ValueOf(stmt.Context, updatingValue)
|
2021-09-07 15:04:54 +03:00
|
|
|
if !stmt.SkipHooks && field.AutoUpdateTime > 0 {
|
|
|
|
if field.AutoUpdateTime == schema.UnixNanosecond {
|
|
|
|
value = stmt.DB.NowFunc().UnixNano()
|
|
|
|
} else if field.AutoUpdateTime == schema.UnixMillisecond {
|
|
|
|
value = stmt.DB.NowFunc().UnixNano() / 1e6
|
2022-02-23 12:48:13 +03:00
|
|
|
} else if field.AutoUpdateTime == schema.UnixSecond {
|
2021-09-07 15:04:54 +03:00
|
|
|
value = stmt.DB.NowFunc().Unix()
|
2022-02-23 12:48:13 +03:00
|
|
|
} else {
|
|
|
|
value = stmt.DB.NowFunc()
|
2021-09-07 15:04:54 +03:00
|
|
|
}
|
|
|
|
isZero = false
|
2020-05-30 12:34:22 +03:00
|
|
|
}
|
2020-03-08 08:24:08 +03:00
|
|
|
|
2021-10-08 08:47:01 +03:00
|
|
|
if (ok || !isZero) && field.Updatable {
|
2021-09-07 15:04:54 +03:00
|
|
|
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value})
|
2023-03-10 12:04:54 +03:00
|
|
|
assignField := field
|
|
|
|
if isDiffSchema {
|
|
|
|
if originField := stmt.Schema.LookUpField(dbName); originField != nil {
|
|
|
|
assignField = originField
|
|
|
|
}
|
|
|
|
}
|
|
|
|
assignValue(assignField, value)
|
2021-09-07 15:04:54 +03:00
|
|
|
}
|
|
|
|
}
|
|
|
|
} else {
|
2022-02-16 10:30:43 +03:00
|
|
|
if value, isZero := field.ValueOf(stmt.Context, updatingValue); !isZero {
|
2021-09-07 15:04:54 +03:00
|
|
|
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
|
2020-03-08 08:24:08 +03:00
|
|
|
}
|
|
|
|
}
|
2020-03-07 08:43:20 +03:00
|
|
|
}
|
|
|
|
}
|
2020-08-18 13:00:36 +03:00
|
|
|
default:
|
|
|
|
stmt.AddError(gorm.ErrInvalidData)
|
2020-03-07 08:43:20 +03:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2020-03-08 08:24:08 +03:00
|
|
|
return
|
2020-03-07 08:43:20 +03:00
|
|
|
}
|