feat: add `Connection` to execute multiple commands in a single connection; (#4982)

This commit is contained in:
kinggo 2022-01-07 09:49:56 +08:00 committed by GitHub
parent f757b8fdc9
commit 0df42e9afc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 72 additions and 0 deletions

View File

@ -515,6 +515,30 @@ func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
return tx.Error return tx.Error
} }
// Connection use a db conn to execute Multiple commands,this conn will put conn pool after it is executed.
func (db *DB) Connection(fc func(tx *DB) error) (err error) {
if db.Error != nil {
return db.Error
}
tx := db.getInstance()
sqlDB, err := tx.DB()
if err != nil {
return
}
conn, err := sqlDB.Conn(tx.Statement.Context)
if err != nil {
return
}
defer conn.Close()
tx.Statement.ConnPool = conn
err = fc(tx)
return
}
// Transaction start a transaction as a block, return error will rollback, otherwise to commit. // Transaction start a transaction as a block, return error will rollback, otherwise to commit.
func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) { func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) {
panicked := true panicked := true

48
tests/connection_test.go Normal file
View File

@ -0,0 +1,48 @@
package tests_test
import (
"fmt"
"gorm.io/driver/mysql"
"gorm.io/gorm"
"testing"
)
func TestWithSingleConnection(t *testing.T) {
var expectedName = "test"
var actualName string
setSQL, getSQL := getSetSQL(DB.Dialector.Name())
if len(setSQL) == 0 || len(getSQL) == 0 {
return
}
err := DB.Connection(func(tx *gorm.DB) error {
if err := tx.Exec(setSQL, expectedName).Error; err != nil {
return err
}
if err := tx.Raw(getSQL).Scan(&actualName).Error; err != nil {
return err
}
return nil
})
if err != nil {
t.Errorf(fmt.Sprintf("WithSingleConnection should work, but got err %v", err))
}
if actualName != expectedName {
t.Errorf("WithSingleConnection() method should get correct value, expect: %v, got %v", expectedName, actualName)
}
}
func getSetSQL(driverName string) (string, string) {
switch driverName {
case mysql.Dialector{}.Name():
return "SET @testName := ?", "SELECT @testName"
default:
return "", ""
}
}