mirror of https://github.com/go-gorm/gorm.git
parent
5228735915
commit
c0de3c5051
|
@ -66,9 +66,7 @@ func SaveBeforeAssociations(db *gorm.DB) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if elems.Len() > 0 {
|
if elems.Len() > 0 {
|
||||||
if db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{
|
if db.AddError(db.Session(&gorm.Session{}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(elems.Interface()).Error) == nil {
|
||||||
DoNothing: true,
|
|
||||||
}).Create(elems.Interface()).Error) == nil {
|
|
||||||
for i := 0; i < elems.Len(); i++ {
|
for i := 0; i < elems.Len(); i++ {
|
||||||
setupReferences(objs[i], elems.Index(i))
|
setupReferences(objs[i], elems.Index(i))
|
||||||
}
|
}
|
||||||
|
@ -81,9 +79,7 @@ func SaveBeforeAssociations(db *gorm.DB) {
|
||||||
rv = rv.Addr()
|
rv = rv.Addr()
|
||||||
}
|
}
|
||||||
|
|
||||||
if db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{
|
if db.AddError(db.Session(&gorm.Session{}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(rv.Interface()).Error) == nil {
|
||||||
DoNothing: true,
|
|
||||||
}).Create(rv.Interface()).Error) == nil {
|
|
||||||
setupReferences(db.Statement.ReflectValue, rv)
|
setupReferences(db.Statement.ReflectValue, rv)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -145,10 +141,9 @@ func SaveAfterAssociations(db *gorm.DB) {
|
||||||
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
||||||
}
|
}
|
||||||
|
|
||||||
db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{
|
db.AddError(db.Session(&gorm.Session{}).Clauses(
|
||||||
Columns: onConflictColumns(rel.FieldSchema),
|
onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns),
|
||||||
DoUpdates: clause.AssignmentColumns(assignmentColumns),
|
).Create(elems.Interface()).Error)
|
||||||
}).Create(elems.Interface()).Error)
|
|
||||||
}
|
}
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
|
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
|
||||||
|
@ -168,10 +163,9 @@ func SaveAfterAssociations(db *gorm.DB) {
|
||||||
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
||||||
}
|
}
|
||||||
|
|
||||||
db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{
|
db.AddError(db.Session(&gorm.Session{}).Clauses(
|
||||||
Columns: onConflictColumns(rel.FieldSchema),
|
onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns),
|
||||||
DoUpdates: clause.AssignmentColumns(assignmentColumns),
|
).Create(f.Interface()).Error)
|
||||||
}).Create(f.Interface()).Error)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -230,10 +224,9 @@ func SaveAfterAssociations(db *gorm.DB) {
|
||||||
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
||||||
}
|
}
|
||||||
|
|
||||||
db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{
|
db.AddError(db.Session(&gorm.Session{}).Clauses(
|
||||||
Columns: onConflictColumns(rel.FieldSchema),
|
onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns),
|
||||||
DoUpdates: clause.AssignmentColumns(assignmentColumns),
|
).Create(elems.Interface()).Error)
|
||||||
}).Create(elems.Interface()).Error)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -298,7 +291,7 @@ func SaveAfterAssociations(db *gorm.DB) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if elems.Len() > 0 {
|
if elems.Len() > 0 {
|
||||||
db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{DoNothing: true}).Create(elems.Interface()).Error)
|
db.AddError(db.Session(&gorm.Session{}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(elems.Interface()).Error)
|
||||||
|
|
||||||
for i := 0; i < elems.Len(); i++ {
|
for i := 0; i < elems.Len(); i++ {
|
||||||
appendToJoins(objs[i], elems.Index(i))
|
appendToJoins(objs[i], elems.Index(i))
|
||||||
|
@ -312,13 +305,31 @@ func SaveAfterAssociations(db *gorm.DB) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func onConflictColumns(s *schema.Schema) (columns []clause.Column) {
|
func onConflictOption(stmt *gorm.Statement, s *schema.Schema, defaultUpdatingColumns []string) clause.OnConflict {
|
||||||
if s.PrioritizedPrimaryField != nil {
|
if stmt.DB.FullSaveAssociations {
|
||||||
return []clause.Column{{Name: s.PrioritizedPrimaryField.DBName}}
|
defaultUpdatingColumns = make([]string, 0, len(s.DBNames))
|
||||||
|
for _, dbName := range s.DBNames {
|
||||||
|
if !s.LookUpField(dbName).PrimaryKey {
|
||||||
|
defaultUpdatingColumns = append(defaultUpdatingColumns, dbName)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(defaultUpdatingColumns) > 0 {
|
||||||
|
var columns []clause.Column
|
||||||
|
if s.PrioritizedPrimaryField != nil {
|
||||||
|
columns = []clause.Column{{Name: s.PrioritizedPrimaryField.DBName}}
|
||||||
|
} else {
|
||||||
for _, dbName := range s.PrimaryFieldDBNames {
|
for _, dbName := range s.PrimaryFieldDBNames {
|
||||||
columns = append(columns, clause.Column{Name: dbName})
|
columns = append(columns, clause.Column{Name: dbName})
|
||||||
}
|
}
|
||||||
return
|
}
|
||||||
|
|
||||||
|
return clause.OnConflict{
|
||||||
|
Columns: columns,
|
||||||
|
DoUpdates: clause.AssignmentColumns(defaultUpdatingColumns),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return clause.OnConflict{DoNothing: true}
|
||||||
}
|
}
|
||||||
|
|
|
@ -88,9 +88,12 @@ func Create(config *Config) func(db *gorm.DB) {
|
||||||
}
|
}
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
if insertID > 0 {
|
if insertID > 0 {
|
||||||
|
if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue); isZero {
|
||||||
|
|
||||||
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID)
|
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
db.AddError(err)
|
db.AddError(err)
|
||||||
}
|
}
|
||||||
|
|
7
gorm.go
7
gorm.go
|
@ -20,6 +20,8 @@ type Config struct {
|
||||||
SkipDefaultTransaction bool
|
SkipDefaultTransaction bool
|
||||||
// NamingStrategy tables, columns naming strategy
|
// NamingStrategy tables, columns naming strategy
|
||||||
NamingStrategy schema.Namer
|
NamingStrategy schema.Namer
|
||||||
|
// FullSaveAssociations full save associations
|
||||||
|
FullSaveAssociations bool
|
||||||
// Logger
|
// Logger
|
||||||
Logger logger.Interface
|
Logger logger.Interface
|
||||||
// NowFunc the function to be used when creating a new timestamp
|
// NowFunc the function to be used when creating a new timestamp
|
||||||
|
@ -64,6 +66,7 @@ type Session struct {
|
||||||
WithConditions bool
|
WithConditions bool
|
||||||
SkipDefaultTransaction bool
|
SkipDefaultTransaction bool
|
||||||
AllowGlobalUpdate bool
|
AllowGlobalUpdate bool
|
||||||
|
FullSaveAssociations bool
|
||||||
Context context.Context
|
Context context.Context
|
||||||
Logger logger.Interface
|
Logger logger.Interface
|
||||||
NowFunc func() time.Time
|
NowFunc func() time.Time
|
||||||
|
@ -161,6 +164,10 @@ func (db *DB) Session(config *Session) *DB {
|
||||||
txConfig.AllowGlobalUpdate = true
|
txConfig.AllowGlobalUpdate = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if config.FullSaveAssociations {
|
||||||
|
txConfig.FullSaveAssociations = true
|
||||||
|
}
|
||||||
|
|
||||||
if config.Context != nil {
|
if config.Context != nil {
|
||||||
tx.Statement = tx.Statement.clone()
|
tx.Statement = tx.Statement.clone()
|
||||||
tx.Statement.DB = tx
|
tx.Statement.DB = tx
|
||||||
|
|
|
@ -20,6 +20,7 @@ const (
|
||||||
Magenta = "\033[35m"
|
Magenta = "\033[35m"
|
||||||
Cyan = "\033[36m"
|
Cyan = "\033[36m"
|
||||||
White = "\033[37m"
|
White = "\033[37m"
|
||||||
|
BlueBold = "\033[34;1m"
|
||||||
MagentaBold = "\033[35;1m"
|
MagentaBold = "\033[35;1m"
|
||||||
RedBold = "\033[31;1m"
|
RedBold = "\033[31;1m"
|
||||||
YellowBold = "\033[33;1m"
|
YellowBold = "\033[33;1m"
|
||||||
|
@ -76,11 +77,11 @@ func New(writer Writer, config Config) Interface {
|
||||||
|
|
||||||
if config.Colorful {
|
if config.Colorful {
|
||||||
infoStr = Green + "%s\n" + Reset + Green + "[info] " + Reset
|
infoStr = Green + "%s\n" + Reset + Green + "[info] " + Reset
|
||||||
warnStr = Blue + "%s\n" + Reset + Magenta + "[warn] " + Reset
|
warnStr = BlueBold + "%s\n" + Reset + Magenta + "[warn] " + Reset
|
||||||
errStr = Magenta + "%s\n" + Reset + Red + "[error] " + Reset
|
errStr = Magenta + "%s\n" + Reset + Red + "[error] " + Reset
|
||||||
traceStr = Green + "%s\n" + Reset + Yellow + "[%.3fms] " + Blue + "[rows:%d]" + Reset + " %s"
|
traceStr = Green + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%d]" + Reset + " %s"
|
||||||
traceWarnStr = Green + "%s\n" + Reset + RedBold + "[%.3fms] " + Yellow + "[rows:%d]" + Magenta + " %s" + Reset
|
traceWarnStr = Green + "%s\n" + Reset + RedBold + "[%.3fms] " + Yellow + "[rows:%d]" + Magenta + " %s" + Reset
|
||||||
traceErrStr = RedBold + "%s " + MagentaBold + "%s\n" + Reset + Yellow + "[%.3fms] " + Blue + "[rows:%d]" + Reset + " %s"
|
traceErrStr = RedBold + "%s " + MagentaBold + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%d]" + Reset + " %s"
|
||||||
}
|
}
|
||||||
|
|
||||||
return &logger{
|
return &logger{
|
||||||
|
|
|
@ -3,6 +3,7 @@ package tests_test
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
. "gorm.io/gorm/utils/tests"
|
. "gorm.io/gorm/utils/tests"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -22,4 +23,22 @@ func TestUpdateBelongsTo(t *testing.T) {
|
||||||
var user2 User
|
var user2 User
|
||||||
DB.Preload("Company").Preload("Manager").Find(&user2, "id = ?", user.ID)
|
DB.Preload("Company").Preload("Manager").Find(&user2, "id = ?", user.ID)
|
||||||
CheckUser(t, user2, user)
|
CheckUser(t, user2, user)
|
||||||
|
|
||||||
|
user.Company.Name += "new"
|
||||||
|
user.Manager.Name += "new"
|
||||||
|
if err := DB.Save(&user).Error; err != nil {
|
||||||
|
t.Fatalf("errors happened when update: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var user3 User
|
||||||
|
DB.Preload("Company").Preload("Manager").Find(&user3, "id = ?", user.ID)
|
||||||
|
CheckUser(t, user2, user3)
|
||||||
|
|
||||||
|
if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil {
|
||||||
|
t.Fatalf("errors happened when update: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var user4 User
|
||||||
|
DB.Preload("Company").Preload("Manager").Find(&user4, "id = ?", user.ID)
|
||||||
|
CheckUser(t, user4, user)
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ package tests_test
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
. "gorm.io/gorm/utils/tests"
|
. "gorm.io/gorm/utils/tests"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -22,6 +23,26 @@ func TestUpdateHasManyAssociations(t *testing.T) {
|
||||||
DB.Preload("Pets").Find(&user2, "id = ?", user.ID)
|
DB.Preload("Pets").Find(&user2, "id = ?", user.ID)
|
||||||
CheckUser(t, user2, user)
|
CheckUser(t, user2, user)
|
||||||
|
|
||||||
|
for _, pet := range user.Pets {
|
||||||
|
pet.Name += "new"
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Save(&user).Error; err != nil {
|
||||||
|
t.Fatalf("errors happened when update: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var user3 User
|
||||||
|
DB.Preload("Pets").Find(&user3, "id = ?", user.ID)
|
||||||
|
CheckUser(t, user2, user3)
|
||||||
|
|
||||||
|
if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil {
|
||||||
|
t.Fatalf("errors happened when update: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var user4 User
|
||||||
|
DB.Preload("Pets").Find(&user4, "id = ?", user.ID)
|
||||||
|
CheckUser(t, user4, user)
|
||||||
|
|
||||||
t.Run("Polymorphic", func(t *testing.T) {
|
t.Run("Polymorphic", func(t *testing.T) {
|
||||||
var user = *GetUser("update-has-many", Config{})
|
var user = *GetUser("update-has-many", Config{})
|
||||||
|
|
||||||
|
@ -37,5 +58,25 @@ func TestUpdateHasManyAssociations(t *testing.T) {
|
||||||
var user2 User
|
var user2 User
|
||||||
DB.Preload("Toys").Find(&user2, "id = ?", user.ID)
|
DB.Preload("Toys").Find(&user2, "id = ?", user.ID)
|
||||||
CheckUser(t, user2, user)
|
CheckUser(t, user2, user)
|
||||||
|
|
||||||
|
for idx := range user.Toys {
|
||||||
|
user.Toys[idx].Name += "new"
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Save(&user).Error; err != nil {
|
||||||
|
t.Fatalf("errors happened when update: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var user3 User
|
||||||
|
DB.Preload("Toys").Find(&user3, "id = ?", user.ID)
|
||||||
|
CheckUser(t, user2, user3)
|
||||||
|
|
||||||
|
if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil {
|
||||||
|
t.Fatalf("errors happened when update: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var user4 User
|
||||||
|
DB.Preload("Toys").Find(&user4, "id = ?", user.ID)
|
||||||
|
CheckUser(t, user4, user)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ package tests_test
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
. "gorm.io/gorm/utils/tests"
|
. "gorm.io/gorm/utils/tests"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -23,6 +24,23 @@ func TestUpdateHasOne(t *testing.T) {
|
||||||
DB.Preload("Account").Find(&user2, "id = ?", user.ID)
|
DB.Preload("Account").Find(&user2, "id = ?", user.ID)
|
||||||
CheckUser(t, user2, user)
|
CheckUser(t, user2, user)
|
||||||
|
|
||||||
|
user.Account.Number += "new"
|
||||||
|
if err := DB.Save(&user).Error; err != nil {
|
||||||
|
t.Fatalf("errors happened when update: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var user3 User
|
||||||
|
DB.Preload("Account").Find(&user3, "id = ?", user.ID)
|
||||||
|
CheckUser(t, user2, user3)
|
||||||
|
|
||||||
|
if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil {
|
||||||
|
t.Fatalf("errors happened when update: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var user4 User
|
||||||
|
DB.Preload("Account").Find(&user4, "id = ?", user.ID)
|
||||||
|
CheckUser(t, user4, user)
|
||||||
|
|
||||||
t.Run("Polymorphic", func(t *testing.T) {
|
t.Run("Polymorphic", func(t *testing.T) {
|
||||||
var pet = Pet{Name: "create"}
|
var pet = Pet{Name: "create"}
|
||||||
|
|
||||||
|
@ -39,5 +57,22 @@ func TestUpdateHasOne(t *testing.T) {
|
||||||
var pet2 Pet
|
var pet2 Pet
|
||||||
DB.Preload("Toy").Find(&pet2, "id = ?", pet.ID)
|
DB.Preload("Toy").Find(&pet2, "id = ?", pet.ID)
|
||||||
CheckPet(t, pet2, pet)
|
CheckPet(t, pet2, pet)
|
||||||
|
|
||||||
|
pet.Toy.Name += "new"
|
||||||
|
if err := DB.Save(&pet).Error; err != nil {
|
||||||
|
t.Fatalf("errors happened when update: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var pet3 Pet
|
||||||
|
DB.Preload("Toy").Find(&pet3, "id = ?", pet.ID)
|
||||||
|
CheckPet(t, pet2, pet3)
|
||||||
|
|
||||||
|
if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&pet).Error; err != nil {
|
||||||
|
t.Fatalf("errors happened when update: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var pet4 Pet
|
||||||
|
DB.Preload("Toy").Find(&pet4, "id = ?", pet.ID)
|
||||||
|
CheckPet(t, pet4, pet)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ package tests_test
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
. "gorm.io/gorm/utils/tests"
|
. "gorm.io/gorm/utils/tests"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -26,4 +27,28 @@ func TestUpdateMany2ManyAssociations(t *testing.T) {
|
||||||
var user2 User
|
var user2 User
|
||||||
DB.Preload("Languages").Preload("Friends").Find(&user2, "id = ?", user.ID)
|
DB.Preload("Languages").Preload("Friends").Find(&user2, "id = ?", user.ID)
|
||||||
CheckUser(t, user2, user)
|
CheckUser(t, user2, user)
|
||||||
|
|
||||||
|
for idx := range user.Friends {
|
||||||
|
user.Friends[idx].Name += "new"
|
||||||
|
}
|
||||||
|
|
||||||
|
for idx := range user.Languages {
|
||||||
|
user.Languages[idx].Name += "new"
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Save(&user).Error; err != nil {
|
||||||
|
t.Fatalf("errors happened when update: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var user3 User
|
||||||
|
DB.Preload("Languages").Preload("Friends").Find(&user3, "id = ?", user.ID)
|
||||||
|
CheckUser(t, user2, user3)
|
||||||
|
|
||||||
|
if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil {
|
||||||
|
t.Fatalf("errors happened when update: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var user4 User
|
||||||
|
DB.Preload("Languages").Preload("Friends").Find(&user4, "id = ?", user.ID)
|
||||||
|
CheckUser(t, user4, user)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue