mirror of https://github.com/go-gorm/gorm.git
Add CreateBatchSize option
This commit is contained in:
parent
41e52f343a
commit
0c12a4c360
|
@ -15,6 +15,10 @@ import (
|
||||||
|
|
||||||
// Create insert the value into database
|
// Create insert the value into database
|
||||||
func (db *DB) Create(value interface{}) (tx *DB) {
|
func (db *DB) Create(value interface{}) (tx *DB) {
|
||||||
|
if db.CreateBatchSize > 0 {
|
||||||
|
return db.CreateInBatches(value, db.CreateBatchSize)
|
||||||
|
}
|
||||||
|
|
||||||
tx = db.getInstance()
|
tx = db.getInstance()
|
||||||
tx.Statement.Dest = value
|
tx.Statement.Dest = value
|
||||||
tx.callbacks.Create().Execute(tx)
|
tx.callbacks.Create().Execute(tx)
|
||||||
|
@ -27,19 +31,30 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
|
||||||
|
|
||||||
switch reflectValue.Kind() {
|
switch reflectValue.Kind() {
|
||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
|
var rowsAffected int64
|
||||||
tx = db.getInstance()
|
tx = db.getInstance()
|
||||||
for i := 0; i < reflectValue.Len(); i += batchSize {
|
tx.AddError(tx.Transaction(func(tx *DB) error {
|
||||||
tx.AddError(tx.Transaction(func(tx *DB) error {
|
for i := 0; i < reflectValue.Len(); i += batchSize {
|
||||||
ends := i + batchSize
|
ends := i + batchSize
|
||||||
if ends > reflectValue.Len() {
|
if ends > reflectValue.Len() {
|
||||||
ends = reflectValue.Len()
|
ends = reflectValue.Len()
|
||||||
}
|
}
|
||||||
|
|
||||||
return tx.Create(reflectValue.Slice(i, ends).Interface()).Error
|
subtx := tx.getInstance()
|
||||||
}))
|
subtx.Statement.Dest = reflectValue.Slice(i, ends).Interface()
|
||||||
}
|
subtx.callbacks.Create().Execute(subtx)
|
||||||
|
if subtx.Error != nil {
|
||||||
|
return subtx.Error
|
||||||
|
}
|
||||||
|
rowsAffected += subtx.RowsAffected
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
tx.RowsAffected = rowsAffected
|
||||||
default:
|
default:
|
||||||
return db.Create(value)
|
tx = db.getInstance()
|
||||||
|
tx.Statement.Dest = value
|
||||||
|
tx.callbacks.Create().Execute(tx)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
7
gorm.go
7
gorm.go
|
@ -38,6 +38,8 @@ type Config struct {
|
||||||
AllowGlobalUpdate bool
|
AllowGlobalUpdate bool
|
||||||
// QueryFields executes the SQL query with all fields of the table
|
// QueryFields executes the SQL query with all fields of the table
|
||||||
QueryFields bool
|
QueryFields bool
|
||||||
|
// CreateBatchSize default create batch size
|
||||||
|
CreateBatchSize int
|
||||||
|
|
||||||
// ClauseBuilders clause builder
|
// ClauseBuilders clause builder
|
||||||
ClauseBuilders map[string]clause.ClauseBuilder
|
ClauseBuilders map[string]clause.ClauseBuilder
|
||||||
|
@ -74,6 +76,7 @@ type Session struct {
|
||||||
Context context.Context
|
Context context.Context
|
||||||
Logger logger.Interface
|
Logger logger.Interface
|
||||||
NowFunc func() time.Time
|
NowFunc func() time.Time
|
||||||
|
CreateBatchSize int
|
||||||
}
|
}
|
||||||
|
|
||||||
// Open initialize db session based on dialector
|
// Open initialize db session based on dialector
|
||||||
|
@ -161,6 +164,10 @@ func (db *DB) Session(config *Session) *DB {
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if config.CreateBatchSize > 0 {
|
||||||
|
tx.Config.CreateBatchSize = config.CreateBatchSize
|
||||||
|
}
|
||||||
|
|
||||||
if config.SkipDefaultTransaction {
|
if config.SkipDefaultTransaction {
|
||||||
tx.Config.SkipDefaultTransaction = true
|
tx.Config.SkipDefaultTransaction = true
|
||||||
}
|
}
|
||||||
|
|
|
@ -50,7 +50,39 @@ func TestCreateInBatches(t *testing.T) {
|
||||||
*GetUser("create_in_batches_6", Config{Account: true, Pets: 4, Toys: 3, Company: false, Manager: true, Team: 1, Languages: 3, Friends: 0}),
|
*GetUser("create_in_batches_6", Config{Account: true, Pets: 4, Toys: 3, Company: false, Manager: true, Team: 1, Languages: 3, Friends: 0}),
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.CreateInBatches(&users, 2)
|
result := DB.CreateInBatches(&users, 2)
|
||||||
|
if result.RowsAffected != int64(len(users)) {
|
||||||
|
t.Errorf("affected rows should be %v, but got %v", len(users), result.RowsAffected)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, user := range users {
|
||||||
|
if user.ID == 0 {
|
||||||
|
t.Fatalf("failed to fill user's ID, got %v", user.ID)
|
||||||
|
} else {
|
||||||
|
var newUser User
|
||||||
|
if err := DB.Where("id = ?", user.ID).Preload(clause.Associations).First(&newUser).Error; err != nil {
|
||||||
|
t.Fatalf("errors happened when query: %v", err)
|
||||||
|
} else {
|
||||||
|
CheckUser(t, newUser, user)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateInBatchesWithDefaultSize(t *testing.T) {
|
||||||
|
users := []User{
|
||||||
|
*GetUser("create_with_default_batch_size_1", Config{Account: true, Pets: 2, Toys: 3, Company: true, Manager: true, Team: 0, Languages: 1, Friends: 1}),
|
||||||
|
*GetUser("create_with_default_batch_sizs_2", Config{Account: false, Pets: 2, Toys: 4, Company: false, Manager: false, Team: 1, Languages: 3, Friends: 5}),
|
||||||
|
*GetUser("create_with_default_batch_sizs_3", Config{Account: true, Pets: 0, Toys: 3, Company: true, Manager: false, Team: 4, Languages: 0, Friends: 1}),
|
||||||
|
*GetUser("create_with_default_batch_sizs_4", Config{Account: true, Pets: 3, Toys: 0, Company: false, Manager: true, Team: 0, Languages: 3, Friends: 0}),
|
||||||
|
*GetUser("create_with_default_batch_sizs_5", Config{Account: false, Pets: 0, Toys: 3, Company: true, Manager: false, Team: 1, Languages: 3, Friends: 1}),
|
||||||
|
*GetUser("create_with_default_batch_sizs_6", Config{Account: true, Pets: 4, Toys: 3, Company: false, Manager: true, Team: 1, Languages: 3, Friends: 0}),
|
||||||
|
}
|
||||||
|
|
||||||
|
result := DB.Session(&gorm.Session{CreateBatchSize: 2}).Create(&users)
|
||||||
|
if result.RowsAffected != int64(len(users)) {
|
||||||
|
t.Errorf("affected rows should be %v, but got %v", len(users), result.RowsAffected)
|
||||||
|
}
|
||||||
|
|
||||||
for _, user := range users {
|
for _, user := range users {
|
||||||
if user.ID == 0 {
|
if user.ID == 0 {
|
||||||
|
|
|
@ -9,9 +9,9 @@ require (
|
||||||
github.com/stretchr/testify v1.5.1
|
github.com/stretchr/testify v1.5.1
|
||||||
gorm.io/driver/mysql v1.0.3
|
gorm.io/driver/mysql v1.0.3
|
||||||
gorm.io/driver/postgres v1.0.5
|
gorm.io/driver/postgres v1.0.5
|
||||||
gorm.io/driver/sqlite v1.1.3
|
gorm.io/driver/sqlite v1.1.4
|
||||||
gorm.io/driver/sqlserver v1.0.5
|
gorm.io/driver/sqlserver v1.0.5
|
||||||
gorm.io/gorm v1.20.5
|
gorm.io/gorm v1.20.7
|
||||||
)
|
)
|
||||||
|
|
||||||
replace gorm.io/gorm => ../
|
replace gorm.io/gorm => ../
|
||||||
|
|
Loading…
Reference in New Issue