From 0df42e9afc15544a6927e4393b36f2ebd32a561e Mon Sep 17 00:00:00 2001 From: kinggo <30891428+longlihale@users.noreply.github.com> Date: Fri, 7 Jan 2022 09:49:56 +0800 Subject: [PATCH] feat: add `Connection` to execute multiple commands in a single connection; (#4982) --- finisher_api.go | 24 ++++++++++++++++++++ tests/connection_test.go | 48 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+) create mode 100644 tests/connection_test.go diff --git a/finisher_api.go b/finisher_api.go index d38d60b7..dd0eb83a 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -515,6 +515,30 @@ func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error { return tx.Error } +// Connection use a db conn to execute Multiple commands,this conn will put conn pool after it is executed. +func (db *DB) Connection(fc func(tx *DB) error) (err error) { + if db.Error != nil { + return db.Error + } + + tx := db.getInstance() + sqlDB, err := tx.DB() + if err != nil { + return + } + + conn, err := sqlDB.Conn(tx.Statement.Context) + if err != nil { + return + } + + defer conn.Close() + tx.Statement.ConnPool = conn + err = fc(tx) + + return +} + // Transaction start a transaction as a block, return error will rollback, otherwise to commit. func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) { panicked := true diff --git a/tests/connection_test.go b/tests/connection_test.go new file mode 100644 index 00000000..9b5dcd05 --- /dev/null +++ b/tests/connection_test.go @@ -0,0 +1,48 @@ +package tests_test + +import ( + "fmt" + "gorm.io/driver/mysql" + "gorm.io/gorm" + "testing" +) + +func TestWithSingleConnection(t *testing.T) { + + var expectedName = "test" + var actualName string + + setSQL, getSQL := getSetSQL(DB.Dialector.Name()) + if len(setSQL) == 0 || len(getSQL) == 0 { + return + } + + err := DB.Connection(func(tx *gorm.DB) error { + if err := tx.Exec(setSQL, expectedName).Error; err != nil { + return err + } + + if err := tx.Raw(getSQL).Scan(&actualName).Error; err != nil { + return err + } + return nil + }) + + if err != nil { + t.Errorf(fmt.Sprintf("WithSingleConnection should work, but got err %v", err)) + } + + if actualName != expectedName { + t.Errorf("WithSingleConnection() method should get correct value, expect: %v, got %v", expectedName, actualName) + } + +} + +func getSetSQL(driverName string) (string, string) { + switch driverName { + case mysql.Dialector{}.Name(): + return "SET @testName := ?", "SELECT @testName" + default: + return "", "" + } +}