mirror of https://github.com/go-gorm/gorm.git
Add test for AlterColumn
This commit is contained in:
parent
d81179557d
commit
536e4d34b0
|
@ -71,6 +71,18 @@ func (m Migrator) HasColumn(value interface{}, field string) bool {
|
|||
return count > 0
|
||||
}
|
||||
|
||||
func (m Migrator) AlterColumn(value interface{}, field string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if field := stmt.Schema.LookUpField(field); field != nil {
|
||||
return m.DB.Exec(
|
||||
"ALTER TABLE ? ALTER COLUMN ? ?",
|
||||
clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, m.FullDataTypeOf(field),
|
||||
).Error
|
||||
}
|
||||
return fmt.Errorf("failed to look up field with name: %s", field)
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if field := stmt.Schema.LookUpField(oldName); field != nil {
|
||||
|
|
|
@ -16,8 +16,8 @@ func (m Migrator) AlterColumn(value interface{}, field string) error {
|
|||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if field := stmt.Schema.LookUpField(field); field != nil {
|
||||
return m.DB.Exec(
|
||||
"ALTER TABLE ? MODIFY COLUMN ? TYPE ?",
|
||||
clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DBDataType},
|
||||
"ALTER TABLE ? MODIFY COLUMN ? ?",
|
||||
clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, m.FullDataTypeOf(field),
|
||||
).Error
|
||||
}
|
||||
return fmt.Errorf("failed to look up field with name: %s", field)
|
||||
|
|
|
@ -89,7 +89,7 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string {
|
|||
}
|
||||
return "text"
|
||||
case schema.Time:
|
||||
return "timestamp with time zone"
|
||||
return "timestamptz"
|
||||
case schema.Bytes:
|
||||
return "bytea"
|
||||
}
|
||||
|
|
|
@ -38,6 +38,42 @@ func (m Migrator) HasColumn(value interface{}, name string) bool {
|
|||
return count > 0
|
||||
}
|
||||
|
||||
func (m Migrator) AlterColumn(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if field := stmt.Schema.LookUpField(name); field != nil {
|
||||
var (
|
||||
createSQL string
|
||||
newTableName = stmt.Table + "__temp"
|
||||
)
|
||||
|
||||
m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "table", stmt.Table, stmt.Table).Row().Scan(&createSQL)
|
||||
|
||||
if reg, err := regexp.Compile("(`|'|\"| )" + name + "(`|'|\"| ) .*?,"); err == nil {
|
||||
tableReg, err := regexp.Compile(" ('|`|\"| )" + stmt.Table + "('|`|\"| ) ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
createSQL = tableReg.ReplaceAllString(createSQL, fmt.Sprintf(" `%v` ", newTableName))
|
||||
createSQL = reg.ReplaceAllString(createSQL, "?")
|
||||
|
||||
var columns []string
|
||||
columnTypes, _ := m.DB.Migrator().ColumnTypes(value)
|
||||
for _, columnType := range columnTypes {
|
||||
columns = append(columns, fmt.Sprintf("`%v`", columnType.Name()))
|
||||
}
|
||||
|
||||
createSQL = fmt.Sprintf("PRAGMA foreign_keys=off;BEGIN TRANSACTION;"+createSQL+";INSERT INTO `%v`(%v) SELECT %v FROM `%v`;DROP TABLE `%v`;ALTER TABLE `%v` RENAME TO `%v`;COMMIT;", newTableName, strings.Join(columns, ","), strings.Join(columns, ","), stmt.Table, stmt.Table, newTableName, stmt.Table)
|
||||
return m.DB.Exec(createSQL, m.FullDataTypeOf(field)).Error
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("failed to alter field with name %v", name)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) DropColumn(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if field := stmt.Schema.LookUpField(name); field != nil {
|
||||
|
|
|
@ -283,7 +283,7 @@ func (m Migrator) AlterColumn(value interface{}, field string) error {
|
|||
if field := stmt.Schema.LookUpField(field); field != nil {
|
||||
return m.DB.Exec(
|
||||
"ALTER TABLE ? ALTER COLUMN ? TYPE ?",
|
||||
clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, clause.Expr{SQL: m.DataTypeOf(field)},
|
||||
clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, m.FullDataTypeOf(field),
|
||||
).Error
|
||||
}
|
||||
return fmt.Errorf("failed to look up field with name: %s", field)
|
||||
|
|
|
@ -2,6 +2,7 @@ package tests_test
|
|||
|
||||
import (
|
||||
"math/rand"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -124,6 +125,31 @@ func TestColumns(t *testing.T) {
|
|||
t.Errorf("Failed to migrate, got %v", err)
|
||||
}
|
||||
|
||||
type ColumnStruct2 struct {
|
||||
gorm.Model
|
||||
Name string `gorm:"size:100"`
|
||||
}
|
||||
|
||||
if err := DB.Table("column_structs").Migrator().AlterColumn(&ColumnStruct2{}, "Name"); err != nil {
|
||||
t.Fatalf("no error should happend when alter column, but got %v", err)
|
||||
}
|
||||
|
||||
if columnTypes, err := DB.Migrator().ColumnTypes(&ColumnStruct{}); err != nil {
|
||||
t.Fatalf("no error should returns for ColumnTypes")
|
||||
} else {
|
||||
stmt := &gorm.Statement{DB: DB}
|
||||
stmt.Parse(&ColumnStruct2{})
|
||||
|
||||
for _, columnType := range columnTypes {
|
||||
if columnType.Name() == "name" {
|
||||
dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name()))
|
||||
if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) {
|
||||
t.Errorf("column type should be correct, name: %v, length: %v, expects: %v", columnType.Name(), columnType.DatabaseTypeName(), dataType)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type NewColumnStruct struct {
|
||||
gorm.Model
|
||||
Name string
|
||||
|
|
Loading…
Reference in New Issue