diff --git a/callbacks/preload.go b/callbacks/preload.go index a77db2b1..5b5beb06 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -52,8 +52,8 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()) // convert join identity map to relation identity map - fieldValues := make([]interface{}, len(foreignFields)) - joinFieldValues := make([]interface{}, len(joinForeignFields)) + fieldValues := make([]interface{}, len(joinForeignFields)) + joinFieldValues := make([]interface{}, len(joinRelForeignFields)) for i := 0; i < joinResults.Len(); i++ { for idx, field := range joinForeignFields { fieldValues[idx], _ = field.ValueOf(joinResults.Index(i)) @@ -94,7 +94,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { column, values := schema.ToQueryValues(relForeignKeys, foreignValues) tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), conds...) - fieldValues := make([]interface{}, len(foreignFields)) + fieldValues := make([]interface{}, len(relForeignFields)) for i := 0; i < reflectResults.Len(); i++ { for idx, field := range relForeignFields { fieldValues[idx], _ = field.ValueOf(reflectResults.Index(i)) diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index 3828c546..066aa38f 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -19,6 +19,10 @@ type Dialector struct { DSN string } +func (dialector Dialector) Name() string { + return "mssql" +} + func Open(dsn string) gorm.Dialector { return &Dialector{DSN: dsn} } diff --git a/dialects/mysql/mysql.go b/dialects/mysql/mysql.go index baeb79c7..e617a1e1 100644 --- a/dialects/mysql/mysql.go +++ b/dialects/mysql/mysql.go @@ -22,6 +22,10 @@ func Open(dsn string) gorm.Dialector { return &Dialector{DSN: dsn} } +func (dialector Dialector) Name() string { + return "mysql" +} + func (dialector Dialector) Initialize(db *gorm.DB) (err error) { // register callbacks callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{}) diff --git a/dialects/postgres/postgres.go b/dialects/postgres/postgres.go index db559b9d..fb3ecc68 100644 --- a/dialects/postgres/postgres.go +++ b/dialects/postgres/postgres.go @@ -23,6 +23,10 @@ func Open(dsn string) gorm.Dialector { return &Dialector{DSN: dsn} } +func (dialector Dialector) Name() string { + return "postgres" +} + func (dialector Dialector) Initialize(db *gorm.DB) (err error) { // register callbacks callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{ diff --git a/dialects/sqlite/sqlite.go b/dialects/sqlite/sqlite.go index 51829b17..1b9809af 100644 --- a/dialects/sqlite/sqlite.go +++ b/dialects/sqlite/sqlite.go @@ -20,6 +20,10 @@ func Open(dsn string) gorm.Dialector { return &Dialector{DSN: dsn} } +func (dialector Dialector) Name() string { + return "sqlite" +} + func (dialector Dialector) Initialize(db *gorm.DB) (err error) { // register callbacks callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{ diff --git a/interfaces.go b/interfaces.go index 14d8fa34..421428a3 100644 --- a/interfaces.go +++ b/interfaces.go @@ -10,6 +10,7 @@ import ( // Dialector GORM database dialector type Dialector interface { + Name() string Initialize(*DB) error Migrator(db *DB) Migrator DataTypeOf(*schema.Field) string diff --git a/schema/relationship_test.go b/schema/relationship_test.go index 41e8c7bd..0f62f45d 100644 --- a/schema/relationship_test.go +++ b/schema/relationship_test.go @@ -197,3 +197,51 @@ func TestMany2ManyOverrideJoinForeignKey(t *testing.T) { }, }) } + +func TestMany2ManyWithMultiPrimaryKeys(t *testing.T) { + type Tag struct { + ID uint `gorm:"primary_key"` + Locale string `gorm:"primary_key"` + Value string + } + + type Blog struct { + ID uint `gorm:"primary_key"` + Locale string `gorm:"primary_key"` + Subject string + Body string + Tags []Tag `gorm:"many2many:blog_tags;"` + SharedTags []Tag `gorm:"many2many:shared_blog_tags;ForeignKey:id;References:id"` + LocaleTags []Tag `gorm:"many2many:locale_blog_tags;ForeignKey:id,locale;References:id"` + } + + checkStructRelation(t, &Blog{}, + Relation{ + Name: "Tags", Type: schema.Many2Many, Schema: "Blog", FieldSchema: "Tag", + JoinTable: JoinTable{Name: "blog_tags", Table: "blog_tags"}, + References: []Reference{ + {"ID", "Blog", "BlogID", "blog_tags", "", true}, + {"Locale", "Blog", "BlogLocale", "blog_tags", "", true}, + {"ID", "Tag", "TagID", "blog_tags", "", false}, + {"Locale", "Tag", "TagLocale", "blog_tags", "", false}, + }, + }, + Relation{ + Name: "SharedTags", Type: schema.Many2Many, Schema: "Blog", FieldSchema: "Tag", + JoinTable: JoinTable{Name: "shared_blog_tags", Table: "shared_blog_tags"}, + References: []Reference{ + {"ID", "Blog", "BlogID", "shared_blog_tags", "", true}, + {"ID", "Tag", "TagID", "shared_blog_tags", "", false}, + }, + }, + Relation{ + Name: "LocaleTags", Type: schema.Many2Many, Schema: "Blog", FieldSchema: "Tag", + JoinTable: JoinTable{Name: "locale_blog_tags", Table: "locale_blog_tags"}, + References: []Reference{ + {"ID", "Blog", "BlogID", "locale_blog_tags", "", true}, + {"Locale", "Blog", "BlogLocale", "locale_blog_tags", "", true}, + {"ID", "Tag", "TagID", "locale_blog_tags", "", false}, + }, + }, + ) +} diff --git a/tests/dummy_dialecter.go b/tests/dummy_dialecter.go index 63af0c9c..4ea17a0f 100644 --- a/tests/dummy_dialecter.go +++ b/tests/dummy_dialecter.go @@ -10,6 +10,10 @@ import ( type DummyDialector struct { } +func (DummyDialector) Name() string { + return "dummy" +} + func (DummyDialector) Initialize(*gorm.DB) error { return nil } diff --git a/tests/multi_primary_keys_test.go b/tests/multi_primary_keys_test.go new file mode 100644 index 00000000..b3284f15 --- /dev/null +++ b/tests/multi_primary_keys_test.go @@ -0,0 +1,395 @@ +package tests_test + +import ( + "reflect" + "sort" + "testing" + + . "github.com/jinzhu/gorm/tests" +) + +type Blog struct { + ID uint `gorm:"primary_key"` + Locale string `gorm:"primary_key"` + Subject string + Body string + Tags []Tag `gorm:"many2many:blog_tags;"` + SharedTags []Tag `gorm:"many2many:shared_blog_tags;ForeignKey:id;References:id"` + LocaleTags []Tag `gorm:"many2many:locale_blog_tags;ForeignKey:id,locale;References:id"` +} + +type Tag struct { + ID uint `gorm:"primary_key"` + Locale string `gorm:"primary_key"` + Value string + Blogs []*Blog `gorm:"many2many:blogs_tags"` +} + +func compareTags(tags []Tag, contents []string) bool { + var tagContents []string + for _, tag := range tags { + tagContents = append(tagContents, tag.Value) + } + sort.Strings(tagContents) + sort.Strings(contents) + return reflect.DeepEqual(tagContents, contents) +} + +func TestManyToManyWithMultiPrimaryKeys(t *testing.T) { + if name := DB.Dialector.Name(); name == "sqlite" || name == "mssql" { + t.Skip("skip sqlite, mssql due to it doesn't support multiple primary keys with auto increment") + } + + DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags") + if err := DB.AutoMigrate(&Blog{}, &Tag{}); err != nil { + t.Fatalf("Failed to auto migrate, got error: %v", err) + } + + blog := Blog{ + Locale: "ZH", + Subject: "subject", + Body: "body", + Tags: []Tag{ + {Locale: "ZH", Value: "tag1"}, + {Locale: "ZH", Value: "tag2"}, + }, + } + + DB.Save(&blog) + if !compareTags(blog.Tags, []string{"tag1", "tag2"}) { + t.Fatalf("Blog should has two tags") + } + + // Append + var tag3 = &Tag{Locale: "ZH", Value: "tag3"} + DB.Model(&blog).Association("Tags").Append([]*Tag{tag3}) + + if !compareTags(blog.Tags, []string{"tag1", "tag2", "tag3"}) { + t.Fatalf("Blog should has three tags after Append") + } + + if count := DB.Model(&blog).Association("Tags").Count(); count != 3 { + t.Fatalf("Blog should has 3 tags after Append, got %v", count) + } + + var tags []Tag + DB.Model(&blog).Association("Tags").Find(&tags) + if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { + t.Fatalf("Should find 3 tags") + } + + var blog1 Blog + DB.Preload("Tags").Find(&blog1) + if !compareTags(blog1.Tags, []string{"tag1", "tag2", "tag3"}) { + t.Fatalf("Preload many2many relations") + } + + // Replace + var tag5 = &Tag{Locale: "ZH", Value: "tag5"} + var tag6 = &Tag{Locale: "ZH", Value: "tag6"} + DB.Model(&blog).Association("Tags").Replace(tag5, tag6) + var tags2 []Tag + DB.Model(&blog).Association("Tags").Find(&tags2) + if !compareTags(tags2, []string{"tag5", "tag6"}) { + t.Fatalf("Should find 2 tags after Replace") + } + + if DB.Model(&blog).Association("Tags").Count() != 2 { + t.Fatalf("Blog should has three tags after Replace") + } + + // Delete + DB.Model(&blog).Association("Tags").Delete(tag5) + var tags3 []Tag + DB.Model(&blog).Association("Tags").Find(&tags3) + if !compareTags(tags3, []string{"tag6"}) { + t.Fatalf("Should find 1 tags after Delete") + } + + if DB.Model(&blog).Association("Tags").Count() != 1 { + t.Fatalf("Blog should has three tags after Delete") + } + + DB.Model(&blog).Association("Tags").Delete(tag3) + var tags4 []Tag + DB.Model(&blog).Association("Tags").Find(&tags4) + if !compareTags(tags4, []string{"tag6"}) { + t.Fatalf("Tag should not be deleted when Delete with a unrelated tag") + } + + // Clear + DB.Model(&blog).Association("Tags").Clear() + if DB.Model(&blog).Association("Tags").Count() != 0 { + t.Fatalf("All tags should be cleared") + } +} + +func TestManyToManyWithCustomizedForeignKeys(t *testing.T) { + if name := DB.Dialector.Name(); name == "sqlite" || name == "mssql" { + t.Skip("skip sqlite, mssql due to it doesn't support multiple primary keys with auto increment") + } + + DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags") + if err := DB.AutoMigrate(&Blog{}, &Tag{}); err != nil { + t.Fatalf("Failed to auto migrate, got error: %v", err) + } + + blog := Blog{ + Locale: "ZH", + Subject: "subject", + Body: "body", + SharedTags: []Tag{ + {Locale: "ZH", Value: "tag1"}, + {Locale: "ZH", Value: "tag2"}, + }, + } + DB.Save(&blog) + + blog2 := Blog{ + ID: blog.ID, + Locale: "EN", + } + DB.Create(&blog2) + + if !compareTags(blog.SharedTags, []string{"tag1", "tag2"}) { + t.Fatalf("Blog should has two tags") + } + + // Append + var tag3 = &Tag{Locale: "ZH", Value: "tag3"} + DB.Model(&blog).Association("SharedTags").Append([]*Tag{tag3}) + if !compareTags(blog.SharedTags, []string{"tag1", "tag2", "tag3"}) { + t.Fatalf("Blog should has three tags after Append") + } + + if DB.Model(&blog).Association("SharedTags").Count() != 3 { + t.Fatalf("Blog should has three tags after Append") + } + + if DB.Model(&blog2).Association("SharedTags").Count() != 3 { + t.Fatalf("Blog should has three tags after Append") + } + + var tags []Tag + DB.Model(&blog).Association("SharedTags").Find(&tags) + if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { + t.Fatalf("Should find 3 tags") + } + + DB.Model(&blog2).Association("SharedTags").Find(&tags) + if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { + t.Fatalf("Should find 3 tags") + } + + var blog1 Blog + DB.Preload("SharedTags").Find(&blog1) + if !compareTags(blog1.SharedTags, []string{"tag1", "tag2", "tag3"}) { + t.Fatalf("Preload many2many relations") + } + + var tag4 = &Tag{Locale: "ZH", Value: "tag4"} + DB.Model(&blog2).Association("SharedTags").Append(tag4) + + DB.Model(&blog).Association("SharedTags").Find(&tags) + if !compareTags(tags, []string{"tag1", "tag2", "tag3", "tag4"}) { + t.Fatalf("Should find 3 tags") + } + + DB.Model(&blog2).Association("SharedTags").Find(&tags) + if !compareTags(tags, []string{"tag1", "tag2", "tag3", "tag4"}) { + t.Fatalf("Should find 3 tags") + } + + // Replace + var tag5 = &Tag{Locale: "ZH", Value: "tag5"} + var tag6 = &Tag{Locale: "ZH", Value: "tag6"} + DB.Model(&blog2).Association("SharedTags").Replace(tag5, tag6) + var tags2 []Tag + DB.Model(&blog).Association("SharedTags").Find(&tags2) + if !compareTags(tags2, []string{"tag5", "tag6"}) { + t.Fatalf("Should find 2 tags after Replace") + } + + DB.Model(&blog2).Association("SharedTags").Find(&tags2) + if !compareTags(tags2, []string{"tag5", "tag6"}) { + t.Fatalf("Should find 2 tags after Replace") + } + + if DB.Model(&blog).Association("SharedTags").Count() != 2 { + t.Fatalf("Blog should has three tags after Replace") + } + + // Delete + DB.Model(&blog).Association("SharedTags").Delete(tag5) + var tags3 []Tag + DB.Model(&blog).Association("SharedTags").Find(&tags3) + if !compareTags(tags3, []string{"tag6"}) { + t.Fatalf("Should find 1 tags after Delete") + } + + if DB.Model(&blog).Association("SharedTags").Count() != 1 { + t.Fatalf("Blog should has three tags after Delete") + } + + DB.Model(&blog2).Association("SharedTags").Delete(tag3) + var tags4 []Tag + DB.Model(&blog).Association("SharedTags").Find(&tags4) + if !compareTags(tags4, []string{"tag6"}) { + t.Fatalf("Tag should not be deleted when Delete with a unrelated tag") + } + + // Clear + DB.Model(&blog2).Association("SharedTags").Clear() + if DB.Model(&blog).Association("SharedTags").Count() != 0 { + t.Fatalf("All tags should be cleared") + } +} + +func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) { + if name := DB.Dialector.Name(); name == "sqlite" || name == "mssql" { + t.Skip("skip sqlite, mssql due to it doesn't support multiple primary keys with auto increment") + } + + DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags") + if err := DB.AutoMigrate(&Blog{}, &Tag{}); err != nil { + t.Fatalf("Failed to auto migrate, got error: %v", err) + } + + blog := Blog{ + Locale: "ZH", + Subject: "subject", + Body: "body", + LocaleTags: []Tag{ + {Locale: "ZH", Value: "tag1"}, + {Locale: "ZH", Value: "tag2"}, + }, + } + DB.Save(&blog) + + blog2 := Blog{ + ID: blog.ID, + Locale: "EN", + } + DB.Create(&blog2) + + // Append + var tag3 = &Tag{Locale: "ZH", Value: "tag3"} + DB.Model(&blog).Association("LocaleTags").Append([]*Tag{tag3}) + if !compareTags(blog.LocaleTags, []string{"tag1", "tag2", "tag3"}) { + t.Fatalf("Blog should has three tags after Append") + } + + if DB.Model(&blog).Association("LocaleTags").Count() != 3 { + t.Fatalf("Blog should has three tags after Append") + } + + if DB.Model(&blog2).Association("LocaleTags").Count() != 0 { + t.Fatalf("EN Blog should has 0 tags after ZH Blog Append") + } + + var tags []Tag + DB.Model(&blog).Association("LocaleTags").Find(&tags) + if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { + t.Fatalf("Should find 3 tags") + } + + DB.Model(&blog2).Association("LocaleTags").Find(&tags) + if len(tags) != 0 { + t.Fatalf("Should find 0 tags for EN Blog") + } + + var blog1 Blog + DB.Preload("LocaleTags").Find(&blog1, "locale = ? AND id = ?", "ZH", blog.ID) + if !compareTags(blog1.LocaleTags, []string{"tag1", "tag2", "tag3"}) { + t.Fatalf("Preload many2many relations") + } + + var tag4 = &Tag{Locale: "ZH", Value: "tag4"} + DB.Model(&blog2).Association("LocaleTags").Append(tag4) + + DB.Model(&blog).Association("LocaleTags").Find(&tags) + if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { + t.Fatalf("Should find 3 tags for EN Blog") + } + + DB.Model(&blog2).Association("LocaleTags").Find(&tags) + if !compareTags(tags, []string{"tag4"}) { + t.Fatalf("Should find 1 tags for EN Blog") + } + + // Replace + var tag5 = &Tag{Locale: "ZH", Value: "tag5"} + var tag6 = &Tag{Locale: "ZH", Value: "tag6"} + DB.Model(&blog2).Association("LocaleTags").Replace(tag5, tag6) + + var tags2 []Tag + DB.Model(&blog).Association("LocaleTags").Find(&tags2) + if !compareTags(tags2, []string{"tag1", "tag2", "tag3"}) { + t.Fatalf("CN Blog's tags should not be changed after EN Blog Replace") + } + + var blog11 Blog + DB.Preload("LocaleTags").First(&blog11, "id = ? AND locale = ?", blog.ID, blog.Locale) + if !compareTags(blog11.LocaleTags, []string{"tag1", "tag2", "tag3"}) { + t.Fatalf("CN Blog's tags should not be changed after EN Blog Replace") + } + + DB.Model(&blog2).Association("LocaleTags").Find(&tags2) + if !compareTags(tags2, []string{"tag5", "tag6"}) { + t.Fatalf("Should find 2 tags after Replace") + } + + var blog21 Blog + DB.Preload("LocaleTags").First(&blog21, "id = ? AND locale = ?", blog2.ID, blog2.Locale) + if !compareTags(blog21.LocaleTags, []string{"tag5", "tag6"}) { + t.Fatalf("EN Blog's tags should be changed after Replace") + } + + if DB.Model(&blog).Association("LocaleTags").Count() != 3 { + t.Fatalf("ZH Blog should has three tags after Replace") + } + + if DB.Model(&blog2).Association("LocaleTags").Count() != 2 { + t.Fatalf("EN Blog should has two tags after Replace") + } + + // Delete + DB.Model(&blog).Association("LocaleTags").Delete(tag5) + + if DB.Model(&blog).Association("LocaleTags").Count() != 3 { + t.Fatalf("ZH Blog should has three tags after Delete with EN's tag") + } + + if DB.Model(&blog2).Association("LocaleTags").Count() != 2 { + t.Fatalf("EN Blog should has two tags after ZH Blog Delete with EN's tag") + } + + DB.Model(&blog2).Association("LocaleTags").Delete(tag5) + + if DB.Model(&blog).Association("LocaleTags").Count() != 3 { + t.Fatalf("ZH Blog should has three tags after EN Blog Delete with EN's tag") + } + + if DB.Model(&blog2).Association("LocaleTags").Count() != 1 { + t.Fatalf("EN Blog should has 1 tags after EN Blog Delete with EN's tag") + } + + // Clear + DB.Model(&blog2).Association("LocaleTags").Clear() + if DB.Model(&blog).Association("LocaleTags").Count() != 3 { + t.Fatalf("ZH Blog's tags should not be cleared when clear EN Blog's tags") + } + + if DB.Model(&blog2).Association("LocaleTags").Count() != 0 { + t.Fatalf("EN Blog's tags should be cleared when clear EN Blog's tags") + } + + DB.Model(&blog).Association("LocaleTags").Clear() + if DB.Model(&blog).Association("LocaleTags").Count() != 0 { + t.Fatalf("ZH Blog's tags should be cleared when clear ZH Blog's tags") + } + + if DB.Model(&blog2).Association("LocaleTags").Count() != 0 { + t.Fatalf("EN Blog's tags should be cleared") + } +}