From 8fb9a317756bc07dcaaa82d43613ddcf9295c1ad Mon Sep 17 00:00:00 2001 From: black-06 Date: Tue, 6 Feb 2024 19:48:40 +0800 Subject: [PATCH] refactor: part 2 of distinguish between Unique and UniqueIndex (#6822) --- migrator/migrator.go | 22 ++-- tests/go.mod | 22 ++-- tests/migrate_test.go | 227 +++++++++++++++++++++++++++++++++++++++++- 3 files changed, 247 insertions(+), 24 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index d97fbf35..ae82f769 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -93,10 +93,6 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { expr.SQL += " NOT NULL" } - if field.Unique { - expr.SQL += " UNIQUE" - } - if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") { if field.DefaultValueInterface != nil { defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}} @@ -512,14 +508,6 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy } } - // check unique - if unique, ok := columnType.Unique(); ok && unique != (field.Unique || field.UniqueIndex != "") { - // not primary key - if !field.PrimaryKey { - alterColumn = true - } - } - // check default value if !field.PrimaryKey { currentDefaultNotNull := field.HasDefaultValue && (field.DefaultValueInterface != nil || !strings.EqualFold(field.DefaultValue, "NULL")) @@ -548,8 +536,14 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy } } - if alterColumn && !field.IgnoreMigration { - return m.DB.Migrator().AlterColumn(value, field.DBName) + if alterColumn { + if err := m.DB.Migrator().AlterColumn(value, field.DBName); err != nil { + return err + } + } + + if err := m.DB.Migrator().MigrateColumnUnique(value, field, columnType); err != nil { + return err } return nil diff --git a/tests/go.mod b/tests/go.mod index 07fedc45..136667b7 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -3,28 +3,34 @@ module gorm.io/gorm/tests go 1.18 require ( - github.com/google/uuid v1.5.0 + github.com/google/uuid v1.6.0 github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.9 - gorm.io/driver/mysql v1.5.2 - gorm.io/driver/postgres v1.5.4 - gorm.io/driver/sqlite v1.5.4 - gorm.io/driver/sqlserver v1.5.2 - gorm.io/gorm v1.25.5 + github.com/stretchr/testify v1.8.4 + gorm.io/driver/mysql v1.5.4 + gorm.io/driver/postgres v1.5.6 + gorm.io/driver/sqlite v1.5.5 + gorm.io/driver/sqlserver v1.5.3 + gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde ) require ( + github.com/davecgh/go-spew v1.1.1 // indirect github.com/go-sql-driver/mysql v1.7.1 // indirect github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/golang-sql/sqlexp v0.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 // indirect - github.com/jackc/pgx/v5 v5.5.1 // indirect + github.com/jackc/pgx/v5 v5.5.3 // indirect github.com/jinzhu/inflection v1.0.0 // indirect - github.com/mattn/go-sqlite3 v1.14.19 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/mattn/go-sqlite3 v1.14.22 // indirect github.com/microsoft/go-mssqldb v1.6.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/rogpeppe/go-internal v1.12.0 // indirect golang.org/x/crypto v0.18.0 // indirect golang.org/x/text v0.14.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) replace gorm.io/gorm => ../ diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 28fa315b..837d92c1 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -2,6 +2,7 @@ package tests_test import ( "context" + "database/sql" "fmt" "math/rand" "os" @@ -10,10 +11,15 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" "gorm.io/driver/postgres" + "gorm.io/gorm" + "gorm.io/gorm/clause" "gorm.io/gorm/logger" + "gorm.io/gorm/migrator" "gorm.io/gorm/schema" + "gorm.io/gorm/utils" . "gorm.io/gorm/utils/tests" ) @@ -984,7 +990,8 @@ func TestCurrentTimestamp(t *testing.T) { if err != nil { t.Fatalf("AutoMigrate err:%v", err) } - AssertEqual(t, true, DB.Migrator().HasIndex(&CurrentTimestampTest{}, "time_at")) + AssertEqual(t, true, DB.Migrator().HasConstraint(&CurrentTimestampTest{}, "uni_current_timestamp_tests_time_at")) + AssertEqual(t, false, DB.Migrator().HasIndex(&CurrentTimestampTest{}, "time_at")) AssertEqual(t, false, DB.Migrator().HasIndex(&CurrentTimestampTest{}, "time_at_2")) } @@ -1046,7 +1053,8 @@ func TestUniqueColumn(t *testing.T) { } // not trigger alert column - AssertEqual(t, true, DB.Migrator().HasIndex(&UniqueTest{}, "name")) + AssertEqual(t, true, DB.Migrator().HasConstraint(&UniqueTest{}, "uni_unique_tests_name")) + AssertEqual(t, false, DB.Migrator().HasIndex(&UniqueTest{}, "name")) AssertEqual(t, false, DB.Migrator().HasIndex(&UniqueTest{}, "name_1")) AssertEqual(t, false, DB.Migrator().HasIndex(&UniqueTest{}, "name_2")) @@ -1712,3 +1720,218 @@ func TestTableType(t *testing.T) { t.Fatalf("expected comment %s got %s", tblComment, comment) } } + +func TestMigrateWithUniqueIndexAndUnique(t *testing.T) { + const table = "unique_struct" + + checkField := func(model interface{}, fieldName string, unique bool, uniqueIndex string) { + stmt := &gorm.Statement{DB: DB} + err := stmt.Parse(model) + if err != nil { + t.Fatalf("%v: failed to parse schema, got error: %v", utils.FileWithLineNum(), err) + } + _ = stmt.Schema.ParseIndexes() + field := stmt.Schema.LookUpField(fieldName) + if field == nil { + t.Fatalf("%v: failed to find column %q", utils.FileWithLineNum(), fieldName) + } + if field.Unique != unique { + t.Fatalf("%v: %q column %q unique should be %v but got %v", utils.FileWithLineNum(), stmt.Schema.Table, fieldName, unique, field.Unique) + } + if field.UniqueIndex != uniqueIndex { + t.Fatalf("%v: %q column %q uniqueIndex should be %v but got %v", utils.FileWithLineNum(), stmt.Schema, fieldName, uniqueIndex, field.UniqueIndex) + } + } + + type ( // not unique + UniqueStruct1 struct { + Name string `gorm:"size:10"` + } + UniqueStruct2 struct { + Name string `gorm:"size:20"` + } + ) + checkField(&UniqueStruct1{}, "name", false, "") + checkField(&UniqueStruct2{}, "name", false, "") + + type ( // unique + UniqueStruct3 struct { + Name string `gorm:"size:30;unique"` + } + UniqueStruct4 struct { + Name string `gorm:"size:40;unique"` + } + ) + checkField(&UniqueStruct3{}, "name", true, "") + checkField(&UniqueStruct4{}, "name", true, "") + + type ( // uniqueIndex + UniqueStruct5 struct { + Name string `gorm:"size:50;uniqueIndex"` + } + UniqueStruct6 struct { + Name string `gorm:"size:60;uniqueIndex"` + } + UniqueStruct7 struct { + Name string `gorm:"size:70;uniqueIndex:idx_us6_all_names"` + NickName string `gorm:"size:70;uniqueIndex:idx_us6_all_names"` + } + ) + checkField(&UniqueStruct5{}, "name", false, "idx_unique_struct5_name") + checkField(&UniqueStruct6{}, "name", false, "idx_unique_struct6_name") + + checkField(&UniqueStruct7{}, "name", false, "") + checkField(&UniqueStruct7{}, "nick_name", false, "") + checkField(&UniqueStruct7{}, "nick_name", false, "") + + type UniqueStruct8 struct { // unique and uniqueIndex + Name string `gorm:"size:60;unique;index:my_us8_index,unique;"` + } + checkField(&UniqueStruct8{}, "name", true, "my_us8_index") + + type TestCase struct { + name string + from, to interface{} + checkFunc func(t *testing.T) + } + + checkColumnType := func(t *testing.T, fieldName string, unique bool) { + columnTypes, err := DB.Migrator().ColumnTypes(table) + if err != nil { + t.Fatalf("%v: failed to get column types, got error: %v", utils.FileWithLineNum(), err) + } + var found gorm.ColumnType + for _, columnType := range columnTypes { + if columnType.Name() == fieldName { + found = columnType + } + } + if found == nil { + t.Fatalf("%v: failed to find column type %q", utils.FileWithLineNum(), fieldName) + } + if actualUnique, ok := found.Unique(); !ok || actualUnique != unique { + t.Fatalf("%v: column %q unique should be %v but got %v", utils.FileWithLineNum(), fieldName, unique, actualUnique) + } + } + + checkIndex := func(t *testing.T, expected []gorm.Index) { + indexes, err := DB.Migrator().GetIndexes(table) + if err != nil { + t.Fatalf("%v: failed to get indexes, got error: %v", utils.FileWithLineNum(), err) + } + assert.ElementsMatch(t, expected, indexes) + } + + uniqueIndex := &migrator.Index{TableName: table, NameValue: DB.Config.NamingStrategy.IndexName(table, "name"), ColumnList: []string{"name"}, PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}} + myIndex := &migrator.Index{TableName: table, NameValue: "my_us8_index", ColumnList: []string{"name"}, PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}} + mulIndex := &migrator.Index{TableName: table, NameValue: "idx_us6_all_names", ColumnList: []string{"name", "nick_name"}, PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}} + + var checkNotUnique, checkUnique, checkUniqueIndex, checkMyIndex, checkMulIndex func(t *testing.T) + // UniqueAffectedByUniqueIndex is true + if DB.Dialector.Name() == "mysql" { + uniqueConstraintIndex := &migrator.Index{TableName: table, NameValue: DB.Config.NamingStrategy.UniqueName(table, "name"), ColumnList: []string{"name"}, PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}} + checkNotUnique = func(t *testing.T) { + checkColumnType(t, "name", false) + checkIndex(t, nil) + } + checkUnique = func(t *testing.T) { + checkColumnType(t, "name", true) + checkIndex(t, []gorm.Index{uniqueConstraintIndex}) + } + checkUniqueIndex = func(t *testing.T) { + checkColumnType(t, "name", true) + checkIndex(t, []gorm.Index{uniqueIndex}) + } + checkMyIndex = func(t *testing.T) { + checkColumnType(t, "name", true) + checkIndex(t, []gorm.Index{uniqueConstraintIndex, myIndex}) + } + checkMulIndex = func(t *testing.T) { + checkColumnType(t, "name", false) + checkColumnType(t, "nick_name", false) + checkIndex(t, []gorm.Index{mulIndex}) + } + } else { + checkNotUnique = func(t *testing.T) { checkColumnType(t, "name", false) } + checkUnique = func(t *testing.T) { checkColumnType(t, "name", true) } + checkUniqueIndex = func(t *testing.T) { + checkColumnType(t, "name", false) + checkIndex(t, []gorm.Index{uniqueIndex}) + } + checkMyIndex = func(t *testing.T) { + checkColumnType(t, "name", true) + if !DB.Migrator().HasIndex(table, myIndex.Name()) { + t.Errorf("%v: should has index %s but not", utils.FileWithLineNum(), myIndex.Name()) + } + } + checkMulIndex = func(t *testing.T) { + checkColumnType(t, "name", false) + checkColumnType(t, "nick_name", false) + if !DB.Migrator().HasIndex(table, mulIndex.Name()) { + t.Errorf("%v: should has index %s but not", utils.FileWithLineNum(), mulIndex.Name()) + } + } + } + + tests := []TestCase{ + {name: "notUnique to notUnique", from: &UniqueStruct1{}, to: &UniqueStruct2{}, checkFunc: checkNotUnique}, + {name: "notUnique to unique", from: &UniqueStruct1{}, to: &UniqueStruct3{}, checkFunc: checkUnique}, + {name: "notUnique to uniqueIndex", from: &UniqueStruct1{}, to: &UniqueStruct5{}, checkFunc: checkUniqueIndex}, + {name: "notUnique to uniqueAndUniqueIndex", from: &UniqueStruct1{}, to: &UniqueStruct8{}, checkFunc: checkMyIndex}, + {name: "unique to unique", from: &UniqueStruct3{}, to: &UniqueStruct4{}, checkFunc: checkUnique}, + {name: "unique to uniqueIndex", from: &UniqueStruct3{}, to: &UniqueStruct5{}, checkFunc: checkUniqueIndex}, + {name: "unique to uniqueAndUniqueIndex", from: &UniqueStruct3{}, to: &UniqueStruct8{}, checkFunc: checkMyIndex}, + {name: "uniqueIndex to uniqueIndex", from: &UniqueStruct5{}, to: &UniqueStruct6{}, checkFunc: checkUniqueIndex}, + {name: "uniqueIndex to uniqueAndUniqueIndex", from: &UniqueStruct5{}, to: &UniqueStruct8{}, checkFunc: checkMyIndex}, + {name: "uniqueIndex to multi uniqueIndex", from: &UniqueStruct5{}, to: &UniqueStruct7{}, checkFunc: checkMulIndex}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if err := DB.Migrator().DropTable(table); err != nil { + t.Fatalf("failed to drop table, got error: %v", err) + } + if err := DB.Table(table).AutoMigrate(test.from); err != nil { + t.Fatalf("failed to migrate table, got error: %v", err) + } + if err := DB.Table(table).AutoMigrate(test.to); err != nil { + t.Fatalf("failed to migrate table, got error: %v", err) + } + test.checkFunc(t) + }) + } + + if DB.Dialector.Name() != "sqlserver" { + // In SQLServer, If an index or constraint depends on the column, + // this column will not be able to run ALTER + // see https://stackoverflow.com/questions/19460912/the-object-df-is-dependent-on-column-changing-int-to-double/19461205#19461205 + // may we need to create another PR to fix it, see https://github.com/go-gorm/sqlserver/pull/106 + tests = []TestCase{ + {name: "unique to notUnique", from: &UniqueStruct3{}, to: &UniqueStruct1{}, checkFunc: checkNotUnique}, + {name: "uniqueIndex to notUnique", from: &UniqueStruct5{}, to: &UniqueStruct2{}, checkFunc: checkNotUnique}, + {name: "uniqueIndex to unique", from: &UniqueStruct5{}, to: &UniqueStruct3{}, checkFunc: checkUnique}, + } + } + + if DB.Dialector.Name() == "mysql" { + compatibilityTests := []TestCase{ + {name: "oldUnique to notUnique", to: UniqueStruct1{}, checkFunc: checkNotUnique}, + {name: "oldUnique to unique", to: UniqueStruct3{}, checkFunc: checkUnique}, + {name: "oldUnique to uniqueIndex", to: UniqueStruct5{}, checkFunc: checkUniqueIndex}, + {name: "oldUnique to uniqueAndUniqueIndex", to: UniqueStruct8{}, checkFunc: checkMyIndex}, + } + for _, test := range compatibilityTests { + t.Run(test.name, func(t *testing.T) { + if err := DB.Migrator().DropTable(table); err != nil { + t.Fatalf("failed to drop table, got error: %v", err) + } + if err := DB.Exec("CREATE TABLE ? (`name` varchar(10) UNIQUE)", clause.Table{Name: table}).Error; err != nil { + t.Fatalf("failed to create table, got error: %v", err) + } + if err := DB.Table(table).AutoMigrate(test.to); err != nil { + t.Fatalf("failed to migrate table, got error: %v", err) + } + test.checkFunc(t) + }) + } + } +}