diff --git a/gorm.go b/gorm.go index 1f1dac21..81b6e2af 100644 --- a/gorm.go +++ b/gorm.go @@ -248,10 +248,18 @@ func (db *DB) Session(config *Session) *DB { if config.PrepareStmt { if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok { preparedStmt := v.(*PreparedStmtDB) - tx.Statement.ConnPool = &PreparedStmtDB{ - ConnPool: db.Config.ConnPool, - Mux: preparedStmt.Mux, - Stmts: preparedStmt.Stmts, + switch t := tx.Statement.ConnPool.(type) { + case Tx: + tx.Statement.ConnPool = &PreparedStmtTX{ + Tx: t, + PreparedStmtDB: preparedStmt, + } + default: + tx.Statement.ConnPool = &PreparedStmtDB{ + ConnPool: db.Config.ConnPool, + Mux: preparedStmt.Mux, + Stmts: preparedStmt.Stmts, + } } txConfig.ConnPool = tx.Statement.ConnPool txConfig.PrepareStmt = true diff --git a/tests/prepared_stmt_test.go b/tests/prepared_stmt_test.go index 8730e547..86e3630d 100644 --- a/tests/prepared_stmt_test.go +++ b/tests/prepared_stmt_test.go @@ -2,6 +2,7 @@ package tests_test import ( "context" + "errors" "testing" "time" @@ -88,3 +89,19 @@ func TestPreparedStmtFromTransaction(t *testing.T) { } tx2.Commit() } + +func TestPreparedStmtInTransaction(t *testing.T) { + user := User{Name: "jinzhu"} + + if err := DB.Transaction(func(tx *gorm.DB) error { + tx.Session(&gorm.Session{PrepareStmt: true}).Create(&user) + return errors.New("test") + }); err == nil { + t.Error(err) + } + + var result User + if err := DB.First(&result, user.ID).Error; err == nil { + t.Errorf("Failed, got error: %v", err) + } +}