From 536e4d34b078ea812521e209be5ac304848559e9 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 31 May 2020 10:38:01 +0800 Subject: [PATCH] Add test for AlterColumn --- dialects/mssql/migrator.go | 12 ++++++++++++ dialects/mysql/migrator.go | 4 ++-- dialects/postgres/postgres.go | 2 +- dialects/sqlite/migrator.go | 36 +++++++++++++++++++++++++++++++++++ migrator/migrator.go | 2 +- tests/migrate_test.go | 26 +++++++++++++++++++++++++ 6 files changed, 78 insertions(+), 4 deletions(-) diff --git a/dialects/mssql/migrator.go b/dialects/mssql/migrator.go index b334268e..1de49ae9 100644 --- a/dialects/mssql/migrator.go +++ b/dialects/mssql/migrator.go @@ -71,6 +71,18 @@ func (m Migrator) HasColumn(value interface{}, field string) bool { return count > 0 } +func (m Migrator) AlterColumn(value interface{}, field string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if field := stmt.Schema.LookUpField(field); field != nil { + return m.DB.Exec( + "ALTER TABLE ? ALTER COLUMN ? ?", + clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, m.FullDataTypeOf(field), + ).Error + } + return fmt.Errorf("failed to look up field with name: %s", field) + }) +} + func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { if field := stmt.Schema.LookUpField(oldName); field != nil { diff --git a/dialects/mysql/migrator.go b/dialects/mysql/migrator.go index 2c11af94..74c11277 100644 --- a/dialects/mysql/migrator.go +++ b/dialects/mysql/migrator.go @@ -16,8 +16,8 @@ func (m Migrator) AlterColumn(value interface{}, field string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { if field := stmt.Schema.LookUpField(field); field != nil { return m.DB.Exec( - "ALTER TABLE ? MODIFY COLUMN ? TYPE ?", - clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DBDataType}, + "ALTER TABLE ? MODIFY COLUMN ? ?", + clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, m.FullDataTypeOf(field), ).Error } return fmt.Errorf("failed to look up field with name: %s", field) diff --git a/dialects/postgres/postgres.go b/dialects/postgres/postgres.go index 73a19e9d..db559b9d 100644 --- a/dialects/postgres/postgres.go +++ b/dialects/postgres/postgres.go @@ -89,7 +89,7 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string { } return "text" case schema.Time: - return "timestamp with time zone" + return "timestamptz" case schema.Bytes: return "bytea" } diff --git a/dialects/sqlite/migrator.go b/dialects/sqlite/migrator.go index e36dc5e7..252e4183 100644 --- a/dialects/sqlite/migrator.go +++ b/dialects/sqlite/migrator.go @@ -38,6 +38,42 @@ func (m Migrator) HasColumn(value interface{}, name string) bool { return count > 0 } +func (m Migrator) AlterColumn(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if field := stmt.Schema.LookUpField(name); field != nil { + var ( + createSQL string + newTableName = stmt.Table + "__temp" + ) + + m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "table", stmt.Table, stmt.Table).Row().Scan(&createSQL) + + if reg, err := regexp.Compile("(`|'|\"| )" + name + "(`|'|\"| ) .*?,"); err == nil { + tableReg, err := regexp.Compile(" ('|`|\"| )" + stmt.Table + "('|`|\"| ) ") + if err != nil { + return err + } + + createSQL = tableReg.ReplaceAllString(createSQL, fmt.Sprintf(" `%v` ", newTableName)) + createSQL = reg.ReplaceAllString(createSQL, "?") + + var columns []string + columnTypes, _ := m.DB.Migrator().ColumnTypes(value) + for _, columnType := range columnTypes { + columns = append(columns, fmt.Sprintf("`%v`", columnType.Name())) + } + + createSQL = fmt.Sprintf("PRAGMA foreign_keys=off;BEGIN TRANSACTION;"+createSQL+";INSERT INTO `%v`(%v) SELECT %v FROM `%v`;DROP TABLE `%v`;ALTER TABLE `%v` RENAME TO `%v`;COMMIT;", newTableName, strings.Join(columns, ","), strings.Join(columns, ","), stmt.Table, stmt.Table, newTableName, stmt.Table) + return m.DB.Exec(createSQL, m.FullDataTypeOf(field)).Error + } else { + return err + } + } else { + return fmt.Errorf("failed to alter field with name %v", name) + } + }) +} + func (m Migrator) DropColumn(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { if field := stmt.Schema.LookUpField(name); field != nil { diff --git a/migrator/migrator.go b/migrator/migrator.go index f22d6d2c..5a06beb1 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -283,7 +283,7 @@ func (m Migrator) AlterColumn(value interface{}, field string) error { if field := stmt.Schema.LookUpField(field); field != nil { return m.DB.Exec( "ALTER TABLE ? ALTER COLUMN ? TYPE ?", - clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, clause.Expr{SQL: m.DataTypeOf(field)}, + clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, m.FullDataTypeOf(field), ).Error } return fmt.Errorf("failed to look up field with name: %s", field) diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 748ee816..957db8d6 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -2,6 +2,7 @@ package tests_test import ( "math/rand" + "strings" "testing" "time" @@ -124,6 +125,31 @@ func TestColumns(t *testing.T) { t.Errorf("Failed to migrate, got %v", err) } + type ColumnStruct2 struct { + gorm.Model + Name string `gorm:"size:100"` + } + + if err := DB.Table("column_structs").Migrator().AlterColumn(&ColumnStruct2{}, "Name"); err != nil { + t.Fatalf("no error should happend when alter column, but got %v", err) + } + + if columnTypes, err := DB.Migrator().ColumnTypes(&ColumnStruct{}); err != nil { + t.Fatalf("no error should returns for ColumnTypes") + } else { + stmt := &gorm.Statement{DB: DB} + stmt.Parse(&ColumnStruct2{}) + + for _, columnType := range columnTypes { + if columnType.Name() == "name" { + dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name())) + if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) { + t.Errorf("column type should be correct, name: %v, length: %v, expects: %v", columnType.Name(), columnType.DatabaseTypeName(), dataType) + } + } + } + } + type NewColumnStruct struct { gorm.Model Name string