Fix panic bug in migrator due to lack of nil check for stmt.Schema (#6932)

This commit is contained in:
PiexlMax(奇淼 2024-04-26 15:15:49 +08:00 committed by GitHub
parent ac59252327
commit 78920199f0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 47 additions and 19 deletions

View File

@ -127,6 +127,11 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
} }
} else { } else {
if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema == nil {
return errors.New("failed to get schema")
}
columnTypes, err := queryTx.Migrator().ColumnTypes(value) columnTypes, err := queryTx.Migrator().ColumnTypes(value)
if err != nil { if err != nil {
return err return err
@ -211,6 +216,11 @@ func (m Migrator) CreateTable(values ...interface{}) error {
for _, value := range m.ReorderModels(values, false) { for _, value := range m.ReorderModels(values, false) {
tx := m.DB.Session(&gorm.Session{}) tx := m.DB.Session(&gorm.Session{})
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) { if err := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
if stmt.Schema == nil {
return errors.New("failed to get schema")
}
var ( var (
createTableSQL = "CREATE TABLE ? (" createTableSQL = "CREATE TABLE ? ("
values = []interface{}{m.CurrentTable(stmt)} values = []interface{}{m.CurrentTable(stmt)}
@ -363,6 +373,9 @@ func (m Migrator) RenameTable(oldName, newName interface{}) error {
func (m Migrator) AddColumn(value interface{}, name string) error { func (m Migrator) AddColumn(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
// avoid using the same name field // avoid using the same name field
if stmt.Schema == nil {
return errors.New("failed to get schema")
}
f := stmt.Schema.LookUpField(name) f := stmt.Schema.LookUpField(name)
if f == nil { if f == nil {
return fmt.Errorf("failed to look up field with name: %s", name) return fmt.Errorf("failed to look up field with name: %s", name)
@ -382,8 +395,10 @@ func (m Migrator) AddColumn(value interface{}, name string) error {
// DropColumn drop value's `name` column // DropColumn drop value's `name` column
func (m Migrator) DropColumn(value interface{}, name string) error { func (m Migrator) DropColumn(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(name); field != nil { if stmt.Schema != nil {
name = field.DBName if field := stmt.Schema.LookUpField(name); field != nil {
name = field.DBName
}
} }
return m.DB.Exec( return m.DB.Exec(
@ -395,13 +410,15 @@ func (m Migrator) DropColumn(value interface{}, name string) error {
// AlterColumn alter value's `field` column' type based on schema definition // AlterColumn alter value's `field` column' type based on schema definition
func (m Migrator) AlterColumn(value interface{}, field string) error { func (m Migrator) AlterColumn(value interface{}, field string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(field); field != nil { if stmt.Schema != nil {
fileType := m.FullDataTypeOf(field) if field := stmt.Schema.LookUpField(field); field != nil {
return m.DB.Exec( fileType := m.FullDataTypeOf(field)
"ALTER TABLE ? ALTER COLUMN ? TYPE ?", return m.DB.Exec(
m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType, "ALTER TABLE ? ALTER COLUMN ? TYPE ?",
).Error m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType,
).Error
}
} }
return fmt.Errorf("failed to look up field with name: %s", field) return fmt.Errorf("failed to look up field with name: %s", field)
}) })
@ -413,8 +430,10 @@ func (m Migrator) HasColumn(value interface{}, field string) bool {
m.RunWithValue(value, func(stmt *gorm.Statement) error { m.RunWithValue(value, func(stmt *gorm.Statement) error {
currentDatabase := m.DB.Migrator().CurrentDatabase() currentDatabase := m.DB.Migrator().CurrentDatabase()
name := field name := field
if field := stmt.Schema.LookUpField(field); field != nil { if stmt.Schema != nil {
name = field.DBName if field := stmt.Schema.LookUpField(field); field != nil {
name = field.DBName
}
} }
return m.DB.Raw( return m.DB.Raw(
@ -429,12 +448,14 @@ func (m Migrator) HasColumn(value interface{}, field string) bool {
// RenameColumn rename value's field name from oldName to newName // RenameColumn rename value's field name from oldName to newName
func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error { func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(oldName); field != nil { if stmt.Schema != nil {
oldName = field.DBName if field := stmt.Schema.LookUpField(oldName); field != nil {
} oldName = field.DBName
}
if field := stmt.Schema.LookUpField(newName); field != nil { if field := stmt.Schema.LookUpField(newName); field != nil {
newName = field.DBName newName = field.DBName
}
} }
return m.DB.Exec( return m.DB.Exec(
@ -794,6 +815,9 @@ type BuildIndexOptionsInterface interface {
// CreateIndex create index `name` // CreateIndex create index `name`
func (m Migrator) CreateIndex(value interface{}, name string) error { func (m Migrator) CreateIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema == nil {
return errors.New("failed to get schema")
}
if idx := stmt.Schema.LookIndex(name); idx != nil { if idx := stmt.Schema.LookIndex(name); idx != nil {
opts := m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt) opts := m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)
values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts} values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts}
@ -826,8 +850,10 @@ func (m Migrator) CreateIndex(value interface{}, name string) error {
// DropIndex drop index `name` // DropIndex drop index `name`
func (m Migrator) DropIndex(value interface{}, name string) error { func (m Migrator) DropIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if idx := stmt.Schema.LookIndex(name); idx != nil { if stmt.Schema != nil {
name = idx.Name if idx := stmt.Schema.LookIndex(name); idx != nil {
name = idx.Name
}
} }
return m.DB.Exec("DROP INDEX ? ON ?", clause.Column{Name: name}, m.CurrentTable(stmt)).Error return m.DB.Exec("DROP INDEX ? ON ?", clause.Column{Name: name}, m.CurrentTable(stmt)).Error
@ -839,8 +865,10 @@ func (m Migrator) HasIndex(value interface{}, name string) bool {
var count int64 var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error { m.RunWithValue(value, func(stmt *gorm.Statement) error {
currentDatabase := m.DB.Migrator().CurrentDatabase() currentDatabase := m.DB.Migrator().CurrentDatabase()
if idx := stmt.Schema.LookIndex(name); idx != nil { if stmt.Schema != nil {
name = idx.Name if idx := stmt.Schema.LookIndex(name); idx != nil {
name = idx.Name
}
} }
return m.DB.Raw( return m.DB.Raw(