From fcb666cfa31ecf0de77fcd23e60a67c6819ad7fa Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 3 Sep 2020 10:58:48 +0800 Subject: [PATCH] Fix associations using composite primary keys without ID field, close #3365 --- callbacks/associations.go | 18 +++++++++++++--- tests/multi_primary_keys_test.go | 36 ++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 3 deletions(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index 2710ffe9..0c677f47 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -5,6 +5,7 @@ import ( "gorm.io/gorm" "gorm.io/gorm/clause" + "gorm.io/gorm/schema" ) func SaveBeforeAssociations(db *gorm.DB) { @@ -145,7 +146,7 @@ func SaveAfterAssociations(db *gorm.DB) { } db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ - Columns: []clause.Column{{Name: rel.FieldSchema.PrioritizedPrimaryField.DBName}}, + Columns: onConflictColumns(rel.FieldSchema), DoUpdates: clause.AssignmentColumns(assignmentColumns), }).Create(elems.Interface()).Error) } @@ -168,7 +169,7 @@ func SaveAfterAssociations(db *gorm.DB) { } db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ - Columns: []clause.Column{{Name: rel.FieldSchema.PrioritizedPrimaryField.DBName}}, + Columns: onConflictColumns(rel.FieldSchema), DoUpdates: clause.AssignmentColumns(assignmentColumns), }).Create(f.Interface()).Error) } @@ -230,7 +231,7 @@ func SaveAfterAssociations(db *gorm.DB) { } db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ - Columns: []clause.Column{{Name: rel.FieldSchema.PrioritizedPrimaryField.DBName}}, + Columns: onConflictColumns(rel.FieldSchema), DoUpdates: clause.AssignmentColumns(assignmentColumns), }).Create(elems.Interface()).Error) } @@ -310,3 +311,14 @@ func SaveAfterAssociations(db *gorm.DB) { } } } + +func onConflictColumns(s *schema.Schema) (columns []clause.Column) { + if s.PrioritizedPrimaryField != nil { + return []clause.Column{{Name: s.PrioritizedPrimaryField.DBName}} + } + + for _, dbName := range s.PrimaryFieldDBNames { + columns = append(columns, clause.Column{Name: dbName}) + } + return +} diff --git a/tests/multi_primary_keys_test.go b/tests/multi_primary_keys_test.go index 051e3ee2..68da8a88 100644 --- a/tests/multi_primary_keys_test.go +++ b/tests/multi_primary_keys_test.go @@ -6,6 +6,7 @@ import ( "testing" "gorm.io/gorm" + . "gorm.io/gorm/utils/tests" ) type Blog struct { @@ -410,3 +411,38 @@ func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) { t.Fatalf("EN Blog's tags should be cleared") } } + +func TestCompositePrimaryKeysAssociations(t *testing.T) { + type Label struct { + BookID *uint `gorm:"primarykey"` + Name string `gorm:"primarykey"` + Value string + } + + type Book struct { + ID int + Name string + Labels []Label + } + + DB.Migrator().DropTable(&Label{}, &Book{}) + if err := DB.AutoMigrate(&Label{}, &Book{}); err != nil { + t.Fatalf("failed to migrate") + } + + book := Book{ + Name: "my book", + Labels: []Label{ + {Name: "region", Value: "emea"}, + }, + } + + DB.Create(&book) + + var result Book + if err := DB.Preload("Labels").First(&result, book.ID).Error; err != nil { + t.Fatalf("failed to preload, got error %v", err) + } + + AssertEqual(t, book, result) +}