mirror of https://github.com/go-gorm/gorm.git
Support smart migrate, close #3078
This commit is contained in:
parent
3a97639880
commit
cc6a64adfb
|
@ -42,6 +42,7 @@ type Migrator interface {
|
||||||
AddColumn(dst interface{}, field string) error
|
AddColumn(dst interface{}, field string) error
|
||||||
DropColumn(dst interface{}, field string) error
|
DropColumn(dst interface{}, field string) error
|
||||||
AlterColumn(dst interface{}, field string) error
|
AlterColumn(dst interface{}, field string) error
|
||||||
|
MigrateColumn(dst interface{}, field *schema.Field, columnType *sql.ColumnType) error
|
||||||
HasColumn(dst interface{}, field string) bool
|
HasColumn(dst interface{}, field string) bool
|
||||||
RenameColumn(dst interface{}, oldName, field string) error
|
RenameColumn(dst interface{}, oldName, field string) error
|
||||||
ColumnTypes(dst interface{}) ([]*sql.ColumnType, error)
|
ColumnTypes(dst interface{}) ([]*sql.ColumnType, error)
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
@ -80,7 +81,6 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
|
||||||
|
|
||||||
// AutoMigrate
|
// AutoMigrate
|
||||||
func (m Migrator) AutoMigrate(values ...interface{}) error {
|
func (m Migrator) AutoMigrate(values ...interface{}) error {
|
||||||
// TODO smart migrate data type
|
|
||||||
for _, value := range m.ReorderModels(values, true) {
|
for _, value := range m.ReorderModels(values, true) {
|
||||||
tx := m.DB.Session(&gorm.Session{})
|
tx := m.DB.Session(&gorm.Session{})
|
||||||
if !tx.Migrator().HasTable(value) {
|
if !tx.Migrator().HasTable(value) {
|
||||||
|
@ -89,11 +89,26 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) {
|
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) {
|
||||||
|
columnTypes, _ := m.DB.Migrator().ColumnTypes(value)
|
||||||
|
|
||||||
for _, field := range stmt.Schema.FieldsByDBName {
|
for _, field := range stmt.Schema.FieldsByDBName {
|
||||||
if !tx.Migrator().HasColumn(value, field.DBName) {
|
var foundColumn *sql.ColumnType
|
||||||
|
|
||||||
|
for _, columnType := range columnTypes {
|
||||||
|
if columnType.Name() == field.DBName {
|
||||||
|
foundColumn = columnType
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if foundColumn == nil {
|
||||||
|
// not found, add column
|
||||||
if err := tx.Migrator().AddColumn(value, field.DBName); err != nil {
|
if err := tx.Migrator().AddColumn(value, field.DBName); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
} else if err := m.DB.Migrator().MigrateColumn(value, field, foundColumn); err != nil {
|
||||||
|
// found, smart migrate
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -120,7 +135,6 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
fmt.Println(err)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -327,6 +341,49 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType *sql.ColumnType) error {
|
||||||
|
// found, smart migrate
|
||||||
|
fullDataType := strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL)
|
||||||
|
realDataType := strings.ToLower(columnType.DatabaseTypeName())
|
||||||
|
|
||||||
|
alterColumn := false
|
||||||
|
|
||||||
|
// check size
|
||||||
|
if length, _ := columnType.Length(); length != int64(field.Size) {
|
||||||
|
if length > 0 && field.Size > 0 {
|
||||||
|
alterColumn = true
|
||||||
|
} else {
|
||||||
|
// has size in data type and not equal
|
||||||
|
matches := regexp.MustCompile(`[^\d](\d+)[^\d]`).FindAllString(realDataType, 1)
|
||||||
|
matches2 := regexp.MustCompile(`[^\d]*(\d+)[^\d]`).FindAllStringSubmatch(fullDataType, -1)
|
||||||
|
if len(matches) > 0 && matches[1] != fmt.Sprint(field.Size) || len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) {
|
||||||
|
alterColumn = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check precision
|
||||||
|
if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision {
|
||||||
|
if strings.Contains(fullDataType, fmt.Sprint(field.Precision)) {
|
||||||
|
alterColumn = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check nullable
|
||||||
|
if nullable, ok := columnType.Nullable(); ok && nullable == field.NotNull {
|
||||||
|
// not primary key & database is nullable
|
||||||
|
if !field.PrimaryKey && nullable {
|
||||||
|
alterColumn = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if alterColumn {
|
||||||
|
return m.DB.Migrator().AlterColumn(value, field.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m Migrator) ColumnTypes(value interface{}) (columnTypes []*sql.ColumnType, err error) {
|
func (m Migrator) ColumnTypes(value interface{}) (columnTypes []*sql.ColumnType, err error) {
|
||||||
err = m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
err = m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||||
rows, err := m.DB.Raw("select * from ?", clause.Table{Name: stmt.Table}).Rows()
|
rows, err := m.DB.Raw("select * from ?", clause.Table{Name: stmt.Table}).Rows()
|
||||||
|
|
|
@ -55,6 +55,7 @@ type Field struct {
|
||||||
Comment string
|
Comment string
|
||||||
Size int
|
Size int
|
||||||
Precision int
|
Precision int
|
||||||
|
Scale int
|
||||||
FieldType reflect.Type
|
FieldType reflect.Type
|
||||||
IndirectFieldType reflect.Type
|
IndirectFieldType reflect.Type
|
||||||
StructField reflect.StructField
|
StructField reflect.StructField
|
||||||
|
@ -160,6 +161,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
||||||
field.Precision, _ = strconv.Atoi(p)
|
field.Precision, _ = strconv.Atoi(p)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if s, ok := field.TagSettings["SCALE"]; ok {
|
||||||
|
field.Scale, _ = strconv.Atoi(s)
|
||||||
|
}
|
||||||
|
|
||||||
if val, ok := field.TagSettings["NOT NULL"]; ok && utils.CheckTruth(val) {
|
if val, ok := field.TagSettings["NOT NULL"]; ok && utils.CheckTruth(val) {
|
||||||
field.NotNull = true
|
field.NotNull = true
|
||||||
}
|
}
|
||||||
|
|
|
@ -379,7 +379,6 @@ func (stmt *Statement) Build(clauses ...string) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// TODO handle named vars
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (stmt *Statement) Parse(value interface{}) (err error) {
|
func (stmt *Statement) Parse(value interface{}) (err error) {
|
||||||
|
|
|
@ -6,11 +6,11 @@ require (
|
||||||
github.com/google/uuid v1.1.1
|
github.com/google/uuid v1.1.1
|
||||||
github.com/jinzhu/now v1.1.1
|
github.com/jinzhu/now v1.1.1
|
||||||
github.com/lib/pq v1.6.0
|
github.com/lib/pq v1.6.0
|
||||||
gorm.io/driver/mysql v0.3.1
|
gorm.io/driver/mysql v0.3.2
|
||||||
gorm.io/driver/postgres v0.2.6
|
gorm.io/driver/postgres v0.2.9
|
||||||
gorm.io/driver/sqlite v1.0.9
|
gorm.io/driver/sqlite v1.0.9
|
||||||
gorm.io/driver/sqlserver v0.2.7
|
gorm.io/driver/sqlserver v0.2.7
|
||||||
gorm.io/gorm v0.2.19
|
gorm.io/gorm v0.2.36
|
||||||
)
|
)
|
||||||
|
|
||||||
replace gorm.io/gorm => ../
|
replace gorm.io/gorm => ../
|
||||||
|
|
|
@ -47,6 +47,86 @@ func TestMigrate(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSmartMigrateColumn(t *testing.T) {
|
||||||
|
type UserMigrateColumn struct {
|
||||||
|
ID uint
|
||||||
|
Name string
|
||||||
|
Salary float64
|
||||||
|
Birthday time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Migrator().DropTable(&UserMigrateColumn{})
|
||||||
|
|
||||||
|
DB.AutoMigrate(&UserMigrateColumn{})
|
||||||
|
|
||||||
|
type UserMigrateColumn2 struct {
|
||||||
|
ID uint
|
||||||
|
Name string `gorm:"size:128"`
|
||||||
|
Salary float64 `gorm:"precision:2"`
|
||||||
|
Birthday time.Time `gorm:"precision:2"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Table("user_migrate_columns").AutoMigrate(&UserMigrateColumn2{}); err != nil {
|
||||||
|
t.Fatalf("failed to auto migrate, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
columnTypes, err := DB.Table("user_migrate_columns").Migrator().ColumnTypes(&UserMigrateColumn{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to get column types, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, columnType := range columnTypes {
|
||||||
|
switch columnType.Name() {
|
||||||
|
case "name":
|
||||||
|
if length, _ := columnType.Length(); length != 0 && length != 128 {
|
||||||
|
t.Fatalf("name's length should be 128, but got %v", length)
|
||||||
|
}
|
||||||
|
case "salary":
|
||||||
|
if precision, o, _ := columnType.DecimalSize(); precision != 0 && precision != 2 {
|
||||||
|
t.Fatalf("salary's precision should be 2, but got %v %v", precision, o)
|
||||||
|
}
|
||||||
|
case "birthday":
|
||||||
|
if precision, _, _ := columnType.DecimalSize(); precision != 0 && precision != 2 {
|
||||||
|
t.Fatalf("birthday's precision should be 2, but got %v", precision)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type UserMigrateColumn3 struct {
|
||||||
|
ID uint
|
||||||
|
Name string `gorm:"size:256"`
|
||||||
|
Salary float64 `gorm:"precision:3"`
|
||||||
|
Birthday time.Time `gorm:"precision:3"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Table("user_migrate_columns").AutoMigrate(&UserMigrateColumn3{}); err != nil {
|
||||||
|
t.Fatalf("failed to auto migrate, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
columnTypes, err = DB.Table("user_migrate_columns").Migrator().ColumnTypes(&UserMigrateColumn{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to get column types, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, columnType := range columnTypes {
|
||||||
|
switch columnType.Name() {
|
||||||
|
case "name":
|
||||||
|
if length, _ := columnType.Length(); length != 0 && length != 256 {
|
||||||
|
t.Fatalf("name's length should be 128, but got %v", length)
|
||||||
|
}
|
||||||
|
case "salary":
|
||||||
|
if precision, _, _ := columnType.DecimalSize(); precision != 0 && precision != 3 {
|
||||||
|
t.Fatalf("salary's precision should be 2, but got %v", precision)
|
||||||
|
}
|
||||||
|
case "birthday":
|
||||||
|
if precision, _, _ := columnType.DecimalSize(); precision != 0 && precision != 3 {
|
||||||
|
t.Fatalf("birthday's precision should be 2, but got %v", precision)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
func TestMigrateWithComment(t *testing.T) {
|
func TestMigrateWithComment(t *testing.T) {
|
||||||
type UserWithComment struct {
|
type UserWithComment struct {
|
||||||
gorm.Model
|
gorm.Model
|
||||||
|
|
Loading…
Reference in New Issue