Support FullSaveAssociations Mode, close #3487, #3506

This commit is contained in:
Jinzhu 2020-09-24 19:28:52 +08:00
parent 5228735915
commit c0de3c5051
8 changed files with 171 additions and 29 deletions

View File

@ -66,9 +66,7 @@ func SaveBeforeAssociations(db *gorm.DB) {
}
if elems.Len() > 0 {
if db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{
DoNothing: true,
}).Create(elems.Interface()).Error) == nil {
if db.AddError(db.Session(&gorm.Session{}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(elems.Interface()).Error) == nil {
for i := 0; i < elems.Len(); i++ {
setupReferences(objs[i], elems.Index(i))
}
@ -81,9 +79,7 @@ func SaveBeforeAssociations(db *gorm.DB) {
rv = rv.Addr()
}
if db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{
DoNothing: true,
}).Create(rv.Interface()).Error) == nil {
if db.AddError(db.Session(&gorm.Session{}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(rv.Interface()).Error) == nil {
setupReferences(db.Statement.ReflectValue, rv)
}
}
@ -145,10 +141,9 @@ func SaveAfterAssociations(db *gorm.DB) {
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
}
db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{
Columns: onConflictColumns(rel.FieldSchema),
DoUpdates: clause.AssignmentColumns(assignmentColumns),
}).Create(elems.Interface()).Error)
db.AddError(db.Session(&gorm.Session{}).Clauses(
onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns),
).Create(elems.Interface()).Error)
}
case reflect.Struct:
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
@ -168,10 +163,9 @@ func SaveAfterAssociations(db *gorm.DB) {
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
}
db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{
Columns: onConflictColumns(rel.FieldSchema),
DoUpdates: clause.AssignmentColumns(assignmentColumns),
}).Create(f.Interface()).Error)
db.AddError(db.Session(&gorm.Session{}).Clauses(
onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns),
).Create(f.Interface()).Error)
}
}
}
@ -230,10 +224,9 @@ func SaveAfterAssociations(db *gorm.DB) {
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
}
db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{
Columns: onConflictColumns(rel.FieldSchema),
DoUpdates: clause.AssignmentColumns(assignmentColumns),
}).Create(elems.Interface()).Error)
db.AddError(db.Session(&gorm.Session{}).Clauses(
onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns),
).Create(elems.Interface()).Error)
}
}
@ -298,7 +291,7 @@ func SaveAfterAssociations(db *gorm.DB) {
}
if elems.Len() > 0 {
db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{DoNothing: true}).Create(elems.Interface()).Error)
db.AddError(db.Session(&gorm.Session{}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(elems.Interface()).Error)
for i := 0; i < elems.Len(); i++ {
appendToJoins(objs[i], elems.Index(i))
@ -312,13 +305,31 @@ func SaveAfterAssociations(db *gorm.DB) {
}
}
func onConflictColumns(s *schema.Schema) (columns []clause.Column) {
if s.PrioritizedPrimaryField != nil {
return []clause.Column{{Name: s.PrioritizedPrimaryField.DBName}}
func onConflictOption(stmt *gorm.Statement, s *schema.Schema, defaultUpdatingColumns []string) clause.OnConflict {
if stmt.DB.FullSaveAssociations {
defaultUpdatingColumns = make([]string, 0, len(s.DBNames))
for _, dbName := range s.DBNames {
if !s.LookUpField(dbName).PrimaryKey {
defaultUpdatingColumns = append(defaultUpdatingColumns, dbName)
}
}
}
for _, dbName := range s.PrimaryFieldDBNames {
columns = append(columns, clause.Column{Name: dbName})
if len(defaultUpdatingColumns) > 0 {
var columns []clause.Column
if s.PrioritizedPrimaryField != nil {
columns = []clause.Column{{Name: s.PrioritizedPrimaryField.DBName}}
} else {
for _, dbName := range s.PrimaryFieldDBNames {
columns = append(columns, clause.Column{Name: dbName})
}
}
return clause.OnConflict{
Columns: columns,
DoUpdates: clause.AssignmentColumns(defaultUpdatingColumns),
}
}
return
return clause.OnConflict{DoNothing: true}
}

View File

@ -88,7 +88,10 @@ func Create(config *Config) func(db *gorm.DB) {
}
case reflect.Struct:
if insertID > 0 {
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID)
if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue); isZero {
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID)
}
}
}
} else {

View File

@ -20,6 +20,8 @@ type Config struct {
SkipDefaultTransaction bool
// NamingStrategy tables, columns naming strategy
NamingStrategy schema.Namer
// FullSaveAssociations full save associations
FullSaveAssociations bool
// Logger
Logger logger.Interface
// NowFunc the function to be used when creating a new timestamp
@ -64,6 +66,7 @@ type Session struct {
WithConditions bool
SkipDefaultTransaction bool
AllowGlobalUpdate bool
FullSaveAssociations bool
Context context.Context
Logger logger.Interface
NowFunc func() time.Time
@ -161,6 +164,10 @@ func (db *DB) Session(config *Session) *DB {
txConfig.AllowGlobalUpdate = true
}
if config.FullSaveAssociations {
txConfig.FullSaveAssociations = true
}
if config.Context != nil {
tx.Statement = tx.Statement.clone()
tx.Statement.DB = tx

View File

@ -20,6 +20,7 @@ const (
Magenta = "\033[35m"
Cyan = "\033[36m"
White = "\033[37m"
BlueBold = "\033[34;1m"
MagentaBold = "\033[35;1m"
RedBold = "\033[31;1m"
YellowBold = "\033[33;1m"
@ -76,11 +77,11 @@ func New(writer Writer, config Config) Interface {
if config.Colorful {
infoStr = Green + "%s\n" + Reset + Green + "[info] " + Reset
warnStr = Blue + "%s\n" + Reset + Magenta + "[warn] " + Reset
warnStr = BlueBold + "%s\n" + Reset + Magenta + "[warn] " + Reset
errStr = Magenta + "%s\n" + Reset + Red + "[error] " + Reset
traceStr = Green + "%s\n" + Reset + Yellow + "[%.3fms] " + Blue + "[rows:%d]" + Reset + " %s"
traceStr = Green + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%d]" + Reset + " %s"
traceWarnStr = Green + "%s\n" + Reset + RedBold + "[%.3fms] " + Yellow + "[rows:%d]" + Magenta + " %s" + Reset
traceErrStr = RedBold + "%s " + MagentaBold + "%s\n" + Reset + Yellow + "[%.3fms] " + Blue + "[rows:%d]" + Reset + " %s"
traceErrStr = RedBold + "%s " + MagentaBold + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%d]" + Reset + " %s"
}
return &logger{

View File

@ -3,6 +3,7 @@ package tests_test
import (
"testing"
"gorm.io/gorm"
. "gorm.io/gorm/utils/tests"
)
@ -22,4 +23,22 @@ func TestUpdateBelongsTo(t *testing.T) {
var user2 User
DB.Preload("Company").Preload("Manager").Find(&user2, "id = ?", user.ID)
CheckUser(t, user2, user)
user.Company.Name += "new"
user.Manager.Name += "new"
if err := DB.Save(&user).Error; err != nil {
t.Fatalf("errors happened when update: %v", err)
}
var user3 User
DB.Preload("Company").Preload("Manager").Find(&user3, "id = ?", user.ID)
CheckUser(t, user2, user3)
if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil {
t.Fatalf("errors happened when update: %v", err)
}
var user4 User
DB.Preload("Company").Preload("Manager").Find(&user4, "id = ?", user.ID)
CheckUser(t, user4, user)
}

View File

@ -3,6 +3,7 @@ package tests_test
import (
"testing"
"gorm.io/gorm"
. "gorm.io/gorm/utils/tests"
)
@ -22,6 +23,26 @@ func TestUpdateHasManyAssociations(t *testing.T) {
DB.Preload("Pets").Find(&user2, "id = ?", user.ID)
CheckUser(t, user2, user)
for _, pet := range user.Pets {
pet.Name += "new"
}
if err := DB.Save(&user).Error; err != nil {
t.Fatalf("errors happened when update: %v", err)
}
var user3 User
DB.Preload("Pets").Find(&user3, "id = ?", user.ID)
CheckUser(t, user2, user3)
if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil {
t.Fatalf("errors happened when update: %v", err)
}
var user4 User
DB.Preload("Pets").Find(&user4, "id = ?", user.ID)
CheckUser(t, user4, user)
t.Run("Polymorphic", func(t *testing.T) {
var user = *GetUser("update-has-many", Config{})
@ -37,5 +58,25 @@ func TestUpdateHasManyAssociations(t *testing.T) {
var user2 User
DB.Preload("Toys").Find(&user2, "id = ?", user.ID)
CheckUser(t, user2, user)
for idx := range user.Toys {
user.Toys[idx].Name += "new"
}
if err := DB.Save(&user).Error; err != nil {
t.Fatalf("errors happened when update: %v", err)
}
var user3 User
DB.Preload("Toys").Find(&user3, "id = ?", user.ID)
CheckUser(t, user2, user3)
if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil {
t.Fatalf("errors happened when update: %v", err)
}
var user4 User
DB.Preload("Toys").Find(&user4, "id = ?", user.ID)
CheckUser(t, user4, user)
})
}

View File

@ -3,6 +3,7 @@ package tests_test
import (
"testing"
"gorm.io/gorm"
. "gorm.io/gorm/utils/tests"
)
@ -23,6 +24,23 @@ func TestUpdateHasOne(t *testing.T) {
DB.Preload("Account").Find(&user2, "id = ?", user.ID)
CheckUser(t, user2, user)
user.Account.Number += "new"
if err := DB.Save(&user).Error; err != nil {
t.Fatalf("errors happened when update: %v", err)
}
var user3 User
DB.Preload("Account").Find(&user3, "id = ?", user.ID)
CheckUser(t, user2, user3)
if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil {
t.Fatalf("errors happened when update: %v", err)
}
var user4 User
DB.Preload("Account").Find(&user4, "id = ?", user.ID)
CheckUser(t, user4, user)
t.Run("Polymorphic", func(t *testing.T) {
var pet = Pet{Name: "create"}
@ -39,5 +57,22 @@ func TestUpdateHasOne(t *testing.T) {
var pet2 Pet
DB.Preload("Toy").Find(&pet2, "id = ?", pet.ID)
CheckPet(t, pet2, pet)
pet.Toy.Name += "new"
if err := DB.Save(&pet).Error; err != nil {
t.Fatalf("errors happened when update: %v", err)
}
var pet3 Pet
DB.Preload("Toy").Find(&pet3, "id = ?", pet.ID)
CheckPet(t, pet2, pet3)
if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&pet).Error; err != nil {
t.Fatalf("errors happened when update: %v", err)
}
var pet4 Pet
DB.Preload("Toy").Find(&pet4, "id = ?", pet.ID)
CheckPet(t, pet4, pet)
})
}

View File

@ -3,6 +3,7 @@ package tests_test
import (
"testing"
"gorm.io/gorm"
. "gorm.io/gorm/utils/tests"
)
@ -26,4 +27,28 @@ func TestUpdateMany2ManyAssociations(t *testing.T) {
var user2 User
DB.Preload("Languages").Preload("Friends").Find(&user2, "id = ?", user.ID)
CheckUser(t, user2, user)
for idx := range user.Friends {
user.Friends[idx].Name += "new"
}
for idx := range user.Languages {
user.Languages[idx].Name += "new"
}
if err := DB.Save(&user).Error; err != nil {
t.Fatalf("errors happened when update: %v", err)
}
var user3 User
DB.Preload("Languages").Preload("Friends").Find(&user3, "id = ?", user.ID)
CheckUser(t, user2, user3)
if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil {
t.Fatalf("errors happened when update: %v", err)
}
var user4 User
DB.Preload("Languages").Preload("Friends").Find(&user4, "id = ?", user.ID)
CheckUser(t, user4, user)
}