Fix create join table

This commit is contained in:
Jinzhu 2020-04-28 08:05:22 +08:00
parent 85f3174467
commit 70d60ef72f
1 changed files with 13 additions and 8 deletions

View File

@ -25,12 +25,14 @@ type Config struct {
} }
func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error { func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error {
stmt := m.DB.Statement stmt := &gorm.Statement{DB: m.DB}
if stmt == nil { if m.DB.Statement != nil {
stmt = &gorm.Statement{DB: m.DB} stmt.Table = m.DB.Statement.Table
} }
if err := stmt.Parse(value); err != nil { if table, ok := value.(string); ok {
stmt.Table = table
} else if err := stmt.Parse(value); err != nil {
return err return err
} }
@ -105,8 +107,10 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
// create join table // create join table
if rel.JoinTable != nil { if rel.JoinTable != nil {
joinValue := reflect.New(rel.JoinTable.ModelType).Interface() joinValue := reflect.New(rel.JoinTable.ModelType).Interface()
if !tx.Migrator().HasTable(joinValue) { if !tx.Migrator().HasTable(rel.JoinTable.Table) {
defer tx.Migrator().CreateTable(joinValue) defer tx.Table(rel.JoinTable.Table).Migrator().CreateTable(joinValue)
} else {
defer tx.Table(rel.JoinTable.Table).Migrator().AutoMigrate(joinValue)
} }
} }
} }
@ -167,8 +171,8 @@ func (m Migrator) CreateTable(values ...interface{}) error {
// create join table // create join table
if rel.JoinTable != nil { if rel.JoinTable != nil {
joinValue := reflect.New(rel.JoinTable.ModelType).Interface() joinValue := reflect.New(rel.JoinTable.ModelType).Interface()
if !tx.Migrator().HasTable(joinValue) { if !tx.Migrator().HasTable(rel.JoinTable.Table) {
defer tx.Migrator().CreateTable(joinValue) defer tx.Table(rel.JoinTable.Table).Migrator().CreateTable(joinValue)
} }
} }
} }
@ -207,6 +211,7 @@ func (m Migrator) DropTable(values ...interface{}) error {
func (m Migrator) HasTable(value interface{}) bool { func (m Migrator) HasTable(value interface{}) bool {
var count int64 var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error { m.RunWithValue(value, func(stmt *gorm.Statement) error {
currentDatabase := m.DB.Migrator().CurrentDatabase() currentDatabase := m.DB.Migrator().CurrentDatabase()
return m.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentDatabase, stmt.Table, "BASE TABLE").Row().Scan(&count) return m.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentDatabase, stmt.Table, "BASE TABLE").Row().Scan(&count)