forked from mirror/gorm
Support smart migrate, close #3078
This commit is contained in:
parent
3a97639880
commit
cc6a64adfb
|
@ -42,6 +42,7 @@ type Migrator interface {
|
|||
AddColumn(dst interface{}, field string) error
|
||||
DropColumn(dst interface{}, field string) error
|
||||
AlterColumn(dst interface{}, field string) error
|
||||
MigrateColumn(dst interface{}, field *schema.Field, columnType *sql.ColumnType) error
|
||||
HasColumn(dst interface{}, field string) bool
|
||||
RenameColumn(dst interface{}, oldName, field string) error
|
||||
ColumnTypes(dst interface{}) ([]*sql.ColumnType, error)
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
@ -80,7 +81,6 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
|
|||
|
||||
// AutoMigrate
|
||||
func (m Migrator) AutoMigrate(values ...interface{}) error {
|
||||
// TODO smart migrate data type
|
||||
for _, value := range m.ReorderModels(values, true) {
|
||||
tx := m.DB.Session(&gorm.Session{})
|
||||
if !tx.Migrator().HasTable(value) {
|
||||
|
@ -89,11 +89,26 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
|
|||
}
|
||||
} else {
|
||||
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) {
|
||||
columnTypes, _ := m.DB.Migrator().ColumnTypes(value)
|
||||
|
||||
for _, field := range stmt.Schema.FieldsByDBName {
|
||||
if !tx.Migrator().HasColumn(value, field.DBName) {
|
||||
var foundColumn *sql.ColumnType
|
||||
|
||||
for _, columnType := range columnTypes {
|
||||
if columnType.Name() == field.DBName {
|
||||
foundColumn = columnType
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if foundColumn == nil {
|
||||
// not found, add column
|
||||
if err := tx.Migrator().AddColumn(value, field.DBName); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if err := m.DB.Migrator().MigrateColumn(value, field, foundColumn); err != nil {
|
||||
// found, smart migrate
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -120,7 +135,6 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
|
|||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
fmt.Println(err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -327,6 +341,49 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error
|
|||
})
|
||||
}
|
||||
|
||||
func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType *sql.ColumnType) error {
|
||||
// found, smart migrate
|
||||
fullDataType := strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL)
|
||||
realDataType := strings.ToLower(columnType.DatabaseTypeName())
|
||||
|
||||
alterColumn := false
|
||||
|
||||
// check size
|
||||
if length, _ := columnType.Length(); length != int64(field.Size) {
|
||||
if length > 0 && field.Size > 0 {
|
||||
alterColumn = true
|
||||
} else {
|
||||
// has size in data type and not equal
|
||||
matches := regexp.MustCompile(`[^\d](\d+)[^\d]`).FindAllString(realDataType, 1)
|
||||
matches2 := regexp.MustCompile(`[^\d]*(\d+)[^\d]`).FindAllStringSubmatch(fullDataType, -1)
|
||||
if len(matches) > 0 && matches[1] != fmt.Sprint(field.Size) || len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) {
|
||||
alterColumn = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check precision
|
||||
if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision {
|
||||
if strings.Contains(fullDataType, fmt.Sprint(field.Precision)) {
|
||||
alterColumn = true
|
||||
}
|
||||
}
|
||||
|
||||
// check nullable
|
||||
if nullable, ok := columnType.Nullable(); ok && nullable == field.NotNull {
|
||||
// not primary key & database is nullable
|
||||
if !field.PrimaryKey && nullable {
|
||||
alterColumn = true
|
||||
}
|
||||
}
|
||||
|
||||
if alterColumn {
|
||||
return m.DB.Migrator().AlterColumn(value, field.Name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m Migrator) ColumnTypes(value interface{}) (columnTypes []*sql.ColumnType, err error) {
|
||||
err = m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
rows, err := m.DB.Raw("select * from ?", clause.Table{Name: stmt.Table}).Rows()
|
||||
|
|
|
@ -55,6 +55,7 @@ type Field struct {
|
|||
Comment string
|
||||
Size int
|
||||
Precision int
|
||||
Scale int
|
||||
FieldType reflect.Type
|
||||
IndirectFieldType reflect.Type
|
||||
StructField reflect.StructField
|
||||
|
@ -160,6 +161,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
|||
field.Precision, _ = strconv.Atoi(p)
|
||||
}
|
||||
|
||||
if s, ok := field.TagSettings["SCALE"]; ok {
|
||||
field.Scale, _ = strconv.Atoi(s)
|
||||
}
|
||||
|
||||
if val, ok := field.TagSettings["NOT NULL"]; ok && utils.CheckTruth(val) {
|
||||
field.NotNull = true
|
||||
}
|
||||
|
|
|
@ -379,7 +379,6 @@ func (stmt *Statement) Build(clauses ...string) {
|
|||
}
|
||||
}
|
||||
}
|
||||
// TODO handle named vars
|
||||
}
|
||||
|
||||
func (stmt *Statement) Parse(value interface{}) (err error) {
|
||||
|
|
|
@ -6,11 +6,11 @@ require (
|
|||
github.com/google/uuid v1.1.1
|
||||
github.com/jinzhu/now v1.1.1
|
||||
github.com/lib/pq v1.6.0
|
||||
gorm.io/driver/mysql v0.3.1
|
||||
gorm.io/driver/postgres v0.2.6
|
||||
gorm.io/driver/mysql v0.3.2
|
||||
gorm.io/driver/postgres v0.2.9
|
||||
gorm.io/driver/sqlite v1.0.9
|
||||
gorm.io/driver/sqlserver v0.2.7
|
||||
gorm.io/gorm v0.2.19
|
||||
gorm.io/gorm v0.2.36
|
||||
)
|
||||
|
||||
replace gorm.io/gorm => ../
|
||||
|
|
|
@ -47,6 +47,86 @@ func TestMigrate(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestSmartMigrateColumn(t *testing.T) {
|
||||
type UserMigrateColumn struct {
|
||||
ID uint
|
||||
Name string
|
||||
Salary float64
|
||||
Birthday time.Time
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable(&UserMigrateColumn{})
|
||||
|
||||
DB.AutoMigrate(&UserMigrateColumn{})
|
||||
|
||||
type UserMigrateColumn2 struct {
|
||||
ID uint
|
||||
Name string `gorm:"size:128"`
|
||||
Salary float64 `gorm:"precision:2"`
|
||||
Birthday time.Time `gorm:"precision:2"`
|
||||
}
|
||||
|
||||
if err := DB.Table("user_migrate_columns").AutoMigrate(&UserMigrateColumn2{}); err != nil {
|
||||
t.Fatalf("failed to auto migrate, got error: %v", err)
|
||||
}
|
||||
|
||||
columnTypes, err := DB.Table("user_migrate_columns").Migrator().ColumnTypes(&UserMigrateColumn{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get column types, got error: %v", err)
|
||||
}
|
||||
|
||||
for _, columnType := range columnTypes {
|
||||
switch columnType.Name() {
|
||||
case "name":
|
||||
if length, _ := columnType.Length(); length != 0 && length != 128 {
|
||||
t.Fatalf("name's length should be 128, but got %v", length)
|
||||
}
|
||||
case "salary":
|
||||
if precision, o, _ := columnType.DecimalSize(); precision != 0 && precision != 2 {
|
||||
t.Fatalf("salary's precision should be 2, but got %v %v", precision, o)
|
||||
}
|
||||
case "birthday":
|
||||
if precision, _, _ := columnType.DecimalSize(); precision != 0 && precision != 2 {
|
||||
t.Fatalf("birthday's precision should be 2, but got %v", precision)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type UserMigrateColumn3 struct {
|
||||
ID uint
|
||||
Name string `gorm:"size:256"`
|
||||
Salary float64 `gorm:"precision:3"`
|
||||
Birthday time.Time `gorm:"precision:3"`
|
||||
}
|
||||
|
||||
if err := DB.Table("user_migrate_columns").AutoMigrate(&UserMigrateColumn3{}); err != nil {
|
||||
t.Fatalf("failed to auto migrate, got error: %v", err)
|
||||
}
|
||||
|
||||
columnTypes, err = DB.Table("user_migrate_columns").Migrator().ColumnTypes(&UserMigrateColumn{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get column types, got error: %v", err)
|
||||
}
|
||||
|
||||
for _, columnType := range columnTypes {
|
||||
switch columnType.Name() {
|
||||
case "name":
|
||||
if length, _ := columnType.Length(); length != 0 && length != 256 {
|
||||
t.Fatalf("name's length should be 128, but got %v", length)
|
||||
}
|
||||
case "salary":
|
||||
if precision, _, _ := columnType.DecimalSize(); precision != 0 && precision != 3 {
|
||||
t.Fatalf("salary's precision should be 2, but got %v", precision)
|
||||
}
|
||||
case "birthday":
|
||||
if precision, _, _ := columnType.DecimalSize(); precision != 0 && precision != 3 {
|
||||
t.Fatalf("birthday's precision should be 2, but got %v", precision)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestMigrateWithComment(t *testing.T) {
|
||||
type UserWithComment struct {
|
||||
gorm.Model
|
||||
|
|
Loading…
Reference in New Issue