forked from mirror/gorm
parent
5228735915
commit
c0de3c5051
|
@ -66,9 +66,7 @@ func SaveBeforeAssociations(db *gorm.DB) {
|
|||
}
|
||||
|
||||
if elems.Len() > 0 {
|
||||
if db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{
|
||||
DoNothing: true,
|
||||
}).Create(elems.Interface()).Error) == nil {
|
||||
if db.AddError(db.Session(&gorm.Session{}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(elems.Interface()).Error) == nil {
|
||||
for i := 0; i < elems.Len(); i++ {
|
||||
setupReferences(objs[i], elems.Index(i))
|
||||
}
|
||||
|
@ -81,9 +79,7 @@ func SaveBeforeAssociations(db *gorm.DB) {
|
|||
rv = rv.Addr()
|
||||
}
|
||||
|
||||
if db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{
|
||||
DoNothing: true,
|
||||
}).Create(rv.Interface()).Error) == nil {
|
||||
if db.AddError(db.Session(&gorm.Session{}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(rv.Interface()).Error) == nil {
|
||||
setupReferences(db.Statement.ReflectValue, rv)
|
||||
}
|
||||
}
|
||||
|
@ -145,10 +141,9 @@ func SaveAfterAssociations(db *gorm.DB) {
|
|||
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
||||
}
|
||||
|
||||
db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{
|
||||
Columns: onConflictColumns(rel.FieldSchema),
|
||||
DoUpdates: clause.AssignmentColumns(assignmentColumns),
|
||||
}).Create(elems.Interface()).Error)
|
||||
db.AddError(db.Session(&gorm.Session{}).Clauses(
|
||||
onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns),
|
||||
).Create(elems.Interface()).Error)
|
||||
}
|
||||
case reflect.Struct:
|
||||
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
|
||||
|
@ -168,10 +163,9 @@ func SaveAfterAssociations(db *gorm.DB) {
|
|||
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
||||
}
|
||||
|
||||
db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{
|
||||
Columns: onConflictColumns(rel.FieldSchema),
|
||||
DoUpdates: clause.AssignmentColumns(assignmentColumns),
|
||||
}).Create(f.Interface()).Error)
|
||||
db.AddError(db.Session(&gorm.Session{}).Clauses(
|
||||
onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns),
|
||||
).Create(f.Interface()).Error)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -230,10 +224,9 @@ func SaveAfterAssociations(db *gorm.DB) {
|
|||
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
||||
}
|
||||
|
||||
db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{
|
||||
Columns: onConflictColumns(rel.FieldSchema),
|
||||
DoUpdates: clause.AssignmentColumns(assignmentColumns),
|
||||
}).Create(elems.Interface()).Error)
|
||||
db.AddError(db.Session(&gorm.Session{}).Clauses(
|
||||
onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns),
|
||||
).Create(elems.Interface()).Error)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -298,7 +291,7 @@ func SaveAfterAssociations(db *gorm.DB) {
|
|||
}
|
||||
|
||||
if elems.Len() > 0 {
|
||||
db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{DoNothing: true}).Create(elems.Interface()).Error)
|
||||
db.AddError(db.Session(&gorm.Session{}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(elems.Interface()).Error)
|
||||
|
||||
for i := 0; i < elems.Len(); i++ {
|
||||
appendToJoins(objs[i], elems.Index(i))
|
||||
|
@ -312,13 +305,31 @@ func SaveAfterAssociations(db *gorm.DB) {
|
|||
}
|
||||
}
|
||||
|
||||
func onConflictColumns(s *schema.Schema) (columns []clause.Column) {
|
||||
if s.PrioritizedPrimaryField != nil {
|
||||
return []clause.Column{{Name: s.PrioritizedPrimaryField.DBName}}
|
||||
func onConflictOption(stmt *gorm.Statement, s *schema.Schema, defaultUpdatingColumns []string) clause.OnConflict {
|
||||
if stmt.DB.FullSaveAssociations {
|
||||
defaultUpdatingColumns = make([]string, 0, len(s.DBNames))
|
||||
for _, dbName := range s.DBNames {
|
||||
if !s.LookUpField(dbName).PrimaryKey {
|
||||
defaultUpdatingColumns = append(defaultUpdatingColumns, dbName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, dbName := range s.PrimaryFieldDBNames {
|
||||
columns = append(columns, clause.Column{Name: dbName})
|
||||
if len(defaultUpdatingColumns) > 0 {
|
||||
var columns []clause.Column
|
||||
if s.PrioritizedPrimaryField != nil {
|
||||
columns = []clause.Column{{Name: s.PrioritizedPrimaryField.DBName}}
|
||||
} else {
|
||||
for _, dbName := range s.PrimaryFieldDBNames {
|
||||
columns = append(columns, clause.Column{Name: dbName})
|
||||
}
|
||||
}
|
||||
|
||||
return clause.OnConflict{
|
||||
Columns: columns,
|
||||
DoUpdates: clause.AssignmentColumns(defaultUpdatingColumns),
|
||||
}
|
||||
}
|
||||
return
|
||||
|
||||
return clause.OnConflict{DoNothing: true}
|
||||
}
|
||||
|
|
|
@ -88,7 +88,10 @@ func Create(config *Config) func(db *gorm.DB) {
|
|||
}
|
||||
case reflect.Struct:
|
||||
if insertID > 0 {
|
||||
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID)
|
||||
if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue); isZero {
|
||||
|
||||
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
|
7
gorm.go
7
gorm.go
|
@ -20,6 +20,8 @@ type Config struct {
|
|||
SkipDefaultTransaction bool
|
||||
// NamingStrategy tables, columns naming strategy
|
||||
NamingStrategy schema.Namer
|
||||
// FullSaveAssociations full save associations
|
||||
FullSaveAssociations bool
|
||||
// Logger
|
||||
Logger logger.Interface
|
||||
// NowFunc the function to be used when creating a new timestamp
|
||||
|
@ -64,6 +66,7 @@ type Session struct {
|
|||
WithConditions bool
|
||||
SkipDefaultTransaction bool
|
||||
AllowGlobalUpdate bool
|
||||
FullSaveAssociations bool
|
||||
Context context.Context
|
||||
Logger logger.Interface
|
||||
NowFunc func() time.Time
|
||||
|
@ -161,6 +164,10 @@ func (db *DB) Session(config *Session) *DB {
|
|||
txConfig.AllowGlobalUpdate = true
|
||||
}
|
||||
|
||||
if config.FullSaveAssociations {
|
||||
txConfig.FullSaveAssociations = true
|
||||
}
|
||||
|
||||
if config.Context != nil {
|
||||
tx.Statement = tx.Statement.clone()
|
||||
tx.Statement.DB = tx
|
||||
|
|
|
@ -20,6 +20,7 @@ const (
|
|||
Magenta = "\033[35m"
|
||||
Cyan = "\033[36m"
|
||||
White = "\033[37m"
|
||||
BlueBold = "\033[34;1m"
|
||||
MagentaBold = "\033[35;1m"
|
||||
RedBold = "\033[31;1m"
|
||||
YellowBold = "\033[33;1m"
|
||||
|
@ -76,11 +77,11 @@ func New(writer Writer, config Config) Interface {
|
|||
|
||||
if config.Colorful {
|
||||
infoStr = Green + "%s\n" + Reset + Green + "[info] " + Reset
|
||||
warnStr = Blue + "%s\n" + Reset + Magenta + "[warn] " + Reset
|
||||
warnStr = BlueBold + "%s\n" + Reset + Magenta + "[warn] " + Reset
|
||||
errStr = Magenta + "%s\n" + Reset + Red + "[error] " + Reset
|
||||
traceStr = Green + "%s\n" + Reset + Yellow + "[%.3fms] " + Blue + "[rows:%d]" + Reset + " %s"
|
||||
traceStr = Green + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%d]" + Reset + " %s"
|
||||
traceWarnStr = Green + "%s\n" + Reset + RedBold + "[%.3fms] " + Yellow + "[rows:%d]" + Magenta + " %s" + Reset
|
||||
traceErrStr = RedBold + "%s " + MagentaBold + "%s\n" + Reset + Yellow + "[%.3fms] " + Blue + "[rows:%d]" + Reset + " %s"
|
||||
traceErrStr = RedBold + "%s " + MagentaBold + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%d]" + Reset + " %s"
|
||||
}
|
||||
|
||||
return &logger{
|
||||
|
|
|
@ -3,6 +3,7 @@ package tests_test
|
|||
import (
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
|
@ -22,4 +23,22 @@ func TestUpdateBelongsTo(t *testing.T) {
|
|||
var user2 User
|
||||
DB.Preload("Company").Preload("Manager").Find(&user2, "id = ?", user.ID)
|
||||
CheckUser(t, user2, user)
|
||||
|
||||
user.Company.Name += "new"
|
||||
user.Manager.Name += "new"
|
||||
if err := DB.Save(&user).Error; err != nil {
|
||||
t.Fatalf("errors happened when update: %v", err)
|
||||
}
|
||||
|
||||
var user3 User
|
||||
DB.Preload("Company").Preload("Manager").Find(&user3, "id = ?", user.ID)
|
||||
CheckUser(t, user2, user3)
|
||||
|
||||
if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil {
|
||||
t.Fatalf("errors happened when update: %v", err)
|
||||
}
|
||||
|
||||
var user4 User
|
||||
DB.Preload("Company").Preload("Manager").Find(&user4, "id = ?", user.ID)
|
||||
CheckUser(t, user4, user)
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package tests_test
|
|||
import (
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
|
@ -22,6 +23,26 @@ func TestUpdateHasManyAssociations(t *testing.T) {
|
|||
DB.Preload("Pets").Find(&user2, "id = ?", user.ID)
|
||||
CheckUser(t, user2, user)
|
||||
|
||||
for _, pet := range user.Pets {
|
||||
pet.Name += "new"
|
||||
}
|
||||
|
||||
if err := DB.Save(&user).Error; err != nil {
|
||||
t.Fatalf("errors happened when update: %v", err)
|
||||
}
|
||||
|
||||
var user3 User
|
||||
DB.Preload("Pets").Find(&user3, "id = ?", user.ID)
|
||||
CheckUser(t, user2, user3)
|
||||
|
||||
if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil {
|
||||
t.Fatalf("errors happened when update: %v", err)
|
||||
}
|
||||
|
||||
var user4 User
|
||||
DB.Preload("Pets").Find(&user4, "id = ?", user.ID)
|
||||
CheckUser(t, user4, user)
|
||||
|
||||
t.Run("Polymorphic", func(t *testing.T) {
|
||||
var user = *GetUser("update-has-many", Config{})
|
||||
|
||||
|
@ -37,5 +58,25 @@ func TestUpdateHasManyAssociations(t *testing.T) {
|
|||
var user2 User
|
||||
DB.Preload("Toys").Find(&user2, "id = ?", user.ID)
|
||||
CheckUser(t, user2, user)
|
||||
|
||||
for idx := range user.Toys {
|
||||
user.Toys[idx].Name += "new"
|
||||
}
|
||||
|
||||
if err := DB.Save(&user).Error; err != nil {
|
||||
t.Fatalf("errors happened when update: %v", err)
|
||||
}
|
||||
|
||||
var user3 User
|
||||
DB.Preload("Toys").Find(&user3, "id = ?", user.ID)
|
||||
CheckUser(t, user2, user3)
|
||||
|
||||
if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil {
|
||||
t.Fatalf("errors happened when update: %v", err)
|
||||
}
|
||||
|
||||
var user4 User
|
||||
DB.Preload("Toys").Find(&user4, "id = ?", user.ID)
|
||||
CheckUser(t, user4, user)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package tests_test
|
|||
import (
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
|
@ -23,6 +24,23 @@ func TestUpdateHasOne(t *testing.T) {
|
|||
DB.Preload("Account").Find(&user2, "id = ?", user.ID)
|
||||
CheckUser(t, user2, user)
|
||||
|
||||
user.Account.Number += "new"
|
||||
if err := DB.Save(&user).Error; err != nil {
|
||||
t.Fatalf("errors happened when update: %v", err)
|
||||
}
|
||||
|
||||
var user3 User
|
||||
DB.Preload("Account").Find(&user3, "id = ?", user.ID)
|
||||
CheckUser(t, user2, user3)
|
||||
|
||||
if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil {
|
||||
t.Fatalf("errors happened when update: %v", err)
|
||||
}
|
||||
|
||||
var user4 User
|
||||
DB.Preload("Account").Find(&user4, "id = ?", user.ID)
|
||||
CheckUser(t, user4, user)
|
||||
|
||||
t.Run("Polymorphic", func(t *testing.T) {
|
||||
var pet = Pet{Name: "create"}
|
||||
|
||||
|
@ -39,5 +57,22 @@ func TestUpdateHasOne(t *testing.T) {
|
|||
var pet2 Pet
|
||||
DB.Preload("Toy").Find(&pet2, "id = ?", pet.ID)
|
||||
CheckPet(t, pet2, pet)
|
||||
|
||||
pet.Toy.Name += "new"
|
||||
if err := DB.Save(&pet).Error; err != nil {
|
||||
t.Fatalf("errors happened when update: %v", err)
|
||||
}
|
||||
|
||||
var pet3 Pet
|
||||
DB.Preload("Toy").Find(&pet3, "id = ?", pet.ID)
|
||||
CheckPet(t, pet2, pet3)
|
||||
|
||||
if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&pet).Error; err != nil {
|
||||
t.Fatalf("errors happened when update: %v", err)
|
||||
}
|
||||
|
||||
var pet4 Pet
|
||||
DB.Preload("Toy").Find(&pet4, "id = ?", pet.ID)
|
||||
CheckPet(t, pet4, pet)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package tests_test
|
|||
import (
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
|
@ -26,4 +27,28 @@ func TestUpdateMany2ManyAssociations(t *testing.T) {
|
|||
var user2 User
|
||||
DB.Preload("Languages").Preload("Friends").Find(&user2, "id = ?", user.ID)
|
||||
CheckUser(t, user2, user)
|
||||
|
||||
for idx := range user.Friends {
|
||||
user.Friends[idx].Name += "new"
|
||||
}
|
||||
|
||||
for idx := range user.Languages {
|
||||
user.Languages[idx].Name += "new"
|
||||
}
|
||||
|
||||
if err := DB.Save(&user).Error; err != nil {
|
||||
t.Fatalf("errors happened when update: %v", err)
|
||||
}
|
||||
|
||||
var user3 User
|
||||
DB.Preload("Languages").Preload("Friends").Find(&user3, "id = ?", user.ID)
|
||||
CheckUser(t, user2, user3)
|
||||
|
||||
if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil {
|
||||
t.Fatalf("errors happened when update: %v", err)
|
||||
}
|
||||
|
||||
var user4 User
|
||||
DB.Preload("Languages").Preload("Friends").Find(&user4, "id = ?", user.ID)
|
||||
CheckUser(t, user4, user)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue