From f5c2126c29e375955b4db406fe6c6440f5c46b8e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 17 Nov 2020 13:14:34 +0800 Subject: [PATCH] Fix FindInBatches tests --- callbacks/create.go | 2 ++ tests/query_test.go | 15 ++++++++------- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index ad91ebc3..aec0afe9 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -55,6 +55,7 @@ func Create(config *Config) func(db *gorm.DB) { if err == nil { db.RowsAffected, _ = result.RowsAffected() + if db.RowsAffected > 0 { if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { if insertID, err := result.LastInsertId(); err == nil && insertID > 0 { @@ -138,6 +139,7 @@ func CreateWithReturning(db *gorm.DB) { } if !db.DryRun && db.Error == nil { + db.RowsAffected = 0 rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) if err == nil { diff --git a/tests/query_test.go b/tests/query_test.go index bb77dfae..20968c7e 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -260,13 +260,6 @@ func TestFindInBatches(t *testing.T) { if result := DB.Where("name = ?", users[0].Name).FindInBatches(&results, 2, func(tx *gorm.DB, batch int) error { totalBatch += batch - for idx := range results { - results[idx].Name = results[idx].Name + "_new" - } - if err := tx.Save(results).Error; err != nil { - t.Errorf("failed to save users, got error %v", err) - } - if tx.RowsAffected != 2 { t.Errorf("Incorrect affected rows, expects: 2, got %v", tx.RowsAffected) } @@ -275,6 +268,14 @@ func TestFindInBatches(t *testing.T) { t.Errorf("Incorrect users length, expects: 2, got %v", len(results)) } + for idx := range results { + results[idx].Name = results[idx].Name + "_new" + } + + if err := tx.Save(results).Error; err != nil { + t.Errorf("failed to save users, got error %v", err) + } + return nil }); result.Error != nil || result.RowsAffected != 6 { t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected)