diff --git a/migrator.go b/migrator.go index b6d273e7..a5ea4d8f 100644 --- a/migrator.go +++ b/migrator.go @@ -33,6 +33,7 @@ type Migrator interface { AddColumn(dst interface{}, field string) error DropColumn(dst interface{}, field string) error AlterColumn(dst interface{}, field string) error + HasColumn(dst interface{}, field string) bool RenameColumn(dst interface{}, oldName, field string) error ColumnTypes(dst interface{}) ([]*sql.ColumnType, error) @@ -43,6 +44,7 @@ type Migrator interface { // Constraints CreateConstraint(dst interface{}, name string) error DropConstraint(dst interface{}, name string) error + HasConstraint(dst interface{}, name string) bool // Indexes CreateIndex(dst interface{}, name string) error diff --git a/migrator/migrator.go b/migrator/migrator.go index fc93954e..7e749037 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -3,9 +3,12 @@ package migrator import ( "database/sql" "fmt" + "reflect" + "strings" "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/clause" + "github.com/jinzhu/gorm/schema" ) // Migrator migrator struct @@ -34,19 +37,133 @@ func (migrator Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement // AutoMigrate func (migrator Migrator) AutoMigrate(values ...interface{}) error { - // if has table - // not -> create table - // check columns -> add column, change column type - // check foreign keys -> create indexes - // check indexes -> create indexes + // TODO smart migrate data type - return gorm.ErrNotImplemented + for _, value := range values { + if !migrator.DB.Migrator().HasTable(value) { + if err := migrator.DB.Migrator().CreateTable(value); err != nil { + return err + } + } else { + if err := migrator.RunWithValue(value, func(stmt *gorm.Statement) error { + for _, field := range stmt.Schema.FieldsByDBName { + if !migrator.DB.Migrator().HasColumn(value, field.DBName) { + if err := migrator.DB.Migrator().AddColumn(value, field.DBName); err != nil { + return err + } + } + } + + for _, rel := range stmt.Schema.Relationships.Relations { + if constraint := rel.ParseConstraint(); constraint != nil { + if !migrator.DB.Migrator().HasConstraint(value, constraint.Name) { + if err := migrator.DB.Migrator().CreateConstraint(value, constraint.Name); err != nil { + return err + } + } + } + + for _, chk := range stmt.Schema.ParseCheckConstraints() { + if !migrator.DB.Migrator().HasConstraint(value, chk.Name) { + if err := migrator.DB.Migrator().CreateConstraint(value, chk.Name); err != nil { + return err + } + } + } + + // create join table + joinValue := reflect.New(rel.JoinTable.ModelType).Interface() + if !migrator.DB.Migrator().HasTable(joinValue) { + defer migrator.DB.Migrator().CreateTable(joinValue) + } + } + return nil + }); err != nil { + return err + } + } + } + + return nil } func (migrator Migrator) CreateTable(values ...interface{}) error { - // migrate - // create join table - return gorm.ErrNotImplemented + for _, value := range values { + if err := migrator.RunWithValue(value, func(stmt *gorm.Statement) error { + var ( + createTableSQL = "CREATE TABLE ? (" + values = []interface{}{clause.Table{Name: stmt.Table}} + hasPrimaryKeyInDataType bool + ) + + for _, dbName := range stmt.Schema.DBNames { + field := stmt.Schema.FieldsByDBName[dbName] + createTableSQL += fmt.Sprintf("? ?") + hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(field.DBDataType), "PRIMARY KEY") + values = append(values, clause.Column{Name: dbName}, clause.Expr{SQL: field.DBDataType}) + + if field.AutoIncrement { + createTableSQL += " AUTO_INCREMENT" + } + + if field.NotNull { + createTableSQL += " NOT NULL" + } + + if field.Unique { + createTableSQL += " UNIQUE" + } + + if field.DefaultValue != "" { + createTableSQL += " DEFAULT ?" + values = append(values, clause.Expr{SQL: field.DefaultValue}) + } + createTableSQL += "," + } + + if !hasPrimaryKeyInDataType { + createTableSQL += "PRIMARY KEY ?," + primaryKeys := []interface{}{} + for _, field := range stmt.Schema.PrimaryFields { + primaryKeys = append(primaryKeys, clause.Column{Name: field.DBName}) + } + + values = append(values, primaryKeys) + } + + for _, idx := range stmt.Schema.ParseIndexes() { + createTableSQL += "INDEX ? ?," + values = append(values, clause.Expr{SQL: idx.Name}, buildIndexOptions(idx.Fields, stmt)) + } + + for _, rel := range stmt.Schema.Relationships.Relations { + if constraint := rel.ParseConstraint(); constraint != nil { + sql, vars := buildConstraint(constraint) + createTableSQL += sql + "," + values = append(values, vars...) + } + + // create join table + joinValue := reflect.New(rel.JoinTable.ModelType).Interface() + if !migrator.DB.Migrator().HasTable(joinValue) { + defer migrator.DB.Migrator().CreateTable(joinValue) + } + } + + for _, chk := range stmt.Schema.ParseCheckConstraints() { + createTableSQL += "CONSTRAINT ? CHECK ?," + values = append(values, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}) + } + + createTableSQL = strings.TrimSuffix(createTableSQL, ",") + + createTableSQL += ")" + return migrator.DB.Exec(createTableSQL, values...).Error + }); err != nil { + return err + } + } + return nil } func (migrator Migrator) DropTable(values ...interface{}) error { @@ -115,6 +232,27 @@ func (migrator Migrator) AlterColumn(value interface{}, field string) error { }) } +func (migrator Migrator) HasColumn(value interface{}, field string) bool { + var count int64 + migrator.RunWithValue(value, func(stmt *gorm.Statement) error { + currentDatabase := migrator.DB.Migrator().CurrentDatabase() + name := field + if field := stmt.Schema.LookUpField(field); field != nil { + name = field.DBName + } + + return migrator.DB.Raw( + "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?", + currentDatabase, stmt.Table, name, + ).Scan(&count).Error + }) + + if count != 0 { + return true + } + return false +} + func (migrator Migrator) RenameColumn(value interface{}, oldName, field string) error { return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { if field := stmt.Schema.LookUpField(field); field != nil { @@ -140,6 +278,28 @@ func (migrator Migrator) DropView(name string) error { return gorm.ErrNotImplemented } +func buildConstraint(constraint *schema.Constraint) (sql string, results []interface{}) { + sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??" + if constraint.OnDelete != "" { + sql += " ON DELETE " + constraint.OnDelete + } + + if constraint.OnUpdate != "" { + sql += " ON UPDATE " + constraint.OnUpdate + } + + var foreignKeys, references []interface{} + for _, field := range constraint.ForeignKeys { + foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName}) + } + + for _, field := range constraint.References { + references = append(references, clause.Column{Name: field.DBName}) + } + results = append(results, constraint.Name, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references) + return +} + func (migrator Migrator) CreateConstraint(value interface{}, name string) error { return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { checkConstraints := stmt.Schema.ParseCheckConstraints() @@ -152,26 +312,8 @@ func (migrator Migrator) CreateConstraint(value interface{}, name string) error for _, rel := range stmt.Schema.Relationships.Relations { if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name { - sql := "ALTER TABLE ? ADD CONSTRAINT ? FOREIGN KEY ? REFERENCES ??" - if constraint.OnDelete != "" { - sql += " ON DELETE " + constraint.OnDelete - } - - if constraint.OnUpdate != "" { - sql += " ON UPDATE " + constraint.OnUpdate - } - var foreignKeys, references []interface{} - for _, field := range constraint.ForeignKeys { - foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName}) - } - - for _, field := range constraint.References { - references = append(references, clause.Column{Name: field.DBName}) - } - - return migrator.DB.Exec( - sql, clause.Table{Name: stmt.Table}, clause.Column{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references, - ).Error + sql, values := buildConstraint(constraint) + return migrator.DB.Exec("ALTER TABLE ? ADD "+sql, append([]interface{}{clause.Table{Name: stmt.Table}}, values...)...).Error } } @@ -205,27 +347,47 @@ func (migrator Migrator) DropConstraint(value interface{}, name string) error { }) } +func (migrator Migrator) HasConstraint(value interface{}, name string) bool { + var count int64 + migrator.RunWithValue(value, func(stmt *gorm.Statement) error { + currentDatabase := migrator.DB.Migrator().CurrentDatabase() + return migrator.DB.Raw( + "SELECT count(*) FROM INFORMATION_SCHEMA.referential_constraints WHERE constraint_schema = ? AND table_name = ? AND constraint_name = ?", + currentDatabase, stmt.Table, name, + ).Scan(&count).Error + }) + + if count != 0 { + return true + } + return false +} + +func buildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { + for _, opt := range opts { + str := stmt.Quote(opt.DBName) + if opt.Expression != "" { + str = opt.Expression + } else if opt.Length > 0 { + str += fmt.Sprintf("(%d)", opt.Length) + } + + if opt.Sort != "" { + str += " " + opt.Sort + } + results = append(results, clause.Expr{SQL: str}) + } + return +} + func (migrator Migrator) CreateIndex(value interface{}, name string) error { return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { err := fmt.Errorf("failed to create index with name %v", name) indexes := stmt.Schema.ParseIndexes() if idx, ok := indexes[name]; ok { - fields := []interface{}{} - for _, field := range idx.Fields { - str := stmt.Quote(field.DBName) - if field.Expression != "" { - str = field.Expression - } else if field.Length > 0 { - str += fmt.Sprintf("(%d)", field.Length) - } - - if field.Sort != "" { - str += " " + field.Sort - } - fields = append(fields, clause.Expr{SQL: str}) - } - values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, fields} + opts := buildIndexOptions(idx.Fields, stmt) + values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts} createIndexSQL := "CREATE " if idx.Class != "" {