diff --git a/migrator.go b/migrator.go index 37051f81..ed8a8e26 100644 --- a/migrator.go +++ b/migrator.go @@ -42,6 +42,7 @@ type Migrator interface { AddColumn(dst interface{}, field string) error DropColumn(dst interface{}, field string) error AlterColumn(dst interface{}, field string) error + MigrateColumn(dst interface{}, field *schema.Field, columnType *sql.ColumnType) error HasColumn(dst interface{}, field string) bool RenameColumn(dst interface{}, oldName, field string) error ColumnTypes(dst interface{}) ([]*sql.ColumnType, error) diff --git a/migrator/migrator.go b/migrator/migrator.go index d50159dd..d93b8a6d 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -5,6 +5,7 @@ import ( "database/sql" "fmt" "reflect" + "regexp" "strings" "gorm.io/gorm" @@ -80,7 +81,6 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { // AutoMigrate func (m Migrator) AutoMigrate(values ...interface{}) error { - // TODO smart migrate data type for _, value := range m.ReorderModels(values, true) { tx := m.DB.Session(&gorm.Session{}) if !tx.Migrator().HasTable(value) { @@ -89,11 +89,26 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { } } else { if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) { + columnTypes, _ := m.DB.Migrator().ColumnTypes(value) + for _, field := range stmt.Schema.FieldsByDBName { - if !tx.Migrator().HasColumn(value, field.DBName) { + var foundColumn *sql.ColumnType + + for _, columnType := range columnTypes { + if columnType.Name() == field.DBName { + foundColumn = columnType + break + } + } + + if foundColumn == nil { + // not found, add column if err := tx.Migrator().AddColumn(value, field.DBName); err != nil { return err } + } else if err := m.DB.Migrator().MigrateColumn(value, field, foundColumn); err != nil { + // found, smart migrate + return err } } @@ -120,7 +135,6 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { } return nil }); err != nil { - fmt.Println(err) return err } } @@ -327,6 +341,49 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error }) } +func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType *sql.ColumnType) error { + // found, smart migrate + fullDataType := strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL) + realDataType := strings.ToLower(columnType.DatabaseTypeName()) + + alterColumn := false + + // check size + if length, _ := columnType.Length(); length != int64(field.Size) { + if length > 0 && field.Size > 0 { + alterColumn = true + } else { + // has size in data type and not equal + matches := regexp.MustCompile(`[^\d](\d+)[^\d]`).FindAllString(realDataType, 1) + matches2 := regexp.MustCompile(`[^\d]*(\d+)[^\d]`).FindAllStringSubmatch(fullDataType, -1) + if len(matches) > 0 && matches[1] != fmt.Sprint(field.Size) || len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) { + alterColumn = true + } + } + } + + // check precision + if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision { + if strings.Contains(fullDataType, fmt.Sprint(field.Precision)) { + alterColumn = true + } + } + + // check nullable + if nullable, ok := columnType.Nullable(); ok && nullable == field.NotNull { + // not primary key & database is nullable + if !field.PrimaryKey && nullable { + alterColumn = true + } + } + + if alterColumn { + return m.DB.Migrator().AlterColumn(value, field.Name) + } + + return nil +} + func (m Migrator) ColumnTypes(value interface{}) (columnTypes []*sql.ColumnType, err error) { err = m.RunWithValue(value, func(stmt *gorm.Statement) error { rows, err := m.DB.Raw("select * from ?", clause.Table{Name: stmt.Table}).Rows() diff --git a/schema/field.go b/schema/field.go index 497aa02d..524d19fb 100644 --- a/schema/field.go +++ b/schema/field.go @@ -55,6 +55,7 @@ type Field struct { Comment string Size int Precision int + Scale int FieldType reflect.Type IndirectFieldType reflect.Type StructField reflect.StructField @@ -160,6 +161,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.Precision, _ = strconv.Atoi(p) } + if s, ok := field.TagSettings["SCALE"]; ok { + field.Scale, _ = strconv.Atoi(s) + } + if val, ok := field.TagSettings["NOT NULL"]; ok && utils.CheckTruth(val) { field.NotNull = true } diff --git a/statement.go b/statement.go index 214a15bb..95d23fa5 100644 --- a/statement.go +++ b/statement.go @@ -379,7 +379,6 @@ func (stmt *Statement) Build(clauses ...string) { } } } - // TODO handle named vars } func (stmt *Statement) Parse(value interface{}) (err error) { diff --git a/tests/go.mod b/tests/go.mod index 54a808d0..9d4e892d 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,11 +6,11 @@ require ( github.com/google/uuid v1.1.1 github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 - gorm.io/driver/mysql v0.3.1 - gorm.io/driver/postgres v0.2.6 + gorm.io/driver/mysql v0.3.2 + gorm.io/driver/postgres v0.2.9 gorm.io/driver/sqlite v1.0.9 gorm.io/driver/sqlserver v0.2.7 - gorm.io/gorm v0.2.19 + gorm.io/gorm v0.2.36 ) replace gorm.io/gorm => ../ diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 1b002049..4cc8a7c3 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -47,6 +47,86 @@ func TestMigrate(t *testing.T) { } } +func TestSmartMigrateColumn(t *testing.T) { + type UserMigrateColumn struct { + ID uint + Name string + Salary float64 + Birthday time.Time + } + + DB.Migrator().DropTable(&UserMigrateColumn{}) + + DB.AutoMigrate(&UserMigrateColumn{}) + + type UserMigrateColumn2 struct { + ID uint + Name string `gorm:"size:128"` + Salary float64 `gorm:"precision:2"` + Birthday time.Time `gorm:"precision:2"` + } + + if err := DB.Table("user_migrate_columns").AutoMigrate(&UserMigrateColumn2{}); err != nil { + t.Fatalf("failed to auto migrate, got error: %v", err) + } + + columnTypes, err := DB.Table("user_migrate_columns").Migrator().ColumnTypes(&UserMigrateColumn{}) + if err != nil { + t.Fatalf("failed to get column types, got error: %v", err) + } + + for _, columnType := range columnTypes { + switch columnType.Name() { + case "name": + if length, _ := columnType.Length(); length != 0 && length != 128 { + t.Fatalf("name's length should be 128, but got %v", length) + } + case "salary": + if precision, o, _ := columnType.DecimalSize(); precision != 0 && precision != 2 { + t.Fatalf("salary's precision should be 2, but got %v %v", precision, o) + } + case "birthday": + if precision, _, _ := columnType.DecimalSize(); precision != 0 && precision != 2 { + t.Fatalf("birthday's precision should be 2, but got %v", precision) + } + } + } + + type UserMigrateColumn3 struct { + ID uint + Name string `gorm:"size:256"` + Salary float64 `gorm:"precision:3"` + Birthday time.Time `gorm:"precision:3"` + } + + if err := DB.Table("user_migrate_columns").AutoMigrate(&UserMigrateColumn3{}); err != nil { + t.Fatalf("failed to auto migrate, got error: %v", err) + } + + columnTypes, err = DB.Table("user_migrate_columns").Migrator().ColumnTypes(&UserMigrateColumn{}) + if err != nil { + t.Fatalf("failed to get column types, got error: %v", err) + } + + for _, columnType := range columnTypes { + switch columnType.Name() { + case "name": + if length, _ := columnType.Length(); length != 0 && length != 256 { + t.Fatalf("name's length should be 128, but got %v", length) + } + case "salary": + if precision, _, _ := columnType.DecimalSize(); precision != 0 && precision != 3 { + t.Fatalf("salary's precision should be 2, but got %v", precision) + } + case "birthday": + if precision, _, _ := columnType.DecimalSize(); precision != 0 && precision != 3 { + t.Fatalf("birthday's precision should be 2, but got %v", precision) + } + } + } + +} + func TestMigrateWithComment(t *testing.T) { type UserWithComment struct { gorm.Model