From 300a23fc3137b947a3ce9bca97fa5c81cc605636 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 2 Dec 2021 10:39:24 +0800 Subject: [PATCH] Check rows.Close error, close #4891 --- callbacks/create.go | 2 +- callbacks/delete.go | 2 +- callbacks/query.go | 3 +-- callbacks/update.go | 2 +- finisher_api.go | 2 +- migrator/migrator.go | 8 +++++--- tests/associations_belongs_to_test.go | 7 +++++++ 7 files changed, 17 insertions(+), 9 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index c585fbe9..9dc5b8b1 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -83,7 +83,7 @@ func Create(config *Config) func(db *gorm.DB) { ) if db.AddError(err) == nil { gorm.Scan(rows, db, mode) - rows.Close() + db.AddError(rows.Close()) } return diff --git a/callbacks/delete.go b/callbacks/delete.go index 525c0145..b05a9d08 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -168,7 +168,7 @@ func Delete(config *Config) func(db *gorm.DB) { if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { gorm.Scan(rows, db, mode) - rows.Close() + db.AddError(rows.Close()) } } } diff --git a/callbacks/query.go b/callbacks/query.go index 6ca3a1fb..2f98a4b6 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -20,9 +20,8 @@ func Query(db *gorm.DB) { db.AddError(err) return } - defer rows.Close() - gorm.Scan(rows, db, 0) + db.AddError(rows.Close()) } } } diff --git a/callbacks/update.go b/callbacks/update.go index 1f4960b5..fa7640de 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -88,7 +88,7 @@ func Update(config *Config) func(db *gorm.DB) { db.Statement.Dest = db.Statement.ReflectValue.Addr().Interface() gorm.Scan(rows, db, mode) db.Statement.Dest = dest - rows.Close() + db.AddError(rows.Close()) } } else { result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) diff --git a/finisher_api.go b/finisher_api.go index b3bdedc8..d38d60b7 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -457,12 +457,12 @@ func (db *DB) Scan(dest interface{}) (tx *DB) { tx.Config = &config if rows, err := tx.Rows(); err == nil { - defer rows.Close() if rows.Next() { tx.ScanRows(rows, dest) } else { tx.RowsAffected = 0 } + tx.AddError(rows.Close()) } currentLogger.Trace(tx.Statement.Context, newLogger.BeginAt, func() (string, int64) { diff --git a/migrator/migrator.go b/migrator/migrator.go index 91bf60a7..18212dbb 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -430,13 +430,15 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy // ColumnTypes return columnTypes []gorm.ColumnType and execErr error func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { columnTypes := make([]gorm.ColumnType, 0) - execErr := m.RunWithValue(value, func(stmt *gorm.Statement) error { + execErr := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) { rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows() if err != nil { return err } - defer rows.Close() + defer func() { + err = rows.Close() + }() var rawColumnTypes []*sql.ColumnType rawColumnTypes, err = rows.ColumnTypes() @@ -448,7 +450,7 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { columnTypes = append(columnTypes, c) } - return nil + return }) return columnTypes, execErr diff --git a/tests/associations_belongs_to_test.go b/tests/associations_belongs_to_test.go index 3e4de726..e37da7d3 100644 --- a/tests/associations_belongs_to_test.go +++ b/tests/associations_belongs_to_test.go @@ -132,6 +132,13 @@ func TestBelongsToAssociation(t *testing.T) { AssertAssociationCount(t, user2, "Company", 0, "after clear") AssertAssociationCount(t, user2, "Manager", 0, "after clear") + + // unexist company id + unexistCompanyID := company.ID + 9999999 + user = User{Name: "invalid-user-with-invalid-belongs-to-foreign-key", CompanyID: &unexistCompanyID} + if err := DB.Create(&user).Error; err == nil { + t.Errorf("should have gotten foreign key violation error") + } } func TestBelongsToAssociationForSlice(t *testing.T) {