From 9fea15ae75fb9ff2bd86dcaa167673c8ed77394f Mon Sep 17 00:00:00 2001 From: black-06 Date: Mon, 30 Oct 2023 17:15:49 +0800 Subject: [PATCH] feat: add MigrateColumnUnique (#6640) * feat: add MigrateColumnUnique * feat: define new methods * delete debug in test --- migrator.go | 2 ++ migrator/migrator.go | 22 ++++++++++++++++++++++ schema/naming.go | 8 ++++++++ tests/associations_belongs_to_test.go | 2 -- tests/count_test.go | 2 +- tests/preload_test.go | 1 - tests/update_test.go | 2 +- 7 files changed, 34 insertions(+), 5 deletions(-) diff --git a/migrator.go b/migrator.go index 0e01f567..3d2b032b 100644 --- a/migrator.go +++ b/migrator.go @@ -87,6 +87,8 @@ type Migrator interface { DropColumn(dst interface{}, field string) error AlterColumn(dst interface{}, field string) error MigrateColumn(dst interface{}, field *schema.Field, columnType ColumnType) error + // MigrateColumnUnique migrate column's UNIQUE constraint, it's part of MigrateColumn. + MigrateColumnUnique(dst interface{}, field *schema.Field, columnType ColumnType) error HasColumn(dst interface{}, field string) bool RenameColumn(dst interface{}, oldName, field string) error ColumnTypes(dst interface{}) ([]ColumnType, error) diff --git a/migrator/migrator.go b/migrator/migrator.go index 49bc9371..64a5a4b5 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -27,6 +27,8 @@ var regFullDataType = regexp.MustCompile(`\D*(\d+)\D?`) // TODO:? Create const vars for raw sql queries ? +var _ gorm.Migrator = (*Migrator)(nil) + // Migrator m struct type Migrator struct { Config @@ -539,6 +541,26 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy return nil } +func (m Migrator) MigrateColumnUnique(value interface{}, field *schema.Field, columnType gorm.ColumnType) error { + unique, ok := columnType.Unique() + if !ok || field.PrimaryKey { + return nil // skip primary key + } + // By default, ColumnType's Unique is not affected by UniqueIndex, so we don't care about UniqueIndex. + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + // We're currently only receiving boolean values on `Unique` tag, + // so the UniqueConstraint name is fixed + constraint := m.DB.NamingStrategy.UniqueName(stmt.Table, field.DBName) + if unique && !field.Unique { + return m.DB.Migrator().DropConstraint(value, constraint) + } + if !unique && field.Unique { + return m.DB.Migrator().CreateConstraint(value, constraint) + } + return nil + }) +} + // ColumnTypes return columnTypes []gorm.ColumnType and execErr error func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { columnTypes := make([]gorm.ColumnType, 0) diff --git a/schema/naming.go b/schema/naming.go index a2a0150a..e6fb81b2 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -19,6 +19,7 @@ type Namer interface { RelationshipFKName(Relationship) string CheckerName(table, column string) string IndexName(table, column string) string + UniqueName(table, column string) string } // Replacer replacer interface like strings.Replacer @@ -26,6 +27,8 @@ type Replacer interface { Replace(name string) string } +var _ Namer = (*NamingStrategy)(nil) + // NamingStrategy tables, columns naming strategy type NamingStrategy struct { TablePrefix string @@ -85,6 +88,11 @@ func (ns NamingStrategy) IndexName(table, column string) string { return ns.formatName("idx", table, ns.toDBName(column)) } +// UniqueName generate unique constraint name +func (ns NamingStrategy) UniqueName(table, column string) string { + return ns.formatName("uni", table, ns.toDBName(column)) +} + func (ns NamingStrategy) formatName(prefix, table, name string) string { formattedName := strings.ReplaceAll(strings.Join([]string{ prefix, table, name, diff --git a/tests/associations_belongs_to_test.go b/tests/associations_belongs_to_test.go index 6befb5f2..103da032 100644 --- a/tests/associations_belongs_to_test.go +++ b/tests/associations_belongs_to_test.go @@ -278,8 +278,6 @@ func TestBelongsToAssociationUnscoped(t *testing.T) { t.Fatalf("failed to create items, got error: %v", err) } - tx = tx.Debug() - // test replace if err := tx.Model(&item).Association("ItemParent").Unscoped().Replace(&ItemParent{ Logo: "updated logo", diff --git a/tests/count_test.go b/tests/count_test.go index b0dfb0b5..4449515b 100644 --- a/tests/count_test.go +++ b/tests/count_test.go @@ -29,7 +29,7 @@ func TestCountWithGroup(t *testing.T) { } var count2 int64 - if err := DB.Debug().Model(&Company{}).Where("name in ?", []string{"company_count_group_b", "company_count_group_c"}).Group("name").Count(&count2).Error; err != nil { + if err := DB.Model(&Company{}).Where("name in ?", []string{"company_count_group_b", "company_count_group_c"}).Group("name").Count(&count2).Error; err != nil { t.Errorf(fmt.Sprintf("Count should work, but got err %v", err)) } if count2 != 2 { diff --git a/tests/preload_test.go b/tests/preload_test.go index 7304e350..3ff86492 100644 --- a/tests/preload_test.go +++ b/tests/preload_test.go @@ -429,7 +429,6 @@ func TestEmbedPreload(t *testing.T) { }, } - DB = DB.Debug() for _, test := range tests { t.Run(test.name, func(t *testing.T) { actual := Org{} diff --git a/tests/update_test.go b/tests/update_test.go index a3fb7015..b719cc45 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -838,7 +838,7 @@ func TestSaveWithHooks(t *testing.T) { saveTokenOwner := func(owner *TokenOwner) (*TokenOwner, error) { var newOwner TokenOwner if err := DB.Transaction(func(tx *gorm.DB) error { - if err := tx.Debug().Session(&gorm.Session{FullSaveAssociations: true}).Save(owner).Error; err != nil { + if err := tx.Session(&gorm.Session{FullSaveAssociations: true}).Save(owner).Error; err != nil { return err } if err := tx.Preload("Token").First(&newOwner, owner.ID).Error; err != nil {