From c0de3c505176b0fea74c2e09fb9cae7c595b7020 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 24 Sep 2020 19:28:52 +0800 Subject: [PATCH] Support FullSaveAssociations Mode, close #3487, #3506 --- callbacks/associations.go | 61 +++++++++++++++++++-------------- callbacks/create.go | 5 ++- gorm.go | 7 ++++ logger/logger.go | 7 ++-- tests/update_belongs_to_test.go | 19 ++++++++++ tests/update_has_many_test.go | 41 ++++++++++++++++++++++ tests/update_has_one_test.go | 35 +++++++++++++++++++ tests/update_many2many_test.go | 25 ++++++++++++++ 8 files changed, 171 insertions(+), 29 deletions(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index 0c677f47..64d79f24 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -66,9 +66,7 @@ func SaveBeforeAssociations(db *gorm.DB) { } if elems.Len() > 0 { - if db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ - DoNothing: true, - }).Create(elems.Interface()).Error) == nil { + if db.AddError(db.Session(&gorm.Session{}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(elems.Interface()).Error) == nil { for i := 0; i < elems.Len(); i++ { setupReferences(objs[i], elems.Index(i)) } @@ -81,9 +79,7 @@ func SaveBeforeAssociations(db *gorm.DB) { rv = rv.Addr() } - if db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ - DoNothing: true, - }).Create(rv.Interface()).Error) == nil { + if db.AddError(db.Session(&gorm.Session{}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(rv.Interface()).Error) == nil { setupReferences(db.Statement.ReflectValue, rv) } } @@ -145,10 +141,9 @@ func SaveAfterAssociations(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ - Columns: onConflictColumns(rel.FieldSchema), - DoUpdates: clause.AssignmentColumns(assignmentColumns), - }).Create(elems.Interface()).Error) + db.AddError(db.Session(&gorm.Session{}).Clauses( + onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), + ).Create(elems.Interface()).Error) } case reflect.Struct: if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { @@ -168,10 +163,9 @@ func SaveAfterAssociations(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ - Columns: onConflictColumns(rel.FieldSchema), - DoUpdates: clause.AssignmentColumns(assignmentColumns), - }).Create(f.Interface()).Error) + db.AddError(db.Session(&gorm.Session{}).Clauses( + onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), + ).Create(f.Interface()).Error) } } } @@ -230,10 +224,9 @@ func SaveAfterAssociations(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ - Columns: onConflictColumns(rel.FieldSchema), - DoUpdates: clause.AssignmentColumns(assignmentColumns), - }).Create(elems.Interface()).Error) + db.AddError(db.Session(&gorm.Session{}).Clauses( + onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), + ).Create(elems.Interface()).Error) } } @@ -298,7 +291,7 @@ func SaveAfterAssociations(db *gorm.DB) { } if elems.Len() > 0 { - db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{DoNothing: true}).Create(elems.Interface()).Error) + db.AddError(db.Session(&gorm.Session{}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(elems.Interface()).Error) for i := 0; i < elems.Len(); i++ { appendToJoins(objs[i], elems.Index(i)) @@ -312,13 +305,31 @@ func SaveAfterAssociations(db *gorm.DB) { } } -func onConflictColumns(s *schema.Schema) (columns []clause.Column) { - if s.PrioritizedPrimaryField != nil { - return []clause.Column{{Name: s.PrioritizedPrimaryField.DBName}} +func onConflictOption(stmt *gorm.Statement, s *schema.Schema, defaultUpdatingColumns []string) clause.OnConflict { + if stmt.DB.FullSaveAssociations { + defaultUpdatingColumns = make([]string, 0, len(s.DBNames)) + for _, dbName := range s.DBNames { + if !s.LookUpField(dbName).PrimaryKey { + defaultUpdatingColumns = append(defaultUpdatingColumns, dbName) + } + } } - for _, dbName := range s.PrimaryFieldDBNames { - columns = append(columns, clause.Column{Name: dbName}) + if len(defaultUpdatingColumns) > 0 { + var columns []clause.Column + if s.PrioritizedPrimaryField != nil { + columns = []clause.Column{{Name: s.PrioritizedPrimaryField.DBName}} + } else { + for _, dbName := range s.PrimaryFieldDBNames { + columns = append(columns, clause.Column{Name: dbName}) + } + } + + return clause.OnConflict{ + Columns: columns, + DoUpdates: clause.AssignmentColumns(defaultUpdatingColumns), + } } - return + + return clause.OnConflict{DoNothing: true} } diff --git a/callbacks/create.go b/callbacks/create.go index c00a0a73..8e2454e8 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -88,7 +88,10 @@ func Create(config *Config) func(db *gorm.DB) { } case reflect.Struct: if insertID > 0 { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) + if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue); isZero { + + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) + } } } } else { diff --git a/gorm.go b/gorm.go index 8efd8a73..e5c4a8a4 100644 --- a/gorm.go +++ b/gorm.go @@ -20,6 +20,8 @@ type Config struct { SkipDefaultTransaction bool // NamingStrategy tables, columns naming strategy NamingStrategy schema.Namer + // FullSaveAssociations full save associations + FullSaveAssociations bool // Logger Logger logger.Interface // NowFunc the function to be used when creating a new timestamp @@ -64,6 +66,7 @@ type Session struct { WithConditions bool SkipDefaultTransaction bool AllowGlobalUpdate bool + FullSaveAssociations bool Context context.Context Logger logger.Interface NowFunc func() time.Time @@ -161,6 +164,10 @@ func (db *DB) Session(config *Session) *DB { txConfig.AllowGlobalUpdate = true } + if config.FullSaveAssociations { + txConfig.FullSaveAssociations = true + } + if config.Context != nil { tx.Statement = tx.Statement.clone() tx.Statement.DB = tx diff --git a/logger/logger.go b/logger/logger.go index 831192fc..e568fb24 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -20,6 +20,7 @@ const ( Magenta = "\033[35m" Cyan = "\033[36m" White = "\033[37m" + BlueBold = "\033[34;1m" MagentaBold = "\033[35;1m" RedBold = "\033[31;1m" YellowBold = "\033[33;1m" @@ -76,11 +77,11 @@ func New(writer Writer, config Config) Interface { if config.Colorful { infoStr = Green + "%s\n" + Reset + Green + "[info] " + Reset - warnStr = Blue + "%s\n" + Reset + Magenta + "[warn] " + Reset + warnStr = BlueBold + "%s\n" + Reset + Magenta + "[warn] " + Reset errStr = Magenta + "%s\n" + Reset + Red + "[error] " + Reset - traceStr = Green + "%s\n" + Reset + Yellow + "[%.3fms] " + Blue + "[rows:%d]" + Reset + " %s" + traceStr = Green + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%d]" + Reset + " %s" traceWarnStr = Green + "%s\n" + Reset + RedBold + "[%.3fms] " + Yellow + "[rows:%d]" + Magenta + " %s" + Reset - traceErrStr = RedBold + "%s " + MagentaBold + "%s\n" + Reset + Yellow + "[%.3fms] " + Blue + "[rows:%d]" + Reset + " %s" + traceErrStr = RedBold + "%s " + MagentaBold + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%d]" + Reset + " %s" } return &logger{ diff --git a/tests/update_belongs_to_test.go b/tests/update_belongs_to_test.go index 47076e69..736dfc5b 100644 --- a/tests/update_belongs_to_test.go +++ b/tests/update_belongs_to_test.go @@ -3,6 +3,7 @@ package tests_test import ( "testing" + "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) @@ -22,4 +23,22 @@ func TestUpdateBelongsTo(t *testing.T) { var user2 User DB.Preload("Company").Preload("Manager").Find(&user2, "id = ?", user.ID) CheckUser(t, user2, user) + + user.Company.Name += "new" + user.Manager.Name += "new" + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user3 User + DB.Preload("Company").Preload("Manager").Find(&user3, "id = ?", user.ID) + CheckUser(t, user2, user3) + + if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user4 User + DB.Preload("Company").Preload("Manager").Find(&user4, "id = ?", user.ID) + CheckUser(t, user4, user) } diff --git a/tests/update_has_many_test.go b/tests/update_has_many_test.go index 01ea2e3a..9066cbac 100644 --- a/tests/update_has_many_test.go +++ b/tests/update_has_many_test.go @@ -3,6 +3,7 @@ package tests_test import ( "testing" + "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) @@ -22,6 +23,26 @@ func TestUpdateHasManyAssociations(t *testing.T) { DB.Preload("Pets").Find(&user2, "id = ?", user.ID) CheckUser(t, user2, user) + for _, pet := range user.Pets { + pet.Name += "new" + } + + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user3 User + DB.Preload("Pets").Find(&user3, "id = ?", user.ID) + CheckUser(t, user2, user3) + + if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user4 User + DB.Preload("Pets").Find(&user4, "id = ?", user.ID) + CheckUser(t, user4, user) + t.Run("Polymorphic", func(t *testing.T) { var user = *GetUser("update-has-many", Config{}) @@ -37,5 +58,25 @@ func TestUpdateHasManyAssociations(t *testing.T) { var user2 User DB.Preload("Toys").Find(&user2, "id = ?", user.ID) CheckUser(t, user2, user) + + for idx := range user.Toys { + user.Toys[idx].Name += "new" + } + + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user3 User + DB.Preload("Toys").Find(&user3, "id = ?", user.ID) + CheckUser(t, user2, user3) + + if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user4 User + DB.Preload("Toys").Find(&user4, "id = ?", user.ID) + CheckUser(t, user4, user) }) } diff --git a/tests/update_has_one_test.go b/tests/update_has_one_test.go index 7b29f424..54568546 100644 --- a/tests/update_has_one_test.go +++ b/tests/update_has_one_test.go @@ -3,6 +3,7 @@ package tests_test import ( "testing" + "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) @@ -23,6 +24,23 @@ func TestUpdateHasOne(t *testing.T) { DB.Preload("Account").Find(&user2, "id = ?", user.ID) CheckUser(t, user2, user) + user.Account.Number += "new" + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user3 User + DB.Preload("Account").Find(&user3, "id = ?", user.ID) + CheckUser(t, user2, user3) + + if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user4 User + DB.Preload("Account").Find(&user4, "id = ?", user.ID) + CheckUser(t, user4, user) + t.Run("Polymorphic", func(t *testing.T) { var pet = Pet{Name: "create"} @@ -39,5 +57,22 @@ func TestUpdateHasOne(t *testing.T) { var pet2 Pet DB.Preload("Toy").Find(&pet2, "id = ?", pet.ID) CheckPet(t, pet2, pet) + + pet.Toy.Name += "new" + if err := DB.Save(&pet).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var pet3 Pet + DB.Preload("Toy").Find(&pet3, "id = ?", pet.ID) + CheckPet(t, pet2, pet3) + + if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&pet).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var pet4 Pet + DB.Preload("Toy").Find(&pet4, "id = ?", pet.ID) + CheckPet(t, pet4, pet) }) } diff --git a/tests/update_many2many_test.go b/tests/update_many2many_test.go index a46deeb0..d94ef4ab 100644 --- a/tests/update_many2many_test.go +++ b/tests/update_many2many_test.go @@ -3,6 +3,7 @@ package tests_test import ( "testing" + "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) @@ -26,4 +27,28 @@ func TestUpdateMany2ManyAssociations(t *testing.T) { var user2 User DB.Preload("Languages").Preload("Friends").Find(&user2, "id = ?", user.ID) CheckUser(t, user2, user) + + for idx := range user.Friends { + user.Friends[idx].Name += "new" + } + + for idx := range user.Languages { + user.Languages[idx].Name += "new" + } + + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user3 User + DB.Preload("Languages").Preload("Friends").Find(&user3, "id = ?", user.ID) + CheckUser(t, user2, user3) + + if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user4 User + DB.Preload("Languages").Preload("Friends").Find(&user4, "id = ?", user.ID) + CheckUser(t, user4, user) }