forked from mirror/gorm
Refactor update record (#4679)
This commit is contained in:
parent
6c94b07e98
commit
ba16b2368f
|
@ -23,38 +23,11 @@ func SetupUpdateReflectValue(db *gorm.DB) {
|
||||||
rel.Field.Set(db.Statement.ReflectValue, dest[rel.Name])
|
rel.Field.Set(db.Statement.ReflectValue, dest[rel.Name])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if modelType, destType := findType(db.Statement.Model), findType(db.Statement.Dest); modelType.Kind() == reflect.Struct && destType.Kind() == reflect.Struct {
|
|
||||||
db.Statement.Dest = transToModel(reflect.Indirect(reflect.ValueOf(db.Statement.Dest)), reflect.New(modelType).Elem())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func findType(target interface{}) reflect.Type {
|
|
||||||
t := reflect.TypeOf(target)
|
|
||||||
if t.Kind() == reflect.Ptr {
|
|
||||||
return t.Elem()
|
|
||||||
}
|
|
||||||
return t
|
|
||||||
}
|
|
||||||
|
|
||||||
func transToModel(from, to reflect.Value) interface{} {
|
|
||||||
if from.String() == to.String() {
|
|
||||||
return from.Interface()
|
|
||||||
}
|
|
||||||
|
|
||||||
fromType := from.Type()
|
|
||||||
for i := 0; i < fromType.NumField(); i++ {
|
|
||||||
fieldName := fromType.Field(i).Name
|
|
||||||
fromField, toField := from.FieldByName(fieldName), to.FieldByName(fieldName)
|
|
||||||
if !toField.IsValid() || !toField.CanSet() || toField.Kind() != fromField.Kind() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
toField.Set(fromField)
|
|
||||||
}
|
|
||||||
return to.Interface()
|
|
||||||
}
|
|
||||||
|
|
||||||
func BeforeUpdate(db *gorm.DB) {
|
func BeforeUpdate(db *gorm.DB) {
|
||||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
|
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
|
||||||
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
||||||
|
@ -249,35 +222,45 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
|
var updatingSchema = stmt.Schema
|
||||||
|
if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
|
||||||
|
// different schema
|
||||||
|
updatingStmt := &gorm.Statement{DB: stmt.DB}
|
||||||
|
if err := updatingStmt.Parse(stmt.Dest); err == nil {
|
||||||
|
updatingSchema = updatingStmt.Schema
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
switch updatingValue.Kind() {
|
switch updatingValue.Kind() {
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName))
|
set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName))
|
||||||
for _, dbName := range stmt.Schema.DBNames {
|
for _, dbName := range stmt.Schema.DBNames {
|
||||||
field := stmt.Schema.LookUpField(dbName)
|
if field := updatingSchema.LookUpField(dbName); field != nil && field.Updatable {
|
||||||
if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
|
if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
|
||||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) {
|
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) {
|
||||||
value, isZero := field.ValueOf(updatingValue)
|
value, isZero := field.ValueOf(updatingValue)
|
||||||
if !stmt.SkipHooks && field.AutoUpdateTime > 0 {
|
if !stmt.SkipHooks && field.AutoUpdateTime > 0 {
|
||||||
if field.AutoUpdateTime == schema.UnixNanosecond {
|
if field.AutoUpdateTime == schema.UnixNanosecond {
|
||||||
value = stmt.DB.NowFunc().UnixNano()
|
value = stmt.DB.NowFunc().UnixNano()
|
||||||
} else if field.AutoUpdateTime == schema.UnixMillisecond {
|
} else if field.AutoUpdateTime == schema.UnixMillisecond {
|
||||||
value = stmt.DB.NowFunc().UnixNano() / 1e6
|
value = stmt.DB.NowFunc().UnixNano() / 1e6
|
||||||
} else if field.GORMDataType == schema.Time {
|
} else if field.GORMDataType == schema.Time {
|
||||||
value = stmt.DB.NowFunc()
|
value = stmt.DB.NowFunc()
|
||||||
} else {
|
} else {
|
||||||
value = stmt.DB.NowFunc().Unix()
|
value = stmt.DB.NowFunc().Unix()
|
||||||
|
}
|
||||||
|
isZero = false
|
||||||
}
|
}
|
||||||
isZero = false
|
|
||||||
}
|
|
||||||
|
|
||||||
if ok || !isZero {
|
if ok || !isZero {
|
||||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value})
|
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value})
|
||||||
assignValue(field, value)
|
assignValue(field, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if value, isZero := field.ValueOf(updatingValue); !isZero {
|
||||||
|
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
|
||||||
}
|
}
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if value, isZero := field.ValueOf(updatingValue); !isZero {
|
|
||||||
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -651,14 +651,16 @@ func TestSave(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
user3.Name = "save3_"
|
user3.Name = "save3_"
|
||||||
DB.Model(User{Model: user3.Model}).Save(&user3)
|
if err := DB.Model(User{Model: user3.Model}).Save(&user3).Error; err != nil {
|
||||||
|
t.Fatalf("failed to save user, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
var result2 User
|
var result2 User
|
||||||
if err := DB.First(&result2, "name = ?", "save3_").Error; err != nil || result2.ID != user3.ID {
|
if err := DB.First(&result2, "name = ?", "save3_").Error; err != nil || result2.ID != user3.ID {
|
||||||
t.Fatalf("failed to find updated user")
|
t.Fatalf("failed to find updated user, got %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Debug().Model(User{Model: user3.Model}).Save(&struct {
|
if err := DB.Model(User{Model: user3.Model}).Save(&struct {
|
||||||
gorm.Model
|
gorm.Model
|
||||||
Placeholder string
|
Placeholder string
|
||||||
Name string
|
Name string
|
||||||
|
@ -666,7 +668,9 @@ func TestSave(t *testing.T) {
|
||||||
Model: user3.Model,
|
Model: user3.Model,
|
||||||
Placeholder: "placeholder",
|
Placeholder: "placeholder",
|
||||||
Name: "save3__",
|
Name: "save3__",
|
||||||
})
|
}).Error; err != nil {
|
||||||
|
t.Fatalf("failed to update user, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
var result3 User
|
var result3 User
|
||||||
if err := DB.First(&result3, "name = ?", "save3__").Error; err != nil || result3.ID != user3.ID {
|
if err := DB.First(&result3, "name = ?", "save3__").Error; err != nil || result3.ID != user3.ID {
|
||||||
|
|
Loading…
Reference in New Issue