Add test for AlterColumn

This commit is contained in:
Jinzhu 2020-05-31 10:38:01 +08:00
parent d81179557d
commit 536e4d34b0
6 changed files with 78 additions and 4 deletions

View File

@ -71,6 +71,18 @@ func (m Migrator) HasColumn(value interface{}, field string) bool {
return count > 0 return count > 0
} }
func (m Migrator) AlterColumn(value interface{}, field string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(field); field != nil {
return m.DB.Exec(
"ALTER TABLE ? ALTER COLUMN ? ?",
clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, m.FullDataTypeOf(field),
).Error
}
return fmt.Errorf("failed to look up field with name: %s", field)
})
}
func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error { func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(oldName); field != nil { if field := stmt.Schema.LookUpField(oldName); field != nil {

View File

@ -16,8 +16,8 @@ func (m Migrator) AlterColumn(value interface{}, field string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(field); field != nil { if field := stmt.Schema.LookUpField(field); field != nil {
return m.DB.Exec( return m.DB.Exec(
"ALTER TABLE ? MODIFY COLUMN ? TYPE ?", "ALTER TABLE ? MODIFY COLUMN ? ?",
clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DBDataType}, clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, m.FullDataTypeOf(field),
).Error ).Error
} }
return fmt.Errorf("failed to look up field with name: %s", field) return fmt.Errorf("failed to look up field with name: %s", field)

View File

@ -89,7 +89,7 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string {
} }
return "text" return "text"
case schema.Time: case schema.Time:
return "timestamp with time zone" return "timestamptz"
case schema.Bytes: case schema.Bytes:
return "bytea" return "bytea"
} }

View File

@ -38,6 +38,42 @@ func (m Migrator) HasColumn(value interface{}, name string) bool {
return count > 0 return count > 0
} }
func (m Migrator) AlterColumn(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(name); field != nil {
var (
createSQL string
newTableName = stmt.Table + "__temp"
)
m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "table", stmt.Table, stmt.Table).Row().Scan(&createSQL)
if reg, err := regexp.Compile("(`|'|\"| )" + name + "(`|'|\"| ) .*?,"); err == nil {
tableReg, err := regexp.Compile(" ('|`|\"| )" + stmt.Table + "('|`|\"| ) ")
if err != nil {
return err
}
createSQL = tableReg.ReplaceAllString(createSQL, fmt.Sprintf(" `%v` ", newTableName))
createSQL = reg.ReplaceAllString(createSQL, "?")
var columns []string
columnTypes, _ := m.DB.Migrator().ColumnTypes(value)
for _, columnType := range columnTypes {
columns = append(columns, fmt.Sprintf("`%v`", columnType.Name()))
}
createSQL = fmt.Sprintf("PRAGMA foreign_keys=off;BEGIN TRANSACTION;"+createSQL+";INSERT INTO `%v`(%v) SELECT %v FROM `%v`;DROP TABLE `%v`;ALTER TABLE `%v` RENAME TO `%v`;COMMIT;", newTableName, strings.Join(columns, ","), strings.Join(columns, ","), stmt.Table, stmt.Table, newTableName, stmt.Table)
return m.DB.Exec(createSQL, m.FullDataTypeOf(field)).Error
} else {
return err
}
} else {
return fmt.Errorf("failed to alter field with name %v", name)
}
})
}
func (m Migrator) DropColumn(value interface{}, name string) error { func (m Migrator) DropColumn(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(name); field != nil { if field := stmt.Schema.LookUpField(name); field != nil {

View File

@ -283,7 +283,7 @@ func (m Migrator) AlterColumn(value interface{}, field string) error {
if field := stmt.Schema.LookUpField(field); field != nil { if field := stmt.Schema.LookUpField(field); field != nil {
return m.DB.Exec( return m.DB.Exec(
"ALTER TABLE ? ALTER COLUMN ? TYPE ?", "ALTER TABLE ? ALTER COLUMN ? TYPE ?",
clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, clause.Expr{SQL: m.DataTypeOf(field)}, clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, m.FullDataTypeOf(field),
).Error ).Error
} }
return fmt.Errorf("failed to look up field with name: %s", field) return fmt.Errorf("failed to look up field with name: %s", field)

View File

@ -2,6 +2,7 @@ package tests_test
import ( import (
"math/rand" "math/rand"
"strings"
"testing" "testing"
"time" "time"
@ -124,6 +125,31 @@ func TestColumns(t *testing.T) {
t.Errorf("Failed to migrate, got %v", err) t.Errorf("Failed to migrate, got %v", err)
} }
type ColumnStruct2 struct {
gorm.Model
Name string `gorm:"size:100"`
}
if err := DB.Table("column_structs").Migrator().AlterColumn(&ColumnStruct2{}, "Name"); err != nil {
t.Fatalf("no error should happend when alter column, but got %v", err)
}
if columnTypes, err := DB.Migrator().ColumnTypes(&ColumnStruct{}); err != nil {
t.Fatalf("no error should returns for ColumnTypes")
} else {
stmt := &gorm.Statement{DB: DB}
stmt.Parse(&ColumnStruct2{})
for _, columnType := range columnTypes {
if columnType.Name() == "name" {
dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name()))
if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) {
t.Errorf("column type should be correct, name: %v, length: %v, expects: %v", columnType.Name(), columnType.DatabaseTypeName(), dataType)
}
}
}
}
type NewColumnStruct struct { type NewColumnStruct struct {
gorm.Model gorm.Model
Name string Name string