2020-01-28 18:01:35 +03:00
package migrator
2020-02-20 18:04:03 +03:00
import (
"database/sql"
"fmt"
2020-02-22 08:09:57 +03:00
"reflect"
"strings"
2020-02-20 18:04:03 +03:00
"github.com/jinzhu/gorm"
2020-02-22 06:15:51 +03:00
"github.com/jinzhu/gorm/clause"
2020-02-22 08:09:57 +03:00
"github.com/jinzhu/gorm/schema"
2020-02-20 18:04:03 +03:00
)
2020-01-28 18:01:35 +03:00
2020-02-22 12:53:57 +03:00
// Migrator m struct
2020-01-28 18:01:35 +03:00
type Migrator struct {
2020-02-22 12:53:57 +03:00
Config
2020-01-28 18:01:35 +03:00
}
// Config schema config
type Config struct {
2020-02-22 15:57:29 +03:00
CreateIndexAfterCreateTable bool
DB * gorm . DB
2020-02-22 12:53:57 +03:00
gorm . Dialector
2020-01-28 18:01:35 +03:00
}
2020-02-20 18:04:03 +03:00
2020-02-22 12:53:57 +03:00
func ( m Migrator ) RunWithValue ( value interface { } , fc func ( * gorm . Statement ) error ) error {
stmt := m . DB . Statement
2020-02-20 18:04:03 +03:00
if stmt == nil {
2020-02-22 12:53:57 +03:00
stmt = & gorm . Statement { DB : m . DB }
2020-02-20 18:04:03 +03:00
}
if err := stmt . Parse ( value ) ; err != nil {
return err
}
return fc ( stmt )
}
2020-02-22 12:53:57 +03:00
// DataTypeOf returns the SQL column type for field. An explicit DBDataType on
// the field takes precedence; otherwise the embedded Dialector decides.
func (m Migrator) DataTypeOf(field *schema.Field) string {
	if explicit := field.DBDataType; explicit != "" {
		return explicit
	}
	return m.Dialector.DataTypeOf(field)
}
2020-02-20 18:04:03 +03:00
// AutoMigrate
2020-02-22 12:53:57 +03:00
func ( m Migrator ) AutoMigrate ( values ... interface { } ) error {
2020-02-22 08:09:57 +03:00
// TODO smart migrate data type
2020-02-22 06:15:51 +03:00
2020-02-22 08:09:57 +03:00
for _ , value := range values {
2020-02-22 12:53:57 +03:00
if ! m . DB . Migrator ( ) . HasTable ( value ) {
if err := m . DB . Migrator ( ) . CreateTable ( value ) ; err != nil {
2020-02-22 08:09:57 +03:00
return err
}
} else {
2020-02-22 12:53:57 +03:00
if err := m . RunWithValue ( value , func ( stmt * gorm . Statement ) error {
2020-02-22 08:09:57 +03:00
for _ , field := range stmt . Schema . FieldsByDBName {
2020-02-22 12:53:57 +03:00
if ! m . DB . Migrator ( ) . HasColumn ( value , field . DBName ) {
if err := m . DB . Migrator ( ) . AddColumn ( value , field . DBName ) ; err != nil {
2020-02-22 08:09:57 +03:00
return err
}
}
}
for _ , rel := range stmt . Schema . Relationships . Relations {
if constraint := rel . ParseConstraint ( ) ; constraint != nil {
2020-02-22 12:53:57 +03:00
if ! m . DB . Migrator ( ) . HasConstraint ( value , constraint . Name ) {
if err := m . DB . Migrator ( ) . CreateConstraint ( value , constraint . Name ) ; err != nil {
2020-02-22 08:09:57 +03:00
return err
}
}
}
for _ , chk := range stmt . Schema . ParseCheckConstraints ( ) {
2020-02-22 12:53:57 +03:00
if ! m . DB . Migrator ( ) . HasConstraint ( value , chk . Name ) {
if err := m . DB . Migrator ( ) . CreateConstraint ( value , chk . Name ) ; err != nil {
2020-02-22 08:09:57 +03:00
return err
}
}
}
// create join table
2020-02-22 15:57:29 +03:00
if rel . JoinTable != nil {
joinValue := reflect . New ( rel . JoinTable . ModelType ) . Interface ( )
if ! m . DB . Migrator ( ) . HasTable ( joinValue ) {
defer m . DB . Migrator ( ) . CreateTable ( joinValue )
}
2020-02-22 08:09:57 +03:00
}
}
return nil
} ) ; err != nil {
return err
}
}
}
return nil
2020-02-20 18:04:03 +03:00
}
2020-02-22 12:53:57 +03:00
func ( m Migrator ) CreateTable ( values ... interface { } ) error {
2020-02-22 08:09:57 +03:00
for _ , value := range values {
2020-02-22 12:53:57 +03:00
if err := m . RunWithValue ( value , func ( stmt * gorm . Statement ) error {
2020-02-22 08:09:57 +03:00
var (
createTableSQL = "CREATE TABLE ? ("
values = [ ] interface { } { clause . Table { Name : stmt . Table } }
hasPrimaryKeyInDataType bool
)
for _ , dbName := range stmt . Schema . DBNames {
field := stmt . Schema . FieldsByDBName [ dbName ]
createTableSQL += fmt . Sprintf ( "? ?" )
hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings . Contains ( strings . ToUpper ( field . DBDataType ) , "PRIMARY KEY" )
2020-02-22 12:53:57 +03:00
values = append ( values , clause . Column { Name : dbName } , clause . Expr { SQL : m . DataTypeOf ( field ) } )
2020-02-22 08:09:57 +03:00
if field . AutoIncrement {
createTableSQL += " AUTO_INCREMENT"
}
if field . NotNull {
createTableSQL += " NOT NULL"
}
if field . Unique {
createTableSQL += " UNIQUE"
}
if field . DefaultValue != "" {
createTableSQL += " DEFAULT ?"
values = append ( values , clause . Expr { SQL : field . DefaultValue } )
}
createTableSQL += ","
}
if ! hasPrimaryKeyInDataType {
createTableSQL += "PRIMARY KEY ?,"
primaryKeys := [ ] interface { } { }
for _ , field := range stmt . Schema . PrimaryFields {
primaryKeys = append ( primaryKeys , clause . Column { Name : field . DBName } )
}
values = append ( values , primaryKeys )
}
for _ , idx := range stmt . Schema . ParseIndexes ( ) {
2020-02-22 15:57:29 +03:00
if m . CreateIndexAfterCreateTable {
m . DB . Migrator ( ) . CreateIndex ( value , idx . Name )
} else {
createTableSQL += "INDEX ? ?,"
values = append ( values , clause . Expr { SQL : idx . Name } , m . DB . Migrator ( ) . ( BuildIndexOptionsInterface ) . BuildIndexOptions ( idx . Fields , stmt ) )
}
2020-02-22 08:09:57 +03:00
}
for _ , rel := range stmt . Schema . Relationships . Relations {
if constraint := rel . ParseConstraint ( ) ; constraint != nil {
sql , vars := buildConstraint ( constraint )
createTableSQL += sql + ","
values = append ( values , vars ... )
}
// create join table
2020-02-22 15:57:29 +03:00
if rel . JoinTable != nil {
joinValue := reflect . New ( rel . JoinTable . ModelType ) . Interface ( )
if ! m . DB . Migrator ( ) . HasTable ( joinValue ) {
defer m . DB . Migrator ( ) . CreateTable ( joinValue )
}
2020-02-22 08:09:57 +03:00
}
}
for _ , chk := range stmt . Schema . ParseCheckConstraints ( ) {
createTableSQL += "CONSTRAINT ? CHECK ?,"
values = append ( values , clause . Column { Name : chk . Name } , clause . Expr { SQL : chk . Constraint } )
}
createTableSQL = strings . TrimSuffix ( createTableSQL , "," )
createTableSQL += ")"
2020-02-22 12:53:57 +03:00
return m . DB . Exec ( createTableSQL , values ... ) . Error
2020-02-22 08:09:57 +03:00
} ) ; err != nil {
return err
}
}
return nil
2020-02-20 18:04:03 +03:00
}
2020-02-22 12:53:57 +03:00
func ( m Migrator ) DropTable ( values ... interface { } ) error {
2020-02-20 18:04:03 +03:00
for _ , value := range values {
2020-02-22 12:53:57 +03:00
if err := m . RunWithValue ( value , func ( stmt * gorm . Statement ) error {
return m . DB . Exec ( "DROP TABLE ?" , clause . Table { Name : stmt . Table } ) . Error
2020-02-20 18:04:03 +03:00
} ) ; err != nil {
return err
}
}
return nil
}
2020-02-22 12:53:57 +03:00
func ( m Migrator ) HasTable ( value interface { } ) bool {
2020-02-20 18:04:03 +03:00
var count int64
2020-02-22 12:53:57 +03:00
m . RunWithValue ( value , func ( stmt * gorm . Statement ) error {
currentDatabase := m . DB . Migrator ( ) . CurrentDatabase ( )
return m . DB . Raw ( "SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?" , currentDatabase , stmt . Table , "BASE TABLE" ) . Row ( ) . Scan ( & count )
} )
2020-02-20 18:04:03 +03:00
2020-02-22 12:53:57 +03:00
return count > 0
2020-02-20 18:04:03 +03:00
}
2020-02-22 12:53:57 +03:00
func ( m Migrator ) RenameTable ( oldName , newName string ) error {
return m . DB . Exec ( "RENAME TABLE ? TO ?" , oldName , newName ) . Error
2020-02-20 18:04:03 +03:00
}
2020-02-22 12:53:57 +03:00
func ( m Migrator ) AddColumn ( value interface { } , field string ) error {
return m . RunWithValue ( value , func ( stmt * gorm . Statement ) error {
2020-02-20 18:04:03 +03:00
if field := stmt . Schema . LookUpField ( field ) ; field != nil {
2020-02-22 12:53:57 +03:00
return m . DB . Exec (
2020-02-22 06:15:51 +03:00
"ALTER TABLE ? ADD ? ?" ,
2020-02-22 12:53:57 +03:00
clause . Table { Name : stmt . Table } , clause . Column { Name : field . DBName } , clause . Expr { SQL : m . DataTypeOf ( field ) } ,
2020-02-22 06:15:51 +03:00
) . Error
2020-02-20 18:04:03 +03:00
}
return fmt . Errorf ( "failed to look up field with name: %s" , field )
} )
}
2020-02-22 12:53:57 +03:00
func ( m Migrator ) DropColumn ( value interface { } , field string ) error {
return m . RunWithValue ( value , func ( stmt * gorm . Statement ) error {
2020-02-20 18:04:03 +03:00
if field := stmt . Schema . LookUpField ( field ) ; field != nil {
2020-02-22 12:53:57 +03:00
return m . DB . Exec (
2020-02-22 06:15:51 +03:00
"ALTER TABLE ? DROP COLUMN ?" , clause . Table { Name : stmt . Table } , clause . Column { Name : field . DBName } ,
) . Error
2020-02-20 18:04:03 +03:00
}
return fmt . Errorf ( "failed to look up field with name: %s" , field )
} )
}
2020-02-22 12:53:57 +03:00
func ( m Migrator ) AlterColumn ( value interface { } , field string ) error {
return m . RunWithValue ( value , func ( stmt * gorm . Statement ) error {
2020-02-20 18:04:03 +03:00
if field := stmt . Schema . LookUpField ( field ) ; field != nil {
2020-02-22 12:53:57 +03:00
return m . DB . Exec (
2020-02-22 06:15:51 +03:00
"ALTER TABLE ? ALTER COLUMN ? TYPE ?" ,
2020-02-22 12:53:57 +03:00
clause . Table { Name : stmt . Table } , clause . Column { Name : field . DBName } , clause . Expr { SQL : m . DataTypeOf ( field ) } ,
2020-02-22 06:15:51 +03:00
) . Error
2020-02-20 18:04:03 +03:00
}
return fmt . Errorf ( "failed to look up field with name: %s" , field )
} )
}
2020-02-22 12:53:57 +03:00
func ( m Migrator ) HasColumn ( value interface { } , field string ) bool {
2020-02-22 08:09:57 +03:00
var count int64
2020-02-22 12:53:57 +03:00
m . RunWithValue ( value , func ( stmt * gorm . Statement ) error {
currentDatabase := m . DB . Migrator ( ) . CurrentDatabase ( )
2020-02-22 08:09:57 +03:00
name := field
if field := stmt . Schema . LookUpField ( field ) ; field != nil {
name = field . DBName
}
2020-02-22 12:53:57 +03:00
return m . DB . Raw (
2020-02-22 08:09:57 +03:00
"SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?" ,
currentDatabase , stmt . Table , name ,
2020-02-22 12:53:57 +03:00
) . Row ( ) . Scan ( & count )
2020-02-22 08:09:57 +03:00
} )
2020-02-22 12:53:57 +03:00
return count > 0
2020-02-22 08:09:57 +03:00
}
2020-02-22 12:53:57 +03:00
func ( m Migrator ) RenameColumn ( value interface { } , oldName , field string ) error {
return m . RunWithValue ( value , func ( stmt * gorm . Statement ) error {
2020-02-20 18:04:03 +03:00
if field := stmt . Schema . LookUpField ( field ) ; field != nil {
2020-02-22 12:53:57 +03:00
oldName = m . DB . NamingStrategy . ColumnName ( stmt . Table , oldName )
return m . DB . Exec (
2020-02-22 06:15:51 +03:00
"ALTER TABLE ? RENAME COLUMN ? TO ?" ,
clause . Table { Name : stmt . Table } , clause . Column { Name : oldName } , clause . Column { Name : field . DBName } ,
) . Error
2020-02-20 18:04:03 +03:00
}
return fmt . Errorf ( "failed to look up field with name: %s" , field )
} )
}
2020-02-22 14:41:01 +03:00
func ( m Migrator ) ColumnTypes ( value interface { } ) ( columnTypes [ ] * sql . ColumnType , err error ) {
err = m . RunWithValue ( value , func ( stmt * gorm . Statement ) error {
rows , err := m . DB . Raw ( "select * from ?" , clause . Table { Name : stmt . Table } ) . Rows ( )
if err == nil {
columnTypes , err = rows . ColumnTypes ( )
}
return err
} )
return
2020-02-20 18:04:03 +03:00
}
2020-02-22 12:53:57 +03:00
func ( m Migrator ) CreateView ( name string , option gorm . ViewOption ) error {
2020-02-20 18:04:03 +03:00
return gorm . ErrNotImplemented
}
2020-02-22 12:53:57 +03:00
func ( m Migrator ) DropView ( name string ) error {
2020-02-20 18:04:03 +03:00
return gorm . ErrNotImplemented
}
2020-02-22 08:09:57 +03:00
func buildConstraint ( constraint * schema . Constraint ) ( sql string , results [ ] interface { } ) {
sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
if constraint . OnDelete != "" {
sql += " ON DELETE " + constraint . OnDelete
}
if constraint . OnUpdate != "" {
sql += " ON UPDATE " + constraint . OnUpdate
}
var foreignKeys , references [ ] interface { }
for _ , field := range constraint . ForeignKeys {
foreignKeys = append ( foreignKeys , clause . Column { Name : field . DBName } )
}
for _ , field := range constraint . References {
references = append ( references , clause . Column { Name : field . DBName } )
}
2020-02-22 15:57:29 +03:00
results = append ( results , clause . Table { Name : constraint . Name } , foreignKeys , clause . Table { Name : constraint . ReferenceSchema . Table } , references )
2020-02-22 08:09:57 +03:00
return
}
2020-02-22 12:53:57 +03:00
func ( m Migrator ) CreateConstraint ( value interface { } , name string ) error {
return m . RunWithValue ( value , func ( stmt * gorm . Statement ) error {
2020-02-22 06:15:51 +03:00
checkConstraints := stmt . Schema . ParseCheckConstraints ( )
if chk , ok := checkConstraints [ name ] ; ok {
2020-02-22 12:53:57 +03:00
return m . DB . Exec (
2020-02-22 06:15:51 +03:00
"ALTER TABLE ? ADD CONSTRAINT ? CHECK ?" ,
clause . Table { Name : stmt . Table } , clause . Column { Name : chk . Name } , clause . Expr { SQL : chk . Constraint } ,
) . Error
}
for _ , rel := range stmt . Schema . Relationships . Relations {
if constraint := rel . ParseConstraint ( ) ; constraint != nil && constraint . Name == name {
2020-02-22 08:09:57 +03:00
sql , values := buildConstraint ( constraint )
2020-02-22 12:53:57 +03:00
return m . DB . Exec ( "ALTER TABLE ? ADD " + sql , append ( [ ] interface { } { clause . Table { Name : stmt . Table } } , values ... ) ... ) . Error
2020-02-22 06:15:51 +03:00
}
}
err := fmt . Errorf ( "failed to create constraint with name %v" , name )
if field := stmt . Schema . LookUpField ( name ) ; field != nil {
for _ , cc := range checkConstraints {
2020-02-22 15:57:29 +03:00
if err = m . DB . Migrator ( ) . CreateIndex ( value , cc . Name ) ; err != nil {
2020-02-22 06:15:51 +03:00
return err
}
}
for _ , rel := range stmt . Schema . Relationships . Relations {
if constraint := rel . ParseConstraint ( ) ; constraint != nil && constraint . Field == field {
2020-02-22 15:57:29 +03:00
if err = m . DB . Migrator ( ) . CreateIndex ( value , constraint . Name ) ; err != nil {
2020-02-22 06:15:51 +03:00
return err
}
}
}
}
return err
} )
2020-02-20 18:04:03 +03:00
}
2020-02-22 12:53:57 +03:00
func ( m Migrator ) DropConstraint ( value interface { } , name string ) error {
return m . RunWithValue ( value , func ( stmt * gorm . Statement ) error {
return m . DB . Exec (
2020-02-22 06:15:51 +03:00
"ALTER TABLE ? DROP CONSTRAINT ?" ,
clause . Table { Name : stmt . Table } , clause . Column { Name : name } ,
) . Error
2020-02-20 18:04:03 +03:00
} )
}
2020-02-22 12:53:57 +03:00
func ( m Migrator ) HasConstraint ( value interface { } , name string ) bool {
2020-02-22 08:09:57 +03:00
var count int64
2020-02-22 12:53:57 +03:00
m . RunWithValue ( value , func ( stmt * gorm . Statement ) error {
currentDatabase := m . DB . Migrator ( ) . CurrentDatabase ( )
return m . DB . Raw (
2020-02-22 08:09:57 +03:00
"SELECT count(*) FROM INFORMATION_SCHEMA.referential_constraints WHERE constraint_schema = ? AND table_name = ? AND constraint_name = ?" ,
currentDatabase , stmt . Table , name ,
2020-02-22 12:53:57 +03:00
) . Row ( ) . Scan ( & count )
2020-02-22 08:09:57 +03:00
} )
2020-02-22 12:53:57 +03:00
return count > 0
2020-02-22 08:09:57 +03:00
}
2020-02-22 12:53:57 +03:00
func ( m Migrator ) BuildIndexOptions ( opts [ ] schema . IndexOption , stmt * gorm . Statement ) ( results [ ] interface { } ) {
2020-02-22 08:09:57 +03:00
for _ , opt := range opts {
str := stmt . Quote ( opt . DBName )
if opt . Expression != "" {
str = opt . Expression
} else if opt . Length > 0 {
str += fmt . Sprintf ( "(%d)" , opt . Length )
}
2020-02-22 12:53:57 +03:00
if opt . Collate != "" {
str += " COLLATE " + opt . Collate
}
2020-02-22 08:09:57 +03:00
if opt . Sort != "" {
str += " " + opt . Sort
}
results = append ( results , clause . Expr { SQL : str } )
}
return
}
2020-02-22 12:53:57 +03:00
// BuildIndexOptionsInterface is implemented by migrators that can render the
// column list of an index. CreateTable and CreateIndex type-assert the value
// returned by DB.Migrator() to it.
type BuildIndexOptionsInterface interface {
	BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) []interface{}
}
func ( m Migrator ) CreateIndex ( value interface { } , name string ) error {
return m . RunWithValue ( value , func ( stmt * gorm . Statement ) error {
2020-02-22 06:15:51 +03:00
err := fmt . Errorf ( "failed to create index with name %v" , name )
indexes := stmt . Schema . ParseIndexes ( )
if idx , ok := indexes [ name ] ; ok {
2020-02-22 12:53:57 +03:00
opts := m . DB . Migrator ( ) . ( BuildIndexOptionsInterface ) . BuildIndexOptions ( idx . Fields , stmt )
2020-02-22 08:09:57 +03:00
values := [ ] interface { } { clause . Column { Name : idx . Name } , clause . Table { Name : stmt . Table } , opts }
2020-02-22 06:15:51 +03:00
createIndexSQL := "CREATE "
if idx . Class != "" {
createIndexSQL += idx . Class + " "
}
createIndexSQL += "INDEX ? ON ??"
if idx . Comment != "" {
values = append ( values , idx . Comment )
createIndexSQL += " COMMENT ?"
}
if idx . Type != "" {
createIndexSQL += " USING " + idx . Type
}
2020-02-22 12:53:57 +03:00
return m . DB . Exec ( createIndexSQL , values ... ) . Error
2020-02-22 06:15:51 +03:00
} else if field := stmt . Schema . LookUpField ( name ) ; field != nil {
for _ , idx := range indexes {
for _ , idxOpt := range idx . Fields {
if idxOpt . Field == field {
2020-02-22 12:53:57 +03:00
if err = m . CreateIndex ( value , idx . Name ) ; err != nil {
2020-02-22 06:15:51 +03:00
return err
}
}
}
}
}
return err
} )
2020-02-20 18:04:03 +03:00
}
2020-02-22 12:53:57 +03:00
func ( m Migrator ) DropIndex ( value interface { } , name string ) error {
return m . RunWithValue ( value , func ( stmt * gorm . Statement ) error {
return m . DB . Exec ( "DROP INDEX ? ON ?" , clause . Column { Name : name } , clause . Table { Name : stmt . Table } ) . Error
2020-02-20 18:04:03 +03:00
} )
}
2020-02-22 12:53:57 +03:00
func ( m Migrator ) HasIndex ( value interface { } , name string ) bool {
2020-02-20 18:04:03 +03:00
var count int64
2020-02-22 12:53:57 +03:00
m . RunWithValue ( value , func ( stmt * gorm . Statement ) error {
currentDatabase := m . DB . Migrator ( ) . CurrentDatabase ( )
return m . DB . Raw (
2020-02-22 06:15:51 +03:00
"SELECT count(*) FROM information_schema.statistics WHERE table_schema = ? AND table_name = ? AND index_name = ?" ,
currentDatabase , stmt . Table , name ,
2020-02-22 12:53:57 +03:00
) . Row ( ) . Scan ( & count )
2020-02-20 18:04:03 +03:00
} )
2020-02-22 12:53:57 +03:00
return count > 0
2020-02-20 18:04:03 +03:00
}
2020-02-22 12:53:57 +03:00
func ( m Migrator ) RenameIndex ( value interface { } , oldName , newName string ) error {
return m . RunWithValue ( value , func ( stmt * gorm . Statement ) error {
return m . DB . Exec (
2020-02-22 06:15:51 +03:00
"ALTER TABLE ? RENAME INDEX ? TO ?" ,
clause . Table { Name : stmt . Table } , clause . Column { Name : oldName } , clause . Column { Name : newName } ,
) . Error
2020-02-20 18:04:03 +03:00
} )
}
2020-02-22 12:53:57 +03:00
func ( m Migrator ) CurrentDatabase ( ) ( name string ) {
m . DB . Raw ( "SELECT DATABASE()" ) . Row ( ) . Scan ( & name )
2020-02-20 18:04:03 +03:00
return
}