Fix create index for other database/schema, close #3698

This commit is contained in:
Jinzhu 2020-11-05 11:43:21 +08:00
parent fcf2ab6c0e
commit 85e9f66d26
1 changed files with 25 additions and 22 deletions

View File

@ -158,14 +158,10 @@ func (m Migrator) CreateTable(values ...interface{}) error {
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) { if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) {
var ( var (
createTableSQL = "CREATE TABLE ? (" createTableSQL = "CREATE TABLE ? ("
values = []interface{}{clause.Table{Name: stmt.Table}} values = []interface{}{m.CurrentTable(stmt)}
hasPrimaryKeyInDataType bool hasPrimaryKeyInDataType bool
) )
if stmt.TableExpr != nil {
values[0] = *stmt.TableExpr
}
for _, dbName := range stmt.Schema.DBNames { for _, dbName := range stmt.Schema.DBNames {
field := stmt.Schema.FieldsByDBName[dbName] field := stmt.Schema.FieldsByDBName[dbName]
createTableSQL += "? ?" createTableSQL += "? ?"
@ -243,7 +239,7 @@ func (m Migrator) DropTable(values ...interface{}) error {
for i := len(values) - 1; i >= 0; i-- { for i := len(values) - 1; i >= 0; i-- {
tx := m.DB.Session(&gorm.Session{}) tx := m.DB.Session(&gorm.Session{})
if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error { if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error {
return tx.Exec("DROP TABLE IF EXISTS ?", clause.Table{Name: stmt.Table}).Error return tx.Exec("DROP TABLE IF EXISTS ?", m.CurrentTable(stmt)).Error
}); err != nil { }); err != nil {
return err return err
} }
@ -263,30 +259,30 @@ func (m Migrator) HasTable(value interface{}) bool {
} }
func (m Migrator) RenameTable(oldName, newName interface{}) error { func (m Migrator) RenameTable(oldName, newName interface{}) error {
var oldTable, newTable string var oldTable, newTable interface{}
if v, ok := oldName.(string); ok { if v, ok := oldName.(string); ok {
oldTable = v oldTable = clause.Table{Name: v}
} else { } else {
stmt := &gorm.Statement{DB: m.DB} stmt := &gorm.Statement{DB: m.DB}
if err := stmt.Parse(oldName); err == nil { if err := stmt.Parse(oldName); err == nil {
oldTable = stmt.Table oldTable = m.CurrentTable(stmt)
} else { } else {
return err return err
} }
} }
if v, ok := newName.(string); ok { if v, ok := newName.(string); ok {
newTable = v newTable = clause.Table{Name: v}
} else { } else {
stmt := &gorm.Statement{DB: m.DB} stmt := &gorm.Statement{DB: m.DB}
if err := stmt.Parse(newName); err == nil { if err := stmt.Parse(newName); err == nil {
newTable = stmt.Table newTable = m.CurrentTable(stmt)
} else { } else {
return err return err
} }
} }
return m.DB.Exec("ALTER TABLE ? RENAME TO ?", clause.Table{Name: oldTable}, clause.Table{Name: newTable}).Error return m.DB.Exec("ALTER TABLE ? RENAME TO ?", oldTable, newTable).Error
} }
func (m Migrator) AddColumn(value interface{}, field string) error { func (m Migrator) AddColumn(value interface{}, field string) error {
@ -294,7 +290,7 @@ func (m Migrator) AddColumn(value interface{}, field string) error {
if field := stmt.Schema.LookUpField(field); field != nil { if field := stmt.Schema.LookUpField(field); field != nil {
return m.DB.Exec( return m.DB.Exec(
"ALTER TABLE ? ADD ? ?", "ALTER TABLE ? ADD ? ?",
clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, m.DB.Migrator().FullDataTypeOf(field), m.CurrentTable(stmt), clause.Column{Name: field.DBName}, m.DB.Migrator().FullDataTypeOf(field),
).Error ).Error
} }
return fmt.Errorf("failed to look up field with name: %s", field) return fmt.Errorf("failed to look up field with name: %s", field)
@ -308,7 +304,7 @@ func (m Migrator) DropColumn(value interface{}, name string) error {
} }
return m.DB.Exec( return m.DB.Exec(
"ALTER TABLE ? DROP COLUMN ?", clause.Table{Name: stmt.Table}, clause.Column{Name: name}, "ALTER TABLE ? DROP COLUMN ?", m.CurrentTable(stmt), clause.Column{Name: name},
).Error ).Error
}) })
} }
@ -319,7 +315,7 @@ func (m Migrator) AlterColumn(value interface{}, field string) error {
fileType := clause.Expr{SQL: m.DataTypeOf(field)} fileType := clause.Expr{SQL: m.DataTypeOf(field)}
return m.DB.Exec( return m.DB.Exec(
"ALTER TABLE ? ALTER COLUMN ? TYPE ?", "ALTER TABLE ? ALTER COLUMN ? TYPE ?",
clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, fileType, m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType,
).Error ).Error
} }
@ -357,7 +353,7 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error
return m.DB.Exec( return m.DB.Exec(
"ALTER TABLE ? RENAME COLUMN ? TO ?", "ALTER TABLE ? RENAME COLUMN ? TO ?",
clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: newName}, m.CurrentTable(stmt), clause.Column{Name: oldName}, clause.Column{Name: newName},
).Error ).Error
}) })
} }
@ -459,14 +455,14 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error {
if chk, ok := checkConstraints[name]; ok { if chk, ok := checkConstraints[name]; ok {
return m.DB.Exec( return m.DB.Exec(
"ALTER TABLE ? ADD CONSTRAINT ? CHECK (?)", "ALTER TABLE ? ADD CONSTRAINT ? CHECK (?)",
clause.Table{Name: stmt.Table}, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}, m.CurrentTable(stmt), clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint},
).Error ).Error
} }
for _, rel := range stmt.Schema.Relationships.Relations { for _, rel := range stmt.Schema.Relationships.Relations {
if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name { if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name {
sql, values := buildConstraint(constraint) sql, values := buildConstraint(constraint)
return m.DB.Exec("ALTER TABLE ? ADD "+sql, append([]interface{}{clause.Table{Name: stmt.Table}}, values...)...).Error return m.DB.Exec("ALTER TABLE ? ADD "+sql, append([]interface{}{m.CurrentTable(stmt)}, values...)...).Error
} }
} }
@ -495,7 +491,7 @@ func (m Migrator) DropConstraint(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
return m.DB.Exec( return m.DB.Exec(
"ALTER TABLE ? DROP CONSTRAINT ?", "ALTER TABLE ? DROP CONSTRAINT ?",
clause.Table{Name: stmt.Table}, clause.Column{Name: name}, m.CurrentTable(stmt), clause.Column{Name: name},
).Error ).Error
}) })
} }
@ -542,7 +538,7 @@ func (m Migrator) CreateIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if idx := stmt.Schema.LookIndex(name); idx != nil { if idx := stmt.Schema.LookIndex(name); idx != nil {
opts := m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt) opts := m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)
values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts} values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts}
createIndexSQL := "CREATE " createIndexSQL := "CREATE "
if idx.Class != "" { if idx.Class != "" {
@ -571,7 +567,7 @@ func (m Migrator) DropIndex(value interface{}, name string) error {
name = idx.Name name = idx.Name
} }
return m.DB.Exec("DROP INDEX ? ON ?", clause.Column{Name: name}, clause.Table{Name: stmt.Table}).Error return m.DB.Exec("DROP INDEX ? ON ?", clause.Column{Name: name}, m.CurrentTable(stmt)).Error
}) })
} }
@ -596,7 +592,7 @@ func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
return m.DB.Exec( return m.DB.Exec(
"ALTER TABLE ? RENAME INDEX ? TO ?", "ALTER TABLE ? RENAME INDEX ? TO ?",
clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: newName}, m.CurrentTable(stmt), clause.Column{Name: oldName}, clause.Column{Name: newName},
).Error ).Error
}) })
} }
@ -701,3 +697,10 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i
} }
return return
} }
func (m Migrator) CurrentTable(stmt *gorm.Statement) interface{} {
if stmt.TableExpr != nil {
return *stmt.TableExpr
}
return clause.Table{Name: stmt.Table}
}