forked from mirror/gorm
fix: FindInBatches with offset limit (#5255)
* fix: FindInBatches with offset limit * fix: break first * fix: FindInBatches Limit zero
This commit is contained in:
parent
e0ed3ce400
commit
b49ae84780
|
@ -181,6 +181,21 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
|
||||||
batch int
|
batch int
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// user specified offset or limit
|
||||||
|
var totalSize int
|
||||||
|
if c, ok := tx.Statement.Clauses["LIMIT"]; ok {
|
||||||
|
if limit, ok := c.Expression.(clause.Limit); ok {
|
||||||
|
totalSize = limit.Limit
|
||||||
|
|
||||||
|
if totalSize > 0 && batchSize > totalSize {
|
||||||
|
batchSize = totalSize
|
||||||
|
}
|
||||||
|
|
||||||
|
// reset to offset to 0 in next batch
|
||||||
|
tx = tx.Offset(-1).Session(&Session{})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
result := queryDB.Limit(batchSize).Find(dest)
|
result := queryDB.Limit(batchSize).Find(dest)
|
||||||
rowsAffected += result.RowsAffected
|
rowsAffected += result.RowsAffected
|
||||||
|
@ -196,6 +211,15 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if totalSize > 0 {
|
||||||
|
if totalSize <= int(rowsAffected) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if totalSize/batchSize == batch {
|
||||||
|
batchSize = totalSize % batchSize
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Optimize for-break
|
// Optimize for-break
|
||||||
resultsValue := reflect.Indirect(reflect.ValueOf(dest))
|
resultsValue := reflect.Indirect(reflect.ValueOf(dest))
|
||||||
if result.Statement.Schema.PrioritizedPrimaryField == nil {
|
if result.Statement.Schema.PrioritizedPrimaryField == nil {
|
||||||
|
|
|
@ -292,6 +292,68 @@ func TestFindInBatches(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFindInBatchesWithOffsetLimit(t *testing.T) {
|
||||||
|
users := []User{
|
||||||
|
*GetUser("find_in_batches_with_offset_limit", Config{}),
|
||||||
|
*GetUser("find_in_batches_with_offset_limit", Config{}),
|
||||||
|
*GetUser("find_in_batches_with_offset_limit", Config{}),
|
||||||
|
*GetUser("find_in_batches_with_offset_limit", Config{}),
|
||||||
|
*GetUser("find_in_batches_with_offset_limit", Config{}),
|
||||||
|
*GetUser("find_in_batches_with_offset_limit", Config{}),
|
||||||
|
*GetUser("find_in_batches_with_offset_limit", Config{}),
|
||||||
|
*GetUser("find_in_batches_with_offset_limit", Config{}),
|
||||||
|
*GetUser("find_in_batches_with_offset_limit", Config{}),
|
||||||
|
*GetUser("find_in_batches_with_offset_limit", Config{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Create(&users)
|
||||||
|
|
||||||
|
var (
|
||||||
|
sub, results []User
|
||||||
|
lastBatch int
|
||||||
|
)
|
||||||
|
|
||||||
|
// offset limit
|
||||||
|
if result := DB.Offset(3).Limit(5).Where("name = ?", users[0].Name).FindInBatches(&sub, 2, func(tx *gorm.DB, batch int) error {
|
||||||
|
results = append(results, sub...)
|
||||||
|
lastBatch = batch
|
||||||
|
return nil
|
||||||
|
}); result.Error != nil || result.RowsAffected != 5 {
|
||||||
|
t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected)
|
||||||
|
}
|
||||||
|
if lastBatch != 3 {
|
||||||
|
t.Fatalf("incorrect last batch, expected: %v, got: %v", 3, lastBatch)
|
||||||
|
}
|
||||||
|
|
||||||
|
targetUsers := users[3:8]
|
||||||
|
for i := 0; i < len(targetUsers); i++ {
|
||||||
|
AssertEqual(t, results[i], targetUsers[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
var sub1 []User
|
||||||
|
// limit < batchSize
|
||||||
|
if result := DB.Limit(5).Where("name = ?", users[0].Name).FindInBatches(&sub1, 10, func(tx *gorm.DB, batch int) error {
|
||||||
|
return nil
|
||||||
|
}); result.Error != nil || result.RowsAffected != 5 {
|
||||||
|
t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected)
|
||||||
|
}
|
||||||
|
|
||||||
|
var sub2 []User
|
||||||
|
// only offset
|
||||||
|
if result := DB.Offset(3).Where("name = ?", users[0].Name).FindInBatches(&sub2, 2, func(tx *gorm.DB, batch int) error {
|
||||||
|
return nil
|
||||||
|
}); result.Error != nil || result.RowsAffected != 7 {
|
||||||
|
t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected)
|
||||||
|
}
|
||||||
|
|
||||||
|
var sub3 []User
|
||||||
|
if result := DB.Limit(4).Where("name = ?", users[0].Name).FindInBatches(&sub3, 2, func(tx *gorm.DB, batch int) error {
|
||||||
|
return nil
|
||||||
|
}); result.Error != nil || result.RowsAffected != 4 {
|
||||||
|
t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestFindInBatchesWithError(t *testing.T) {
|
func TestFindInBatchesWithError(t *testing.T) {
|
||||||
if name := DB.Dialector.Name(); name == "sqlserver" {
|
if name := DB.Dialector.Name(); name == "sqlserver" {
|
||||||
t.Skip("skip sqlserver due to it will raise data race for invalid sql")
|
t.Skip("skip sqlserver due to it will raise data race for invalid sql")
|
||||||
|
|
Loading…
Reference in New Issue