Drop table with CASCADE option

This commit is contained in:
Jinzhu 2020-06-02 07:28:29 +08:00
parent b71171dd92
commit 5ecbf25b22
6 changed files with 49 additions and 16 deletions

View File

@ -24,6 +24,21 @@ func (m Migrator) AlterColumn(value interface{}, field string) error {
}) })
} }
func (m Migrator) DropTable(values ...interface{}) error {
values = m.ReorderModels(values, false)
tx := m.DB.Session(&gorm.Session{})
tx.Exec("SET FOREIGN_KEY_CHECKS = 0;")
for i := len(values) - 1; i >= 0; i-- {
if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error {
return tx.Exec("DROP TABLE IF EXISTS ? CASCADE", clause.Table{Name: stmt.Table}).Error
}); err != nil {
return err
}
}
tx.Exec("SET FOREIGN_KEY_CHECKS = 1;")
return nil
}
func (m Migrator) DropConstraint(value interface{}, name string) error { func (m Migrator) DropConstraint(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
for _, chk := range stmt.Schema.ParseCheckConstraints() { for _, chk := range stmt.Schema.ParseCheckConstraints() {

View File

@ -108,6 +108,19 @@ func (m Migrator) HasTable(value interface{}) bool {
return count > 0 return count > 0
} }
func (m Migrator) DropTable(values ...interface{}) error {
values = m.ReorderModels(values, false)
tx := m.DB.Session(&gorm.Session{})
for i := len(values) - 1; i >= 0; i-- {
if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error {
return tx.Exec("DROP TABLE IF EXISTS ? CASCADE", clause.Table{Name: stmt.Table}).Error
}); err != nil {
return err
}
}
return nil
}
func (m Migrator) HasColumn(value interface{}, field string) bool { func (m Migrator) HasColumn(value interface{}, field string) bool {
var count int64 var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error { m.RunWithValue(value, func(stmt *gorm.Statement) error {

View File

@ -204,6 +204,7 @@ func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interfac
if relation, ok := modelSchema.Relationships.Relations[field]; ok && relation.JoinTable != nil { if relation, ok := modelSchema.Relationships.Relations[field]; ok && relation.JoinTable != nil {
for _, ref := range relation.References { for _, ref := range relation.References {
if f := joinSchema.LookUpField(ref.ForeignKey.DBName); f != nil { if f := joinSchema.LookUpField(ref.ForeignKey.DBName); f != nil {
f.DataType = ref.ForeignKey.DataType
ref.ForeignKey = f ref.ForeignKey = f
} else { } else {
return fmt.Errorf("missing field %v for join table", ref.ForeignKey.DBName) return fmt.Errorf("missing field %v for join table", ref.ForeignKey.DBName)

View File

@ -203,14 +203,11 @@ func (m Migrator) CreateTable(values ...interface{}) error {
func (m Migrator) DropTable(values ...interface{}) error { func (m Migrator) DropTable(values ...interface{}) error {
values = m.ReorderModels(values, false) values = m.ReorderModels(values, false)
for i := len(values) - 1; i >= 0; i-- { for i := len(values) - 1; i >= 0; i-- {
value := values[i] tx := m.DB.Session(&gorm.Session{})
if m.DB.Migrator().HasTable(value) { if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error {
tx := m.DB.Session(&gorm.Session{}) return tx.Exec("DROP TABLE IF EXISTS ?", clause.Table{Name: stmt.Table}).Error
if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { }); err != nil {
return tx.Exec("DROP TABLE ?", clause.Table{Name: stmt.Table}).Error return err
}); err != nil {
return err
}
} }
} }
return nil return nil

View File

@ -150,6 +150,10 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi
schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %v", relation.foreignKeys, schema, field.Name) schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %v", relation.foreignKeys, schema, field.Name)
} }
} }
// use same data type for foreign keys
relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType
relation.References = append(relation.References, &Reference{ relation.References = append(relation.References, &Reference{
PrimaryKey: primaryKeyField, PrimaryKey: primaryKeyField,
ForeignKey: relation.Polymorphic.PolymorphicID, ForeignKey: relation.Polymorphic.PolymorphicID,
@ -246,6 +250,9 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
// build references // build references
for _, f := range relation.JoinTable.Fields { for _, f := range relation.JoinTable.Fields {
// use same data type for foreign keys
f.DataType = fieldsMap[f.Name].DataType
relation.References = append(relation.References, &Reference{ relation.References = append(relation.References, &Reference{
PrimaryKey: fieldsMap[f.Name], PrimaryKey: fieldsMap[f.Name],
ForeignKey: f, ForeignKey: f,
@ -326,6 +333,9 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH
// build references // build references
for idx, foreignField := range foreignFields { for idx, foreignField := range foreignFields {
// use same data type for foreign keys
foreignField.DataType = primaryFields[idx].DataType
relation.References = append(relation.References, &Reference{ relation.References = append(relation.References, &Reference{
PrimaryKey: primaryFields[idx], PrimaryKey: primaryFields[idx],
ForeignKey: foreignField, ForeignKey: foreignField,

View File

@ -1167,9 +1167,8 @@ func TestNestedManyToManyPreload4(t *testing.T) {
} }
) )
DB.Migrator().DropTable("level1_level2", "level2_level3")
DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}, &Level4{}) DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}, &Level4{})
DB.Migrator().DropTable("level1_level2")
DB.Migrator().DropTable("level2_level3")
dummy := Level1{ dummy := Level1{
Value: "Level1", Value: "Level1",
@ -1211,8 +1210,7 @@ func TestManyToManyPreloadForPointer(t *testing.T) {
} }
) )
DB.Migrator().DropTable(&Level2{}, &Level1{}) DB.Migrator().DropTable("levels", &Level2{}, &Level1{})
DB.Migrator().DropTable("levels")
if err := DB.AutoMigrate(&Level2{}, &Level1{}); err != nil { if err := DB.AutoMigrate(&Level2{}, &Level1{}); err != nil {
t.Error(err) t.Error(err)
@ -1296,7 +1294,7 @@ func TestNilPointerSlice(t *testing.T) {
Level1 struct { Level1 struct {
ID uint ID uint
Value string Value string
Level2ID uint Level2ID *uint
Level2 *Level2 Level2 *Level2
} }
) )
@ -1325,7 +1323,7 @@ func TestNilPointerSlice(t *testing.T) {
Level2: nil, Level2: nil,
} }
if err := DB.Save(&want2).Error; err != nil { if err := DB.Save(&want2).Error; err != nil {
t.Error(err) t.Fatalf("Got error %v", err)
} }
var got []Level1 var got []Level1
@ -1481,8 +1479,7 @@ func TestPreloadManyToManyCallbacks(t *testing.T) {
} }
) )
DB.Migrator().DropTable(&Level2{}, &Level1{}) DB.Migrator().DropTable("level1_level2s", &Level2{}, &Level1{})
DB.Migrator().DropTable("level1_level2s")
if err := DB.AutoMigrate(new(Level1), new(Level2)); err != nil { if err := DB.AutoMigrate(new(Level1), new(Level2)); err != nil {
t.Error(err) t.Error(err)