From 70d60ef72fccd8822c9dc54e0be492294e78c58d Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 28 Apr 2020 08:05:22 +0800 Subject: [PATCH] Fix create join table --- migrator/migrator.go | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index 763b4ec3..f581f714 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -25,12 +25,14 @@ type Config struct { } func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error { - stmt := m.DB.Statement - if stmt == nil { - stmt = &gorm.Statement{DB: m.DB} + stmt := &gorm.Statement{DB: m.DB} + if m.DB.Statement != nil { + stmt.Table = m.DB.Statement.Table } - if err := stmt.Parse(value); err != nil { + if table, ok := value.(string); ok { + stmt.Table = table + } else if err := stmt.Parse(value); err != nil { return err } @@ -105,8 +107,10 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { // create join table if rel.JoinTable != nil { joinValue := reflect.New(rel.JoinTable.ModelType).Interface() - if !tx.Migrator().HasTable(joinValue) { - defer tx.Migrator().CreateTable(joinValue) + if !tx.Migrator().HasTable(rel.JoinTable.Table) { + defer tx.Table(rel.JoinTable.Table).Migrator().CreateTable(joinValue) + } else { + defer tx.Table(rel.JoinTable.Table).Migrator().AutoMigrate(joinValue) } } } @@ -167,8 +171,8 @@ func (m Migrator) CreateTable(values ...interface{}) error { // create join table if rel.JoinTable != nil { joinValue := reflect.New(rel.JoinTable.ModelType).Interface() - if !tx.Migrator().HasTable(joinValue) { - defer tx.Migrator().CreateTable(joinValue) + if !tx.Migrator().HasTable(rel.JoinTable.Table) { + defer tx.Table(rel.JoinTable.Table).Migrator().CreateTable(joinValue) } } } @@ -207,6 +211,7 @@ func (m Migrator) DropTable(values ...interface{}) error { func (m Migrator) HasTable(value interface{}) bool { var count int64 + m.RunWithValue(value, func(stmt *gorm.Statement) error { currentDatabase := m.DB.Migrator().CurrentDatabase() return m.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentDatabase, stmt.Table, "BASE TABLE").Row().Scan(&count)