gorm/migrator/migrator.go

296 lines
8.9 KiB
Go
Raw Normal View History

2020-01-28 18:01:35 +03:00
package migrator
2020-02-20 18:04:03 +03:00
import (
"database/sql"
"fmt"
"github.com/jinzhu/gorm"
2020-02-22 06:15:51 +03:00
"github.com/jinzhu/gorm/clause"
2020-02-20 18:04:03 +03:00
)
2020-01-28 18:01:35 +03:00
// Migrator migrator struct
//
// Migrator implements schema-migration helpers (create/drop/rename tables,
// columns, constraints and indexes, plus existence checks). Its methods parse
// a model value into a *gorm.Statement and run generated SQL through the
// embedded Config's DB.
//
// *Config is embedded, so Config's fields (DB, CheckExistsBeforeDropping) are
// promoted and read directly as migrator.DB etc.
type Migrator struct {
*Config
}
// Config schema config
//
// Config holds the settings shared by every Migrator operation.
type Config struct {
// CheckExistsBeforeDropping is a flag for drop operations.
// NOTE(review): presumably controls whether a table's existence is checked
// before it is dropped — confirm against DropTable.
CheckExistsBeforeDropping bool
// DB is the handle used to execute the generated SQL (Exec / Raw+Scan) and
// to reach the naming strategy and DB.Migrator().
DB *gorm.DB
}
2020-02-20 18:04:03 +03:00
// RunWithValue parses value (a model) into a new *gorm.Statement bound to
// migrator.DB and hands that statement to fc. It returns the parse error, or
// otherwise whatever fc returns.
//
// A fresh Statement is built on every call instead of parsing into the shared
// migrator.DB.Statement. Parse mutates the statement it is called on (schema,
// table), so reusing the shared one leaks state into later calls. For example,
// a table name resolved for the first model of DropTable(a, b) may be kept
// for b. The table name already present on the DB's statement (if any) is
// copied over so an explicitly chosen table is still honored.
func (migrator Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error {
	stmt := &gorm.Statement{DB: migrator.DB}
	if migrator.DB.Statement != nil {
		stmt.Table = migrator.DB.Statement.Table
	}
	if err := stmt.Parse(value); err != nil {
		return err
	}
	return fc(stmt)
}
// AutoMigrate
func (migrator Migrator) AutoMigrate(values ...interface{}) error {
2020-02-22 06:15:51 +03:00
// if has table
// not -> create table
// check columns -> add column, change column type
// check foreign keys -> create indexes
// check indexes -> create indexes
2020-02-20 18:04:03 +03:00
return gorm.ErrNotImplemented
}
func (migrator Migrator) CreateTable(values ...interface{}) error {
2020-02-22 06:15:51 +03:00
// migrate
// create join table
2020-02-20 18:04:03 +03:00
return gorm.ErrNotImplemented
}
func (migrator Migrator) DropTable(values ...interface{}) error {
for _, value := range values {
if err := migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
2020-02-22 06:15:51 +03:00
return migrator.DB.Exec("DROP TABLE ?", clause.Table{Name: stmt.Table}).Error
2020-02-20 18:04:03 +03:00
}); err != nil {
return err
}
}
return nil
}
// HasTable reports whether every model in values has a BASE TABLE in the
// current database, according to information_schema.tables.
//
// It returns false as soon as a model cannot be parsed, the lookup query
// fails, or the count is zero. With no values at all it returns true.
func (migrator Migrator) HasTable(values ...interface{}) bool {
	var found int64
	lookup := func(stmt *gorm.Statement) error {
		schemaName := migrator.DB.Migrator().CurrentDatabase()
		return migrator.DB.Raw(
			"SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?",
			schemaName, stmt.Table, "BASE TABLE",
		).Scan(&found).Error
	}
	for _, value := range values {
		if err := migrator.RunWithValue(value, lookup); err != nil || found == 0 {
			return false
		}
	}
	return true
}
// RenameTable renames table oldName to newName with RENAME TABLE (MySQL
// syntax).
//
// Both names are wrapped in clause.Table so they are rendered as quoted
// identifiers, as every other DDL statement in this file does. Passed as plain
// strings, they would be bound as string parameters and yield invalid SQL such
// as RENAME TABLE 'old' TO 'new'.
func (migrator Migrator) RenameTable(oldName, newName string) error {
	return migrator.DB.Exec(
		"RENAME TABLE ? TO ?",
		clause.Table{Name: oldName}, clause.Table{Name: newName},
	).Error
}
func (migrator Migrator) AddColumn(value interface{}, field string) error {
return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(field); field != nil {
2020-02-22 06:15:51 +03:00
return migrator.DB.Exec(
"ALTER TABLE ? ADD ? ?",
clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DBDataType},
).Error
2020-02-20 18:04:03 +03:00
}
return fmt.Errorf("failed to look up field with name: %s", field)
})
}
func (migrator Migrator) DropColumn(value interface{}, field string) error {
return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(field); field != nil {
2020-02-22 06:15:51 +03:00
return migrator.DB.Exec(
"ALTER TABLE ? DROP COLUMN ?", clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName},
).Error
2020-02-20 18:04:03 +03:00
}
return fmt.Errorf("failed to look up field with name: %s", field)
})
}
func (migrator Migrator) AlterColumn(value interface{}, field string) error {
return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(field); field != nil {
2020-02-22 06:15:51 +03:00
return migrator.DB.Exec(
"ALTER TABLE ? ALTER COLUMN ? TYPE ?",
clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DBDataType},
).Error
2020-02-20 18:04:03 +03:00
}
return fmt.Errorf("failed to look up field with name: %s", field)
})
}
func (migrator Migrator) RenameColumn(value interface{}, oldName, field string) error {
return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(field); field != nil {
oldName = migrator.DB.NamingStrategy.ColumnName(stmt.Table, oldName)
2020-02-22 06:15:51 +03:00
return migrator.DB.Exec(
"ALTER TABLE ? RENAME COLUMN ? TO ?",
clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: field.DBName},
).Error
2020-02-20 18:04:03 +03:00
}
return fmt.Errorf("failed to look up field with name: %s", field)
})
}
// ColumnTypes is intended to report the column types of value's table. It is
// not implemented yet and always returns a nil slice together with
// gorm.ErrNotImplemented.
func (migrator Migrator) ColumnTypes(value interface{}) ([]*sql.ColumnType, error) {
	var columnTypes []*sql.ColumnType
	return columnTypes, gorm.ErrNotImplemented
}
// CreateView is intended to create a view called name as described by option.
// It is not implemented yet and always returns gorm.ErrNotImplemented.
func (migrator Migrator) CreateView(name string, option gorm.ViewOption) error {
	err := gorm.ErrNotImplemented
	return err
}
// DropView is intended to drop the view called name. It is not implemented
// yet and always returns gorm.ErrNotImplemented.
func (migrator Migrator) DropView(name string) error {
	err := gorm.ErrNotImplemented
	return err
}
func (migrator Migrator) CreateConstraint(value interface{}, name string) error {
2020-02-22 06:15:51 +03:00
return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
checkConstraints := stmt.Schema.ParseCheckConstraints()
if chk, ok := checkConstraints[name]; ok {
return migrator.DB.Exec(
"ALTER TABLE ? ADD CONSTRAINT ? CHECK ?",
clause.Table{Name: stmt.Table}, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint},
).Error
}
for _, rel := range stmt.Schema.Relationships.Relations {
if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name {
sql := "ALTER TABLE ? ADD CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
if constraint.OnDelete != "" {
sql += " ON DELETE " + constraint.OnDelete
}
if constraint.OnUpdate != "" {
sql += " ON UPDATE " + constraint.OnUpdate
}
var foreignKeys, references []interface{}
for _, field := range constraint.ForeignKeys {
foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
}
for _, field := range constraint.References {
references = append(references, clause.Column{Name: field.DBName})
}
return migrator.DB.Exec(
sql, clause.Table{Name: stmt.Table}, clause.Column{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references,
).Error
}
}
err := fmt.Errorf("failed to create constraint with name %v", name)
if field := stmt.Schema.LookUpField(name); field != nil {
for _, cc := range checkConstraints {
if err = migrator.CreateIndex(value, cc.Name); err != nil {
return err
}
}
for _, rel := range stmt.Schema.Relationships.Relations {
if constraint := rel.ParseConstraint(); constraint != nil && constraint.Field == field {
if err = migrator.CreateIndex(value, constraint.Name); err != nil {
return err
}
}
}
}
return err
})
2020-02-20 18:04:03 +03:00
}
func (migrator Migrator) DropConstraint(value interface{}, name string) error {
return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
2020-02-22 06:15:51 +03:00
return migrator.DB.Exec(
"ALTER TABLE ? DROP CONSTRAINT ?",
clause.Table{Name: stmt.Table}, clause.Column{Name: name},
).Error
2020-02-20 18:04:03 +03:00
})
}
func (migrator Migrator) CreateIndex(value interface{}, name string) error {
2020-02-22 06:15:51 +03:00
return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
err := fmt.Errorf("failed to create index with name %v", name)
indexes := stmt.Schema.ParseIndexes()
if idx, ok := indexes[name]; ok {
fields := []interface{}{}
for _, field := range idx.Fields {
str := stmt.Quote(field.DBName)
if field.Expression != "" {
str = field.Expression
} else if field.Length > 0 {
str += fmt.Sprintf("(%d)", field.Length)
}
if field.Sort != "" {
str += " " + field.Sort
}
fields = append(fields, clause.Expr{SQL: str})
}
values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, fields}
createIndexSQL := "CREATE "
if idx.Class != "" {
createIndexSQL += idx.Class + " "
}
createIndexSQL += "INDEX ? ON ??"
if idx.Comment != "" {
values = append(values, idx.Comment)
createIndexSQL += " COMMENT ?"
}
if idx.Type != "" {
createIndexSQL += " USING " + idx.Type
}
return migrator.DB.Raw(createIndexSQL, values...).Error
} else if field := stmt.Schema.LookUpField(name); field != nil {
for _, idx := range indexes {
for _, idxOpt := range idx.Fields {
if idxOpt.Field == field {
if err = migrator.CreateIndex(value, idx.Name); err != nil {
return err
}
}
}
}
}
return err
})
2020-02-20 18:04:03 +03:00
}
func (migrator Migrator) DropIndex(value interface{}, name string) error {
return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
2020-02-22 06:15:51 +03:00
return migrator.DB.Raw("DROP INDEX ? ON ?", clause.Column{Name: name}, clause.Table{Name: stmt.Table}).Error
2020-02-20 18:04:03 +03:00
})
}
func (migrator Migrator) HasIndex(value interface{}, name string) bool {
var count int64
migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
currentDatabase := migrator.DB.Migrator().CurrentDatabase()
2020-02-22 06:15:51 +03:00
return migrator.DB.Raw(
"SELECT count(*) FROM information_schema.statistics WHERE table_schema = ? AND table_name = ? AND index_name = ?",
currentDatabase, stmt.Table, name,
).Scan(&count).Error
2020-02-20 18:04:03 +03:00
})
if count != 0 {
return true
}
return false
}
func (migrator Migrator) RenameIndex(value interface{}, oldName, newName string) error {
return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
2020-02-22 06:15:51 +03:00
return migrator.DB.Exec(
"ALTER TABLE ? RENAME INDEX ? TO ?",
clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: newName},
).Error
2020-02-20 18:04:03 +03:00
})
}
// CurrentDatabase returns the name of the currently selected database, as
// reported by SELECT DATABASE() (MySQL). A failed query is not reported; the
// result is whatever Scan left in the string, which is empty if nothing was
// scanned.
func (migrator Migrator) CurrentDatabase() (name string) {
	var current string
	migrator.DB.Raw("SELECT DATABASE()").Scan(&current)
	return current
}