forked from mirror/gorm
feat: add `Connection` to execute multiple commands in a single connection; (#4982)
This commit is contained in:
parent
f757b8fdc9
commit
0df42e9afc
|
@ -515,6 +515,30 @@ func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
|
|||
return tx.Error
|
||||
}
|
||||
|
||||
// Connection use a db conn to execute Multiple commands,this conn will put conn pool after it is executed.
|
||||
func (db *DB) Connection(fc func(tx *DB) error) (err error) {
|
||||
if db.Error != nil {
|
||||
return db.Error
|
||||
}
|
||||
|
||||
tx := db.getInstance()
|
||||
sqlDB, err := tx.DB()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := sqlDB.Conn(tx.Statement.Context)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
defer conn.Close()
|
||||
tx.Statement.ConnPool = conn
|
||||
err = fc(tx)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Transaction start a transaction as a block, return error will rollback, otherwise to commit.
|
||||
func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) {
|
||||
panicked := true
|
||||
|
|
|
@ -0,0 +1,48 @@
|
|||
package tests_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/gorm"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestWithSingleConnection(t *testing.T) {
|
||||
|
||||
var expectedName = "test"
|
||||
var actualName string
|
||||
|
||||
setSQL, getSQL := getSetSQL(DB.Dialector.Name())
|
||||
if len(setSQL) == 0 || len(getSQL) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
err := DB.Connection(func(tx *gorm.DB) error {
|
||||
if err := tx.Exec(setSQL, expectedName).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tx.Raw(getSQL).Scan(&actualName).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Errorf(fmt.Sprintf("WithSingleConnection should work, but got err %v", err))
|
||||
}
|
||||
|
||||
if actualName != expectedName {
|
||||
t.Errorf("WithSingleConnection() method should get correct value, expect: %v, got %v", expectedName, actualName)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func getSetSQL(driverName string) (string, string) {
|
||||
switch driverName {
|
||||
case mysql.Dialector{}.Name():
|
||||
return "SET @testName := ?", "SELECT @testName"
|
||||
default:
|
||||
return "", ""
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue