From a2cac75218c9844d0b43832e2d2a3c35f9700406 Mon Sep 17 00:00:00 2001 From: Alexis Viscogliosi Date: Fri, 15 Dec 2023 09:36:08 +0100 Subject: [PATCH] feature: bring custom type and id column name to polymorphism (#6716) * feature: bring custom type and id column name to polymorphism * relationship: better returns for hasPolymorphicRelation * fix: tests --- schema/relationship.go | 67 +++++++--- schema/relationship_test.go | 187 ++++++++++++++++++++++++++++ tests/associations_has_many_test.go | 11 +- tests/helper_test.go | 26 ++-- tests/migrate_test.go | 85 ++++++++----- tests/tests_test.go | 2 +- utils/tests/models.go | 10 +- 7 files changed, 335 insertions(+), 53 deletions(-) diff --git a/schema/relationship.go b/schema/relationship.go index e03dcc52..57167859 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -76,8 +76,8 @@ func (schema *Schema) parseRelation(field *Field) *Relationship { return nil } - if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { - schema.buildPolymorphicRelation(relation, field, polymorphic) + if hasPolymorphicRelation(field.TagSettings) { + schema.buildPolymorphicRelation(relation, field) } else if many2many := field.TagSettings["MANY2MANY"]; many2many != "" { schema.buildMany2ManyRelation(relation, field, many2many) } else if belongsTo := field.TagSettings["BELONGSTO"]; belongsTo != "" { @@ -89,7 +89,8 @@ func (schema *Schema) parseRelation(field *Field) *Relationship { case reflect.Slice: schema.guessRelation(relation, field, guessHas) default: - schema.err = fmt.Errorf("unsupported data type %v for %v on field %s", relation.FieldSchema, schema, field.Name) + schema.err = fmt.Errorf("unsupported data type %v for %v on field %s", relation.FieldSchema, schema, + field.Name) } } @@ -124,6 +125,20 @@ func (schema *Schema) parseRelation(field *Field) *Relationship { return relation } +// hasPolymorphicRelation check if has polymorphic relation +// 1. `POLYMORPHIC` tag +// 2. `POLYMORPHICTYPE` and `POLYMORPHICID` tag +func hasPolymorphicRelation(tagSettings map[string]string) bool { + if _, ok := tagSettings["POLYMORPHIC"]; ok { + return true + } + + _, hasType := tagSettings["POLYMORPHICTYPE"] + _, hasId := tagSettings["POLYMORPHICID"] + + return hasType && hasId +} + func (schema *Schema) setRelation(relation *Relationship) { // set non-embedded relation if rel := schema.Relationships.Relations[relation.Name]; rel != nil { @@ -169,23 +184,41 @@ func (schema *Schema) setRelation(relation *Relationship) { // OwnerID int // OwnerType string // } -func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field, polymorphic string) { +func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field) { + polymorphic := field.TagSettings["POLYMORPHIC"] + relation.Polymorphic = &Polymorphic{ - Value: schema.Table, - PolymorphicType: relation.FieldSchema.FieldsByName[polymorphic+"Type"], - PolymorphicID: relation.FieldSchema.FieldsByName[polymorphic+"ID"], + Value: schema.Table, } + var ( + typeName = polymorphic + "Type" + typeId = polymorphic + "ID" + ) + + if value, ok := field.TagSettings["POLYMORPHICTYPE"]; ok { + typeName = strings.TrimSpace(value) + } + + if value, ok := field.TagSettings["POLYMORPHICID"]; ok { + typeId = strings.TrimSpace(value) + } + + relation.Polymorphic.PolymorphicType = relation.FieldSchema.FieldsByName[typeName] + relation.Polymorphic.PolymorphicID = relation.FieldSchema.FieldsByName[typeId] + if value, ok := field.TagSettings["POLYMORPHICVALUE"]; ok { relation.Polymorphic.Value = strings.TrimSpace(value) } if relation.Polymorphic.PolymorphicType == nil { - schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"Type") + schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", + relation.FieldSchema, schema, field.Name, polymorphic+"Type") } if relation.Polymorphic.PolymorphicID == nil { - schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"ID") + schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", + relation.FieldSchema, schema, field.Name, polymorphic+"ID") } if schema.err == nil { @@ -197,12 +230,14 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi primaryKeyField := schema.PrioritizedPrimaryField if len(relation.foreignKeys) > 0 { if primaryKeyField = schema.LookUpField(relation.foreignKeys[0]); primaryKeyField == nil || len(relation.foreignKeys) > 1 { - schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %s", relation.foreignKeys, schema, field.Name) + schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %s", relation.foreignKeys, + schema, field.Name) } } if primaryKeyField == nil { - schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing primaryKey field", relation.FieldSchema, schema, field.Name) + schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing primaryKey field", + relation.FieldSchema, schema, field.Name) return } @@ -317,7 +352,8 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel Tag: `gorm:"-"`, }) - if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil { + if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, + schema.namer); err != nil { schema.err = err } relation.JoinTable.Name = many2many @@ -436,7 +472,8 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu schema.guessRelation(relation, field, guessEmbeddedHas) // case guessEmbeddedHas: default: - schema.err = fmt.Errorf("invalid field found for struct %v's field %s: define a valid foreign key for relations or implement the Valuer/Scanner interface", schema, field.Name) + schema.err = fmt.Errorf("invalid field found for struct %v's field %s: define a valid foreign key for relations or implement the Valuer/Scanner interface", + schema, field.Name) } } @@ -492,7 +529,9 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu lookUpNames := []string{lookUpName} if len(primaryFields) == 1 { - lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID", strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID")) + lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID", + strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table, + strings.TrimSuffix(lookUpName, primaryField.Name)+"ID")) } for _, name := range lookUpNames { diff --git a/schema/relationship_test.go b/schema/relationship_test.go index 1eb66bb4..23d79bbb 100644 --- a/schema/relationship_test.go +++ b/schema/relationship_test.go @@ -577,6 +577,193 @@ func TestEmbeddedHas(t *testing.T) { }) } +func TestPolymorphic(t *testing.T) { + t.Run("has one", func(t *testing.T) { + type Toy struct { + ID int + Name string + OwnerID int + OwnerType string + } + + type Cat struct { + ID int + Name string + Toy Toy `gorm:"polymorphic:Owner;"` + } + + s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("Failed to parse schema, got error %v", err) + } + + checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{ + "Cat": { + Relations: map[string]Relation{ + "Toy": { + Name: "Toy", + Type: schema.HasOne, + Schema: "User", + FieldSchema: "Toy", + Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"}, + References: []Reference{ + {ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"}, + }, + }, + }, + }, + }) + }) + + t.Run("has one with custom polymorphic type and id", func(t *testing.T) { + type Toy struct { + ID int + Name string + RefId int + Type string + } + + type Cat struct { + ID int + Name string + Toy Toy `gorm:"polymorphic:Owner;polymorphicType:Type;polymorphicId:RefId"` + } + + s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("Failed to parse schema, got error %v", err) + } + + checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{ + "Cat": { + Relations: map[string]Relation{ + "Toy": { + Name: "Toy", + Type: schema.HasOne, + Schema: "User", + FieldSchema: "Toy", + Polymorphic: Polymorphic{ID: "ref_id", Type: "Type", Value: "users"}, + References: []Reference{ + {ForeignKey: "Type", ForeignSchema: "Toy", PrimaryValue: "users"}, + }, + }, + }, + }, + }) + }) + + t.Run("has one with only polymorphic type", func(t *testing.T) { + type Toy struct { + ID int + Name string + OwnerID int + Type string + } + + type Cat struct { + ID int + Name string + Toy Toy `gorm:"polymorphic:Owner;polymorphicType:Type"` + } + + s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("Failed to parse schema, got error %v", err) + } + + checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{ + "Cat": { + Relations: map[string]Relation{ + "Toy": { + Name: "Toy", + Type: schema.HasOne, + Schema: "User", + FieldSchema: "Toy", + Polymorphic: Polymorphic{ID: "owner_id", Type: "Type", Value: "users"}, + References: []Reference{ + {ForeignKey: "Type", ForeignSchema: "Toy", PrimaryValue: "users"}, + }, + }, + }, + }, + }) + }) + + t.Run("has many", func(t *testing.T) { + type Toy struct { + ID int + Name string + OwnerID int + OwnerType string + } + + type Cat struct { + ID int + Name string + Toys []Toy `gorm:"polymorphic:Owner;"` + } + + s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("Failed to parse schema, got error %v", err) + } + + checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{ + "Cat": { + Relations: map[string]Relation{ + "Toys": { + Name: "Toys", + Type: schema.HasMany, + Schema: "User", + FieldSchema: "Toy", + Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"}, + References: []Reference{ + {ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"}, + }, + }, + }, + }, + }) + }) + + t.Run("has many with custom polymorphic type and id", func(t *testing.T) { + type Toy struct { + ID int + Name string + RefId int + Type string + } + + type Cat struct { + ID int + Name string + Toys []Toy `gorm:"polymorphicType:Type;polymorphicId:RefId"` + } + + s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("Failed to parse schema, got error %v", err) + } + + checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{ + "Cat": { + Relations: map[string]Relation{ + "Toys": { + Name: "Toys", + Type: schema.HasMany, + Schema: "User", + FieldSchema: "Toy", + Polymorphic: Polymorphic{ID: "ref_id", Type: "Type", Value: "users"}, + References: []Reference{ + {ForeignKey: "Type", ForeignSchema: "Toy", PrimaryValue: "users"}, + }, + }, + }, + }, + }) + }) +} + func TestEmbeddedBelongsTo(t *testing.T) { type Country struct { ID int `gorm:"primaryKey"` diff --git a/tests/associations_has_many_test.go b/tests/associations_has_many_test.go index c31c4b40..b8e8ff5e 100644 --- a/tests/associations_has_many_test.go +++ b/tests/associations_has_many_test.go @@ -422,7 +422,7 @@ func TestPolymorphicHasManyAssociation(t *testing.T) { func TestPolymorphicHasManyAssociationForSlice(t *testing.T) { users := []User{ *GetUser("slice-hasmany-1", Config{Toys: 2}), - *GetUser("slice-hasmany-2", Config{Toys: 0}), + *GetUser("slice-hasmany-2", Config{Toys: 0, Tools: 2}), *GetUser("slice-hasmany-3", Config{Toys: 4}), } @@ -430,6 +430,7 @@ func TestPolymorphicHasManyAssociationForSlice(t *testing.T) { // Count AssertAssociationCount(t, users, "Toys", 6, "") + AssertAssociationCount(t, users, "Tools", 2, "") // Find var toys []Toy @@ -437,6 +438,14 @@ func TestPolymorphicHasManyAssociationForSlice(t *testing.T) { t.Errorf("toys count should be %v, but got %v", 6, len(toys)) } + // Find Tools (polymorphic with custom type and id) + var tools []Tools + DB.Model(&users).Association("Tools").Find(&tools) + + if len(tools) != 2 { + t.Errorf("tools count should be %v, but got %v", 2, len(tools)) + } + // Append DB.Model(&users).Association("Toys").Append( &Toy{Name: "toy-slice-append-1"}, diff --git a/tests/helper_test.go b/tests/helper_test.go index 1a4874ee..feb67f9e 100644 --- a/tests/helper_test.go +++ b/tests/helper_test.go @@ -23,6 +23,7 @@ type Config struct { Languages int Friends int NamedPet bool + Tools int } func GetUser(name string, config Config) *User { @@ -47,6 +48,10 @@ func GetUser(name string, config Config) *User { user.Toys = append(user.Toys, Toy{Name: name + "_toy_" + strconv.Itoa(i+1)}) } + for i := 0; i < config.Tools; i++ { + user.Tools = append(user.Tools, Tools{Name: name + "_tool_" + strconv.Itoa(i+1)}) + } + if config.Company { user.Company = Company{Name: "company-" + name} } @@ -118,11 +123,13 @@ func doCheckUser(t *testing.T, user User, expect User, unscoped bool) { if err := db(unscoped).Where("id = ?", user.ID).First(&newUser).Error; err != nil { t.Fatalf("errors happened when query: %v", err) } else { - AssertObjEqual(t, newUser, user, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") + AssertObjEqual(t, newUser, user, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", + "CompanyID", "ManagerID", "Active") } } - AssertObjEqual(t, user, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") + AssertObjEqual(t, user, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", + "ManagerID", "Active") t.Run("Account", func(t *testing.T) { AssertObjEqual(t, user.Account, expect.Account, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Number") @@ -133,7 +140,8 @@ func doCheckUser(t *testing.T, user User, expect User, unscoped bool) { } else { var account Account db(unscoped).First(&account, "user_id = ?", user.ID) - AssertObjEqual(t, account, user.Account, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Number") + AssertObjEqual(t, account, user.Account, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", + "Number") } } }) @@ -193,8 +201,10 @@ func doCheckUser(t *testing.T, user User, expect User, unscoped bool) { } else { var manager User db(unscoped).First(&manager, "id = ?", *user.ManagerID) - AssertObjEqual(t, manager, user.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") - AssertObjEqual(t, manager, expect.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") + AssertObjEqual(t, manager, user.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", + "Birthday", "CompanyID", "ManagerID", "Active") + AssertObjEqual(t, manager, expect.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", + "Birthday", "CompanyID", "ManagerID", "Active") } } else if user.ManagerID != nil { t.Errorf("Manager should not be created for zero value, got: %+v", user.ManagerID) @@ -215,7 +225,8 @@ func doCheckUser(t *testing.T, user User, expect User, unscoped bool) { }) for idx, team := range user.Team { - AssertObjEqual(t, team, expect.Team[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") + AssertObjEqual(t, team, expect.Team[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", + "Birthday", "CompanyID", "ManagerID", "Active") } }) @@ -250,7 +261,8 @@ func doCheckUser(t *testing.T, user User, expect User, unscoped bool) { }) for idx, friend := range user.Friends { - AssertObjEqual(t, friend, expect.Friends[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") + AssertObjEqual(t, friend, expect.Friends[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", + "Birthday", "CompanyID", "ManagerID", "Active") } }) } diff --git a/tests/migrate_test.go b/tests/migrate_test.go index cfd3e0ac..28fa315b 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -18,7 +18,7 @@ import ( ) func TestMigrate(t *testing.T) { - allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}} + allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Tools{}} rand.Seed(time.Now().UnixNano()) rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) DB.Migrator().DropTable("user_speaks", "user_friends", "ccc") @@ -34,7 +34,7 @@ func TestMigrate(t *testing.T) { if tables, err := DB.Migrator().GetTables(); err != nil { t.Fatalf("Failed to get database all tables, but got error %v", err) } else { - for _, t1 := range []string{"users", "accounts", "pets", "companies", "toys", "languages"} { + for _, t1 := range []string{"users", "accounts", "pets", "companies", "toys", "languages", "tools"} { hasTable := false for _, t2 := range tables { if t2 == t1 { @@ -93,7 +93,8 @@ func TestAutoMigrateInt8PG(t *testing.T) { Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { sql, _ := fc() if strings.HasPrefix(sql, "ALTER TABLE \"migrate_ints\" ALTER COLUMN \"int8\" TYPE smallint") { - t.Fatalf("shouldn't execute ALTER COLUMN TYPE if such type is already existed in DB schema: sql: %s", sql) + t.Fatalf("shouldn't execute ALTER COLUMN TYPE if such type is already existed in DB schema: sql: %s", + sql) } }, } @@ -432,40 +433,50 @@ func TestTiDBMigrateColumns(t *testing.T) { switch columnType.Name() { case "id": if v, ok := columnType.PrimaryKey(); !ok || !v { - t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), columnType) + t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), + columnType) } case "name": dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name())) if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) { - t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType) + t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", + columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType) } if length, ok := columnType.Length(); !ok || length != 100 { - t.Fatalf("column name length should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), length, 100, columnType) + t.Fatalf("column name length should be correct, name: %v, length: %v, expects: %v, column: %#v", + columnType.Name(), length, 100, columnType) } case "age": if v, ok := columnType.DefaultValue(); !ok || v != "18" { - t.Fatalf("column age default value should be correct, name: %v, column: %#v", columnType.Name(), columnType) + t.Fatalf("column age default value should be correct, name: %v, column: %#v", columnType.Name(), + columnType) } if v, ok := columnType.Comment(); !ok || v != "my age" { - t.Fatalf("column age comment should be correct, name: %v, column: %#v", columnType.Name(), columnType) + t.Fatalf("column age comment should be correct, name: %v, column: %#v", columnType.Name(), + columnType) } case "code": if v, ok := columnType.Unique(); !ok || !v { - t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType) + t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), + columnType) } if v, ok := columnType.DefaultValue(); !ok || v != "hello" { - t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v", columnType.Name(), columnType, v) + t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v", + columnType.Name(), columnType, v) } if v, ok := columnType.Comment(); !ok || v != "my code2" { - t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), columnType) + t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), + columnType) } case "code2": // Code2 string `gorm:"comment:my code2;default:hello"` if v, ok := columnType.DefaultValue(); !ok || v != "hello" { - t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v", columnType.Name(), columnType, v) + t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v", + columnType.Name(), columnType, v) } if v, ok := columnType.Comment(); !ok || v != "my code2" { - t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), columnType) + t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), + columnType) } } } @@ -497,7 +508,8 @@ func TestTiDBMigrateColumns(t *testing.T) { t.Fatalf("Failed to add column, got %v", err) } - if err := DB.Table("column_structs").Migrator().RenameColumn(&NewColumnStruct{}, "NewName", "new_new_name"); err != nil { + if err := DB.Table("column_structs").Migrator().RenameColumn(&NewColumnStruct{}, "NewName", + "new_new_name"); err != nil { t.Fatalf("Failed to add column, got %v", err) } @@ -561,36 +573,45 @@ func TestMigrateColumns(t *testing.T) { switch columnType.Name() { case "id": if v, ok := columnType.PrimaryKey(); !ok || !v { - t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), columnType) + t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), + columnType) } case "name": dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name())) if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) { - t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType) + t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", + columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType) } if length, ok := columnType.Length(); !sqlite && (!ok || length != 100) { - t.Fatalf("column name length should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), length, 100, columnType) + t.Fatalf("column name length should be correct, name: %v, length: %v, expects: %v, column: %#v", + columnType.Name(), length, 100, columnType) } case "age": if v, ok := columnType.DefaultValue(); !ok || v != "18" { - t.Fatalf("column age default value should be correct, name: %v, column: %#v", columnType.Name(), columnType) + t.Fatalf("column age default value should be correct, name: %v, column: %#v", columnType.Name(), + columnType) } if v, ok := columnType.Comment(); !sqlite && !sqlserver && (!ok || v != "my age") { - t.Fatalf("column age comment should be correct, name: %v, column: %#v", columnType.Name(), columnType) + t.Fatalf("column age comment should be correct, name: %v, column: %#v", columnType.Name(), + columnType) } case "code": if v, ok := columnType.Unique(); !ok || !v { - t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType) + t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), + columnType) } if v, ok := columnType.DefaultValue(); !sqlserver && (!ok || v != "hello") { - t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v", columnType.Name(), columnType, v) + t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v", + columnType.Name(), columnType, v) } if v, ok := columnType.Comment(); !sqlite && !sqlserver && (!ok || v != "my code2") { - t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), columnType) + t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), + columnType) } case "code2": if v, ok := columnType.Unique(); !sqlserver && (!ok || !v) { - t.Fatalf("column code2 unique should be correct, name: %v, column: %#v", columnType.Name(), columnType) + t.Fatalf("column code2 unique should be correct, name: %v, column: %#v", columnType.Name(), + columnType) } case "code3": // TODO @@ -627,7 +648,8 @@ func TestMigrateColumns(t *testing.T) { t.Fatalf("Failed to add column, got %v", err) } - if err := DB.Table("column_structs").Migrator().RenameColumn(&NewColumnStruct{}, "NewName", "new_new_name"); err != nil { + if err := DB.Table("column_structs").Migrator().RenameColumn(&NewColumnStruct{}, "NewName", + "new_new_name"); err != nil { t.Fatalf("Failed to add column, got %v", err) } @@ -1555,7 +1577,8 @@ func TestMigrateIgnoreRelations(t *testing.T) { func TestMigrateView(t *testing.T) { DB.Save(GetUser("joins-args-db", Config{Pets: 2})) - if err := DB.Migrator().CreateView("invalid_users_pets", gorm.ViewOption{Query: nil}); err != gorm.ErrSubQueryRequired { + if err := DB.Migrator().CreateView("invalid_users_pets", + gorm.ViewOption{Query: nil}); err != gorm.ErrSubQueryRequired { t.Fatalf("no view should be created, got %v", err) } @@ -1624,17 +1647,20 @@ func TestMigrateExistingBoolColumnPG(t *testing.T) { switch columnType.Name() { case "id": if v, ok := columnType.PrimaryKey(); !ok || !v { - t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), columnType) + t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), + columnType) } case "string_bool": dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name())) if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) { - t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType) + t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", + columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType) } case "smallint_bool": dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name())) if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) { - t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType) + t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", + columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType) } } } @@ -1659,7 +1685,8 @@ func TestTableType(t *testing.T) { DB.Migrator().DropTable(&City{}) - if err := DB.Set("gorm:table_options", fmt.Sprintf("ENGINE InnoDB COMMENT '%s'", tblComment)).AutoMigrate(&City{}); err != nil { + if err := DB.Set("gorm:table_options", + fmt.Sprintf("ENGINE InnoDB COMMENT '%s'", tblComment)).AutoMigrate(&City{}); err != nil { t.Fatalf("failed to migrate cities tables, got error: %v", err) } diff --git a/tests/tests_test.go b/tests/tests_test.go index f9c6cab5..a127734e 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -107,7 +107,7 @@ func OpenTestConnection(cfg *gorm.Config) (db *gorm.DB, err error) { func RunMigrations() { var err error - allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}, &Parent{}, &Child{}} + allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}, &Parent{}, &Child{}, &Tools{}} rand.Seed(time.Now().UnixNano()) rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) diff --git a/utils/tests/models.go b/utils/tests/models.go index a4bad2fc..f9f4f50e 100644 --- a/utils/tests/models.go +++ b/utils/tests/models.go @@ -20,7 +20,8 @@ type User struct { Account Account Pets []*Pet NamedPet *Pet - Toys []Toy `gorm:"polymorphic:Owner"` + Toys []Toy `gorm:"polymorphic:Owner"` + Tools []Tools `gorm:"polymorphicType:Type;polymorphicId:CustomID"` CompanyID *int Company Company ManagerID *uint @@ -51,6 +52,13 @@ type Toy struct { OwnerType string } +type Tools struct { + gorm.Model + Name string + CustomID string + Type string +} + type Company struct { ID int Name string