Support smart migrate, close #3078

This commit is contained in:
Jinzhu 2020-08-23 15:40:19 +08:00
parent 3a97639880
commit cc6a64adfb
6 changed files with 149 additions and 7 deletions

View File

@ -42,6 +42,7 @@ type Migrator interface {
AddColumn(dst interface{}, field string) error AddColumn(dst interface{}, field string) error
DropColumn(dst interface{}, field string) error DropColumn(dst interface{}, field string) error
AlterColumn(dst interface{}, field string) error AlterColumn(dst interface{}, field string) error
MigrateColumn(dst interface{}, field *schema.Field, columnType *sql.ColumnType) error
HasColumn(dst interface{}, field string) bool HasColumn(dst interface{}, field string) bool
RenameColumn(dst interface{}, oldName, field string) error RenameColumn(dst interface{}, oldName, field string) error
ColumnTypes(dst interface{}) ([]*sql.ColumnType, error) ColumnTypes(dst interface{}) ([]*sql.ColumnType, error)

View File

@ -5,6 +5,7 @@ import (
"database/sql" "database/sql"
"fmt" "fmt"
"reflect" "reflect"
"regexp"
"strings" "strings"
"gorm.io/gorm" "gorm.io/gorm"
@ -80,7 +81,6 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
// AutoMigrate // AutoMigrate
func (m Migrator) AutoMigrate(values ...interface{}) error { func (m Migrator) AutoMigrate(values ...interface{}) error {
// TODO smart migrate data type
for _, value := range m.ReorderModels(values, true) { for _, value := range m.ReorderModels(values, true) {
tx := m.DB.Session(&gorm.Session{}) tx := m.DB.Session(&gorm.Session{})
if !tx.Migrator().HasTable(value) { if !tx.Migrator().HasTable(value) {
@ -89,11 +89,26 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
} }
} else { } else {
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) { if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) {
columnTypes, _ := m.DB.Migrator().ColumnTypes(value)
for _, field := range stmt.Schema.FieldsByDBName { for _, field := range stmt.Schema.FieldsByDBName {
if !tx.Migrator().HasColumn(value, field.DBName) { var foundColumn *sql.ColumnType
for _, columnType := range columnTypes {
if columnType.Name() == field.DBName {
foundColumn = columnType
break
}
}
if foundColumn == nil {
// not found, add column
if err := tx.Migrator().AddColumn(value, field.DBName); err != nil { if err := tx.Migrator().AddColumn(value, field.DBName); err != nil {
return err return err
} }
} else if err := m.DB.Migrator().MigrateColumn(value, field, foundColumn); err != nil {
// found, smart migrate
return err
} }
} }
@ -120,7 +135,6 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
} }
return nil return nil
}); err != nil { }); err != nil {
fmt.Println(err)
return err return err
} }
} }
@ -327,6 +341,49 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error
}) })
} }
func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType *sql.ColumnType) error {
// found, smart migrate
fullDataType := strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL)
realDataType := strings.ToLower(columnType.DatabaseTypeName())
alterColumn := false
// check size
if length, _ := columnType.Length(); length != int64(field.Size) {
if length > 0 && field.Size > 0 {
alterColumn = true
} else {
// has size in data type and not equal
matches := regexp.MustCompile(`[^\d](\d+)[^\d]`).FindAllString(realDataType, 1)
matches2 := regexp.MustCompile(`[^\d]*(\d+)[^\d]`).FindAllStringSubmatch(fullDataType, -1)
if len(matches) > 0 && matches[1] != fmt.Sprint(field.Size) || len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) {
alterColumn = true
}
}
}
// check precision
if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision {
if strings.Contains(fullDataType, fmt.Sprint(field.Precision)) {
alterColumn = true
}
}
// check nullable
if nullable, ok := columnType.Nullable(); ok && nullable == field.NotNull {
// not primary key & database is nullable
if !field.PrimaryKey && nullable {
alterColumn = true
}
}
if alterColumn {
return m.DB.Migrator().AlterColumn(value, field.Name)
}
return nil
}
func (m Migrator) ColumnTypes(value interface{}) (columnTypes []*sql.ColumnType, err error) { func (m Migrator) ColumnTypes(value interface{}) (columnTypes []*sql.ColumnType, err error) {
err = m.RunWithValue(value, func(stmt *gorm.Statement) error { err = m.RunWithValue(value, func(stmt *gorm.Statement) error {
rows, err := m.DB.Raw("select * from ?", clause.Table{Name: stmt.Table}).Rows() rows, err := m.DB.Raw("select * from ?", clause.Table{Name: stmt.Table}).Rows()

View File

@ -55,6 +55,7 @@ type Field struct {
Comment string Comment string
Size int Size int
Precision int Precision int
Scale int
FieldType reflect.Type FieldType reflect.Type
IndirectFieldType reflect.Type IndirectFieldType reflect.Type
StructField reflect.StructField StructField reflect.StructField
@ -160,6 +161,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
field.Precision, _ = strconv.Atoi(p) field.Precision, _ = strconv.Atoi(p)
} }
if s, ok := field.TagSettings["SCALE"]; ok {
field.Scale, _ = strconv.Atoi(s)
}
if val, ok := field.TagSettings["NOT NULL"]; ok && utils.CheckTruth(val) { if val, ok := field.TagSettings["NOT NULL"]; ok && utils.CheckTruth(val) {
field.NotNull = true field.NotNull = true
} }

View File

@ -379,7 +379,6 @@ func (stmt *Statement) Build(clauses ...string) {
} }
} }
} }
// TODO handle named vars
} }
func (stmt *Statement) Parse(value interface{}) (err error) { func (stmt *Statement) Parse(value interface{}) (err error) {

View File

@ -6,11 +6,11 @@ require (
github.com/google/uuid v1.1.1 github.com/google/uuid v1.1.1
github.com/jinzhu/now v1.1.1 github.com/jinzhu/now v1.1.1
github.com/lib/pq v1.6.0 github.com/lib/pq v1.6.0
gorm.io/driver/mysql v0.3.1 gorm.io/driver/mysql v0.3.2
gorm.io/driver/postgres v0.2.6 gorm.io/driver/postgres v0.2.9
gorm.io/driver/sqlite v1.0.9 gorm.io/driver/sqlite v1.0.9
gorm.io/driver/sqlserver v0.2.7 gorm.io/driver/sqlserver v0.2.7
gorm.io/gorm v0.2.19 gorm.io/gorm v0.2.36
) )
replace gorm.io/gorm => ../ replace gorm.io/gorm => ../

View File

@ -47,6 +47,86 @@ func TestMigrate(t *testing.T) {
} }
} }
func TestSmartMigrateColumn(t *testing.T) {
type UserMigrateColumn struct {
ID uint
Name string
Salary float64
Birthday time.Time
}
DB.Migrator().DropTable(&UserMigrateColumn{})
DB.AutoMigrate(&UserMigrateColumn{})
type UserMigrateColumn2 struct {
ID uint
Name string `gorm:"size:128"`
Salary float64 `gorm:"precision:2"`
Birthday time.Time `gorm:"precision:2"`
}
if err := DB.Table("user_migrate_columns").AutoMigrate(&UserMigrateColumn2{}); err != nil {
t.Fatalf("failed to auto migrate, got error: %v", err)
}
columnTypes, err := DB.Table("user_migrate_columns").Migrator().ColumnTypes(&UserMigrateColumn{})
if err != nil {
t.Fatalf("failed to get column types, got error: %v", err)
}
for _, columnType := range columnTypes {
switch columnType.Name() {
case "name":
if length, _ := columnType.Length(); length != 0 && length != 128 {
t.Fatalf("name's length should be 128, but got %v", length)
}
case "salary":
if precision, o, _ := columnType.DecimalSize(); precision != 0 && precision != 2 {
t.Fatalf("salary's precision should be 2, but got %v %v", precision, o)
}
case "birthday":
if precision, _, _ := columnType.DecimalSize(); precision != 0 && precision != 2 {
t.Fatalf("birthday's precision should be 2, but got %v", precision)
}
}
}
type UserMigrateColumn3 struct {
ID uint
Name string `gorm:"size:256"`
Salary float64 `gorm:"precision:3"`
Birthday time.Time `gorm:"precision:3"`
}
if err := DB.Table("user_migrate_columns").AutoMigrate(&UserMigrateColumn3{}); err != nil {
t.Fatalf("failed to auto migrate, got error: %v", err)
}
columnTypes, err = DB.Table("user_migrate_columns").Migrator().ColumnTypes(&UserMigrateColumn{})
if err != nil {
t.Fatalf("failed to get column types, got error: %v", err)
}
for _, columnType := range columnTypes {
switch columnType.Name() {
case "name":
if length, _ := columnType.Length(); length != 0 && length != 256 {
t.Fatalf("name's length should be 128, but got %v", length)
}
case "salary":
if precision, _, _ := columnType.DecimalSize(); precision != 0 && precision != 3 {
t.Fatalf("salary's precision should be 2, but got %v", precision)
}
case "birthday":
if precision, _, _ := columnType.DecimalSize(); precision != 0 && precision != 3 {
t.Fatalf("birthday's precision should be 2, but got %v", precision)
}
}
}
}
func TestMigrateWithComment(t *testing.T) { func TestMigrateWithComment(t *testing.T) {
type UserWithComment struct { type UserWithComment struct {
gorm.Model gorm.Model