Fix FindInBatches to modify the query conditions, close #3734

This commit is contained in:
Jinzhu 2020-11-17 11:19:04 +08:00
parent a8db54afd6
commit 320f33061c
2 changed files with 28 additions and 6 deletions

View File

@ -140,13 +140,18 @@ func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {
} }
// FindInBatches find records in batches // FindInBatches find records in batches
func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) (tx *DB) { func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB {
tx = db.Session(&Session{WithConditions: true}) var (
rowsAffected := int64(0) tx = db.Order(clause.OrderByColumn{
batch := 0 Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
}).Session(&Session{WithConditions: true})
queryDB = tx
rowsAffected int64
batch int
)
for { for {
result := tx.Limit(batchSize).Offset(batch * batchSize).Find(dest) result := queryDB.Limit(batchSize).Find(dest)
rowsAffected += result.RowsAffected rowsAffected += result.RowsAffected
batch++ batch++
@ -156,11 +161,15 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
if tx.Error != nil || int(result.RowsAffected) < batchSize { if tx.Error != nil || int(result.RowsAffected) < batchSize {
break break
} else {
resultsValue := reflect.Indirect(reflect.ValueOf(dest))
primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(resultsValue.Index(resultsValue.Len() - 1))
queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue})
} }
} }
tx.RowsAffected = rowsAffected tx.RowsAffected = rowsAffected
return return tx
} }
func (tx *DB) assignInterfacesToValue(values ...interface{}) { func (tx *DB) assignInterfacesToValue(values ...interface{}) {

View File

@ -260,6 +260,13 @@ func TestFindInBatches(t *testing.T) {
if result := DB.Where("name = ?", users[0].Name).FindInBatches(&results, 2, func(tx *gorm.DB, batch int) error { if result := DB.Where("name = ?", users[0].Name).FindInBatches(&results, 2, func(tx *gorm.DB, batch int) error {
totalBatch += batch totalBatch += batch
for idx := range results {
results[idx].Name = results[idx].Name + "_new"
}
if err := tx.Save(results).Error; err != nil {
t.Errorf("failed to save users, got error %v", err)
}
if tx.RowsAffected != 2 { if tx.RowsAffected != 2 {
t.Errorf("Incorrect affected rows, expects: 2, got %v", tx.RowsAffected) t.Errorf("Incorrect affected rows, expects: 2, got %v", tx.RowsAffected)
} }
@ -276,6 +283,12 @@ func TestFindInBatches(t *testing.T) {
if totalBatch != 6 { if totalBatch != 6 {
t.Errorf("incorrect total batch, expects: %v, got %v", 6, totalBatch) t.Errorf("incorrect total batch, expects: %v, got %v", 6, totalBatch)
} }
var count int64
DB.Model(&User{}).Where("name = ?", "find_in_batches_new").Count(&count)
if count != 6 {
t.Errorf("incorrect count after update, expects: %v, got %v", 6, count)
}
} }
func TestFillSmallerStruct(t *testing.T) { func TestFillSmallerStruct(t *testing.T) {