gorm/migrator/migrator.go

950 lines
27 KiB
Go
Raw Normal View History

2020-01-28 18:01:35 +03:00
package migrator
2020-02-20 18:04:03 +03:00
import (
2020-07-16 06:27:04 +03:00
"context"
"database/sql"
"errors"
2020-02-20 18:04:03 +03:00
"fmt"
2020-02-22 08:09:57 +03:00
"reflect"
2020-08-23 10:40:19 +03:00
"regexp"
2020-02-22 08:09:57 +03:00
"strings"
"time"
2020-02-20 18:04:03 +03:00
2020-06-02 04:16:07 +03:00
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
2020-06-02 04:16:07 +03:00
"gorm.io/gorm/schema"
2020-02-20 18:04:03 +03:00
)
2020-01-28 18:01:35 +03:00
// regFullDataType matches runs of digits in a full data type string
// (e.g. "varchar(255)" yields the submatch "255"). It is compiled once at
// package level because MigrateColumn uses it, and AutoMigrate calls
// MigrateColumn for every existing column of every model.
var regFullDataType = regexp.MustCompile(`\D*(\d+)\D?`)
2020-02-22 12:53:57 +03:00
// Migrator is a generic schema migrator whose existence checks query
// information_schema (see HasTable, HasColumn, HasIndex and HasConstraint).
// Several methods call back through m.DB.Migrator() rather than m itself
// (for example FullDataTypeOf, CurrentDatabase and AlterColumn), so the
// migrator returned by the DB decides those steps.
2020-01-28 18:01:35 +03:00
type Migrator struct {
2020-02-22 12:53:57 +03:00
// Config carries the DB handle, the dialector and migration options.
Config
2020-01-28 18:01:35 +03:00
}
// Config holds the settings shared by every Migrator method.
type Config struct {
// CreateIndexAfterCreateTable makes CreateTable issue a separate CreateIndex
// call per index after the CREATE TABLE statement succeeds, instead of
// declaring the indexes inline in that statement.
CreateIndexAfterCreateTable bool
// DB is the handle every migration query and statement is executed through.
DB *gorm.DB
2020-02-22 12:53:57 +03:00
// Dialector supplies dialect-specific behaviour such as DataTypeOf,
// BindVarTo, Explain and QuoteTo, which are promoted onto Migrator.
gorm.Dialector
2020-01-28 18:01:35 +03:00
}
2020-02-20 18:04:03 +03:00
// printSQLLogger wraps a logger.Interface so that every traced statement is
// also printed to stdout. AutoMigrate installs it on the session that runs
// DDL when the DB is in DryRun mode.
type printSQLLogger struct {
logger.Interface
}
// Trace writes the SQL produced by fc to stdout, terminated by ";", and then
// forwards the same arguments unchanged to the wrapped logger.
func (l *printSQLLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
	statement, _ := fc()
	fmt.Printf("%s;\n", statement)
	l.Interface.Trace(ctx, begin, fc, err)
}
// GormDataTypeInterface lets a field's Go type choose its own database column
// type. DataTypeOf checks for it on a pointer to a new zero value of
// field.IndirectFieldType (so value or pointer receivers both qualify). A
// non-empty result is used verbatim; an empty one falls back to the Dialector.
2020-06-06 05:47:32 +03:00
type GormDataTypeInterface interface {
GormDBDataType(*gorm.DB, *schema.Field) string
}
// RunWithValue run migration with statement value
2020-02-22 12:53:57 +03:00
func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error {
2020-04-28 03:05:22 +03:00
stmt := &gorm.Statement{DB: m.DB}
if m.DB.Statement != nil {
stmt.Table = m.DB.Statement.Table
stmt.TableExpr = m.DB.Statement.TableExpr
2020-02-20 18:04:03 +03:00
}
2020-04-28 03:05:22 +03:00
if table, ok := value.(string); ok {
stmt.Table = table
} else if err := stmt.ParseWithSpecialTableName(value, stmt.Table); err != nil {
2020-02-20 18:04:03 +03:00
return err
}
return fc(stmt)
}
// DataTypeOf return field's db data type
2020-02-22 12:53:57 +03:00
func (m Migrator) DataTypeOf(field *schema.Field) string {
2020-06-06 05:47:32 +03:00
fieldValue := reflect.New(field.IndirectFieldType)
if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok {
if dataType := dataTyper.GormDBDataType(m.DB, field); dataType != "" {
return dataType
}
}
2020-02-22 12:53:57 +03:00
return m.Dialector.DataTypeOf(field)
}
// FullDataTypeOf returns field's db full data type
2020-05-30 19:42:52 +03:00
func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
expr.SQL = m.DataTypeOf(field)
if field.NotNull {
2020-05-30 19:42:52 +03:00
expr.SQL += " NOT NULL"
}
if field.Unique {
2020-05-30 19:42:52 +03:00
expr.SQL += " UNIQUE"
}
if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") {
2020-06-25 03:00:10 +03:00
if field.DefaultValueInterface != nil {
defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}}
m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValueInterface)
expr.SQL += " DEFAULT " + m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValueInterface)
} else if field.DefaultValue != "(-)" {
2020-05-30 19:42:52 +03:00
expr.SQL += " DEFAULT " + field.DefaultValue
}
}
2020-05-30 19:42:52 +03:00
return
}
// AutoMigrate brings the database schema in line with values.
//
// Models are processed in the order returned by ReorderModels(values, true),
// which also pulls in models the given ones depend on. For each model:
//   - if HasTable reports the table missing, it is created with CreateTable;
//   - otherwise each schema column is added with AddColumn when ColumnTypes
//     does not report it, or reconciled with MigrateColumn when it does.
//     Missing foreign-key constraints (unless foreign keys are disabled or
//     relationships are ignored), check constraints and indexes are then
//     created.
//
// Existence checks go through queryTx and DDL goes through execTx. In DryRun
// mode, queryTx has DryRun switched off so the checks really query the
// database. execTx is a fresh session from m.DB whose logger is a
// printSQLLogger.
// The first error stops the migration and is returned.
2020-02-22 12:53:57 +03:00
func (m Migrator) AutoMigrate(values ...interface{}) error {
2020-02-22 19:18:12 +03:00
for _, value := range m.ReorderModels(values, true) {
queryTx := m.DB.Session(&gorm.Session{})
execTx := queryTx
if m.DB.DryRun {
// Let the existence checks hit the database; DDL goes through a separate
// session derived from m.DB that prints each statement.
queryTx.DryRun = false
execTx = m.DB.Session(&gorm.Session{Logger: &printSQLLogger{Interface: m.DB.Logger}})
}
if !queryTx.Migrator().HasTable(value) {
if err := execTx.Migrator().CreateTable(value); err != nil {
2020-02-22 08:09:57 +03:00
return err
}
} else {
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) {
// Columns as currently reported by the database.
columnTypes, err := queryTx.Migrator().ColumnTypes(value)
if err != nil {
return err
}
var (
parseIndexes = stmt.Schema.ParseIndexes()
parseCheckConstraints = stmt.Schema.ParseCheckConstraints()
)
for _, dbName := range stmt.Schema.DBNames {
field := stmt.Schema.FieldsByDBName[dbName]
var foundColumn gorm.ColumnType
2020-08-23 10:40:19 +03:00
// Linear search by column name in the database's column list.
for _, columnType := range columnTypes {
if columnType.Name() == dbName {
2020-08-23 10:40:19 +03:00
foundColumn = columnType
break
}
}
if foundColumn == nil {
// not found, add column
if err := execTx.Migrator().AddColumn(value, dbName); err != nil {
2020-02-22 08:09:57 +03:00
return err
}
// found, smart migrate: MigrateColumn decides whether an ALTER is needed
} else if err := execTx.Migrator().MigrateColumn(value, field, foundColumn); err != nil {
2020-08-23 10:40:19 +03:00
return err
2020-02-22 08:09:57 +03:00
}
}
// Create foreign keys owned by this schema that the database lacks.
if !m.DB.DisableForeignKeyConstraintWhenMigrating && !m.DB.IgnoreRelationshipsWhenMigrating {
for _, rel := range stmt.Schema.Relationships.Relations {
if rel.Field.IgnoreMigration {
continue
}
if constraint := rel.ParseConstraint(); constraint != nil &&
constraint.Schema == stmt.Schema && !queryTx.Migrator().HasConstraint(value, constraint.Name) {
if err := execTx.Migrator().CreateConstraint(value, constraint.Name); err != nil {
return err
2020-02-22 08:09:57 +03:00
}
}
}
}
2020-02-22 08:09:57 +03:00
// Create missing check constraints.
for _, chk := range parseCheckConstraints {
if !queryTx.Migrator().HasConstraint(value, chk.Name) {
if err := execTx.Migrator().CreateConstraint(value, chk.Name); err != nil {
return err
2020-02-22 08:09:57 +03:00
}
}
}
// Create missing indexes.
for _, idx := range parseIndexes {
if !queryTx.Migrator().HasIndex(value, idx.Name) {
if err := execTx.Migrator().CreateIndex(value, idx.Name); err != nil {
return err
}
}
}
2020-02-22 08:09:57 +03:00
return nil
}); err != nil {
return err
}
}
}
return nil
2020-02-20 18:04:03 +03:00
}
// GetTables lists the names of all tables in the current database by
// querying information_schema.tables. On error, whatever was scanned is
// returned alongside the error.
func (m Migrator) GetTables() (tableList []string, err error) {
	query := m.DB.Raw("SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA=?", m.CurrentDatabase())
	if err = query.Scan(&tableList).Error; err != nil {
		return tableList, err
	}
	return tableList, nil
}
// CreateTable create table in database for values
2020-02-22 12:53:57 +03:00
func (m Migrator) CreateTable(values ...interface{}) error {
2020-02-22 19:18:12 +03:00
for _, value := range m.ReorderModels(values, false) {
tx := m.DB.Session(&gorm.Session{})
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) {
2020-02-22 08:09:57 +03:00
var (
createTableSQL = "CREATE TABLE ? ("
values = []interface{}{m.CurrentTable(stmt)}
2020-02-22 08:09:57 +03:00
hasPrimaryKeyInDataType bool
)
for _, dbName := range stmt.Schema.DBNames {
field := stmt.Schema.FieldsByDBName[dbName]
if !field.IgnoreMigration {
createTableSQL += "? ?"
hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(string(field.DataType)), "PRIMARY KEY")
values = append(values, clause.Column{Name: dbName}, m.DB.Migrator().FullDataTypeOf(field))
createTableSQL += ","
}
2020-02-22 08:09:57 +03:00
}
2020-05-31 07:52:49 +03:00
if !hasPrimaryKeyInDataType && len(stmt.Schema.PrimaryFields) > 0 {
2020-02-22 08:09:57 +03:00
createTableSQL += "PRIMARY KEY ?,"
primaryKeys := []interface{}{}
for _, field := range stmt.Schema.PrimaryFields {
primaryKeys = append(primaryKeys, clause.Column{Name: field.DBName})
}
values = append(values, primaryKeys)
}
for _, idx := range stmt.Schema.ParseIndexes() {
2020-02-22 15:57:29 +03:00
if m.CreateIndexAfterCreateTable {
defer func(value interface{}, name string) {
if errr == nil {
errr = tx.Migrator().CreateIndex(value, name)
}
}(value, idx.Name)
2020-02-22 15:57:29 +03:00
} else {
if idx.Class != "" {
createTableSQL += idx.Class + " "
}
2020-10-21 15:15:49 +03:00
createTableSQL += "INDEX ? ?"
if idx.Comment != "" {
createTableSQL += fmt.Sprintf(" COMMENT '%s'", idx.Comment)
}
2020-10-21 15:15:49 +03:00
if idx.Option != "" {
createTableSQL += " " + idx.Option
}
createTableSQL += ","
values = append(values, clause.Column{Name: idx.Name}, tx.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt))
2020-02-22 15:57:29 +03:00
}
2020-02-22 08:09:57 +03:00
}
if !m.DB.DisableForeignKeyConstraintWhenMigrating && !m.DB.IgnoreRelationshipsWhenMigrating {
for _, rel := range stmt.Schema.Relationships.Relations {
if rel.Field.IgnoreMigration {
continue
}
if constraint := rel.ParseConstraint(); constraint != nil {
if constraint.Schema == stmt.Schema {
sql, vars := buildConstraint(constraint)
createTableSQL += sql + ","
values = append(values, vars...)
}
2020-06-19 19:48:15 +03:00
}
2020-02-22 08:09:57 +03:00
}
}
for _, chk := range stmt.Schema.ParseCheckConstraints() {
2020-07-08 13:15:45 +03:00
createTableSQL += "CONSTRAINT ? CHECK (?),"
2020-02-22 08:09:57 +03:00
values = append(values, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint})
}
createTableSQL = strings.TrimSuffix(createTableSQL, ",")
createTableSQL += ")"
2020-06-14 14:13:16 +03:00
if tableOption, ok := m.DB.Get("gorm:table_options"); ok {
createTableSQL += fmt.Sprint(tableOption)
}
errr = tx.Exec(createTableSQL, values...).Error
return errr
2020-02-22 08:09:57 +03:00
}); err != nil {
return err
}
}
return nil
2020-02-20 18:04:03 +03:00
}
// DropTable drop table for values
2020-02-22 12:53:57 +03:00
func (m Migrator) DropTable(values ...interface{}) error {
2020-02-22 19:18:12 +03:00
values = m.ReorderModels(values, false)
for i := len(values) - 1; i >= 0; i-- {
tx := m.DB.Session(&gorm.Session{})
2020-06-02 02:28:29 +03:00
if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error {
return tx.Exec("DROP TABLE IF EXISTS ?", m.CurrentTable(stmt)).Error
2020-06-02 02:28:29 +03:00
}); err != nil {
return err
2020-02-20 18:04:03 +03:00
}
}
return nil
}
// HasTable returns table exists or not for value, value could be a struct or string
2020-02-22 12:53:57 +03:00
func (m Migrator) HasTable(value interface{}) bool {
2020-02-20 18:04:03 +03:00
var count int64
2020-04-28 03:05:22 +03:00
2020-02-22 12:53:57 +03:00
m.RunWithValue(value, func(stmt *gorm.Statement) error {
currentDatabase := m.DB.Migrator().CurrentDatabase()
return m.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentDatabase, stmt.Table, "BASE TABLE").Row().Scan(&count)
})
2020-02-20 18:04:03 +03:00
2020-02-22 12:53:57 +03:00
return count > 0
2020-02-20 18:04:03 +03:00
}
// RenameTable rename table from oldName to newName
2020-05-31 05:24:49 +03:00
func (m Migrator) RenameTable(oldName, newName interface{}) error {
var oldTable, newTable interface{}
2020-05-31 05:24:49 +03:00
if v, ok := oldName.(string); ok {
oldTable = clause.Table{Name: v}
2020-05-31 05:24:49 +03:00
} else {
stmt := &gorm.Statement{DB: m.DB}
if err := stmt.Parse(oldName); err == nil {
oldTable = m.CurrentTable(stmt)
2020-05-31 05:24:49 +03:00
} else {
return err
}
}
if v, ok := newName.(string); ok {
newTable = clause.Table{Name: v}
2020-05-31 05:24:49 +03:00
} else {
stmt := &gorm.Statement{DB: m.DB}
if err := stmt.Parse(newName); err == nil {
newTable = m.CurrentTable(stmt)
2020-05-31 05:24:49 +03:00
} else {
return err
}
}
return m.DB.Exec("ALTER TABLE ? RENAME TO ?", oldTable, newTable).Error
2020-02-20 18:04:03 +03:00
}
// AddColumn create `name` column for value
func (m Migrator) AddColumn(value interface{}, name string) error {
2020-02-22 12:53:57 +03:00
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
// avoid using the same name field
f := stmt.Schema.LookUpField(name)
if f == nil {
return fmt.Errorf("failed to look up field with name: %s", name)
2020-02-20 18:04:03 +03:00
}
if !f.IgnoreMigration {
return m.DB.Exec(
"ALTER TABLE ? ADD ? ?",
m.CurrentTable(stmt), clause.Column{Name: f.DBName}, m.DB.Migrator().FullDataTypeOf(f),
).Error
}
return nil
2020-02-20 18:04:03 +03:00
})
}
// DropColumn drop value's `name` column
func (m Migrator) DropColumn(value interface{}, name string) error {
2020-02-22 12:53:57 +03:00
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(name); field != nil {
name = field.DBName
2020-02-20 18:04:03 +03:00
}
return m.DB.Exec(
"ALTER TABLE ? DROP COLUMN ?", m.CurrentTable(stmt), clause.Column{Name: name},
).Error
2020-02-20 18:04:03 +03:00
})
}
// AlterColumn alter value's `field` column' type based on schema definition
2020-02-22 12:53:57 +03:00
func (m Migrator) AlterColumn(value interface{}, field string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
2020-02-20 18:04:03 +03:00
if field := stmt.Schema.LookUpField(field); field != nil {
fileType := m.FullDataTypeOf(field)
2020-02-22 12:53:57 +03:00
return m.DB.Exec(
2020-02-22 06:15:51 +03:00
"ALTER TABLE ? ALTER COLUMN ? TYPE ?",
m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType,
2020-02-22 06:15:51 +03:00
).Error
2020-02-20 18:04:03 +03:00
}
return fmt.Errorf("failed to look up field with name: %s", field)
})
}
// HasColumn check has column `field` for value or not
2020-02-22 12:53:57 +03:00
func (m Migrator) HasColumn(value interface{}, field string) bool {
2020-02-22 08:09:57 +03:00
var count int64
2020-02-22 12:53:57 +03:00
m.RunWithValue(value, func(stmt *gorm.Statement) error {
currentDatabase := m.DB.Migrator().CurrentDatabase()
2020-02-22 08:09:57 +03:00
name := field
if field := stmt.Schema.LookUpField(field); field != nil {
name = field.DBName
}
2020-02-22 12:53:57 +03:00
return m.DB.Raw(
2020-02-22 08:09:57 +03:00
"SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?",
currentDatabase, stmt.Table, name,
2020-02-22 12:53:57 +03:00
).Row().Scan(&count)
2020-02-22 08:09:57 +03:00
})
2020-02-22 12:53:57 +03:00
return count > 0
2020-02-22 08:09:57 +03:00
}
// RenameColumn rename value's field name from oldName to newName
func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error {
2020-02-22 12:53:57 +03:00
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(oldName); field != nil {
oldName = field.DBName
2020-02-20 18:04:03 +03:00
}
if field := stmt.Schema.LookUpField(newName); field != nil {
newName = field.DBName
}
return m.DB.Exec(
"ALTER TABLE ? RENAME COLUMN ? TO ?",
m.CurrentTable(stmt), clause.Column{Name: oldName}, clause.Column{Name: newName},
).Error
2020-02-20 18:04:03 +03:00
})
}
// MigrateColumn compares field's schema definition with columnType (the
// column as reported by the database). When a difference is found and the
// field is not IgnoreMigration, it calls m.DB.Migrator().AlterColumn.
//
// Differences checked:
//   - type (non-primary-key fields): the lower-cased FullDataTypeOf does not
//     start with the lower-cased DatabaseTypeName or any GetTypeAliases entry;
//   - size and decimal precision, only when the full type is neither an exact
//     nor an alias match;
//   - nullability: the column is nullable but the field is NOT NULL;
//   - uniqueness, default value and comment (non-primary-key fields only).
func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error {
2020-08-23 10:40:19 +03:00
// found, smart migrate
fullDataType := strings.TrimSpace(strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL))
2020-08-23 10:40:19 +03:00
realDataType := strings.ToLower(columnType.DatabaseTypeName())
var (
alterColumn bool
// exact textual match; may also become true through a type alias below
isSameType = fullDataType == realDataType
)
2020-08-23 10:40:19 +03:00
if !field.PrimaryKey {
// check type
if !strings.HasPrefix(fullDataType, realDataType) {
// check type aliases
aliases := m.DB.Migrator().GetTypeAliases(realDataType)
for _, alias := range aliases {
if strings.HasPrefix(fullDataType, alias) {
isSameType = true
break
}
}
if !isSameType {
alterColumn = true
}
}
}
if !isSameType {
// check size: two positive, different sizes always trigger an alter (this
// branch does not exclude primary keys)
if length, ok := columnType.Length(); length != int64(field.Size) {
if length > 0 && field.Size > 0 {
2020-08-23 10:40:19 +03:00
alterColumn = true
} else {
// has size in data type and not equal
// Since the following code is frequently called in the for loop, reg optimization is needed here
// Alter when the full data type contains exactly one number and it
// differs from the reported length (only if Length reported ok).
matches2 := regFullDataType.FindAllStringSubmatch(fullDataType, -1)
if !field.PrimaryKey &&
(len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) && ok) {
alterColumn = true
}
2020-08-23 10:40:19 +03:00
}
}
// check precision: only alter when the schema precision appears, delimited
// by non-digits, in DataTypeOf(field)
// NOTE(review): this regexp is compiled on every call that reaches here;
// consider a cheaper check if MigrateColumn shows up in profiles.
if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision {
if regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) {
alterColumn = true
}
2020-08-23 10:40:19 +03:00
}
}
// check nullable: alter only when the column is nullable but the field is
// NOT NULL; a NOT NULL column with a nullable field is left alone here
if nullable, ok := columnType.Nullable(); ok && nullable == field.NotNull {
// not primary key & database is nullable
if !field.PrimaryKey && nullable {
alterColumn = true
}
}
2022-02-19 18:42:20 +03:00
// check unique
if unique, ok := columnType.Unique(); ok && unique != field.Unique {
// not primary key
if !field.PrimaryKey {
alterColumn = true
}
}
// check default value
if !field.PrimaryKey {
// the schema declares a default that is not the literal NULL
currentDefaultNotNull := field.HasDefaultValue && (field.DefaultValueInterface != nil || !strings.EqualFold(field.DefaultValue, "NULL"))
dv, dvNotNull := columnType.DefaultValue()
if dvNotNull && !currentDefaultNotNull {
// default value -> null
alterColumn = true
} else if !dvNotNull && currentDefaultNotNull {
// null -> default value
2022-02-19 18:42:20 +03:00
alterColumn = true
// time defaults are compared case-insensitively with a trailing "()"
// stripped on both sides; other defaults are compared exactly
} else if (field.GORMDataType != schema.Time && dv != field.DefaultValue) ||
(field.GORMDataType == schema.Time && !strings.EqualFold(strings.TrimSuffix(dv, "()"), strings.TrimSuffix(field.DefaultValue, "()"))) {
// default value not equal
// not both null
if currentDefaultNotNull || dvNotNull {
alterColumn = true
}
2022-02-19 18:42:20 +03:00
}
}
// check comment
if comment, ok := columnType.Comment(); ok && comment != field.Comment {
// not primary key
if !field.PrimaryKey {
alterColumn = true
}
}
if alterColumn && !field.IgnoreMigration {
2022-09-22 06:25:03 +03:00
return m.DB.Migrator().AlterColumn(value, field.DBName)
2020-08-23 10:40:19 +03:00
}
return nil
}
// ColumnTypes return columnTypes []gorm.ColumnType and execErr error
func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
columnTypes := make([]gorm.ColumnType, 0)
2021-12-02 05:39:24 +03:00
execErr := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows()
if err != nil {
return err
2020-02-22 14:41:01 +03:00
}
2021-12-02 05:39:24 +03:00
defer func() {
err = rows.Close()
}()
var rawColumnTypes []*sql.ColumnType
rawColumnTypes, err = rows.ColumnTypes()
if err != nil {
return err
}
for _, c := range rawColumnTypes {
columnTypes = append(columnTypes, ColumnType{SQLColumnType: c})
}
2021-12-02 05:39:24 +03:00
return
2020-02-22 14:41:01 +03:00
})
return columnTypes, execErr
2020-02-20 18:04:03 +03:00
}
// CreateView create view from Query in gorm.ViewOption.
// Query in gorm.ViewOption is a [subquery]
//
// // CREATE VIEW `user_view` AS SELECT * FROM `users` WHERE age > 20
// q := DB.Model(&User{}).Where("age > ?", 20)
// DB.Debug().Migrator().CreateView("user_view", gorm.ViewOption{Query: q})
//
// // CREATE OR REPLACE VIEW `users_view` AS SELECT * FROM `users` WITH CHECK OPTION
// q := DB.Model(&User{})
// DB.Debug().Migrator().CreateView("user_view", gorm.ViewOption{Query: q, Replace: true, CheckOption: "WITH CHECK OPTION"})
//
// [subquery]: https://gorm.io/docs/advanced_query.html#SubQuery
2020-02-22 12:53:57 +03:00
func (m Migrator) CreateView(name string, option gorm.ViewOption) error {
if option.Query == nil {
return gorm.ErrSubQueryRequired
}
sql := new(strings.Builder)
sql.WriteString("CREATE ")
if option.Replace {
sql.WriteString("OR REPLACE ")
}
sql.WriteString("VIEW ")
m.QuoteTo(sql, name)
sql.WriteString(" AS ")
m.DB.Statement.AddVar(sql, option.Query)
if option.CheckOption != "" {
sql.WriteString(" ")
sql.WriteString(option.CheckOption)
}
return m.DB.Exec(m.Explain(sql.String(), m.DB.Statement.Vars...)).Error
2020-02-20 18:04:03 +03:00
}
// DropView drop view
2020-02-22 12:53:57 +03:00
func (m Migrator) DropView(name string) error {
return m.DB.Exec("DROP VIEW IF EXISTS ?", clause.Table{Name: name}).Error
2020-02-20 18:04:03 +03:00
}
2020-02-22 08:09:57 +03:00
func buildConstraint(constraint *schema.Constraint) (sql string, results []interface{}) {
sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
if constraint.OnDelete != "" {
sql += " ON DELETE " + constraint.OnDelete
}
if constraint.OnUpdate != "" {
2020-06-19 19:48:15 +03:00
sql += " ON UPDATE " + constraint.OnUpdate
2020-02-22 08:09:57 +03:00
}
var foreignKeys, references []interface{}
for _, field := range constraint.ForeignKeys {
foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
}
for _, field := range constraint.References {
references = append(references, clause.Column{Name: field.DBName})
}
2020-02-22 15:57:29 +03:00
results = append(results, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
2020-02-22 08:09:57 +03:00
return
}
// GuessConstraintAndTable guess statement's constraint and it's table based on name
2021-01-26 08:39:34 +03:00
func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_ *schema.Constraint, _ *schema.Check, table string) {
if stmt.Schema == nil {
return nil, nil, stmt.Table
}
checkConstraints := stmt.Schema.ParseCheckConstraints()
if chk, ok := checkConstraints[name]; ok {
return nil, &chk, stmt.Table
}
getTable := func(rel *schema.Relationship) string {
switch rel.Type {
case schema.HasOne, schema.HasMany:
return rel.FieldSchema.Table
case schema.Many2Many:
return rel.JoinTable.Table
}
return stmt.Table
}
for _, rel := range stmt.Schema.Relationships.Relations {
if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name {
return constraint, nil, getTable(rel)
}
}
if field := stmt.Schema.LookUpField(name); field != nil {
for k := range checkConstraints {
if checkConstraints[k].Field == field {
v := checkConstraints[k]
return nil, &v, stmt.Table
2021-01-26 08:39:34 +03:00
}
}
for _, rel := range stmt.Schema.Relationships.Relations {
if constraint := rel.ParseConstraint(); constraint != nil && rel.Field == field {
return constraint, nil, getTable(rel)
}
}
}
2021-02-26 12:30:00 +03:00
return nil, nil, stmt.Schema.Table
2021-01-26 08:39:34 +03:00
}
// CreateConstraint create constraint
2020-02-22 12:53:57 +03:00
func (m Migrator) CreateConstraint(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
2021-01-26 08:39:34 +03:00
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
if chk != nil {
2020-02-22 12:53:57 +03:00
return m.DB.Exec(
2020-07-08 13:15:45 +03:00
"ALTER TABLE ? ADD CONSTRAINT ? CHECK (?)",
m.CurrentTable(stmt), clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint},
2020-02-22 06:15:51 +03:00
).Error
}
2021-01-26 08:39:34 +03:00
if constraint != nil {
vars := []interface{}{clause.Table{Name: table}}
2021-01-26 08:39:34 +03:00
if stmt.TableExpr != nil {
vars[0] = stmt.TableExpr
2020-02-22 06:15:51 +03:00
}
2021-01-26 08:39:34 +03:00
sql, values := buildConstraint(constraint)
return m.DB.Exec("ALTER TABLE ? ADD "+sql, append(vars, values...)...).Error
2020-02-22 06:15:51 +03:00
}
2021-01-26 08:39:34 +03:00
return nil
2020-02-22 06:15:51 +03:00
})
2020-02-20 18:04:03 +03:00
}
// DropConstraint drop constraint
2020-02-22 12:53:57 +03:00
func (m Migrator) DropConstraint(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
2021-01-26 08:39:34 +03:00
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
if constraint != nil {
name = constraint.Name
} else if chk != nil {
name = chk.Name
}
return m.DB.Exec("ALTER TABLE ? DROP CONSTRAINT ?", clause.Table{Name: table}, clause.Column{Name: name}).Error
2020-02-20 18:04:03 +03:00
})
}
// HasConstraint check has constraint or not
2020-02-22 12:53:57 +03:00
func (m Migrator) HasConstraint(value interface{}, name string) bool {
2020-02-22 08:09:57 +03:00
var count int64
2020-02-22 12:53:57 +03:00
m.RunWithValue(value, func(stmt *gorm.Statement) error {
currentDatabase := m.DB.Migrator().CurrentDatabase()
2021-01-26 08:39:34 +03:00
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
if constraint != nil {
name = constraint.Name
} else if chk != nil {
name = chk.Name
}
2020-02-22 12:53:57 +03:00
return m.DB.Raw(
"SELECT count(*) FROM INFORMATION_SCHEMA.table_constraints WHERE constraint_schema = ? AND table_name = ? AND constraint_name = ?",
2021-01-26 08:39:34 +03:00
currentDatabase, table, name,
2020-02-22 12:53:57 +03:00
).Row().Scan(&count)
2020-02-22 08:09:57 +03:00
})
2020-02-22 12:53:57 +03:00
return count > 0
2020-02-22 08:09:57 +03:00
}
// BuildIndexOptions build index options
2020-02-22 12:53:57 +03:00
func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
2020-02-22 08:09:57 +03:00
for _, opt := range opts {
str := stmt.Quote(opt.DBName)
if opt.Expression != "" {
str = opt.Expression
} else if opt.Length > 0 {
str += fmt.Sprintf("(%d)", opt.Length)
}
2020-02-22 12:53:57 +03:00
if opt.Collate != "" {
str += " COLLATE " + opt.Collate
}
2020-02-22 08:09:57 +03:00
if opt.Sort != "" {
str += " " + opt.Sort
}
results = append(results, clause.Expr{SQL: str})
}
return
}
// BuildIndexOptionsInterface renders an index's column list. CreateTable and
// CreateIndex use an unchecked type assertion on the active migrator
// (tx.Migrator() / m.DB.Migrator()), so that migrator must implement it or
// those calls panic. Migrator implements it via BuildIndexOptions.
2020-02-22 12:53:57 +03:00
type BuildIndexOptionsInterface interface {
BuildIndexOptions([]schema.IndexOption, *gorm.Statement) []interface{}
}
// CreateIndex create index `name`
2020-02-22 12:53:57 +03:00
func (m Migrator) CreateIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
2020-05-30 17:27:20 +03:00
if idx := stmt.Schema.LookIndex(name); idx != nil {
2020-02-22 12:53:57 +03:00
opts := m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)
values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts}
2020-02-22 06:15:51 +03:00
createIndexSQL := "CREATE "
if idx.Class != "" {
createIndexSQL += idx.Class + " "
}
createIndexSQL += "INDEX ? ON ??"
if idx.Type != "" {
createIndexSQL += " USING " + idx.Type
}
if idx.Comment != "" {
createIndexSQL += fmt.Sprintf(" COMMENT '%s'", idx.Comment)
}
2020-10-21 15:15:49 +03:00
if idx.Option != "" {
createIndexSQL += " " + idx.Option
}
2020-02-22 12:53:57 +03:00
return m.DB.Exec(createIndexSQL, values...).Error
2020-02-22 06:15:51 +03:00
}
2020-05-30 17:27:20 +03:00
return fmt.Errorf("failed to create index with name %s", name)
2020-02-22 06:15:51 +03:00
})
2020-02-20 18:04:03 +03:00
}
// DropIndex drop index `name`
2020-02-22 12:53:57 +03:00
func (m Migrator) DropIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
2020-05-30 17:27:20 +03:00
if idx := stmt.Schema.LookIndex(name); idx != nil {
name = idx.Name
}
return m.DB.Exec("DROP INDEX ? ON ?", clause.Column{Name: name}, m.CurrentTable(stmt)).Error
2020-02-20 18:04:03 +03:00
})
}
// HasIndex check has index `name` or not
2020-02-22 12:53:57 +03:00
func (m Migrator) HasIndex(value interface{}, name string) bool {
2020-02-20 18:04:03 +03:00
var count int64
2020-02-22 12:53:57 +03:00
m.RunWithValue(value, func(stmt *gorm.Statement) error {
currentDatabase := m.DB.Migrator().CurrentDatabase()
2020-05-30 17:27:20 +03:00
if idx := stmt.Schema.LookIndex(name); idx != nil {
name = idx.Name
}
2020-02-22 12:53:57 +03:00
return m.DB.Raw(
2020-02-22 06:15:51 +03:00
"SELECT count(*) FROM information_schema.statistics WHERE table_schema = ? AND table_name = ? AND index_name = ?",
currentDatabase, stmt.Table, name,
2020-02-22 12:53:57 +03:00
).Row().Scan(&count)
2020-02-20 18:04:03 +03:00
})
2020-02-22 12:53:57 +03:00
return count > 0
2020-02-20 18:04:03 +03:00
}
// RenameIndex rename index from oldName to newName
2020-02-22 12:53:57 +03:00
func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
return m.DB.Exec(
2020-02-22 06:15:51 +03:00
"ALTER TABLE ? RENAME INDEX ? TO ?",
m.CurrentTable(stmt), clause.Column{Name: oldName}, clause.Column{Name: newName},
2020-02-22 06:15:51 +03:00
).Error
2020-02-20 18:04:03 +03:00
})
}
// CurrentDatabase returns current database name
2020-02-22 12:53:57 +03:00
func (m Migrator) CurrentDatabase() (name string) {
m.DB.Raw("SELECT DATABASE()").Row().Scan(&name)
2020-02-20 18:04:03 +03:00
return
}
2020-02-22 19:18:12 +03:00
// ReorderModels returns values reordered for migration.
//
// String values (table names) are copied to the front of the result in their
// original order. Every other value is parsed into a schema and recorded
// under its table name; a value whose schema was already parsed is skipped.
// Models are then emitted once per table name, in the order first seen.
//
// When autoAdd is true, each model is preceded (recursively) by the models it
// depends on: the reference schema of every foreign-key constraint it owns.
// Dependencies that were not among values are instantiated with reflect.New,
// parsed and included in the result. Many2many join tables are also parsed,
// deferred until the owning model's relations have been scanned, and, with
// autoAdd, appended to the list. When autoAdd is false, no dependency
// ordering is applied.
func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []interface{}) {
type Dependency struct {
2020-02-23 03:29:59 +03:00
*gorm.Statement
2020-02-22 19:18:12 +03:00
// schemas that must come before this one when autoAdd is set
Depends []*schema.Schema
}
var (
modelNames, orderedModelNames []string
orderedModelNamesMap = map[string]bool{}
parsedSchemas = map[*schema.Schema]bool{}
2020-02-23 03:29:59 +03:00
valuesMap = map[string]Dependency{}
insertIntoOrderedList func(name string)
2020-06-24 11:43:53 +03:00
parseDependence func(value interface{}, addToList bool)
2020-02-22 19:18:12 +03:00
)
2020-06-24 11:43:53 +03:00
// parseDependence parses value, records its Dependency in valuesMap and,
// when addToList is set, appends its table name to modelNames.
parseDependence = func(value interface{}, addToList bool) {
2020-02-23 03:29:59 +03:00
dep := Dependency{
2020-03-09 15:37:01 +03:00
Statement: &gorm.Statement{DB: m.DB, Dest: value},
2020-02-23 03:29:59 +03:00
}
beDependedOn := map[*schema.Schema]bool{}
// support for special table name
// NOTE(review): a parse error is only logged; if it leaves dep.Schema nil,
// the dep.Schema dereferences below would panic — confirm.
if err := dep.ParseWithSpecialTableName(value, m.DB.Statement.Table); err != nil {
2020-07-16 06:27:04 +03:00
m.DB.Logger.Error(context.Background(), "failed to parse value %#v, got error %v", value, err)
}
if _, ok := parsedSchemas[dep.Statement.Schema]; ok {
return
}
parsedSchemas[dep.Statement.Schema] = true
2020-02-22 19:18:12 +03:00
if !m.DB.IgnoreRelationshipsWhenMigrating {
for _, rel := range dep.Schema.Relationships.Relations {
if rel.Field.IgnoreMigration {
continue
}
// a foreign key owned by this schema pointing at another schema makes
// the referenced schema a dependency (self-references are ignored)
if c := rel.ParseConstraint(); c != nil && c.Schema == dep.Statement.Schema && c.Schema != c.ReferenceSchema {
dep.Depends = append(dep.Depends, c.ReferenceSchema)
}
if rel.Type == schema.HasOne || rel.Type == schema.HasMany {
beDependedOn[rel.FieldSchema] = true
}
if rel.JoinTable != nil {
// append join value
// NOTE(review): this deferred function runs after
// valuesMap[dep.Schema.Table] = dep below has stored a copy of dep
// (Dependency is a struct value), so the dep.Depends append here never
// reaches valuesMap. Confirm whether that is intended.
defer func(rel *schema.Relationship, joinValue interface{}) {
if !beDependedOn[rel.FieldSchema] {
dep.Depends = append(dep.Depends, rel.FieldSchema)
} else {
fieldValue := reflect.New(rel.FieldSchema.ModelType).Interface()
parseDependence(fieldValue, autoAdd)
}
parseDependence(joinValue, autoAdd)
}(rel, reflect.New(rel.JoinTable.ModelType).Interface())
}
}
2020-02-22 19:18:12 +03:00
}
2020-02-23 03:29:59 +03:00
valuesMap[dep.Schema.Table] = dep
2020-02-22 19:18:12 +03:00
2020-02-23 03:29:59 +03:00
if addToList {
modelNames = append(modelNames, dep.Schema.Table)
}
2020-02-22 19:18:12 +03:00
}
2020-02-23 03:29:59 +03:00
// insertIntoOrderedList appends name to orderedModelNames once. With
// autoAdd, it first inserts the tables name depends on, parsing any that
// have not been seen yet.
insertIntoOrderedList = func(name string) {
2020-02-22 19:18:12 +03:00
if _, ok := orderedModelNamesMap[name]; ok {
2020-02-23 03:29:59 +03:00
return // avoid loop
2020-02-22 19:18:12 +03:00
}
orderedModelNamesMap[name] = true
2020-02-22 19:18:12 +03:00
if autoAdd {
dep := valuesMap[name]
for _, d := range dep.Depends {
if _, ok := valuesMap[d.Table]; ok {
insertIntoOrderedList(d.Table)
} else {
parseDependence(reflect.New(d.ModelType).Interface(), autoAdd)
insertIntoOrderedList(d.Table)
}
2020-02-22 19:18:12 +03:00
}
}
orderedModelNames = append(orderedModelNames, name)
}
2020-02-23 03:29:59 +03:00
// Table-name strings go straight to the result; models are parsed.
for _, value := range values {
2020-05-23 11:38:13 +03:00
if v, ok := value.(string); ok {
results = append(results, v)
} else {
parseDependence(value, true)
}
2020-02-23 03:29:59 +03:00
}
2020-02-22 19:18:12 +03:00
for _, name := range modelNames {
2020-02-23 03:29:59 +03:00
insertIntoOrderedList(name)
2020-02-22 19:18:12 +03:00
}
for _, name := range orderedModelNames {
2020-02-23 03:29:59 +03:00
results = append(results, valuesMap[name].Statement.Dest)
2020-02-22 19:18:12 +03:00
}
return
}
// CurrentTable returns the table reference to interpolate for stmt: a copy
// of stmt.TableExpr when one is set, otherwise a clause.Table named
// stmt.Table.
func (m Migrator) CurrentTable(stmt *gorm.Statement) interface{} {
	if expr := stmt.TableExpr; expr != nil {
		return *expr
	}
	return clause.Table{Name: stmt.Table}
}
// GetIndexes is not implemented by the generic migrator. It always returns a
// nil slice and a "not support" error.
func (m Migrator) GetIndexes(dst interface{}) ([]gorm.Index, error) {
	var indexes []gorm.Index
	return indexes, errors.New("not support")
}
// GetTypeAliases returns alternative spellings of databaseTypeName. The
// generic migrator knows none and always returns nil.
func (m Migrator) GetTypeAliases(databaseTypeName string) []string {
	var aliases []string
	return aliases
}