From 4a15540504db9a7e1ecf69bb2a88bdb7097f6d1a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 18 Jan 2021 11:43:42 +0800 Subject: [PATCH] SkipDefaultTransaction skip CreateInBatches transaction --- callbacks/transaction.go | 2 +- finisher_api.go | 12 ++++++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/callbacks/transaction.go b/callbacks/transaction.go index 3171b5bb..45c6ca11 100644 --- a/callbacks/transaction.go +++ b/callbacks/transaction.go @@ -9,7 +9,7 @@ func BeginTransaction(db *gorm.DB) { if tx := db.Begin(); tx.Error == nil { db.Statement.ConnPool = tx.Statement.ConnPool db.InstanceSet("gorm:started_transaction", true) - } else { + } else if tx.Error == gorm.ErrInvalidTransaction { tx.Error = nil } } diff --git a/finisher_api.go b/finisher_api.go index 7424a9cb..528f32be 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -33,7 +33,8 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) { case reflect.Slice, reflect.Array: var rowsAffected int64 tx = db.getInstance() - tx.AddError(tx.Transaction(func(tx *DB) error { + + callFc := func(tx *DB) error { for i := 0; i < reflectValue.Len(); i += batchSize { ends := i + batchSize if ends > reflectValue.Len() { @@ -49,7 +50,14 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) { rowsAffected += subtx.RowsAffected } return nil - })) + } + + if tx.SkipDefaultTransaction { + tx.AddError(callFc(tx.Session(&Session{}))) + } else { + tx.AddError(tx.Transaction(callFc)) + } + tx.RowsAffected = rowsAffected default: tx = db.getInstance()