mirror of https://github.com/go-gorm/gorm.git
Add FindInBatches support
This commit is contained in:
parent
dbc3f8feb0
commit
45cb6b49bf
|
@ -106,6 +106,30 @@ func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {
|
|||
return
|
||||
}
|
||||
|
||||
// FindInBatches find records in batches
|
||||
func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) (tx *DB) {
|
||||
tx = db.Session(&Session{WithConditions: true})
|
||||
rowsAffected := int64(0)
|
||||
batch := 0
|
||||
|
||||
for {
|
||||
result := tx.Limit(batchSize).Offset(batch * batchSize).Find(dest)
|
||||
rowsAffected += result.RowsAffected
|
||||
batch++
|
||||
|
||||
if result.Error == nil && result.RowsAffected != 0 {
|
||||
tx.AddError(fc(result, batch))
|
||||
}
|
||||
|
||||
if tx.Error != nil || int(result.RowsAffected) < batchSize {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
tx.RowsAffected = rowsAffected
|
||||
return
|
||||
}
|
||||
|
||||
func (tx *DB) assignExprsToValue(exprs []clause.Expression) {
|
||||
for _, expr := range exprs {
|
||||
if eq, ok := expr.(clause.Eq); ok {
|
||||
|
|
|
@ -102,6 +102,44 @@ func TestFind(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestFindInBatches(t *testing.T) {
|
||||
var users = []User{
|
||||
*GetUser("find_in_batches", Config{}),
|
||||
*GetUser("find_in_batches", Config{}),
|
||||
*GetUser("find_in_batches", Config{}),
|
||||
*GetUser("find_in_batches", Config{}),
|
||||
*GetUser("find_in_batches", Config{}),
|
||||
*GetUser("find_in_batches", Config{}),
|
||||
}
|
||||
|
||||
DB.Create(&users)
|
||||
|
||||
var (
|
||||
results []User
|
||||
totalBatch int
|
||||
)
|
||||
|
||||
if result := DB.Where("name = ?", users[0].Name).FindInBatches(&results, 2, func(tx *gorm.DB, batch int) error {
|
||||
totalBatch += batch
|
||||
|
||||
if tx.RowsAffected != 2 {
|
||||
t.Errorf("Incorrect affected rows, expects: 2, got %v", tx.RowsAffected)
|
||||
}
|
||||
|
||||
if len(results) != 2 {
|
||||
t.Errorf("Incorrect users length, expects: 2, got %v", len(results))
|
||||
}
|
||||
|
||||
return nil
|
||||
}); result.Error != nil || result.RowsAffected != 6 {
|
||||
t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected)
|
||||
}
|
||||
|
||||
if totalBatch != 6 {
|
||||
t.Errorf("incorrect total batch, expects: %v, got %v", 6, totalBatch)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFillSmallerStruct(t *testing.T) {
|
||||
user := User{Name: "SmallerUser", Age: 100}
|
||||
DB.Save(&user)
|
||||
|
|
Loading…
Reference in New Issue