2020-01-28 18:01:35 +03:00
package migrator
2020-02-20 18:04:03 +03:00
import (
"database/sql"
"fmt"
"github.com/jinzhu/gorm"
2020-02-22 06:15:51 +03:00
"github.com/jinzhu/gorm/clause"
2020-02-20 18:04:03 +03:00
)
2020-01-28 18:01:35 +03:00
// Migrator implements schema migration operations (tables, columns,
// constraints, indexes) by issuing SQL through the embedded Config's DB.
// Several statements use MySQL-flavoured syntax (e.g. SELECT DATABASE(),
// RENAME TABLE); NOTE(review): presumably dialect-specific migrators
// override the methods whose SQL differs — confirm.
type Migrator struct {
	*Config
}
// Config holds the settings shared by all Migrator methods.
type Config struct {
	// CheckExistsBeforeDropping is not read by any method in this file.
	// NOTE(review): presumably meant to make DropTable skip tables that do
	// not exist — TODO confirm and wire up.
	CheckExistsBeforeDropping bool

	// DB is the handle every migration statement is built and executed on.
	DB *gorm.DB
}
2020-02-20 18:04:03 +03:00
// RunWithValue parses value into a statement and hands that statement to fc.
//
// The DB's existing Statement is reused when there is one; otherwise a fresh
// Statement bound to the DB is created. A Parse error is returned without
// calling fc; otherwise fc's result is returned. Callers rely on the parsed
// statement exposing the model's Table and Schema.
func (migrator Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error {
	stmt := migrator.DB.Statement
	if stmt == nil {
		stmt = &gorm.Statement{DB: migrator.DB}
	}

	err := stmt.Parse(value)
	if err != nil {
		return err
	}
	return fc(stmt)
}
// AutoMigrate
func ( migrator Migrator ) AutoMigrate ( values ... interface { } ) error {
2020-02-22 06:15:51 +03:00
// if has table
// not -> create table
// check columns -> add column, change column type
// check foreign keys -> create indexes
// check indexes -> create indexes
2020-02-20 18:04:03 +03:00
return gorm . ErrNotImplemented
}
func ( migrator Migrator ) CreateTable ( values ... interface { } ) error {
2020-02-22 06:15:51 +03:00
// migrate
// create join table
2020-02-20 18:04:03 +03:00
return gorm . ErrNotImplemented
}
func ( migrator Migrator ) DropTable ( values ... interface { } ) error {
for _ , value := range values {
if err := migrator . RunWithValue ( value , func ( stmt * gorm . Statement ) error {
2020-02-22 06:15:51 +03:00
return migrator . DB . Exec ( "DROP TABLE ?" , clause . Table { Name : stmt . Table } ) . Error
2020-02-20 18:04:03 +03:00
} ) ; err != nil {
return err
}
}
return nil
}
// HasTable reports whether every value has a BASE TABLE in the current
// database, according to information_schema.tables.
//
// It returns false as soon as one lookup errors or counts zero tables, and
// true when values is empty.
func (migrator Migrator) HasTable(values ...interface{}) bool {
	var count int64
	countTables := func(stmt *gorm.Statement) error {
		currentDatabase := migrator.DB.Migrator().CurrentDatabase()
		return migrator.DB.Raw(
			"SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?",
			currentDatabase, stmt.Table, "BASE TABLE",
		).Scan(&count).Error
	}

	for _, value := range values {
		if migrator.RunWithValue(value, countTables) != nil || count == 0 {
			return false
		}
	}
	return true
}
// RenameTable renames table oldName to newName.
//
// Both names are wrapped in clause.Table, as every other method in this file
// does for table names, so they are rendered as identifiers. Passing the raw
// strings would bind them as string values (e.g. RENAME TABLE 'a' TO 'b'),
// which is not valid SQL.
// NOTE(review): RENAME TABLE is MySQL syntax; other dialects presumably
// override this method — confirm.
func (migrator Migrator) RenameTable(oldName, newName string) error {
	return migrator.DB.Exec(
		"RENAME TABLE ? TO ?",
		clause.Table{Name: oldName}, clause.Table{Name: newName},
	).Error
}
func ( migrator Migrator ) AddColumn ( value interface { } , field string ) error {
return migrator . RunWithValue ( value , func ( stmt * gorm . Statement ) error {
if field := stmt . Schema . LookUpField ( field ) ; field != nil {
2020-02-22 06:15:51 +03:00
return migrator . DB . Exec (
"ALTER TABLE ? ADD ? ?" ,
clause . Table { Name : stmt . Table } , clause . Column { Name : field . DBName } , clause . Expr { SQL : field . DBDataType } ,
) . Error
2020-02-20 18:04:03 +03:00
}
return fmt . Errorf ( "failed to look up field with name: %s" , field )
} )
}
func ( migrator Migrator ) DropColumn ( value interface { } , field string ) error {
return migrator . RunWithValue ( value , func ( stmt * gorm . Statement ) error {
if field := stmt . Schema . LookUpField ( field ) ; field != nil {
2020-02-22 06:15:51 +03:00
return migrator . DB . Exec (
"ALTER TABLE ? DROP COLUMN ?" , clause . Table { Name : stmt . Table } , clause . Column { Name : field . DBName } ,
) . Error
2020-02-20 18:04:03 +03:00
}
return fmt . Errorf ( "failed to look up field with name: %s" , field )
} )
}
func ( migrator Migrator ) AlterColumn ( value interface { } , field string ) error {
return migrator . RunWithValue ( value , func ( stmt * gorm . Statement ) error {
if field := stmt . Schema . LookUpField ( field ) ; field != nil {
2020-02-22 06:15:51 +03:00
return migrator . DB . Exec (
"ALTER TABLE ? ALTER COLUMN ? TYPE ?" ,
clause . Table { Name : stmt . Table } , clause . Column { Name : field . DBName } , clause . Expr { SQL : field . DBDataType } ,
) . Error
2020-02-20 18:04:03 +03:00
}
return fmt . Errorf ( "failed to look up field with name: %s" , field )
} )
}
func ( migrator Migrator ) RenameColumn ( value interface { } , oldName , field string ) error {
return migrator . RunWithValue ( value , func ( stmt * gorm . Statement ) error {
if field := stmt . Schema . LookUpField ( field ) ; field != nil {
oldName = migrator . DB . NamingStrategy . ColumnName ( stmt . Table , oldName )
2020-02-22 06:15:51 +03:00
return migrator . DB . Exec (
"ALTER TABLE ? RENAME COLUMN ? TO ?" ,
clause . Table { Name : stmt . Table } , clause . Column { Name : oldName } , clause . Column { Name : field . DBName } ,
) . Error
2020-02-20 18:04:03 +03:00
}
return fmt . Errorf ( "failed to look up field with name: %s" , field )
} )
}
// ColumnTypes is not implemented yet. It always returns a nil slice together
// with gorm.ErrNotImplemented.
func (migrator Migrator) ColumnTypes(value interface{}) ([]*sql.ColumnType, error) {
	var columnTypes []*sql.ColumnType
	return columnTypes, gorm.ErrNotImplemented
}
// CreateView is not implemented yet and always returns
// gorm.ErrNotImplemented; name and option are ignored.
func (migrator Migrator) CreateView(name string, option gorm.ViewOption) error {
	return gorm.ErrNotImplemented
}
// DropView is not implemented yet and always returns gorm.ErrNotImplemented;
// name is ignored.
func (migrator Migrator) DropView(name string) error {
	return gorm.ErrNotImplemented
}
func ( migrator Migrator ) CreateConstraint ( value interface { } , name string ) error {
2020-02-22 06:15:51 +03:00
return migrator . RunWithValue ( value , func ( stmt * gorm . Statement ) error {
checkConstraints := stmt . Schema . ParseCheckConstraints ( )
if chk , ok := checkConstraints [ name ] ; ok {
return migrator . DB . Exec (
"ALTER TABLE ? ADD CONSTRAINT ? CHECK ?" ,
clause . Table { Name : stmt . Table } , clause . Column { Name : chk . Name } , clause . Expr { SQL : chk . Constraint } ,
) . Error
}
for _ , rel := range stmt . Schema . Relationships . Relations {
if constraint := rel . ParseConstraint ( ) ; constraint != nil && constraint . Name == name {
sql := "ALTER TABLE ? ADD CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
if constraint . OnDelete != "" {
sql += " ON DELETE " + constraint . OnDelete
}
if constraint . OnUpdate != "" {
sql += " ON UPDATE " + constraint . OnUpdate
}
var foreignKeys , references [ ] interface { }
for _ , field := range constraint . ForeignKeys {
foreignKeys = append ( foreignKeys , clause . Column { Name : field . DBName } )
}
for _ , field := range constraint . References {
references = append ( references , clause . Column { Name : field . DBName } )
}
return migrator . DB . Exec (
sql , clause . Table { Name : stmt . Table } , clause . Column { Name : constraint . Name } , foreignKeys , clause . Table { Name : constraint . ReferenceSchema . Table } , references ,
) . Error
}
}
err := fmt . Errorf ( "failed to create constraint with name %v" , name )
if field := stmt . Schema . LookUpField ( name ) ; field != nil {
for _ , cc := range checkConstraints {
if err = migrator . CreateIndex ( value , cc . Name ) ; err != nil {
return err
}
}
for _ , rel := range stmt . Schema . Relationships . Relations {
if constraint := rel . ParseConstraint ( ) ; constraint != nil && constraint . Field == field {
if err = migrator . CreateIndex ( value , constraint . Name ) ; err != nil {
return err
}
}
}
}
return err
} )
2020-02-20 18:04:03 +03:00
}
func ( migrator Migrator ) DropConstraint ( value interface { } , name string ) error {
return migrator . RunWithValue ( value , func ( stmt * gorm . Statement ) error {
2020-02-22 06:15:51 +03:00
return migrator . DB . Exec (
"ALTER TABLE ? DROP CONSTRAINT ?" ,
clause . Table { Name : stmt . Table } , clause . Column { Name : name } ,
) . Error
2020-02-20 18:04:03 +03:00
} )
}
func ( migrator Migrator ) CreateIndex ( value interface { } , name string ) error {
2020-02-22 06:15:51 +03:00
return migrator . RunWithValue ( value , func ( stmt * gorm . Statement ) error {
err := fmt . Errorf ( "failed to create index with name %v" , name )
indexes := stmt . Schema . ParseIndexes ( )
if idx , ok := indexes [ name ] ; ok {
fields := [ ] interface { } { }
for _ , field := range idx . Fields {
str := stmt . Quote ( field . DBName )
if field . Expression != "" {
str = field . Expression
} else if field . Length > 0 {
str += fmt . Sprintf ( "(%d)" , field . Length )
}
if field . Sort != "" {
str += " " + field . Sort
}
fields = append ( fields , clause . Expr { SQL : str } )
}
values := [ ] interface { } { clause . Column { Name : idx . Name } , clause . Table { Name : stmt . Table } , fields }
createIndexSQL := "CREATE "
if idx . Class != "" {
createIndexSQL += idx . Class + " "
}
createIndexSQL += "INDEX ? ON ??"
if idx . Comment != "" {
values = append ( values , idx . Comment )
createIndexSQL += " COMMENT ?"
}
if idx . Type != "" {
createIndexSQL += " USING " + idx . Type
}
return migrator . DB . Raw ( createIndexSQL , values ... ) . Error
} else if field := stmt . Schema . LookUpField ( name ) ; field != nil {
for _ , idx := range indexes {
for _ , idxOpt := range idx . Fields {
if idxOpt . Field == field {
if err = migrator . CreateIndex ( value , idx . Name ) ; err != nil {
return err
}
}
}
}
}
return err
} )
2020-02-20 18:04:03 +03:00
}
func ( migrator Migrator ) DropIndex ( value interface { } , name string ) error {
return migrator . RunWithValue ( value , func ( stmt * gorm . Statement ) error {
2020-02-22 06:15:51 +03:00
return migrator . DB . Raw ( "DROP INDEX ? ON ?" , clause . Column { Name : name } , clause . Table { Name : stmt . Table } ) . Error
2020-02-20 18:04:03 +03:00
} )
}
func ( migrator Migrator ) HasIndex ( value interface { } , name string ) bool {
var count int64
migrator . RunWithValue ( value , func ( stmt * gorm . Statement ) error {
currentDatabase := migrator . DB . Migrator ( ) . CurrentDatabase ( )
2020-02-22 06:15:51 +03:00
return migrator . DB . Raw (
"SELECT count(*) FROM information_schema.statistics WHERE table_schema = ? AND table_name = ? AND index_name = ?" ,
currentDatabase , stmt . Table , name ,
) . Scan ( & count ) . Error
2020-02-20 18:04:03 +03:00
} )
if count != 0 {
return true
}
return false
}
func ( migrator Migrator ) RenameIndex ( value interface { } , oldName , newName string ) error {
return migrator . RunWithValue ( value , func ( stmt * gorm . Statement ) error {
2020-02-22 06:15:51 +03:00
return migrator . DB . Exec (
"ALTER TABLE ? RENAME INDEX ? TO ?" ,
clause . Table { Name : stmt . Table } , clause . Column { Name : oldName } , clause . Column { Name : newName } ,
) . Error
2020-02-20 18:04:03 +03:00
} )
}
// CurrentDatabase returns the result of SELECT DATABASE().
//
// The query's error is not checked; presumably name is left empty when the
// query fails.
func (migrator Migrator) CurrentDatabase() (name string) {
	_ = migrator.DB.Raw("SELECT DATABASE()").Scan(&name)
	return name
}