gorm/callbacks/update.go

219 lines
6.3 KiB
Go
Raw Normal View History

2020-02-02 14:32:27 +03:00
package callbacks
2020-02-23 16:22:35 +03:00
import (
"reflect"
2020-03-07 08:43:20 +03:00
"sort"
2020-03-08 08:24:08 +03:00
"time"
2020-02-23 16:22:35 +03:00
"github.com/jinzhu/gorm"
2020-03-07 08:43:20 +03:00
"github.com/jinzhu/gorm/clause"
2020-05-24 12:24:23 +03:00
"github.com/jinzhu/gorm/schema"
2020-02-23 16:22:35 +03:00
)
2020-02-02 14:32:27 +03:00
func BeforeUpdate(db *gorm.DB) {
2020-02-23 16:22:35 +03:00
if db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
callMethod := func(value interface{}) bool {
var ok bool
if db.Statement.Schema.BeforeSave {
if i, ok := value.(gorm.BeforeSaveInterface); ok {
ok = true
i.BeforeSave(db)
}
}
if db.Statement.Schema.BeforeUpdate {
if i, ok := value.(gorm.BeforeUpdateInterface); ok {
ok = true
i.BeforeUpdate(db)
}
}
return ok
}
if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Interface())
}
case reflect.Struct:
callMethod(db.Statement.ReflectValue.Interface())
}
}
}
2020-02-02 14:32:27 +03:00
}
func Update(db *gorm.DB) {
2020-05-29 02:35:45 +03:00
if db.Statement.Schema != nil && !db.Statement.Unscoped {
for _, c := range db.Statement.Schema.UpdateClauses {
db.Statement.AddClause(c)
}
}
if db.Statement.SQL.String() == "" {
db.Statement.AddClauseIfNotExists(clause.Update{})
if set := ConvertToAssignments(db.Statement); len(set) != 0 {
db.Statement.AddClause(set)
} else {
return
}
db.Statement.Build("UPDATE", "SET", "WHERE")
2020-05-24 17:52:16 +03:00
}
2020-03-07 08:43:20 +03:00
2020-03-09 08:10:48 +03:00
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
2020-03-07 08:43:20 +03:00
if err == nil {
db.RowsAffected, _ = result.RowsAffected()
} else {
db.AddError(err)
}
2020-02-02 14:32:27 +03:00
}
func AfterUpdate(db *gorm.DB) {
2020-02-23 16:22:35 +03:00
if db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {
callMethod := func(value interface{}) bool {
var ok bool
if db.Statement.Schema.AfterSave {
if i, ok := value.(gorm.AfterSaveInterface); ok {
ok = true
i.AfterSave(db)
}
}
if db.Statement.Schema.AfterUpdate {
if i, ok := value.(gorm.AfterUpdateInterface); ok {
ok = true
i.AfterUpdate(db)
}
}
return ok
}
if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Interface())
}
case reflect.Struct:
callMethod(db.Statement.ReflectValue.Interface())
}
}
}
2020-02-02 14:32:27 +03:00
}
2020-03-07 08:43:20 +03:00
// ConvertToAssignments convert to update assignments
2020-03-08 08:24:08 +03:00
func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
2020-05-24 12:24:23 +03:00
var (
selectColumns, restricted = SelectAndOmitColumns(stmt, false, true)
reflectModelValue = reflect.Indirect(reflect.ValueOf(stmt.Model))
assignValue func(field *schema.Field, value interface{})
)
switch reflectModelValue.Kind() {
case reflect.Slice, reflect.Array:
assignValue = func(field *schema.Field, value interface{}) {
for i := 0; i < reflectModelValue.Len(); i++ {
field.Set(reflectModelValue.Index(i), value)
}
}
case reflect.Struct:
assignValue = func(field *schema.Field, value interface{}) {
if reflectModelValue.CanAddr() {
field.Set(reflectModelValue, value)
}
2020-05-24 12:24:23 +03:00
}
default:
assignValue = func(field *schema.Field, value interface{}) {
}
}
2020-03-07 08:43:20 +03:00
switch value := stmt.Dest.(type) {
case map[string]interface{}:
2020-03-08 08:24:08 +03:00
set = make([]clause.Assignment, 0, len(value))
2020-03-07 08:43:20 +03:00
var keys []string
for k, _ := range value {
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys {
if field := stmt.Schema.LookUpField(k); field != nil {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value[k]})
2020-05-24 12:24:23 +03:00
assignValue(field, value[k])
2020-03-07 08:43:20 +03:00
}
} else if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: value[k]})
}
}
2020-05-30 12:34:22 +03:00
if !stmt.DisableUpdateTime {
for _, field := range stmt.Schema.FieldsByDBName {
if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil {
now := time.Now()
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now})
assignValue(field, now)
}
2020-03-08 08:24:08 +03:00
}
}
2020-03-07 08:43:20 +03:00
default:
switch stmt.ReflectValue.Kind() {
case reflect.Struct:
2020-03-08 08:24:08 +03:00
set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName))
2020-03-07 08:43:20 +03:00
for _, field := range stmt.Schema.FieldsByDBName {
2020-05-30 11:47:16 +03:00
if !field.PrimaryKey || (!stmt.ReflectValue.CanAddr() || stmt.Dest != stmt.Model) {
2020-03-08 08:24:08 +03:00
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
value, isZero := field.ValueOf(stmt.ReflectValue)
2020-05-30 12:34:22 +03:00
if !stmt.DisableUpdateTime {
if field.AutoUpdateTime > 0 {
value = time.Now()
isZero = false
}
2020-03-08 08:24:08 +03:00
}
if ok || !isZero {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value})
2020-05-24 12:24:23 +03:00
assignValue(field, value)
2020-03-08 08:24:08 +03:00
}
}
} else {
if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero {
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
}
2020-03-07 08:43:20 +03:00
}
}
}
}
2020-05-30 11:47:16 +03:00
if !stmt.ReflectValue.CanAddr() || stmt.Dest != stmt.Model {
2020-05-24 15:44:37 +03:00
reflectValue := reflect.Indirect(reflect.ValueOf(stmt.Model))
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
var priamryKeyExprs []clause.Expression
for i := 0; i < reflectValue.Len(); i++ {
var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields))
var notZero bool
for idx, field := range stmt.Schema.PrimaryFields {
value, isZero := field.ValueOf(reflectValue.Index(i))
exprs[idx] = clause.Eq{Column: field.DBName, Value: value}
notZero = notZero || !isZero
}
if notZero {
priamryKeyExprs = append(priamryKeyExprs, clause.And(exprs...))
}
}
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(priamryKeyExprs...)}})
case reflect.Struct:
for _, field := range stmt.Schema.PrimaryFields {
if value, isZero := field.ValueOf(reflectValue); !isZero {
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
}
2020-03-08 08:24:08 +03:00
}
}
}
2020-05-24 17:52:16 +03:00
2020-03-08 08:24:08 +03:00
return
2020-03-07 08:43:20 +03:00
}