Almost finish Migrator

This commit is contained in:
Jinzhu 2020-02-22 13:09:57 +08:00
parent 0be4817ff9
commit 0801cdf164
2 changed files with 208 additions and 44 deletions

View File

@ -33,6 +33,7 @@ type Migrator interface {
AddColumn(dst interface{}, field string) error AddColumn(dst interface{}, field string) error
DropColumn(dst interface{}, field string) error DropColumn(dst interface{}, field string) error
AlterColumn(dst interface{}, field string) error AlterColumn(dst interface{}, field string) error
HasColumn(dst interface{}, field string) bool
RenameColumn(dst interface{}, oldName, field string) error RenameColumn(dst interface{}, oldName, field string) error
ColumnTypes(dst interface{}) ([]*sql.ColumnType, error) ColumnTypes(dst interface{}) ([]*sql.ColumnType, error)
@ -43,6 +44,7 @@ type Migrator interface {
// Constraints // Constraints
CreateConstraint(dst interface{}, name string) error CreateConstraint(dst interface{}, name string) error
DropConstraint(dst interface{}, name string) error DropConstraint(dst interface{}, name string) error
HasConstraint(dst interface{}, name string) bool
// Indexes // Indexes
CreateIndex(dst interface{}, name string) error CreateIndex(dst interface{}, name string) error

View File

@ -3,9 +3,12 @@ package migrator
import ( import (
"database/sql" "database/sql"
"fmt" "fmt"
"reflect"
"strings"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/clause" "github.com/jinzhu/gorm/clause"
"github.com/jinzhu/gorm/schema"
) )
// Migrator migrator struct // Migrator migrator struct
@ -34,19 +37,133 @@ func (migrator Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement
// AutoMigrate // AutoMigrate
func (migrator Migrator) AutoMigrate(values ...interface{}) error { func (migrator Migrator) AutoMigrate(values ...interface{}) error {
// if has table // TODO smart migrate data type
// not -> create table
// check columns -> add column, change column type
// check foreign keys -> create indexes
// check indexes -> create indexes
return gorm.ErrNotImplemented for _, value := range values {
if !migrator.DB.Migrator().HasTable(value) {
if err := migrator.DB.Migrator().CreateTable(value); err != nil {
return err
}
} else {
if err := migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
for _, field := range stmt.Schema.FieldsByDBName {
if !migrator.DB.Migrator().HasColumn(value, field.DBName) {
if err := migrator.DB.Migrator().AddColumn(value, field.DBName); err != nil {
return err
}
}
}
for _, rel := range stmt.Schema.Relationships.Relations {
if constraint := rel.ParseConstraint(); constraint != nil {
if !migrator.DB.Migrator().HasConstraint(value, constraint.Name) {
if err := migrator.DB.Migrator().CreateConstraint(value, constraint.Name); err != nil {
return err
}
}
}
for _, chk := range stmt.Schema.ParseCheckConstraints() {
if !migrator.DB.Migrator().HasConstraint(value, chk.Name) {
if err := migrator.DB.Migrator().CreateConstraint(value, chk.Name); err != nil {
return err
}
}
}
// create join table
joinValue := reflect.New(rel.JoinTable.ModelType).Interface()
if !migrator.DB.Migrator().HasTable(joinValue) {
defer migrator.DB.Migrator().CreateTable(joinValue)
}
}
return nil
}); err != nil {
return err
}
}
}
return nil
} }
func (migrator Migrator) CreateTable(values ...interface{}) error { func (migrator Migrator) CreateTable(values ...interface{}) error {
// migrate for _, value := range values {
// create join table if err := migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
return gorm.ErrNotImplemented var (
createTableSQL = "CREATE TABLE ? ("
values = []interface{}{clause.Table{Name: stmt.Table}}
hasPrimaryKeyInDataType bool
)
for _, dbName := range stmt.Schema.DBNames {
field := stmt.Schema.FieldsByDBName[dbName]
createTableSQL += fmt.Sprintf("? ?")
hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(field.DBDataType), "PRIMARY KEY")
values = append(values, clause.Column{Name: dbName}, clause.Expr{SQL: field.DBDataType})
if field.AutoIncrement {
createTableSQL += " AUTO_INCREMENT"
}
if field.NotNull {
createTableSQL += " NOT NULL"
}
if field.Unique {
createTableSQL += " UNIQUE"
}
if field.DefaultValue != "" {
createTableSQL += " DEFAULT ?"
values = append(values, clause.Expr{SQL: field.DefaultValue})
}
createTableSQL += ","
}
if !hasPrimaryKeyInDataType {
createTableSQL += "PRIMARY KEY ?,"
primaryKeys := []interface{}{}
for _, field := range stmt.Schema.PrimaryFields {
primaryKeys = append(primaryKeys, clause.Column{Name: field.DBName})
}
values = append(values, primaryKeys)
}
for _, idx := range stmt.Schema.ParseIndexes() {
createTableSQL += "INDEX ? ?,"
values = append(values, clause.Expr{SQL: idx.Name}, buildIndexOptions(idx.Fields, stmt))
}
for _, rel := range stmt.Schema.Relationships.Relations {
if constraint := rel.ParseConstraint(); constraint != nil {
sql, vars := buildConstraint(constraint)
createTableSQL += sql + ","
values = append(values, vars...)
}
// create join table
joinValue := reflect.New(rel.JoinTable.ModelType).Interface()
if !migrator.DB.Migrator().HasTable(joinValue) {
defer migrator.DB.Migrator().CreateTable(joinValue)
}
}
for _, chk := range stmt.Schema.ParseCheckConstraints() {
createTableSQL += "CONSTRAINT ? CHECK ?,"
values = append(values, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint})
}
createTableSQL = strings.TrimSuffix(createTableSQL, ",")
createTableSQL += ")"
return migrator.DB.Exec(createTableSQL, values...).Error
}); err != nil {
return err
}
}
return nil
} }
func (migrator Migrator) DropTable(values ...interface{}) error { func (migrator Migrator) DropTable(values ...interface{}) error {
@ -115,6 +232,27 @@ func (migrator Migrator) AlterColumn(value interface{}, field string) error {
}) })
} }
func (migrator Migrator) HasColumn(value interface{}, field string) bool {
var count int64
migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
currentDatabase := migrator.DB.Migrator().CurrentDatabase()
name := field
if field := stmt.Schema.LookUpField(field); field != nil {
name = field.DBName
}
return migrator.DB.Raw(
"SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?",
currentDatabase, stmt.Table, name,
).Scan(&count).Error
})
if count != 0 {
return true
}
return false
}
func (migrator Migrator) RenameColumn(value interface{}, oldName, field string) error { func (migrator Migrator) RenameColumn(value interface{}, oldName, field string) error {
return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(field); field != nil { if field := stmt.Schema.LookUpField(field); field != nil {
@ -140,6 +278,28 @@ func (migrator Migrator) DropView(name string) error {
return gorm.ErrNotImplemented return gorm.ErrNotImplemented
} }
func buildConstraint(constraint *schema.Constraint) (sql string, results []interface{}) {
sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
if constraint.OnDelete != "" {
sql += " ON DELETE " + constraint.OnDelete
}
if constraint.OnUpdate != "" {
sql += " ON UPDATE " + constraint.OnUpdate
}
var foreignKeys, references []interface{}
for _, field := range constraint.ForeignKeys {
foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
}
for _, field := range constraint.References {
references = append(references, clause.Column{Name: field.DBName})
}
results = append(results, constraint.Name, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
return
}
func (migrator Migrator) CreateConstraint(value interface{}, name string) error { func (migrator Migrator) CreateConstraint(value interface{}, name string) error {
return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
checkConstraints := stmt.Schema.ParseCheckConstraints() checkConstraints := stmt.Schema.ParseCheckConstraints()
@ -152,26 +312,8 @@ func (migrator Migrator) CreateConstraint(value interface{}, name string) error
for _, rel := range stmt.Schema.Relationships.Relations { for _, rel := range stmt.Schema.Relationships.Relations {
if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name { if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name {
sql := "ALTER TABLE ? ADD CONSTRAINT ? FOREIGN KEY ? REFERENCES ??" sql, values := buildConstraint(constraint)
if constraint.OnDelete != "" { return migrator.DB.Exec("ALTER TABLE ? ADD "+sql, append([]interface{}{clause.Table{Name: stmt.Table}}, values...)...).Error
sql += " ON DELETE " + constraint.OnDelete
}
if constraint.OnUpdate != "" {
sql += " ON UPDATE " + constraint.OnUpdate
}
var foreignKeys, references []interface{}
for _, field := range constraint.ForeignKeys {
foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
}
for _, field := range constraint.References {
references = append(references, clause.Column{Name: field.DBName})
}
return migrator.DB.Exec(
sql, clause.Table{Name: stmt.Table}, clause.Column{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references,
).Error
} }
} }
@ -205,27 +347,47 @@ func (migrator Migrator) DropConstraint(value interface{}, name string) error {
}) })
} }
func (migrator Migrator) HasConstraint(value interface{}, name string) bool {
var count int64
migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
currentDatabase := migrator.DB.Migrator().CurrentDatabase()
return migrator.DB.Raw(
"SELECT count(*) FROM INFORMATION_SCHEMA.referential_constraints WHERE constraint_schema = ? AND table_name = ? AND constraint_name = ?",
currentDatabase, stmt.Table, name,
).Scan(&count).Error
})
if count != 0 {
return true
}
return false
}
func buildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
for _, opt := range opts {
str := stmt.Quote(opt.DBName)
if opt.Expression != "" {
str = opt.Expression
} else if opt.Length > 0 {
str += fmt.Sprintf("(%d)", opt.Length)
}
if opt.Sort != "" {
str += " " + opt.Sort
}
results = append(results, clause.Expr{SQL: str})
}
return
}
func (migrator Migrator) CreateIndex(value interface{}, name string) error { func (migrator Migrator) CreateIndex(value interface{}, name string) error {
return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
err := fmt.Errorf("failed to create index with name %v", name) err := fmt.Errorf("failed to create index with name %v", name)
indexes := stmt.Schema.ParseIndexes() indexes := stmt.Schema.ParseIndexes()
if idx, ok := indexes[name]; ok { if idx, ok := indexes[name]; ok {
fields := []interface{}{} opts := buildIndexOptions(idx.Fields, stmt)
for _, field := range idx.Fields { values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts}
str := stmt.Quote(field.DBName)
if field.Expression != "" {
str = field.Expression
} else if field.Length > 0 {
str += fmt.Sprintf("(%d)", field.Length)
}
if field.Sort != "" {
str += " " + field.Sort
}
fields = append(fields, clause.Expr{SQL: str})
}
values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, fields}
createIndexSQL := "CREATE " createIndexSQL := "CREATE "
if idx.Class != "" { if idx.Class != "" {