forked from mirror/gorm
Fix FindInBatches to modify the query conditions, close #3734
This commit is contained in:
parent
a8db54afd6
commit
320f33061c
|
@ -140,13 +140,18 @@ func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {
|
|||
}
|
||||
|
||||
// FindInBatches find records in batches
|
||||
func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) (tx *DB) {
|
||||
tx = db.Session(&Session{WithConditions: true})
|
||||
rowsAffected := int64(0)
|
||||
batch := 0
|
||||
func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB {
|
||||
var (
|
||||
tx = db.Order(clause.OrderByColumn{
|
||||
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
||||
}).Session(&Session{WithConditions: true})
|
||||
queryDB = tx
|
||||
rowsAffected int64
|
||||
batch int
|
||||
)
|
||||
|
||||
for {
|
||||
result := tx.Limit(batchSize).Offset(batch * batchSize).Find(dest)
|
||||
result := queryDB.Limit(batchSize).Find(dest)
|
||||
rowsAffected += result.RowsAffected
|
||||
batch++
|
||||
|
||||
|
@ -156,11 +161,15 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
|
|||
|
||||
if tx.Error != nil || int(result.RowsAffected) < batchSize {
|
||||
break
|
||||
} else {
|
||||
resultsValue := reflect.Indirect(reflect.ValueOf(dest))
|
||||
primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(resultsValue.Index(resultsValue.Len() - 1))
|
||||
queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue})
|
||||
}
|
||||
}
|
||||
|
||||
tx.RowsAffected = rowsAffected
|
||||
return
|
||||
return tx
|
||||
}
|
||||
|
||||
func (tx *DB) assignInterfacesToValue(values ...interface{}) {
|
||||
|
|
|
@ -260,6 +260,13 @@ func TestFindInBatches(t *testing.T) {
|
|||
if result := DB.Where("name = ?", users[0].Name).FindInBatches(&results, 2, func(tx *gorm.DB, batch int) error {
|
||||
totalBatch += batch
|
||||
|
||||
for idx := range results {
|
||||
results[idx].Name = results[idx].Name + "_new"
|
||||
}
|
||||
if err := tx.Save(results).Error; err != nil {
|
||||
t.Errorf("failed to save users, got error %v", err)
|
||||
}
|
||||
|
||||
if tx.RowsAffected != 2 {
|
||||
t.Errorf("Incorrect affected rows, expects: 2, got %v", tx.RowsAffected)
|
||||
}
|
||||
|
@ -276,6 +283,12 @@ func TestFindInBatches(t *testing.T) {
|
|||
if totalBatch != 6 {
|
||||
t.Errorf("incorrect total batch, expects: %v, got %v", 6, totalBatch)
|
||||
}
|
||||
|
||||
var count int64
|
||||
DB.Model(&User{}).Where("name = ?", "find_in_batches_new").Count(&count)
|
||||
if count != 6 {
|
||||
t.Errorf("incorrect count after update, expects: %v, got %v", 6, count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFillSmallerStruct(t *testing.T) {
|
||||
|
|
Loading…
Reference in New Issue