forked from mirror/gorm
Fix FindInBatches to modify the query conditions, close #3734
This commit is contained in:
parent
a8db54afd6
commit
320f33061c
|
@ -140,13 +140,18 @@ func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// FindInBatches find records in batches
|
// FindInBatches find records in batches
|
||||||
func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) (tx *DB) {
|
func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB {
|
||||||
tx = db.Session(&Session{WithConditions: true})
|
var (
|
||||||
rowsAffected := int64(0)
|
tx = db.Order(clause.OrderByColumn{
|
||||||
batch := 0
|
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
||||||
|
}).Session(&Session{WithConditions: true})
|
||||||
|
queryDB = tx
|
||||||
|
rowsAffected int64
|
||||||
|
batch int
|
||||||
|
)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
result := tx.Limit(batchSize).Offset(batch * batchSize).Find(dest)
|
result := queryDB.Limit(batchSize).Find(dest)
|
||||||
rowsAffected += result.RowsAffected
|
rowsAffected += result.RowsAffected
|
||||||
batch++
|
batch++
|
||||||
|
|
||||||
|
@ -156,11 +161,15 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
|
||||||
|
|
||||||
if tx.Error != nil || int(result.RowsAffected) < batchSize {
|
if tx.Error != nil || int(result.RowsAffected) < batchSize {
|
||||||
break
|
break
|
||||||
|
} else {
|
||||||
|
resultsValue := reflect.Indirect(reflect.ValueOf(dest))
|
||||||
|
primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(resultsValue.Index(resultsValue.Len() - 1))
|
||||||
|
queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tx.RowsAffected = rowsAffected
|
tx.RowsAffected = rowsAffected
|
||||||
return
|
return tx
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tx *DB) assignInterfacesToValue(values ...interface{}) {
|
func (tx *DB) assignInterfacesToValue(values ...interface{}) {
|
||||||
|
|
|
@ -260,6 +260,13 @@ func TestFindInBatches(t *testing.T) {
|
||||||
if result := DB.Where("name = ?", users[0].Name).FindInBatches(&results, 2, func(tx *gorm.DB, batch int) error {
|
if result := DB.Where("name = ?", users[0].Name).FindInBatches(&results, 2, func(tx *gorm.DB, batch int) error {
|
||||||
totalBatch += batch
|
totalBatch += batch
|
||||||
|
|
||||||
|
for idx := range results {
|
||||||
|
results[idx].Name = results[idx].Name + "_new"
|
||||||
|
}
|
||||||
|
if err := tx.Save(results).Error; err != nil {
|
||||||
|
t.Errorf("failed to save users, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
if tx.RowsAffected != 2 {
|
if tx.RowsAffected != 2 {
|
||||||
t.Errorf("Incorrect affected rows, expects: 2, got %v", tx.RowsAffected)
|
t.Errorf("Incorrect affected rows, expects: 2, got %v", tx.RowsAffected)
|
||||||
}
|
}
|
||||||
|
@ -276,6 +283,12 @@ func TestFindInBatches(t *testing.T) {
|
||||||
if totalBatch != 6 {
|
if totalBatch != 6 {
|
||||||
t.Errorf("incorrect total batch, expects: %v, got %v", 6, totalBatch)
|
t.Errorf("incorrect total batch, expects: %v, got %v", 6, totalBatch)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var count int64
|
||||||
|
DB.Model(&User{}).Where("name = ?", "find_in_batches_new").Count(&count)
|
||||||
|
if count != 6 {
|
||||||
|
t.Errorf("incorrect count after update, expects: %v, got %v", 6, count)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFillSmallerStruct(t *testing.T) {
|
func TestFillSmallerStruct(t *testing.T) {
|
||||||
|
|
Loading…
Reference in New Issue