From 5ecbf25b225b824660c70dba134051888e78ee76 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 2 Jun 2020 07:28:29 +0800 Subject: [PATCH] Drop table with CASCADE option --- dialects/mysql/migrator.go | 15 +++++++++++++++ dialects/postgres/migrator.go | 13 +++++++++++++ gorm.go | 1 + migrator/migrator.go | 13 +++++-------- schema/relationship.go | 10 ++++++++++ tests/preload_suits_test.go | 13 +++++-------- 6 files changed, 49 insertions(+), 16 deletions(-) diff --git a/dialects/mysql/migrator.go b/dialects/mysql/migrator.go index 74c11277..467da9a2 100644 --- a/dialects/mysql/migrator.go +++ b/dialects/mysql/migrator.go @@ -24,6 +24,21 @@ func (m Migrator) AlterColumn(value interface{}, field string) error { }) } +func (m Migrator) DropTable(values ...interface{}) error { + values = m.ReorderModels(values, false) + tx := m.DB.Session(&gorm.Session{}) + tx.Exec("SET FOREIGN_KEY_CHECKS = 0;") + for i := len(values) - 1; i >= 0; i-- { + if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error { + return tx.Exec("DROP TABLE IF EXISTS ? CASCADE", clause.Table{Name: stmt.Table}).Error + }); err != nil { + return err + } + } + tx.Exec("SET FOREIGN_KEY_CHECKS = 1;") + return nil +} + func (m Migrator) DropConstraint(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { for _, chk := range stmt.Schema.ParseCheckConstraints() { diff --git a/dialects/postgres/migrator.go b/dialects/postgres/migrator.go index d93f681c..ef582f00 100644 --- a/dialects/postgres/migrator.go +++ b/dialects/postgres/migrator.go @@ -108,6 +108,19 @@ func (m Migrator) HasTable(value interface{}) bool { return count > 0 } +func (m Migrator) DropTable(values ...interface{}) error { + values = m.ReorderModels(values, false) + tx := m.DB.Session(&gorm.Session{}) + for i := len(values) - 1; i >= 0; i-- { + if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error { + return tx.Exec("DROP TABLE IF EXISTS ? CASCADE", clause.Table{Name: stmt.Table}).Error + }); err != nil { + return err + } + } + return nil +} + func (m Migrator) HasColumn(value interface{}, field string) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { diff --git a/gorm.go b/gorm.go index fd0d4b7e..07f94266 100644 --- a/gorm.go +++ b/gorm.go @@ -204,6 +204,7 @@ func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interfac if relation, ok := modelSchema.Relationships.Relations[field]; ok && relation.JoinTable != nil { for _, ref := range relation.References { if f := joinSchema.LookUpField(ref.ForeignKey.DBName); f != nil { + f.DataType = ref.ForeignKey.DataType ref.ForeignKey = f } else { return fmt.Errorf("missing field %v for join table", ref.ForeignKey.DBName) diff --git a/migrator/migrator.go b/migrator/migrator.go index 4e0f28b5..d78c6224 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -203,14 +203,11 @@ func (m Migrator) CreateTable(values ...interface{}) error { func (m Migrator) DropTable(values ...interface{}) error { values = m.ReorderModels(values, false) for i := len(values) - 1; i >= 0; i-- { - value := values[i] - if m.DB.Migrator().HasTable(value) { - tx := m.DB.Session(&gorm.Session{}) - if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { - return tx.Exec("DROP TABLE ?", clause.Table{Name: stmt.Table}).Error - }); err != nil { - return err - } + tx := m.DB.Session(&gorm.Session{}) + if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error { + return tx.Exec("DROP TABLE IF EXISTS ?", clause.Table{Name: stmt.Table}).Error + }); err != nil { + return err } } return nil diff --git a/schema/relationship.go b/schema/relationship.go index 194fbeff..8b5e987c 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -150,6 +150,10 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %v", relation.foreignKeys, schema, field.Name) } } + + // use same data type for foreign keys + relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType + relation.References = append(relation.References, &Reference{ PrimaryKey: primaryKeyField, ForeignKey: relation.Polymorphic.PolymorphicID, @@ -246,6 +250,9 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel // build references for _, f := range relation.JoinTable.Fields { + // use same data type for foreign keys + f.DataType = fieldsMap[f.Name].DataType + relation.References = append(relation.References, &Reference{ PrimaryKey: fieldsMap[f.Name], ForeignKey: f, @@ -326,6 +333,9 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH // build references for idx, foreignField := range foreignFields { + // use same data type for foreign keys + foreignField.DataType = primaryFields[idx].DataType + relation.References = append(relation.References, &Reference{ PrimaryKey: primaryFields[idx], ForeignKey: foreignField, diff --git a/tests/preload_suits_test.go b/tests/preload_suits_test.go index 2e7eeb1f..b71b7299 100644 --- a/tests/preload_suits_test.go +++ b/tests/preload_suits_test.go @@ -1167,9 +1167,8 @@ func TestNestedManyToManyPreload4(t *testing.T) { } ) + DB.Migrator().DropTable("level1_level2", "level2_level3") DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}, &Level4{}) - DB.Migrator().DropTable("level1_level2") - DB.Migrator().DropTable("level2_level3") dummy := Level1{ Value: "Level1", @@ -1211,8 +1210,7 @@ func TestManyToManyPreloadForPointer(t *testing.T) { } ) - DB.Migrator().DropTable(&Level2{}, &Level1{}) - DB.Migrator().DropTable("levels") + DB.Migrator().DropTable("levels", &Level2{}, &Level1{}) if err := DB.AutoMigrate(&Level2{}, &Level1{}); err != nil { t.Error(err) @@ -1296,7 +1294,7 @@ func TestNilPointerSlice(t *testing.T) { Level1 struct { ID uint Value string - Level2ID uint + Level2ID *uint Level2 *Level2 } ) @@ -1325,7 +1323,7 @@ func TestNilPointerSlice(t *testing.T) { Level2: nil, } if err := DB.Save(&want2).Error; err != nil { - t.Error(err) + t.Fatalf("Got error %v", err) } var got []Level1 @@ -1481,8 +1479,7 @@ func TestPreloadManyToManyCallbacks(t *testing.T) { } ) - DB.Migrator().DropTable(&Level2{}, &Level1{}) - DB.Migrator().DropTable("level1_level2s") + DB.Migrator().DropTable("level1_level2s", &Level2{}, &Level1{}) if err := DB.AutoMigrate(new(Level1), new(Level2)); err != nil { t.Error(err)