refactor(migrator): non-standard codes (#6180)

This commit is contained in:
jessetang 2023-04-11 10:32:46 +08:00 committed by GitHub
parent 1d9f4b0f55
commit 05bb9d6106
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 18 additions and 16 deletions

View File

@ -17,12 +17,12 @@ func (idx Index) Table() string {
return idx.TableName return idx.TableName
} }
// Name return the name of the index. // Name return the name of the index.
func (idx Index) Name() string { func (idx Index) Name() string {
return idx.NameValue return idx.NameValue
} }
// Columns return the columns fo the index // Columns return the columns of the index
func (idx Index) Columns() []string { func (idx Index) Columns() []string {
return idx.ColumnList return idx.ColumnList
} }
@ -37,7 +37,7 @@ func (idx Index) Unique() (unique bool, ok bool) {
return idx.UniqueValue.Bool, idx.UniqueValue.Valid return idx.UniqueValue.Bool, idx.UniqueValue.Valid
} }
// Option return the optional attribute fo the index // Option return the optional attribute of the index
func (idx Index) Option() string { func (idx Index) Option() string {
return idx.OptionValue return idx.OptionValue
} }

View File

@ -113,7 +113,7 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
return err return err
} }
} else { } else {
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) { if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
columnTypes, err := queryTx.Migrator().ColumnTypes(value) columnTypes, err := queryTx.Migrator().ColumnTypes(value)
if err != nil { if err != nil {
return err return err
@ -123,7 +123,6 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
parseCheckConstraints = stmt.Schema.ParseCheckConstraints() parseCheckConstraints = stmt.Schema.ParseCheckConstraints()
) )
for _, dbName := range stmt.Schema.DBNames { for _, dbName := range stmt.Schema.DBNames {
field := stmt.Schema.FieldsByDBName[dbName]
var foundColumn gorm.ColumnType var foundColumn gorm.ColumnType
for _, columnType := range columnTypes { for _, columnType := range columnTypes {
@ -135,12 +134,15 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
if foundColumn == nil { if foundColumn == nil {
// not found, add column // not found, add column
if err := execTx.Migrator().AddColumn(value, dbName); err != nil { if err = execTx.Migrator().AddColumn(value, dbName); err != nil {
return err
}
} else {
// found, smartly migrate
field := stmt.Schema.FieldsByDBName[dbName]
if err = execTx.Migrator().MigrateColumn(value, field, foundColumn); err != nil {
return err return err
} }
} else if err := execTx.Migrator().MigrateColumn(value, field, foundColumn); err != nil {
// found, smart migrate
return err
} }
} }
@ -195,7 +197,7 @@ func (m Migrator) GetTables() (tableList []string, err error) {
func (m Migrator) CreateTable(values ...interface{}) error { func (m Migrator) CreateTable(values ...interface{}) error {
for _, value := range m.ReorderModels(values, false) { for _, value := range m.ReorderModels(values, false) {
tx := m.DB.Session(&gorm.Session{}) tx := m.DB.Session(&gorm.Session{})
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) { if err := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
var ( var (
createTableSQL = "CREATE TABLE ? (" createTableSQL = "CREATE TABLE ? ("
values = []interface{}{m.CurrentTable(stmt)} values = []interface{}{m.CurrentTable(stmt)}
@ -214,7 +216,7 @@ func (m Migrator) CreateTable(values ...interface{}) error {
if !hasPrimaryKeyInDataType && len(stmt.Schema.PrimaryFields) > 0 { if !hasPrimaryKeyInDataType && len(stmt.Schema.PrimaryFields) > 0 {
createTableSQL += "PRIMARY KEY ?," createTableSQL += "PRIMARY KEY ?,"
primaryKeys := []interface{}{} primaryKeys := make([]interface{}, 0, len(stmt.Schema.PrimaryFields))
for _, field := range stmt.Schema.PrimaryFields { for _, field := range stmt.Schema.PrimaryFields {
primaryKeys = append(primaryKeys, clause.Column{Name: field.DBName}) primaryKeys = append(primaryKeys, clause.Column{Name: field.DBName})
} }
@ -225,8 +227,8 @@ func (m Migrator) CreateTable(values ...interface{}) error {
for _, idx := range stmt.Schema.ParseIndexes() { for _, idx := range stmt.Schema.ParseIndexes() {
if m.CreateIndexAfterCreateTable { if m.CreateIndexAfterCreateTable {
defer func(value interface{}, name string) { defer func(value interface{}, name string) {
if errr == nil { if err == nil {
errr = tx.Migrator().CreateIndex(value, name) err = tx.Migrator().CreateIndex(value, name)
} }
}(value, idx.Name) }(value, idx.Name)
} else { } else {
@ -276,8 +278,8 @@ func (m Migrator) CreateTable(values ...interface{}) error {
createTableSQL += fmt.Sprint(tableOption) createTableSQL += fmt.Sprint(tableOption)
} }
errr = tx.Exec(createTableSQL, values...).Error err = tx.Exec(createTableSQL, values...).Error
return errr return err
}); err != nil { }); err != nil {
return err return err
} }
@ -498,7 +500,7 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
currentDefaultNotNull := field.HasDefaultValue && (field.DefaultValueInterface != nil || !strings.EqualFold(field.DefaultValue, "NULL")) currentDefaultNotNull := field.HasDefaultValue && (field.DefaultValueInterface != nil || !strings.EqualFold(field.DefaultValue, "NULL"))
dv, dvNotNull := columnType.DefaultValue() dv, dvNotNull := columnType.DefaultValue()
if dvNotNull && !currentDefaultNotNull { if dvNotNull && !currentDefaultNotNull {
// defalut value -> null // default value -> null
alterColumn = true alterColumn = true
} else if !dvNotNull && currentDefaultNotNull { } else if !dvNotNull && currentDefaultNotNull {
// null -> default value // null -> default value