mirror of https://github.com/go-gorm/gorm.git
fix: use preparestmt in trasaction will use new conn, close #5508
This commit is contained in:
parent
73bc53f061
commit
12237454ed
16
gorm.go
16
gorm.go
|
@ -248,10 +248,18 @@ func (db *DB) Session(config *Session) *DB {
|
||||||
if config.PrepareStmt {
|
if config.PrepareStmt {
|
||||||
if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok {
|
if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok {
|
||||||
preparedStmt := v.(*PreparedStmtDB)
|
preparedStmt := v.(*PreparedStmtDB)
|
||||||
tx.Statement.ConnPool = &PreparedStmtDB{
|
switch t := tx.Statement.ConnPool.(type) {
|
||||||
ConnPool: db.Config.ConnPool,
|
case Tx:
|
||||||
Mux: preparedStmt.Mux,
|
tx.Statement.ConnPool = &PreparedStmtTX{
|
||||||
Stmts: preparedStmt.Stmts,
|
Tx: t,
|
||||||
|
PreparedStmtDB: preparedStmt,
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
tx.Statement.ConnPool = &PreparedStmtDB{
|
||||||
|
ConnPool: db.Config.ConnPool,
|
||||||
|
Mux: preparedStmt.Mux,
|
||||||
|
Stmts: preparedStmt.Stmts,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
txConfig.ConnPool = tx.Statement.ConnPool
|
txConfig.ConnPool = tx.Statement.ConnPool
|
||||||
txConfig.PrepareStmt = true
|
txConfig.PrepareStmt = true
|
||||||
|
|
|
@ -2,6 +2,7 @@ package tests_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -88,3 +89,19 @@ func TestPreparedStmtFromTransaction(t *testing.T) {
|
||||||
}
|
}
|
||||||
tx2.Commit()
|
tx2.Commit()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPreparedStmtInTransaction(t *testing.T) {
|
||||||
|
user := User{Name: "jinzhu"}
|
||||||
|
|
||||||
|
if err := DB.Transaction(func(tx *gorm.DB) error {
|
||||||
|
tx.Session(&gorm.Session{PrepareStmt: true}).Create(&user)
|
||||||
|
return errors.New("test")
|
||||||
|
}); err == nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result User
|
||||||
|
if err := DB.First(&result, user.ID).Error; err == nil {
|
||||||
|
t.Errorf("Failed, got error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue