Improve support for AutoMigrate

This commit is contained in:
Jinzhu 2022-02-19 23:42:20 +08:00
parent e0b4e0ec8f
commit 48ced75d1d
4 changed files with 66 additions and 19 deletions

View File

@ -11,7 +11,7 @@ type ColumnType struct {
NameValue sql.NullString NameValue sql.NullString
DataTypeValue sql.NullString DataTypeValue sql.NullString
ColumnTypeValue sql.NullString ColumnTypeValue sql.NullString
PrimayKeyValue sql.NullBool PrimaryKeyValue sql.NullBool
UniqueValue sql.NullBool UniqueValue sql.NullBool
AutoIncrementValue sql.NullBool AutoIncrementValue sql.NullBool
LengthValue sql.NullInt64 LengthValue sql.NullInt64
@ -51,7 +51,7 @@ func (ct ColumnType) ColumnType() (columnType string, ok bool) {
// PrimaryKey returns the column is primary key or not. // PrimaryKey returns the column is primary key or not.
func (ct ColumnType) PrimaryKey() (isPrimaryKey bool, ok bool) { func (ct ColumnType) PrimaryKey() (isPrimaryKey bool, ok bool) {
return ct.PrimayKeyValue.Bool, ct.PrimayKeyValue.Valid return ct.PrimaryKeyValue.Bool, ct.PrimaryKeyValue.Valid
} }
// AutoIncrement returns the column is auto increment or not. // AutoIncrement returns the column is auto increment or not.

View File

@ -436,6 +436,30 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
} }
} }
// check unique
if unique, ok := columnType.Unique(); ok && unique != field.Unique {
// not primary key
if !field.PrimaryKey {
alterColumn = true
}
}
// check default value
if v, ok := columnType.DefaultValue(); ok && v != field.DefaultValue {
// not primary key
if !field.PrimaryKey {
alterColumn = true
}
}
// check comment
if comment, ok := columnType.Comment(); ok && comment != field.Comment {
// not primary key
if !field.PrimaryKey {
alterColumn = true
}
}
if alterColumn && !field.IgnoreMigration { if alterColumn && !field.IgnoreMigration {
return m.DB.Migrator().AlterColumn(value, field.Name) return m.DB.Migrator().AlterColumn(value, field.Name)
} }

View File

@ -9,11 +9,11 @@ require (
github.com/lib/pq v1.10.4 github.com/lib/pq v1.10.4
github.com/mattn/go-sqlite3 v1.14.11 // indirect github.com/mattn/go-sqlite3 v1.14.11 // indirect
golang.org/x/crypto v0.0.0-20220214200702-86341886e292 // indirect golang.org/x/crypto v0.0.0-20220214200702-86341886e292 // indirect
gorm.io/driver/mysql v1.3.0 gorm.io/driver/mysql v1.3.1
gorm.io/driver/postgres v1.3.0 gorm.io/driver/postgres v1.3.1
gorm.io/driver/sqlite v1.3.0 gorm.io/driver/sqlite v1.3.1
gorm.io/driver/sqlserver v1.3.0 gorm.io/driver/sqlserver v1.3.1
gorm.io/gorm v1.22.5 gorm.io/gorm v1.23.0
) )
replace gorm.io/gorm => ../ replace gorm.io/gorm => ../

View File

