diff --git a/finisher_api.go b/finisher_api.go index afefd9fd..032c3059 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -106,6 +106,30 @@ func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) { return } +// FindInBatches find records in batches +func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) (tx *DB) { + tx = db.Session(&Session{WithConditions: true}) + rowsAffected := int64(0) + batch := 0 + + for { + result := tx.Limit(batchSize).Offset(batch * batchSize).Find(dest) + rowsAffected += result.RowsAffected + batch++ + + if result.Error == nil && result.RowsAffected != 0 { + tx.AddError(fc(result, batch)) + } + + if tx.Error != nil || int(result.RowsAffected) < batchSize { + break + } + } + + tx.RowsAffected = rowsAffected + return +} + func (tx *DB) assignExprsToValue(exprs []clause.Expression) { for _, expr := range exprs { if eq, ok := expr.(clause.Eq); ok { diff --git a/tests/query_test.go b/tests/query_test.go index 66413b3b..de65b63b 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -102,6 +102,44 @@ func TestFind(t *testing.T) { }) } +func TestFindInBatches(t *testing.T) { + var users = []User{ + *GetUser("find_in_batches", Config{}), + *GetUser("find_in_batches", Config{}), + *GetUser("find_in_batches", Config{}), + *GetUser("find_in_batches", Config{}), + *GetUser("find_in_batches", Config{}), + *GetUser("find_in_batches", Config{}), + } + + DB.Create(&users) + + var ( + results []User + totalBatch int + ) + + if result := DB.Where("name = ?", users[0].Name).FindInBatches(&results, 2, func(tx *gorm.DB, batch int) error { + totalBatch += batch + + if tx.RowsAffected != 2 { + t.Errorf("Incorrect affected rows, expects: 2, got %v", tx.RowsAffected) + } + + if len(results) != 2 { + t.Errorf("Incorrect users length, expects: 2, got %v", len(results)) + } + + return nil + }); result.Error != nil || result.RowsAffected != 6 { + t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected) + } + + if totalBatch != 6 { + t.Errorf("incorrect total batch, expects: %v, got %v", 6, totalBatch) + } +} + func TestFillSmallerStruct(t *testing.T) { user := User{Name: "SmallerUser", Age: 100} DB.Save(&user)