gorm/migrator/migrator.go

475 lines
14 KiB
Go
Raw Normal View History

2020-01-28 18:01:35 +03:00
package migrator
2020-02-20 18:04:03 +03:00
import (
"database/sql"
"fmt"
2020-02-22 08:09:57 +03:00
"reflect"
"strings"
2020-02-20 18:04:03 +03:00
"github.com/jinzhu/gorm"
2020-02-22 06:15:51 +03:00
"github.com/jinzhu/gorm/clause"
2020-02-22 08:09:57 +03:00
"github.com/jinzhu/gorm/schema"
2020-02-20 18:04:03 +03:00
)
2020-01-28 18:01:35 +03:00
2020-02-22 12:53:57 +03:00
// Migrator m struct
2020-01-28 18:01:35 +03:00
type Migrator struct {
2020-02-22 12:53:57 +03:00
Config
2020-01-28 18:01:35 +03:00
}
// Config schema config
type Config struct {
2020-02-22 15:57:29 +03:00
CreateIndexAfterCreateTable bool
DB *gorm.DB
2020-02-22 12:53:57 +03:00
gorm.Dialector
2020-01-28 18:01:35 +03:00
}
2020-02-20 18:04:03 +03:00
2020-02-22 12:53:57 +03:00
func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error {
stmt := m.DB.Statement
2020-02-20 18:04:03 +03:00
if stmt == nil {
2020-02-22 12:53:57 +03:00
stmt = &gorm.Statement{DB: m.DB}
2020-02-20 18:04:03 +03:00
}
if err := stmt.Parse(value); err != nil {
return err
}
return fc(stmt)
}
2020-02-22 12:53:57 +03:00
// DataTypeOf returns the SQL data type for field. An explicit DBDataType on
// the field wins; otherwise the embedded Dialector decides.
func (m Migrator) DataTypeOf(field *schema.Field) string {
	if explicit := field.DBDataType; explicit != "" {
		return explicit
	}
	return m.Dialector.DataTypeOf(field)
}
2020-02-20 18:04:03 +03:00
// AutoMigrate
2020-02-22 12:53:57 +03:00
func (m Migrator) AutoMigrate(values ...interface{}) error {
2020-02-22 08:09:57 +03:00
// TODO smart migrate data type
2020-02-22 06:15:51 +03:00
2020-02-22 08:09:57 +03:00
for _, value := range values {
2020-02-22 12:53:57 +03:00
if !m.DB.Migrator().HasTable(value) {
if err := m.DB.Migrator().CreateTable(value); err != nil {
2020-02-22 08:09:57 +03:00
return err
}
} else {
2020-02-22 12:53:57 +03:00
if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
2020-02-22 08:09:57 +03:00
for _, field := range stmt.Schema.FieldsByDBName {
2020-02-22 12:53:57 +03:00
if !m.DB.Migrator().HasColumn(value, field.DBName) {
if err := m.DB.Migrator().AddColumn(value, field.DBName); err != nil {
2020-02-22 08:09:57 +03:00
return err
}
}
}
for _, rel := range stmt.Schema.Relationships.Relations {
if constraint := rel.ParseConstraint(); constraint != nil {
2020-02-22 12:53:57 +03:00
if !m.DB.Migrator().HasConstraint(value, constraint.Name) {
if err := m.DB.Migrator().CreateConstraint(value, constraint.Name); err != nil {
2020-02-22 08:09:57 +03:00
return err
}
}
}
for _, chk := range stmt.Schema.ParseCheckConstraints() {
2020-02-22 12:53:57 +03:00
if !m.DB.Migrator().HasConstraint(value, chk.Name) {
if err := m.DB.Migrator().CreateConstraint(value, chk.Name); err != nil {
2020-02-22 08:09:57 +03:00
return err
}
}
}
// create join table
2020-02-22 15:57:29 +03:00
if rel.JoinTable != nil {
joinValue := reflect.New(rel.JoinTable.ModelType).Interface()
if !m.DB.Migrator().HasTable(joinValue) {
defer m.DB.Migrator().CreateTable(joinValue)
}
2020-02-22 08:09:57 +03:00
}
}
return nil
}); err != nil {
return err
}
}
}
return nil
2020-02-20 18:04:03 +03:00
}
2020-02-22 12:53:57 +03:00
func (m Migrator) CreateTable(values ...interface{}) error {
2020-02-22 08:09:57 +03:00
for _, value := range values {
2020-02-22 12:53:57 +03:00
if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
2020-02-22 08:09:57 +03:00
var (
createTableSQL = "CREATE TABLE ? ("
values = []interface{}{clause.Table{Name: stmt.Table}}
hasPrimaryKeyInDataType bool
)
for _, dbName := range stmt.Schema.DBNames {
field := stmt.Schema.FieldsByDBName[dbName]
createTableSQL += fmt.Sprintf("? ?")
hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(field.DBDataType), "PRIMARY KEY")
2020-02-22 12:53:57 +03:00
values = append(values, clause.Column{Name: dbName}, clause.Expr{SQL: m.DataTypeOf(field)})
2020-02-22 08:09:57 +03:00
if field.AutoIncrement {
createTableSQL += " AUTO_INCREMENT"
}
if field.NotNull {
createTableSQL += " NOT NULL"
}
if field.Unique {
createTableSQL += " UNIQUE"
}
if field.DefaultValue != "" {
createTableSQL += " DEFAULT ?"
values = append(values, clause.Expr{SQL: field.DefaultValue})
}
createTableSQL += ","
}
if !hasPrimaryKeyInDataType {
createTableSQL += "PRIMARY KEY ?,"
primaryKeys := []interface{}{}
for _, field := range stmt.Schema.PrimaryFields {
primaryKeys = append(primaryKeys, clause.Column{Name: field.DBName})
}
values = append(values, primaryKeys)
}
for _, idx := range stmt.Schema.ParseIndexes() {
2020-02-22 15:57:29 +03:00
if m.CreateIndexAfterCreateTable {
m.DB.Migrator().CreateIndex(value, idx.Name)
} else {
createTableSQL += "INDEX ? ?,"
values = append(values, clause.Expr{SQL: idx.Name}, m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt))
}
2020-02-22 08:09:57 +03:00
}
for _, rel := range stmt.Schema.Relationships.Relations {
if constraint := rel.ParseConstraint(); constraint != nil {
sql, vars := buildConstraint(constraint)
createTableSQL += sql + ","
values = append(values, vars...)
}
// create join table
2020-02-22 15:57:29 +03:00
if rel.JoinTable != nil {
joinValue := reflect.New(rel.JoinTable.ModelType).Interface()
if !m.DB.Migrator().HasTable(joinValue) {
defer m.DB.Migrator().CreateTable(joinValue)
}
2020-02-22 08:09:57 +03:00
}
}
for _, chk := range stmt.Schema.ParseCheckConstraints() {
createTableSQL += "CONSTRAINT ? CHECK ?,"
values = append(values, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint})
}
createTableSQL = strings.TrimSuffix(createTableSQL, ",")
createTableSQL += ")"
2020-02-22 12:53:57 +03:00
return m.DB.Exec(createTableSQL, values...).Error
2020-02-22 08:09:57 +03:00
}); err != nil {
return err
}
}
return nil
2020-02-20 18:04:03 +03:00
}
2020-02-22 12:53:57 +03:00
func (m Migrator) DropTable(values ...interface{}) error {
2020-02-20 18:04:03 +03:00
for _, value := range values {
2020-02-22 12:53:57 +03:00
if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
return m.DB.Exec("DROP TABLE ?", clause.Table{Name: stmt.Table}).Error
2020-02-20 18:04:03 +03:00
}); err != nil {
return err
}
}
return nil
}
2020-02-22 12:53:57 +03:00
func (m Migrator) HasTable(value interface{}) bool {
2020-02-20 18:04:03 +03:00
var count int64
2020-02-22 12:53:57 +03:00
m.RunWithValue(value, func(stmt *gorm.Statement) error {
currentDatabase := m.DB.Migrator().CurrentDatabase()
return m.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentDatabase, stmt.Table, "BASE TABLE").Row().Scan(&count)
})
2020-02-20 18:04:03 +03:00
2020-02-22 12:53:57 +03:00
return count > 0
2020-02-20 18:04:03 +03:00
}
2020-02-22 12:53:57 +03:00
func (m Migrator) RenameTable(oldName, newName string) error {
return m.DB.Exec("RENAME TABLE ? TO ?", oldName, newName).Error
2020-02-20 18:04:03 +03:00
}
2020-02-22 12:53:57 +03:00
func (m Migrator) AddColumn(value interface{}, field string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
2020-02-20 18:04:03 +03:00
if field := stmt.Schema.LookUpField(field); field != nil {
2020-02-22 12:53:57 +03:00
return m.DB.Exec(
2020-02-22 06:15:51 +03:00
"ALTER TABLE ? ADD ? ?",
2020-02-22 12:53:57 +03:00
clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, clause.Expr{SQL: m.DataTypeOf(field)},
2020-02-22 06:15:51 +03:00
).Error
2020-02-20 18:04:03 +03:00
}
return fmt.Errorf("failed to look up field with name: %s", field)
})
}
2020-02-22 12:53:57 +03:00
func (m Migrator) DropColumn(value interface{}, field string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
2020-02-20 18:04:03 +03:00
if field := stmt.Schema.LookUpField(field); field != nil {
2020-02-22 12:53:57 +03:00
return m.DB.Exec(
2020-02-22 06:15:51 +03:00
"ALTER TABLE ? DROP COLUMN ?", clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName},
).Error
2020-02-20 18:04:03 +03:00
}
return fmt.Errorf("failed to look up field with name: %s", field)
})
}
2020-02-22 12:53:57 +03:00
func (m Migrator) AlterColumn(value interface{}, field string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
2020-02-20 18:04:03 +03:00
if field := stmt.Schema.LookUpField(field); field != nil {
2020-02-22 12:53:57 +03:00
return m.DB.Exec(
2020-02-22 06:15:51 +03:00
"ALTER TABLE ? ALTER COLUMN ? TYPE ?",
2020-02-22 12:53:57 +03:00
clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, clause.Expr{SQL: m.DataTypeOf(field)},
2020-02-22 06:15:51 +03:00
).Error
2020-02-20 18:04:03 +03:00
}
return fmt.Errorf("failed to look up field with name: %s", field)
})
}
2020-02-22 12:53:57 +03:00
func (m Migrator) HasColumn(value interface{}, field string) bool {
2020-02-22 08:09:57 +03:00
var count int64
2020-02-22 12:53:57 +03:00
m.RunWithValue(value, func(stmt *gorm.Statement) error {
currentDatabase := m.DB.Migrator().CurrentDatabase()
2020-02-22 08:09:57 +03:00
name := field
if field := stmt.Schema.LookUpField(field); field != nil {
name = field.DBName
}
2020-02-22 12:53:57 +03:00
return m.DB.Raw(
2020-02-22 08:09:57 +03:00
"SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?",
currentDatabase, stmt.Table, name,
2020-02-22 12:53:57 +03:00
).Row().Scan(&count)
2020-02-22 08:09:57 +03:00
})
2020-02-22 12:53:57 +03:00
return count > 0
2020-02-22 08:09:57 +03:00
}
2020-02-22 12:53:57 +03:00
func (m Migrator) RenameColumn(value interface{}, oldName, field string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
2020-02-20 18:04:03 +03:00
if field := stmt.Schema.LookUpField(field); field != nil {
2020-02-22 12:53:57 +03:00
oldName = m.DB.NamingStrategy.ColumnName(stmt.Table, oldName)
return m.DB.Exec(
2020-02-22 06:15:51 +03:00
"ALTER TABLE ? RENAME COLUMN ? TO ?",
clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: field.DBName},
).Error
2020-02-20 18:04:03 +03:00
}
return fmt.Errorf("failed to look up field with name: %s", field)
})
}
2020-02-22 14:41:01 +03:00
func (m Migrator) ColumnTypes(value interface{}) (columnTypes []*sql.ColumnType, err error) {
err = m.RunWithValue(value, func(stmt *gorm.Statement) error {
rows, err := m.DB.Raw("select * from ?", clause.Table{Name: stmt.Table}).Rows()
if err == nil {
columnTypes, err = rows.ColumnTypes()
}
return err
})
return
2020-02-20 18:04:03 +03:00
}
2020-02-22 12:53:57 +03:00
func (m Migrator) CreateView(name string, option gorm.ViewOption) error {
2020-02-20 18:04:03 +03:00
return gorm.ErrNotImplemented
}
2020-02-22 12:53:57 +03:00
func (m Migrator) DropView(name string) error {
2020-02-20 18:04:03 +03:00
return gorm.ErrNotImplemented
}
2020-02-22 08:09:57 +03:00
func buildConstraint(constraint *schema.Constraint) (sql string, results []interface{}) {
sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
if constraint.OnDelete != "" {
sql += " ON DELETE " + constraint.OnDelete
}
if constraint.OnUpdate != "" {
sql += " ON UPDATE " + constraint.OnUpdate
}
var foreignKeys, references []interface{}
for _, field := range constraint.ForeignKeys {
foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
}
for _, field := range constraint.References {
references = append(references, clause.Column{Name: field.DBName})
}
2020-02-22 15:57:29 +03:00
results = append(results, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
2020-02-22 08:09:57 +03:00
return
}
2020-02-22 12:53:57 +03:00
func (m Migrator) CreateConstraint(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
2020-02-22 06:15:51 +03:00
checkConstraints := stmt.Schema.ParseCheckConstraints()
if chk, ok := checkConstraints[name]; ok {
2020-02-22 12:53:57 +03:00
return m.DB.Exec(
2020-02-22 06:15:51 +03:00
"ALTER TABLE ? ADD CONSTRAINT ? CHECK ?",
clause.Table{Name: stmt.Table}, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint},
).Error
}
for _, rel := range stmt.Schema.Relationships.Relations {
if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name {
2020-02-22 08:09:57 +03:00
sql, values := buildConstraint(constraint)
2020-02-22 12:53:57 +03:00
return m.DB.Exec("ALTER TABLE ? ADD "+sql, append([]interface{}{clause.Table{Name: stmt.Table}}, values...)...).Error
2020-02-22 06:15:51 +03:00
}
}
err := fmt.Errorf("failed to create constraint with name %v", name)
if field := stmt.Schema.LookUpField(name); field != nil {
for _, cc := range checkConstraints {
2020-02-22 15:57:29 +03:00
if err = m.DB.Migrator().CreateIndex(value, cc.Name); err != nil {
2020-02-22 06:15:51 +03:00
return err
}
}
for _, rel := range stmt.Schema.Relationships.Relations {
if constraint := rel.ParseConstraint(); constraint != nil && constraint.Field == field {
2020-02-22 15:57:29 +03:00
if err = m.DB.Migrator().CreateIndex(value, constraint.Name); err != nil {
2020-02-22 06:15:51 +03:00
return err
}
}
}
}
return err
})
2020-02-20 18:04:03 +03:00
}
2020-02-22 12:53:57 +03:00
func (m Migrator) DropConstraint(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
return m.DB.Exec(
2020-02-22 06:15:51 +03:00
"ALTER TABLE ? DROP CONSTRAINT ?",
clause.Table{Name: stmt.Table}, clause.Column{Name: name},
).Error
2020-02-20 18:04:03 +03:00
})
}
2020-02-22 12:53:57 +03:00
func (m Migrator) HasConstraint(value interface{}, name string) bool {
2020-02-22 08:09:57 +03:00
var count int64
2020-02-22 12:53:57 +03:00
m.RunWithValue(value, func(stmt *gorm.Statement) error {
currentDatabase := m.DB.Migrator().CurrentDatabase()
return m.DB.Raw(
2020-02-22 08:09:57 +03:00
"SELECT count(*) FROM INFORMATION_SCHEMA.referential_constraints WHERE constraint_schema = ? AND table_name = ? AND constraint_name = ?",
currentDatabase, stmt.Table, name,
2020-02-22 12:53:57 +03:00
).Row().Scan(&count)
2020-02-22 08:09:57 +03:00
})
2020-02-22 12:53:57 +03:00
return count > 0
2020-02-22 08:09:57 +03:00
}
2020-02-22 12:53:57 +03:00
func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
2020-02-22 08:09:57 +03:00
for _, opt := range opts {
str := stmt.Quote(opt.DBName)
if opt.Expression != "" {
str = opt.Expression
} else if opt.Length > 0 {
str += fmt.Sprintf("(%d)", opt.Length)
}
2020-02-22 12:53:57 +03:00
if opt.Collate != "" {
str += " COLLATE " + opt.Collate
}
2020-02-22 08:09:57 +03:00
if opt.Sort != "" {
str += " " + opt.Sort
}
results = append(results, clause.Expr{SQL: str})
}
return
}
2020-02-22 12:53:57 +03:00
// BuildIndexOptionsInterface is the capability CreateTable and CreateIndex
// assert on the value returned by DB.Migrator() in order to render index
// field options.
type BuildIndexOptionsInterface interface {
	BuildIndexOptions([]schema.IndexOption, *gorm.Statement) []interface{}
}
func (m Migrator) CreateIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
2020-02-22 06:15:51 +03:00
err := fmt.Errorf("failed to create index with name %v", name)
indexes := stmt.Schema.ParseIndexes()
if idx, ok := indexes[name]; ok {
2020-02-22 12:53:57 +03:00
opts := m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)
2020-02-22 08:09:57 +03:00
values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts}
2020-02-22 06:15:51 +03:00
createIndexSQL := "CREATE "
if idx.Class != "" {
createIndexSQL += idx.Class + " "
}
createIndexSQL += "INDEX ? ON ??"
if idx.Comment != "" {
values = append(values, idx.Comment)
createIndexSQL += " COMMENT ?"
}
if idx.Type != "" {
createIndexSQL += " USING " + idx.Type
}
2020-02-22 12:53:57 +03:00
return m.DB.Exec(createIndexSQL, values...).Error
2020-02-22 06:15:51 +03:00
} else if field := stmt.Schema.LookUpField(name); field != nil {
for _, idx := range indexes {
for _, idxOpt := range idx.Fields {
if idxOpt.Field == field {
2020-02-22 12:53:57 +03:00
if err = m.CreateIndex(value, idx.Name); err != nil {
2020-02-22 06:15:51 +03:00
return err
}
}
}
}
}
return err
})
2020-02-20 18:04:03 +03:00
}
2020-02-22 12:53:57 +03:00
func (m Migrator) DropIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
return m.DB.Exec("DROP INDEX ? ON ?", clause.Column{Name: name}, clause.Table{Name: stmt.Table}).Error
2020-02-20 18:04:03 +03:00
})
}
2020-02-22 12:53:57 +03:00
func (m Migrator) HasIndex(value interface{}, name string) bool {
2020-02-20 18:04:03 +03:00
var count int64
2020-02-22 12:53:57 +03:00
m.RunWithValue(value, func(stmt *gorm.Statement) error {
currentDatabase := m.DB.Migrator().CurrentDatabase()
return m.DB.Raw(
2020-02-22 06:15:51 +03:00
"SELECT count(*) FROM information_schema.statistics WHERE table_schema = ? AND table_name = ? AND index_name = ?",
currentDatabase, stmt.Table, name,
2020-02-22 12:53:57 +03:00
).Row().Scan(&count)
2020-02-20 18:04:03 +03:00
})
2020-02-22 12:53:57 +03:00
return count > 0
2020-02-20 18:04:03 +03:00
}
2020-02-22 12:53:57 +03:00
func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
return m.DB.Exec(
2020-02-22 06:15:51 +03:00
"ALTER TABLE ? RENAME INDEX ? TO ?",
clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: newName},
).Error
2020-02-20 18:04:03 +03:00
})
}
2020-02-22 12:53:57 +03:00
func (m Migrator) CurrentDatabase() (name string) {
m.DB.Raw("SELECT DATABASE()").Row().Scan(&name)
2020-02-20 18:04:03 +03:00
return
}