2020-02-02 14:32:27 +03:00
|
|
|
package callbacks
|
|
|
|
|
2020-02-23 16:22:35 +03:00
|
|
|
import (
|
|
|
|
"reflect"
|
2020-03-07 08:43:20 +03:00
|
|
|
"sort"
|
2020-03-08 08:24:08 +03:00
|
|
|
"time"
|
2020-02-23 16:22:35 +03:00
|
|
|
|
|
|
|
"github.com/jinzhu/gorm"
|
2020-03-07 08:43:20 +03:00
|
|
|
"github.com/jinzhu/gorm/clause"
|
2020-02-23 16:22:35 +03:00
|
|
|
)
|
2020-02-02 14:32:27 +03:00
|
|
|
|
|
|
|
func BeforeUpdate(db *gorm.DB) {
|
2020-02-23 16:22:35 +03:00
|
|
|
if db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
|
|
|
|
callMethod := func(value interface{}) bool {
|
|
|
|
var ok bool
|
|
|
|
if db.Statement.Schema.BeforeSave {
|
|
|
|
if i, ok := value.(gorm.BeforeSaveInterface); ok {
|
|
|
|
ok = true
|
|
|
|
i.BeforeSave(db)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if db.Statement.Schema.BeforeUpdate {
|
|
|
|
if i, ok := value.(gorm.BeforeUpdateInterface); ok {
|
|
|
|
ok = true
|
|
|
|
i.BeforeUpdate(db)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return ok
|
|
|
|
}
|
|
|
|
|
|
|
|
if ok := callMethod(db.Statement.Dest); !ok {
|
|
|
|
switch db.Statement.ReflectValue.Kind() {
|
|
|
|
case reflect.Slice, reflect.Array:
|
|
|
|
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
|
|
|
|
callMethod(db.Statement.ReflectValue.Index(i).Interface())
|
|
|
|
}
|
|
|
|
case reflect.Struct:
|
|
|
|
callMethod(db.Statement.ReflectValue.Interface())
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
2020-02-02 14:32:27 +03:00
|
|
|
}
|
|
|
|
|
|
|
|
func Update(db *gorm.DB) {
|
2020-03-07 08:43:20 +03:00
|
|
|
db.Statement.AddClauseIfNotExists(clause.Update{})
|
|
|
|
db.Statement.AddClause(ConvertToAssignments(db.Statement))
|
|
|
|
db.Statement.Build("UPDATE", "SET", "WHERE")
|
|
|
|
|
2020-03-09 08:10:48 +03:00
|
|
|
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
2020-03-07 08:43:20 +03:00
|
|
|
|
|
|
|
if err == nil {
|
|
|
|
db.RowsAffected, _ = result.RowsAffected()
|
|
|
|
} else {
|
|
|
|
db.AddError(err)
|
|
|
|
}
|
2020-02-02 14:32:27 +03:00
|
|
|
}
|
|
|
|
|
|
|
|
func AfterUpdate(db *gorm.DB) {
|
2020-02-23 16:22:35 +03:00
|
|
|
if db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {
|
|
|
|
callMethod := func(value interface{}) bool {
|
|
|
|
var ok bool
|
|
|
|
if db.Statement.Schema.AfterSave {
|
|
|
|
if i, ok := value.(gorm.AfterSaveInterface); ok {
|
|
|
|
ok = true
|
|
|
|
i.AfterSave(db)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if db.Statement.Schema.AfterUpdate {
|
|
|
|
if i, ok := value.(gorm.AfterUpdateInterface); ok {
|
|
|
|
ok = true
|
|
|
|
i.AfterUpdate(db)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return ok
|
|
|
|
}
|
|
|
|
|
|
|
|
if ok := callMethod(db.Statement.Dest); !ok {
|
|
|
|
switch db.Statement.ReflectValue.Kind() {
|
|
|
|
case reflect.Slice, reflect.Array:
|
|
|
|
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
|
|
|
|
callMethod(db.Statement.ReflectValue.Index(i).Interface())
|
|
|
|
}
|
|
|
|
case reflect.Struct:
|
|
|
|
callMethod(db.Statement.ReflectValue.Interface())
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
2020-02-02 14:32:27 +03:00
|
|
|
}
|
2020-03-07 08:43:20 +03:00
|
|
|
|
|
|
|
// ConvertToAssignments convert to update assignments
|
2020-03-08 08:24:08 +03:00
|
|
|
func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
2020-04-08 03:15:00 +03:00
|
|
|
selectColumns, restricted := SelectAndOmitColumns(stmt, false, true)
|
2020-03-07 08:43:20 +03:00
|
|
|
reflectModelValue := reflect.ValueOf(stmt.Model)
|
|
|
|
|
|
|
|
switch value := stmt.Dest.(type) {
|
|
|
|
case map[string]interface{}:
|
2020-03-08 08:24:08 +03:00
|
|
|
set = make([]clause.Assignment, 0, len(value))
|
2020-03-07 08:43:20 +03:00
|
|
|
|
|
|
|
var keys []string
|
|
|
|
for k, _ := range value {
|
|
|
|
keys = append(keys, k)
|
|
|
|
}
|
|
|
|
sort.Strings(keys)
|
|
|
|
|
|
|
|
for _, k := range keys {
|
|
|
|
if field := stmt.Schema.LookUpField(k); field != nil {
|
|
|
|
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
2020-03-08 08:24:08 +03:00
|
|
|
if field.AutoUpdateTime > 0 {
|
|
|
|
value[k] = time.Now()
|
|
|
|
}
|
2020-03-07 08:43:20 +03:00
|
|
|
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value[k]})
|
|
|
|
field.Set(reflectModelValue, value[k])
|
|
|
|
}
|
|
|
|
} else if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
|
|
|
|
set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: value[k]})
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2020-03-08 08:24:08 +03:00
|
|
|
for _, field := range stmt.Schema.FieldsByDBName {
|
|
|
|
if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil {
|
|
|
|
now := time.Now()
|
|
|
|
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now})
|
|
|
|
field.Set(reflectModelValue, now)
|
|
|
|
}
|
|
|
|
}
|
2020-03-07 08:43:20 +03:00
|
|
|
default:
|
|
|
|
switch stmt.ReflectValue.Kind() {
|
|
|
|
case reflect.Struct:
|
2020-03-08 08:24:08 +03:00
|
|
|
set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName))
|
2020-03-07 08:43:20 +03:00
|
|
|
for _, field := range stmt.Schema.FieldsByDBName {
|
2020-03-08 08:24:08 +03:00
|
|
|
if !field.PrimaryKey || stmt.Dest != stmt.Model {
|
|
|
|
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
|
|
|
value, isZero := field.ValueOf(stmt.ReflectValue)
|
|
|
|
if field.AutoUpdateTime > 0 {
|
|
|
|
value = time.Now()
|
|
|
|
isZero = false
|
|
|
|
}
|
|
|
|
|
|
|
|
if ok || !isZero {
|
|
|
|
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value})
|
|
|
|
field.Set(reflectModelValue, value)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero {
|
|
|
|
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
|
|
|
|
}
|
2020-03-07 08:43:20 +03:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2020-03-08 08:24:08 +03:00
|
|
|
if stmt.Dest != stmt.Model {
|
|
|
|
reflectValue := reflect.ValueOf(stmt.Model)
|
|
|
|
for _, field := range stmt.Schema.PrimaryFields {
|
|
|
|
if value, isZero := field.ValueOf(reflectValue); !isZero {
|
|
|
|
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return
|
2020-03-07 08:43:20 +03:00
|
|
|
}
|