forked from mirror/gorm
Fix create join table
This commit is contained in:
parent
85f3174467
commit
70d60ef72f
|
@ -25,12 +25,14 @@ type Config struct {
|
|||
}
|
||||
|
||||
func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error {
|
||||
stmt := m.DB.Statement
|
||||
if stmt == nil {
|
||||
stmt = &gorm.Statement{DB: m.DB}
|
||||
stmt := &gorm.Statement{DB: m.DB}
|
||||
if m.DB.Statement != nil {
|
||||
stmt.Table = m.DB.Statement.Table
|
||||
}
|
||||
|
||||
if err := stmt.Parse(value); err != nil {
|
||||
if table, ok := value.(string); ok {
|
||||
stmt.Table = table
|
||||
} else if err := stmt.Parse(value); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -105,8 +107,10 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
|
|||
// create join table
|
||||
if rel.JoinTable != nil {
|
||||
joinValue := reflect.New(rel.JoinTable.ModelType).Interface()
|
||||
if !tx.Migrator().HasTable(joinValue) {
|
||||
defer tx.Migrator().CreateTable(joinValue)
|
||||
if !tx.Migrator().HasTable(rel.JoinTable.Table) {
|
||||
defer tx.Table(rel.JoinTable.Table).Migrator().CreateTable(joinValue)
|
||||
} else {
|
||||
defer tx.Table(rel.JoinTable.Table).Migrator().AutoMigrate(joinValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -167,8 +171,8 @@ func (m Migrator) CreateTable(values ...interface{}) error {
|
|||
// create join table
|
||||
if rel.JoinTable != nil {
|
||||
joinValue := reflect.New(rel.JoinTable.ModelType).Interface()
|
||||
if !tx.Migrator().HasTable(joinValue) {
|
||||
defer tx.Migrator().CreateTable(joinValue)
|
||||
if !tx.Migrator().HasTable(rel.JoinTable.Table) {
|
||||
defer tx.Table(rel.JoinTable.Table).Migrator().CreateTable(joinValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -207,6 +211,7 @@ func (m Migrator) DropTable(values ...interface{}) error {
|
|||
|
||||
func (m Migrator) HasTable(value interface{}) bool {
|
||||
var count int64
|
||||
|
||||
m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
currentDatabase := m.DB.Migrator().CurrentDatabase()
|
||||
return m.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentDatabase, stmt.Table, "BASE TABLE").Row().Scan(&count)
|
||||
|
|
Loading…
Reference in New Issue