diff --git a/finisher_api.go b/finisher_api.go index f2aed8da..fc7a73be 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -15,6 +15,10 @@ import ( // Create insert the value into database func (db *DB) Create(value interface{}) (tx *DB) { + if db.CreateBatchSize > 0 { + return db.CreateInBatches(value, db.CreateBatchSize) + } + tx = db.getInstance() tx.Statement.Dest = value tx.callbacks.Create().Execute(tx) @@ -27,19 +31,30 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) { switch reflectValue.Kind() { case reflect.Slice, reflect.Array: + var rowsAffected int64 tx = db.getInstance() - for i := 0; i < reflectValue.Len(); i += batchSize { - tx.AddError(tx.Transaction(func(tx *DB) error { + tx.AddError(tx.Transaction(func(tx *DB) error { + for i := 0; i < reflectValue.Len(); i += batchSize { ends := i + batchSize if ends > reflectValue.Len() { ends = reflectValue.Len() } - return tx.Create(reflectValue.Slice(i, ends).Interface()).Error - })) - } + subtx := tx.getInstance() + subtx.Statement.Dest = reflectValue.Slice(i, ends).Interface() + subtx.callbacks.Create().Execute(subtx) + if subtx.Error != nil { + return subtx.Error + } + rowsAffected += subtx.RowsAffected + } + return nil + })) + tx.RowsAffected = rowsAffected default: - return db.Create(value) + tx = db.getInstance() + tx.Statement.Dest = value + tx.callbacks.Create().Execute(tx) } return } diff --git a/gorm.go b/gorm.go index 1947b4df..ae1cf2c9 100644 --- a/gorm.go +++ b/gorm.go @@ -38,6 +38,8 @@ type Config struct { AllowGlobalUpdate bool // QueryFields executes the SQL query with all fields of the table QueryFields bool + // CreateBatchSize default create batch size + CreateBatchSize int // ClauseBuilders clause builder ClauseBuilders map[string]clause.ClauseBuilder @@ -74,6 +76,7 @@ type Session struct { Context context.Context Logger logger.Interface NowFunc func() time.Time + CreateBatchSize int } // Open initialize db session based on dialector @@ -161,6 +164,10 @@ func (db *DB) Session(config *Session) *DB { } ) + if config.CreateBatchSize > 0 { + tx.Config.CreateBatchSize = config.CreateBatchSize + } + if config.SkipDefaultTransaction { tx.Config.SkipDefaultTransaction = true } diff --git a/tests/create_test.go b/tests/create_test.go index 8d005d0b..170c8546 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -50,7 +50,39 @@ func TestCreateInBatches(t *testing.T) { *GetUser("create_in_batches_6", Config{Account: true, Pets: 4, Toys: 3, Company: false, Manager: true, Team: 1, Languages: 3, Friends: 0}), } - DB.CreateInBatches(&users, 2) + result := DB.CreateInBatches(&users, 2) + if result.RowsAffected != int64(len(users)) { + t.Errorf("affected rows should be %v, but got %v", len(users), result.RowsAffected) + } + + for _, user := range users { + if user.ID == 0 { + t.Fatalf("failed to fill user's ID, got %v", user.ID) + } else { + var newUser User + if err := DB.Where("id = ?", user.ID).Preload(clause.Associations).First(&newUser).Error; err != nil { + t.Fatalf("errors happened when query: %v", err) + } else { + CheckUser(t, newUser, user) + } + } + } +} + +func TestCreateInBatchesWithDefaultSize(t *testing.T) { + users := []User{ + *GetUser("create_with_default_batch_size_1", Config{Account: true, Pets: 2, Toys: 3, Company: true, Manager: true, Team: 0, Languages: 1, Friends: 1}), + *GetUser("create_with_default_batch_sizs_2", Config{Account: false, Pets: 2, Toys: 4, Company: false, Manager: false, Team: 1, Languages: 3, Friends: 5}), + *GetUser("create_with_default_batch_sizs_3", Config{Account: true, Pets: 0, Toys: 3, Company: true, Manager: false, Team: 4, Languages: 0, Friends: 1}), + *GetUser("create_with_default_batch_sizs_4", Config{Account: true, Pets: 3, Toys: 0, Company: false, Manager: true, Team: 0, Languages: 3, Friends: 0}), + *GetUser("create_with_default_batch_sizs_5", Config{Account: false, Pets: 0, Toys: 3, Company: true, Manager: false, Team: 1, Languages: 3, Friends: 1}), + *GetUser("create_with_default_batch_sizs_6", Config{Account: true, Pets: 4, Toys: 3, Company: false, Manager: true, Team: 1, Languages: 3, Friends: 0}), + } + + result := DB.Session(&gorm.Session{CreateBatchSize: 2}).Create(&users) + if result.RowsAffected != int64(len(users)) { + t.Errorf("affected rows should be %v, but got %v", len(users), result.RowsAffected) + } for _, user := range users { if user.ID == 0 { diff --git a/tests/go.mod b/tests/go.mod index fa293987..03283a53 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,9 +9,9 @@ require ( github.com/stretchr/testify v1.5.1 gorm.io/driver/mysql v1.0.3 gorm.io/driver/postgres v1.0.5 - gorm.io/driver/sqlite v1.1.3 + gorm.io/driver/sqlite v1.1.4 gorm.io/driver/sqlserver v1.0.5 - gorm.io/gorm v1.20.5 + gorm.io/gorm v1.20.7 ) replace gorm.io/gorm => ../