@ -45,7 +45,7 @@ func TestMigrate(t *testing.T) {
for _, m := range allModels { for _, m := range allModels {
if !DB.Migrator().HasTable(m) { if !DB.Migrator().HasTable(m) {
t.Fatalf("Failed to create table for %#v---", m) t.Fatalf("Failed to create table for %#v", m)
} }
} }
@ -313,7 +313,6 @@ func TestMigrateIndexes(t *testing.T) {
} }
func TestMigrateColumns(t *testing.T) { func TestMigrateColumns(t *testing.T) {
fullSupported := map[string]bool{"sqlite": true, "mysql": true, "postgres": true, "sqlserver": true}[DB.Dialector.Name()]
sqlite := DB.Dialector.Name() == "sqlite" sqlite := DB.Dialector.Name() == "sqlite"
sqlserver := DB.Dialector.Name() == "sqlserver" sqlserver := DB.Dialector.Name() == "sqlserver"
@ -321,7 +320,9 @@ func TestMigrateColumns(t *testing.T) {
gorm.Model gorm.Model
Name string Name string
Age int `gorm:"default:18;comment:my age"` Age int `gorm:"default:18;comment:my age"`
Code string `gorm:"unique"` Code string `gorm:"unique;comment:my code;"`
Code2 string
Code3 string `gorm:"unique"`
} }
DB.Migrator().DropTable(&ColumnStruct{}) DB.Migrator().DropTable(&ColumnStruct{})
@ -333,12 +334,19 @@ func TestMigrateColumns(t *testing.T) {
type ColumnStruct2 struct { type ColumnStruct2 struct {
gorm.Model gorm.Model
Name string `gorm:"size:100"` Name string `gorm:"size:100"`
Code string `gorm:"unique;comment:my code2;default:hello"`
Code2 string `gorm:"unique"`
// Code3 string
} }
if err := DB.Table("column_structs").Migrator().AlterColumn(&ColumnStruct2{}, "Name"); err != nil { if err := DB.Table("column_structs").Migrator().AlterColumn(&ColumnStruct{}, "Name"); err != nil {
t.Fatalf("no error should happened when alter column, but got %v", err) t.Fatalf("no error should happened when alter column, but got %v", err)
} }
if err := DB.Table("column_structs").AutoMigrate(&ColumnStruct2{}); err != nil {
t.Fatalf("no error should happened when auto migrate column, but got %v", err)
}
if columnTypes, err := DB.Migrator().ColumnTypes(&ColumnStruct{}); err != nil { if columnTypes, err := DB.Migrator().ColumnTypes(&ColumnStruct{}); err != nil {
t.Fatalf("no error should returns for ColumnTypes") t.Fatalf("no error should returns for ColumnTypes")
} else { } else {
@ -348,7 +356,7 @@ func TestMigrateColumns(t *testing.T) {
for _, columnType := range columnTypes { for _, columnType := range columnTypes {
switch columnType.Name() { switch columnType.Name() {
case "id": case "id":
if v, ok := columnType.PrimaryKey(); (fullSupported || ok) && !v { if v, ok := columnType.PrimaryKey(); !ok || !v {
t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), columnType) t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), columnType)
} }
case "name": case "name":
@ -356,20 +364,35 @@ func TestMigrateColumns(t *testing.T) {
if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) { if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) {
t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType) t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType)
} }
if length, ok := columnType.Length(); ((fullSupported && !sqlite) || ok) && length != 100 { if length, ok := columnType.Length(); !sqlite && (!ok || length != 100) {
t.Fatalf("column name length should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), length, 100, columnType) t.Fatalf("column name length should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), length, 100, columnType)
} }
case "age": case "age":
if v, ok := columnType.DefaultValue(); (fullSupported || ok) && v != "18" { if v, ok := columnType.DefaultValue(); !ok || v != "18" {
t.Fatalf("column age default value should be correct, name: %v, column: %#v", columnType.Name(), columnType) t.Fatalf("column age default value should be correct, name: %v, column: %#v", columnType.Name(), columnType)
} }
if v, ok := columnType.Comment(); ((fullSupported && !sqlite && !sqlserver) || ok) && v != "my age" { if v, ok := columnType.Comment(); !sqlite && !sqlserver && (!ok || v != "my age") {
t.Fatalf("column age comment should be correct, name: %v, column: %#v", columnType.Name(), columnType) t.Fatalf("column age comment should be correct, name: %v, column: %#v", columnType.Name(), columnType)
} }
case "code": case "code":
if v, ok := columnType.Unique(); (fullSupported || ok) && !v { if v, ok := columnType.Unique(); !ok || !v {
t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType) t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType)
} }
if v, ok := columnType.DefaultValue(); !sqlserver && (!ok || v != "hello") {
t.Fatalf("column code default value should be correct, name: %v, column: %#v", columnType.Name(), columnType)
}
if v, ok := columnType.Comment(); !sqlite && !sqlserver && (!ok || v != "my code2") {
t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), columnType)
}
case "code2":
if v, ok := columnType.Unique(); !sqlserver && (!ok || !v) {
t.Fatalf("column code2 unique should be correct, name: %v, column: %#v", columnType.Name(), columnType)
}
case "code3":
// TODO
// if v, ok := columnType.Unique(); !ok || v {
// t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType)
// }
} }
} }
} }