diff --git a/migrator/migrator.go b/migrator/migrator.go index 9493a00c..016ebfc7 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -158,14 +158,10 @@ func (m Migrator) CreateTable(values ...interface{}) error { if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) { var ( createTableSQL = "CREATE TABLE ? (" - values = []interface{}{clause.Table{Name: stmt.Table}} + values = []interface{}{m.CurrentTable(stmt)} hasPrimaryKeyInDataType bool ) - if stmt.TableExpr != nil { - values[0] = *stmt.TableExpr - } - for _, dbName := range stmt.Schema.DBNames { field := stmt.Schema.FieldsByDBName[dbName] createTableSQL += "? ?" @@ -243,7 +239,7 @@ func (m Migrator) DropTable(values ...interface{}) error { for i := len(values) - 1; i >= 0; i-- { tx := m.DB.Session(&gorm.Session{}) if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error { - return tx.Exec("DROP TABLE IF EXISTS ?", clause.Table{Name: stmt.Table}).Error + return tx.Exec("DROP TABLE IF EXISTS ?", m.CurrentTable(stmt)).Error }); err != nil { return err } @@ -263,30 +259,30 @@ func (m Migrator) HasTable(value interface{}) bool { } func (m Migrator) RenameTable(oldName, newName interface{}) error { - var oldTable, newTable string + var oldTable, newTable interface{} if v, ok := oldName.(string); ok { - oldTable = v + oldTable = clause.Table{Name: v} } else { stmt := &gorm.Statement{DB: m.DB} if err := stmt.Parse(oldName); err == nil { - oldTable = stmt.Table + oldTable = m.CurrentTable(stmt) } else { return err } } if v, ok := newName.(string); ok { - newTable = v + newTable = clause.Table{Name: v} } else { stmt := &gorm.Statement{DB: m.DB} if err := stmt.Parse(newName); err == nil { - newTable = stmt.Table + newTable = m.CurrentTable(stmt) } else { return err } } - return m.DB.Exec("ALTER TABLE ? RENAME TO ?", clause.Table{Name: oldTable}, clause.Table{Name: newTable}).Error + return m.DB.Exec("ALTER TABLE ? RENAME TO ?", oldTable, newTable).Error } func (m Migrator) AddColumn(value interface{}, field string) error { @@ -294,7 +290,7 @@ func (m Migrator) AddColumn(value interface{}, field string) error { if field := stmt.Schema.LookUpField(field); field != nil { return m.DB.Exec( "ALTER TABLE ? ADD ? ?", - clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, m.DB.Migrator().FullDataTypeOf(field), + m.CurrentTable(stmt), clause.Column{Name: field.DBName}, m.DB.Migrator().FullDataTypeOf(field), ).Error } return fmt.Errorf("failed to look up field with name: %s", field) @@ -308,7 +304,7 @@ func (m Migrator) DropColumn(value interface{}, name string) error { } return m.DB.Exec( - "ALTER TABLE ? DROP COLUMN ?", clause.Table{Name: stmt.Table}, clause.Column{Name: name}, + "ALTER TABLE ? DROP COLUMN ?", m.CurrentTable(stmt), clause.Column{Name: name}, ).Error }) } @@ -319,7 +315,7 @@ func (m Migrator) AlterColumn(value interface{}, field string) error { fileType := clause.Expr{SQL: m.DataTypeOf(field)} return m.DB.Exec( "ALTER TABLE ? ALTER COLUMN ? TYPE ?", - clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, fileType, + m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType, ).Error } @@ -357,7 +353,7 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error return m.DB.Exec( "ALTER TABLE ? RENAME COLUMN ? TO ?", - clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: newName}, + m.CurrentTable(stmt), clause.Column{Name: oldName}, clause.Column{Name: newName}, ).Error }) } @@ -459,14 +455,14 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error { if chk, ok := checkConstraints[name]; ok { return m.DB.Exec( "ALTER TABLE ? ADD CONSTRAINT ? CHECK (?)", - clause.Table{Name: stmt.Table}, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}, + m.CurrentTable(stmt), clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}, ).Error } for _, rel := range stmt.Schema.Relationships.Relations { if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name { sql, values := buildConstraint(constraint) - return m.DB.Exec("ALTER TABLE ? ADD "+sql, append([]interface{}{clause.Table{Name: stmt.Table}}, values...)...).Error + return m.DB.Exec("ALTER TABLE ? ADD "+sql, append([]interface{}{m.CurrentTable(stmt)}, values...)...).Error } } @@ -495,7 +491,7 @@ func (m Migrator) DropConstraint(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.DB.Exec( "ALTER TABLE ? DROP CONSTRAINT ?", - clause.Table{Name: stmt.Table}, clause.Column{Name: name}, + m.CurrentTable(stmt), clause.Column{Name: name}, ).Error }) } @@ -542,7 +538,7 @@ func (m Migrator) CreateIndex(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { if idx := stmt.Schema.LookIndex(name); idx != nil { opts := m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt) - values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts} + values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts} createIndexSQL := "CREATE " if idx.Class != "" { @@ -571,7 +567,7 @@ func (m Migrator) DropIndex(value interface{}, name string) error { name = idx.Name } - return m.DB.Exec("DROP INDEX ? ON ?", clause.Column{Name: name}, clause.Table{Name: stmt.Table}).Error + return m.DB.Exec("DROP INDEX ? ON ?", clause.Column{Name: name}, m.CurrentTable(stmt)).Error }) } @@ -596,7 +592,7 @@ func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.DB.Exec( "ALTER TABLE ? RENAME INDEX ? TO ?", - clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: newName}, + m.CurrentTable(stmt), clause.Column{Name: oldName}, clause.Column{Name: newName}, ).Error }) } @@ -701,3 +697,10 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i } return } + +func (m Migrator) CurrentTable(stmt *gorm.Statement) interface{} { + if stmt.TableExpr != nil { + return *stmt.TableExpr + } + return clause.Table{Name: stmt.Table} +}