mirror of https://github.com/go-gorm/gorm.git
Fix RenameColumn for mssql, DropColumn for sqlite
This commit is contained in:
parent
58bc0f51c1
commit
24285060d5
|
@ -41,6 +41,23 @@ func (m Migrator) HasColumn(value interface{}, field string) bool {
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error {
|
||||||
|
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||||
|
if field := stmt.Schema.LookUpField(oldName); field != nil {
|
||||||
|
oldName = field.DBName
|
||||||
|
}
|
||||||
|
|
||||||
|
if field := stmt.Schema.LookUpField(newName); field != nil {
|
||||||
|
newName = field.DBName
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.DB.Exec(
|
||||||
|
"sp_rename @objname = ?, @newname = ?, @objtype = 'COLUMN';",
|
||||||
|
fmt.Sprintf("%s.%s", stmt.Table, oldName), clause.Column{Name: newName},
|
||||||
|
).Error
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (m Migrator) HasIndex(value interface{}, name string) bool {
|
func (m Migrator) HasIndex(value interface{}, name string) bool {
|
||||||
var count int
|
var count int
|
||||||
m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||||
|
|
|
@ -2,6 +2,7 @@ package sqlite
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/jinzhu/gorm"
|
"github.com/jinzhu/gorm"
|
||||||
|
@ -22,11 +23,10 @@ func (m Migrator) HasTable(value interface{}) bool {
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m Migrator) HasColumn(value interface{}, field string) bool {
|
func (m Migrator) HasColumn(value interface{}, name string) bool {
|
||||||
var count int
|
var count int
|
||||||
m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
|
m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||||
name := field
|
if field := stmt.Schema.LookUpField(name); field != nil {
|
||||||
if field := stmt.Schema.LookUpField(field); field != nil {
|
|
||||||
name = field.DBName
|
name = field.DBName
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -38,6 +38,45 @@ func (m Migrator) HasColumn(value interface{}, field string) bool {
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m Migrator) DropColumn(value interface{}, name string) error {
|
||||||
|
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||||
|
if field := stmt.Schema.LookUpField(name); field != nil {
|
||||||
|
name = field.DBName
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
createSQL string
|
||||||
|
newTableName = stmt.Table + "__temp"
|
||||||
|
)
|
||||||
|
|
||||||
|
m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "table", stmt.Table, stmt.Table).Row().Scan(&createSQL)
|
||||||
|
|
||||||
|
if reg, err := regexp.Compile("(`|'|\"| )" + name + "(`|'|\"| ) .*?,"); err == nil {
|
||||||
|
tableReg, err := regexp.Compile(" ('|`|\"| )" + stmt.Table + "('|`|\"| ) ")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
createSQL = tableReg.ReplaceAllString(createSQL, fmt.Sprintf(" `%v` ", newTableName))
|
||||||
|
createSQL = reg.ReplaceAllString(createSQL, "")
|
||||||
|
|
||||||
|
var columns []string
|
||||||
|
columnTypes, _ := m.DB.Migrator().ColumnTypes(value)
|
||||||
|
for _, columnType := range columnTypes {
|
||||||
|
if columnType.Name() != name {
|
||||||
|
columns = append(columns, fmt.Sprintf("`%v`", columnType.Name()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
createSQL = fmt.Sprintf("PRAGMA foreign_keys=off;BEGIN TRANSACTION;"+createSQL+";INSERT INTO `%v`(%v) SELECT %v FROM `%v`;DROP TABLE `%v`;ALTER TABLE `%v` RENAME TO `%v`;COMMIT;", newTableName, strings.Join(columns, ","), strings.Join(columns, ","), stmt.Table, stmt.Table, newTableName, stmt.Table)
|
||||||
|
|
||||||
|
return m.DB.Exec(createSQL).Error
|
||||||
|
} else {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (m Migrator) CreateConstraint(interface{}, string) error {
|
func (m Migrator) CreateConstraint(interface{}, string) error {
|
||||||
return gorm.ErrNotImplemented
|
return gorm.ErrNotImplemented
|
||||||
}
|
}
|
||||||
|
|
2
gorm.go
2
gorm.go
|
@ -66,7 +66,7 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.NowFunc == nil {
|
if config.NowFunc == nil {
|
||||||
config.NowFunc = func() time.Time { return time.Now().Local().Round(time.Second) }
|
config.NowFunc = func() time.Time { return time.Now().Local() }
|
||||||
}
|
}
|
||||||
|
|
||||||
if dialector != nil {
|
if dialector != nil {
|
||||||
|
|
|
@ -243,14 +243,15 @@ func (m Migrator) AddColumn(value interface{}, field string) error {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m Migrator) DropColumn(value interface{}, field string) error {
|
func (m Migrator) DropColumn(value interface{}, name string) error {
|
||||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||||
if field := stmt.Schema.LookUpField(field); field != nil {
|
if field := stmt.Schema.LookUpField(name); field != nil {
|
||||||
return m.DB.Exec(
|
name = field.DBName
|
||||||
"ALTER TABLE ? DROP COLUMN ?", clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName},
|
|
||||||
).Error
|
|
||||||
}
|
}
|
||||||
return fmt.Errorf("failed to look up field with name: %s", field)
|
|
||||||
|
return m.DB.Exec(
|
||||||
|
"ALTER TABLE ? DROP COLUMN ?", clause.Table{Name: stmt.Table}, clause.Column{Name: name},
|
||||||
|
).Error
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -284,16 +285,20 @@ func (m Migrator) HasColumn(value interface{}, field string) bool {
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m Migrator) RenameColumn(value interface{}, oldName, field string) error {
|
func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error {
|
||||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||||
if field := stmt.Schema.LookUpField(field); field != nil {
|
if field := stmt.Schema.LookUpField(oldName); field != nil {
|
||||||
oldName = m.DB.NamingStrategy.ColumnName(stmt.Table, oldName)
|
oldName = field.DBName
|
||||||
return m.DB.Exec(
|
|
||||||
"ALTER TABLE ? RENAME COLUMN ? TO ?",
|
|
||||||
clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: field.DBName},
|
|
||||||
).Error
|
|
||||||
}
|
}
|
||||||
return fmt.Errorf("failed to look up field with name: %s", field)
|
|
||||||
|
if field := stmt.Schema.LookUpField(newName); field != nil {
|
||||||
|
newName = field.DBName
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.DB.Exec(
|
||||||
|
"ALTER TABLE ? RENAME COLUMN ? TO ?",
|
||||||
|
clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: newName},
|
||||||
|
).Error
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -98,18 +98,38 @@ func TestColumns(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := DB.Table("column_structs").Migrator().AddColumn(&NewColumnStruct{}, "NewName"); err != nil {
|
if err := DB.Table("column_structs").Migrator().AddColumn(&NewColumnStruct{}, "NewName"); err != nil {
|
||||||
t.Errorf("Failed to add column, got %v", err)
|
t.Fatalf("Failed to add column, got %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "NewName") {
|
if !DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "NewName") {
|
||||||
t.Errorf("Failed to find added column")
|
t.Fatalf("Failed to find added column")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := DB.Table("column_structs").Migrator().DropColumn(&NewColumnStruct{}, "NewName"); err != nil {
|
if err := DB.Table("column_structs").Migrator().DropColumn(&NewColumnStruct{}, "NewName"); err != nil {
|
||||||
t.Errorf("Failed to add column, got %v", err)
|
t.Fatalf("Failed to add column, got %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "NewName") {
|
if DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "NewName") {
|
||||||
t.Errorf("Found deleted column")
|
t.Fatalf("Found deleted column")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Table("column_structs").Migrator().AddColumn(&NewColumnStruct{}, "NewName"); err != nil {
|
||||||
|
t.Fatalf("Failed to add column, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Table("column_structs").Migrator().RenameColumn(&NewColumnStruct{}, "NewName", "new_new_name"); err != nil {
|
||||||
|
t.Fatalf("Failed to add column, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "new_new_name") {
|
||||||
|
t.Fatalf("Failed to found renamed column")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Table("column_structs").Migrator().DropColumn(&NewColumnStruct{}, "new_new_name"); err != nil {
|
||||||
|
t.Fatalf("Failed to add column, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "new_new_name") {
|
||||||
|
t.Fatalf("Found deleted column")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -88,8 +88,8 @@ func AssertEqual(t *testing.T, got, expect interface{}) {
|
||||||
if curTime, ok := got.(time.Time); ok {
|
if curTime, ok := got.(time.Time); ok {
|
||||||
format := "2006-01-02T15:04:05Z07:00"
|
format := "2006-01-02T15:04:05Z07:00"
|
||||||
|
|
||||||
if curTime.Round(time.Second).Format(format) != expect.(time.Time).Round(time.Second).Format(format) {
|
if curTime.Round(time.Second).Format(format) != expect.(time.Time).Round(time.Second).Format(format) && curTime.Truncate(time.Second).Format(format) != expect.(time.Time).Truncate(time.Second).Format(format) {
|
||||||
t.Errorf("%v: expect: %v, got %v after time round", utils.FileWithLineNum(), expect.(time.Time).Round(time.Second).Format(format), curTime.Round(time.Second).Format(format))
|
t.Errorf("%v: expect: %v, got %v after time round", utils.FileWithLineNum(), expect.(time.Time), curTime)
|
||||||
}
|
}
|
||||||
} else if fmt.Sprint(got) != fmt.Sprint(expect) {
|
} else if fmt.Sprint(got) != fmt.Sprint(expect) {
|
||||||
t.Errorf("%v: expect: %#v, got %#v", utils.FileWithLineNum(), expect, got)
|
t.Errorf("%v: expect: %#v, got %#v", utils.FileWithLineNum(), expect, got)
|
||||||
|
|
Loading…
Reference in New Issue