diff --git a/dialects/mssql/migrator.go b/dialects/mssql/migrator.go index d1abd0e9..42a6b9b9 100644 --- a/dialects/mssql/migrator.go +++ b/dialects/mssql/migrator.go @@ -41,6 +41,23 @@ func (m Migrator) HasColumn(value interface{}, field string) bool { return count > 0 } +func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if field := stmt.Schema.LookUpField(oldName); field != nil { + oldName = field.DBName + } + + if field := stmt.Schema.LookUpField(newName); field != nil { + newName = field.DBName + } + + return m.DB.Exec( + "sp_rename @objname = ?, @newname = ?, @objtype = 'COLUMN';", + fmt.Sprintf("%s.%s", stmt.Table, oldName), clause.Column{Name: newName}, + ).Error + }) +} + func (m Migrator) HasIndex(value interface{}, name string) bool { var count int m.RunWithValue(value, func(stmt *gorm.Statement) error { diff --git a/dialects/sqlite/migrator.go b/dialects/sqlite/migrator.go index 5f3671b4..e36dc5e7 100644 --- a/dialects/sqlite/migrator.go +++ b/dialects/sqlite/migrator.go @@ -2,6 +2,7 @@ package sqlite import ( "fmt" + "regexp" "strings" "github.com/jinzhu/gorm" @@ -22,11 +23,10 @@ func (m Migrator) HasTable(value interface{}) bool { return count > 0 } -func (m Migrator) HasColumn(value interface{}, field string) bool { +func (m Migrator) HasColumn(value interface{}, name string) bool { var count int m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error { - name := field - if field := stmt.Schema.LookUpField(field); field != nil { + if field := stmt.Schema.LookUpField(name); field != nil { name = field.DBName } @@ -38,6 +38,45 @@ func (m Migrator) HasColumn(value interface{}, field string) bool { return count > 0 } +func (m Migrator) DropColumn(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if field := stmt.Schema.LookUpField(name); field != nil { + name = field.DBName + } + + var ( + createSQL string + newTableName = stmt.Table + "__temp" + ) + + m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "table", stmt.Table, stmt.Table).Row().Scan(&createSQL) + + if reg, err := regexp.Compile("(`|'|\"| )" + name + "(`|'|\"| ) .*?,"); err == nil { + tableReg, err := regexp.Compile(" ('|`|\"| )" + stmt.Table + "('|`|\"| ) ") + if err != nil { + return err + } + + createSQL = tableReg.ReplaceAllString(createSQL, fmt.Sprintf(" `%v` ", newTableName)) + createSQL = reg.ReplaceAllString(createSQL, "") + + var columns []string + columnTypes, _ := m.DB.Migrator().ColumnTypes(value) + for _, columnType := range columnTypes { + if columnType.Name() != name { + columns = append(columns, fmt.Sprintf("`%v`", columnType.Name())) + } + } + + createSQL = fmt.Sprintf("PRAGMA foreign_keys=off;BEGIN TRANSACTION;"+createSQL+";INSERT INTO `%v`(%v) SELECT %v FROM `%v`;DROP TABLE `%v`;ALTER TABLE `%v` RENAME TO `%v`;COMMIT;", newTableName, strings.Join(columns, ","), strings.Join(columns, ","), stmt.Table, stmt.Table, newTableName, stmt.Table) + + return m.DB.Exec(createSQL).Error + } else { + return err + } + }) +} + func (m Migrator) CreateConstraint(interface{}, string) error { return gorm.ErrNotImplemented } diff --git a/gorm.go b/gorm.go index 9adc0858..6b2a6d75 100644 --- a/gorm.go +++ b/gorm.go @@ -66,7 +66,7 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { } if config.NowFunc == nil { - config.NowFunc = func() time.Time { return time.Now().Local().Round(time.Second) } + config.NowFunc = func() time.Time { return time.Now().Local() } } if dialector != nil { diff --git a/migrator/migrator.go b/migrator/migrator.go index 8f35cbea..d41646f4 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -243,14 +243,15 @@ func (m Migrator) AddColumn(value interface{}, field string) error { }) } -func (m Migrator) DropColumn(value interface{}, field string) error { +func (m Migrator) DropColumn(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { - if field := stmt.Schema.LookUpField(field); field != nil { - return m.DB.Exec( - "ALTER TABLE ? DROP COLUMN ?", clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, - ).Error + if field := stmt.Schema.LookUpField(name); field != nil { + name = field.DBName } - return fmt.Errorf("failed to look up field with name: %s", field) + + return m.DB.Exec( + "ALTER TABLE ? DROP COLUMN ?", clause.Table{Name: stmt.Table}, clause.Column{Name: name}, + ).Error }) } @@ -284,16 +285,20 @@ func (m Migrator) HasColumn(value interface{}, field string) bool { return count > 0 } -func (m Migrator) RenameColumn(value interface{}, oldName, field string) error { +func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { - if field := stmt.Schema.LookUpField(field); field != nil { - oldName = m.DB.NamingStrategy.ColumnName(stmt.Table, oldName) - return m.DB.Exec( - "ALTER TABLE ? RENAME COLUMN ? TO ?", - clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: field.DBName}, - ).Error + if field := stmt.Schema.LookUpField(oldName); field != nil { + oldName = field.DBName } - return fmt.Errorf("failed to look up field with name: %s", field) + + if field := stmt.Schema.LookUpField(newName); field != nil { + newName = field.DBName + } + + return m.DB.Exec( + "ALTER TABLE ? RENAME COLUMN ? TO ?", + clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: newName}, + ).Error }) } diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 00025c58..2252d09d 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -98,18 +98,38 @@ func TestColumns(t *testing.T) { } if err := DB.Table("column_structs").Migrator().AddColumn(&NewColumnStruct{}, "NewName"); err != nil { - t.Errorf("Failed to add column, got %v", err) + t.Fatalf("Failed to add column, got %v", err) } if !DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "NewName") { - t.Errorf("Failed to find added column") + t.Fatalf("Failed to find added column") } if err := DB.Table("column_structs").Migrator().DropColumn(&NewColumnStruct{}, "NewName"); err != nil { - t.Errorf("Failed to add column, got %v", err) + t.Fatalf("Failed to add column, got %v", err) } if DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "NewName") { - t.Errorf("Found deleted column") + t.Fatalf("Found deleted column") + } + + if err := DB.Table("column_structs").Migrator().AddColumn(&NewColumnStruct{}, "NewName"); err != nil { + t.Fatalf("Failed to add column, got %v", err) + } + + if err := DB.Table("column_structs").Migrator().RenameColumn(&NewColumnStruct{}, "NewName", "new_new_name"); err != nil { + t.Fatalf("Failed to add column, got %v", err) + } + + if !DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "new_new_name") { + t.Fatalf("Failed to found renamed column") + } + + if err := DB.Table("column_structs").Migrator().DropColumn(&NewColumnStruct{}, "new_new_name"); err != nil { + t.Fatalf("Failed to add column, got %v", err) + } + + if DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "new_new_name") { + t.Fatalf("Found deleted column") } } diff --git a/tests/utils.go b/tests/utils.go index 0add8143..7cc6d2bc 100644 --- a/tests/utils.go +++ b/tests/utils.go @@ -88,8 +88,8 @@ func AssertEqual(t *testing.T, got, expect interface{}) { if curTime, ok := got.(time.Time); ok { format := "2006-01-02T15:04:05Z07:00" - if curTime.Round(time.Second).Format(format) != expect.(time.Time).Round(time.Second).Format(format) { - t.Errorf("%v: expect: %v, got %v after time round", utils.FileWithLineNum(), expect.(time.Time).Round(time.Second).Format(format), curTime.Round(time.Second).Format(format)) + if curTime.Round(time.Second).Format(format) != expect.(time.Time).Round(time.Second).Format(format) && curTime.Truncate(time.Second).Format(format) != expect.(time.Time).Truncate(time.Second).Format(format) { + t.Errorf("%v: expect: %v, got %v after time round", utils.FileWithLineNum(), expect.(time.Time), curTime) } } else if fmt.Sprint(got) != fmt.Sprint(expect) { t.Errorf("%v: expect: %#v, got %#v", utils.FileWithLineNum(), expect, got)