From ba16b2368f253572195de14fef62272a752595ef Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 7 Sep 2021 20:04:54 +0800 Subject: [PATCH] Refactor update record (#4679) --- callbacks/update.go | 81 +++++++++++++++++--------------------------- tests/update_test.go | 12 ++++--- 2 files changed, 40 insertions(+), 53 deletions(-) diff --git a/callbacks/update.go b/callbacks/update.go index ee60bcd7..7d5ea4a4 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -23,38 +23,11 @@ func SetupUpdateReflectValue(db *gorm.DB) { rel.Field.Set(db.Statement.ReflectValue, dest[rel.Name]) } } - } else if modelType, destType := findType(db.Statement.Model), findType(db.Statement.Dest); modelType.Kind() == reflect.Struct && destType.Kind() == reflect.Struct { - db.Statement.Dest = transToModel(reflect.Indirect(reflect.ValueOf(db.Statement.Dest)), reflect.New(modelType).Elem()) } } } } -func findType(target interface{}) reflect.Type { - t := reflect.TypeOf(target) - if t.Kind() == reflect.Ptr { - return t.Elem() - } - return t -} - -func transToModel(from, to reflect.Value) interface{} { - if from.String() == to.String() { - return from.Interface() - } - - fromType := from.Type() - for i := 0; i < fromType.NumField(); i++ { - fieldName := fromType.Field(i).Name - fromField, toField := from.FieldByName(fieldName), to.FieldByName(fieldName) - if !toField.IsValid() || !toField.CanSet() || toField.Kind() != fromField.Kind() { - continue - } - toField.Set(fromField) - } - return to.Interface() -} - func BeforeUpdate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { @@ -249,35 +222,45 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } } default: + var updatingSchema = stmt.Schema + if !updatingValue.CanAddr() || stmt.Dest != stmt.Model { + // different schema + updatingStmt := &gorm.Statement{DB: stmt.DB} + if err := updatingStmt.Parse(stmt.Dest); err == nil { + updatingSchema = updatingStmt.Schema + } + } + switch updatingValue.Kind() { case reflect.Struct: set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName)) for _, dbName := range stmt.Schema.DBNames { - field := stmt.Schema.LookUpField(dbName) - if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model { - if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) { - value, isZero := field.ValueOf(updatingValue) - if !stmt.SkipHooks && field.AutoUpdateTime > 0 { - if field.AutoUpdateTime == schema.UnixNanosecond { - value = stmt.DB.NowFunc().UnixNano() - } else if field.AutoUpdateTime == schema.UnixMillisecond { - value = stmt.DB.NowFunc().UnixNano() / 1e6 - } else if field.GORMDataType == schema.Time { - value = stmt.DB.NowFunc() - } else { - value = stmt.DB.NowFunc().Unix() + if field := updatingSchema.LookUpField(dbName); field != nil && field.Updatable { + if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model { + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) { + value, isZero := field.ValueOf(updatingValue) + if !stmt.SkipHooks && field.AutoUpdateTime > 0 { + if field.AutoUpdateTime == schema.UnixNanosecond { + value = stmt.DB.NowFunc().UnixNano() + } else if field.AutoUpdateTime == schema.UnixMillisecond { + value = stmt.DB.NowFunc().UnixNano() / 1e6 + } else if field.GORMDataType == schema.Time { + value = stmt.DB.NowFunc() + } else { + value = stmt.DB.NowFunc().Unix() + } + isZero = false } - isZero = false - } - if ok || !isZero { - set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value}) - assignValue(field, value) + if ok || !isZero { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value}) + assignValue(field, value) + } + } + } else { + if value, isZero := field.ValueOf(updatingValue); !isZero { + stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) } - } - } else { - if value, isZero := field.ValueOf(updatingValue); !isZero { - stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) } } } diff --git a/tests/update_test.go b/tests/update_test.go index 2a747ce5..9e5b630e 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -651,14 +651,16 @@ func TestSave(t *testing.T) { } user3.Name = "save3_" - DB.Model(User{Model: user3.Model}).Save(&user3) + if err := DB.Model(User{Model: user3.Model}).Save(&user3).Error; err != nil { + t.Fatalf("failed to save user, got %v", err) + } var result2 User if err := DB.First(&result2, "name = ?", "save3_").Error; err != nil || result2.ID != user3.ID { - t.Fatalf("failed to find updated user") + t.Fatalf("failed to find updated user, got %v", err) } - DB.Debug().Model(User{Model: user3.Model}).Save(&struct { + if err := DB.Model(User{Model: user3.Model}).Save(&struct { gorm.Model Placeholder string Name string @@ -666,7 +668,9 @@ func TestSave(t *testing.T) { Model: user3.Model, Placeholder: "placeholder", Name: "save3__", - }) + }).Error; err != nil { + t.Fatalf("failed to update user, got %v", err) + } var result3 User if err := DB.First(&result3, "name = ?", "save3__").Error; err != nil || result3.ID != user3.ID {