Refactor update record (#4679)

This commit is contained in:
Jinzhu 2021-09-07 20:04:54 +08:00 committed by GitHub
parent 6c94b07e98
commit ba16b2368f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 40 additions and 53 deletions

View File

@ -23,38 +23,11 @@ func SetupUpdateReflectValue(db *gorm.DB) {
rel.Field.Set(db.Statement.ReflectValue, dest[rel.Name]) rel.Field.Set(db.Statement.ReflectValue, dest[rel.Name])
} }
} }
} else if modelType, destType := findType(db.Statement.Model), findType(db.Statement.Dest); modelType.Kind() == reflect.Struct && destType.Kind() == reflect.Struct {
db.Statement.Dest = transToModel(reflect.Indirect(reflect.ValueOf(db.Statement.Dest)), reflect.New(modelType).Elem())
} }
} }
} }
} }
func findType(target interface{}) reflect.Type {
t := reflect.TypeOf(target)
if t.Kind() == reflect.Ptr {
return t.Elem()
}
return t
}
func transToModel(from, to reflect.Value) interface{} {
if from.String() == to.String() {
return from.Interface()
}
fromType := from.Type()
for i := 0; i < fromType.NumField(); i++ {
fieldName := fromType.Field(i).Name
fromField, toField := from.FieldByName(fieldName), to.FieldByName(fieldName)
if !toField.IsValid() || !toField.CanSet() || toField.Kind() != fromField.Kind() {
continue
}
toField.Set(fromField)
}
return to.Interface()
}
func BeforeUpdate(db *gorm.DB) { func BeforeUpdate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
@ -249,11 +222,20 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
} }
} }
default: default:
var updatingSchema = stmt.Schema
if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
// different schema
updatingStmt := &gorm.Statement{DB: stmt.DB}
if err := updatingStmt.Parse(stmt.Dest); err == nil {
updatingSchema = updatingStmt.Schema
}
}
switch updatingValue.Kind() { switch updatingValue.Kind() {
case reflect.Struct: case reflect.Struct:
set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName)) set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName))
for _, dbName := range stmt.Schema.DBNames { for _, dbName := range stmt.Schema.DBNames {
field := stmt.Schema.LookUpField(dbName) if field := updatingSchema.LookUpField(dbName); field != nil && field.Updatable {
if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model { if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) {
value, isZero := field.ValueOf(updatingValue) value, isZero := field.ValueOf(updatingValue)
@ -281,6 +263,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
} }
} }
} }
}
default: default:
stmt.AddError(gorm.ErrInvalidData) stmt.AddError(gorm.ErrInvalidData)
} }

View File

@ -651,14 +651,16 @@ func TestSave(t *testing.T) {
} }
user3.Name = "save3_" user3.Name = "save3_"
DB.Model(User{Model: user3.Model}).Save(&user3) if err := DB.Model(User{Model: user3.Model}).Save(&user3).Error; err != nil {
t.Fatalf("failed to save user, got %v", err)
}
var result2 User var result2 User
if err := DB.First(&result2, "name = ?", "save3_").Error; err != nil || result2.ID != user3.ID { if err := DB.First(&result2, "name = ?", "save3_").Error; err != nil || result2.ID != user3.ID {
t.Fatalf("failed to find updated user") t.Fatalf("failed to find updated user, got %v", err)
} }
DB.Debug().Model(User{Model: user3.Model}).Save(&struct { if err := DB.Model(User{Model: user3.Model}).Save(&struct {
gorm.Model gorm.Model
Placeholder string Placeholder string
Name string Name string
@ -666,7 +668,9 @@ func TestSave(t *testing.T) {
Model: user3.Model, Model: user3.Model,
Placeholder: "placeholder", Placeholder: "placeholder",
Name: "save3__", Name: "save3__",
}) }).Error; err != nil {
t.Fatalf("failed to update user, got %v", err)
}
var result3 User var result3 User
if err := DB.First(&result3, "name = ?", "save3__").Error; err != nil || result3.ID != user3.ID { if err := DB.First(&result3, "name = ?", "save3__").Error; err != nil || result3.ID != user3.ID {