diff --git a/finisher_api.go b/finisher_api.go index c9e2a3b2..211e2f8f 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -140,13 +140,18 @@ func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) { } // FindInBatches find records in batches -func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) (tx *DB) { - tx = db.Session(&Session{WithConditions: true}) - rowsAffected := int64(0) - batch := 0 +func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB { + var ( + tx = db.Order(clause.OrderByColumn{ + Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, + }).Session(&Session{WithConditions: true}) + queryDB = tx + rowsAffected int64 + batch int + ) for { - result := tx.Limit(batchSize).Offset(batch * batchSize).Find(dest) + result := queryDB.Limit(batchSize).Find(dest) rowsAffected += result.RowsAffected batch++ @@ -156,11 +161,15 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat if tx.Error != nil || int(result.RowsAffected) < batchSize { break + } else { + resultsValue := reflect.Indirect(reflect.ValueOf(dest)) + primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(resultsValue.Index(resultsValue.Len() - 1)) + queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue}) } } tx.RowsAffected = rowsAffected - return + return tx } func (tx *DB) assignInterfacesToValue(values ...interface{}) { diff --git a/tests/query_test.go b/tests/query_test.go index dc2907e6..bb77dfae 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -260,6 +260,13 @@ func TestFindInBatches(t *testing.T) { if result := DB.Where("name = ?", users[0].Name).FindInBatches(&results, 2, func(tx *gorm.DB, batch int) error { totalBatch += batch + for idx := range results { + results[idx].Name = results[idx].Name + "_new" + } + if err := tx.Save(results).Error; err != nil { + t.Errorf("failed to save users, got error %v", err) + } + if tx.RowsAffected != 2 { t.Errorf("Incorrect affected rows, expects: 2, got %v", tx.RowsAffected) } @@ -276,6 +283,12 @@ func TestFindInBatches(t *testing.T) { if totalBatch != 6 { t.Errorf("incorrect total batch, expects: %v, got %v", 6, totalBatch) } + + var count int64 + DB.Model(&User{}).Where("name = ?", "find_in_batches_new").Count(&count) + if count != 6 { + t.Errorf("incorrect count after update, expects: %v, got %v", 6, count) + } } func TestFillSmallerStruct(t *testing.T) {