diff --git a/gorm.go b/gorm.go index 355a0e55..88885407 100644 --- a/gorm.go +++ b/gorm.go @@ -126,7 +126,7 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { preparedStmt := &PreparedStmtDB{ ConnPool: db.ConnPool, - Stmts: map[string]*sql.Stmt{}, + Stmts: map[string]Stmt{}, Mux: &sync.RWMutex{}, PreparedSQL: make([]string, 0, 100), } diff --git a/prepare_stmt.go b/prepare_stmt.go index dbf21118..78a8adb4 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -6,8 +6,13 @@ import ( "sync" ) +type Stmt struct { + *sql.Stmt + Transaction bool +} + type PreparedStmtDB struct { - Stmts map[string]*sql.Stmt + Stmts map[string]Stmt PreparedSQL []string Mux *sync.RWMutex ConnPool @@ -25,9 +30,9 @@ func (db *PreparedStmtDB) Close() { db.Mux.Unlock() } -func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, query string) (*sql.Stmt, error) { +func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) { db.Mux.RLock() - if stmt, ok := db.Stmts[query]; ok { + if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) { db.Mux.RUnlock() return stmt, nil } @@ -35,19 +40,21 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, query stri db.Mux.Lock() // double check - if stmt, ok := db.Stmts[query]; ok { + if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) { db.Mux.Unlock() return stmt, nil + } else if ok { + stmt.Close() } stmt, err := conn.PrepareContext(ctx, query) if err == nil { - db.Stmts[query] = stmt + db.Stmts[query] = Stmt{Stmt: stmt, Transaction: isTransaction} db.PreparedSQL = append(db.PreparedSQL, query) } db.Mux.Unlock() - return stmt, err + return db.Stmts[query], err } func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) { @@ -59,7 +66,7 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn } func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) { - stmt, err := db.prepare(ctx, db.ConnPool, query) + stmt, err := db.prepare(ctx, db.ConnPool, false, query) if err == nil { result, err = stmt.ExecContext(ctx, args...) if err != nil { @@ -73,7 +80,7 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args .. } func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { - stmt, err := db.prepare(ctx, db.ConnPool, query) + stmt, err := db.prepare(ctx, db.ConnPool, false, query) if err == nil { rows, err = stmt.QueryContext(ctx, args...) if err != nil { @@ -87,7 +94,7 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args . } func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { - stmt, err := db.prepare(ctx, db.ConnPool, query) + stmt, err := db.prepare(ctx, db.ConnPool, false, query) if err == nil { return stmt.QueryRowContext(ctx, args...) } @@ -114,9 +121,9 @@ func (tx *PreparedStmtTX) Rollback() error { } func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) { - stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, query) + stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) if err == nil { - result, err = tx.Tx.StmtContext(ctx, stmt).ExecContext(ctx, args...) + result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...) if err != nil { tx.PreparedStmtDB.Mux.Lock() stmt.Close() @@ -128,9 +135,9 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args .. } func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { - stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, query) + stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) if err == nil { - rows, err = tx.Tx.Stmt(stmt).QueryContext(ctx, args...) + rows, err = tx.Tx.Stmt(stmt.Stmt).QueryContext(ctx, args...) if err != nil { tx.PreparedStmtDB.Mux.Lock() stmt.Close() @@ -142,9 +149,9 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args . } func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { - stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, query) + stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) if err == nil { - return tx.Tx.StmtContext(ctx, stmt).QueryRowContext(ctx, args...) + return tx.Tx.StmtContext(ctx, stmt.Stmt).QueryRowContext(ctx, args...) } return &sql.Row{} } diff --git a/tests/prepared_stmt_test.go b/tests/prepared_stmt_test.go index 6b10b6dc..8730e547 100644 --- a/tests/prepared_stmt_test.go +++ b/tests/prepared_stmt_test.go @@ -50,3 +50,41 @@ func TestPreparedStmt(t *testing.T) { t.Fatalf("no error should happen but got %v", err) } } + +func TestPreparedStmtFromTransaction(t *testing.T) { + db := DB.Session(&gorm.Session{PrepareStmt: true, SkipDefaultTransaction: true}) + + tx := db.Begin() + defer func() { + if r := recover(); r != nil { + tx.Rollback() + } + }() + if err := tx.Error; err != nil { + t.Errorf("Failed to start transaction, got error %v\n", err) + } + + if err := tx.Where("name=?", "zzjin").Delete(&User{}).Error; err != nil { + tx.Rollback() + t.Errorf("Failed to run one transaction, got error %v\n", err) + } + + if err := tx.Create(&User{Name: "zzjin"}).Error; err != nil { + tx.Rollback() + t.Errorf("Failed to run one transaction, got error %v\n", err) + } + + if err := tx.Commit().Error; err != nil { + t.Errorf("Failed to commit transaction, got error %v\n", err) + } + + if result := db.Where("name=?", "zzjin").Delete(&User{}); result.Error != nil || result.RowsAffected != 1 { + t.Fatalf("Failed, got error: %v, rows affected: %v", result.Error, result.RowsAffected) + } + + tx2 := db.Begin() + if result := tx2.Where("name=?", "zzjin").Delete(&User{}); result.Error != nil || result.RowsAffected != 0 { + t.Fatalf("Failed, got error: %v, rows affected: %v", result.Error, result.RowsAffected) + } + tx2.Commit() +}