Add Transaction Support

This commit is contained in:
Jinzhu 2013-11-11 13:16:08 +08:00
parent 50a1b6e3e5
commit d550315548
7 changed files with 115 additions and 7 deletions

View File

@ -9,6 +9,7 @@ Yet Another ORM library for Go, aims for developer friendly
* Callbacks (before/after create/save/update/delete)
* Soft Delete
* Auto Migration
* Transaction
* Every feature comes with tests
* Convention Over Configuration
* Developer Friendly
@ -617,6 +618,24 @@ func (u User) TableName() string {
}
```
## Transaction
```go
tx := db.Begin()
user := User{Name: "transcation"}
tx.Save(&u)
tx.Update("age": 90)
// do whatever
// rollback
tx.Rollback()
// commit
tx.Commit()
```
## Run Raw SQl
```go
@ -668,7 +687,6 @@ db.Where("email = ?", "x@example.org").Attrs(User{FromIp: "111.111.111.111"}).Fi
```
## TODO
* Transaction
* Logger
* Join, Having, Group
* Index, Unique, Valiations

View File

@ -1,14 +1,13 @@
package gorm
import (
"database/sql"
"errors"
"fmt"
"regexp"
)
type Chain struct {
db *sql.DB
db sql_common
driver string
value interface{}
@ -243,6 +242,36 @@ func (s *Chain) Related(value interface{}, foreign_keys ...string) *Chain {
return s
}
func (s *Chain) Begin() *Chain {
if db, ok := s.db.(sql_db); ok {
tx, err := db.Begin()
s.db = interface{}(tx).(sql_common)
s.err(err)
} else {
s.err(errors.New("Can't start a transaction."))
}
return s
}
func (s *Chain) Commit() *Chain {
if db, ok := s.db.(sql_tx); ok {
s.err(db.Commit())
} else {
s.err(errors.New("Commit is not supported, no database transaction found."))
}
return s
}
func (s *Chain) Rollback() *Chain {
if db, ok := s.db.(sql_tx); ok {
s.err(db.Rollback())
} else {
s.err(errors.New("Rollback is not supported, no database transaction found."))
}
return s
}
func (s *Chain) validSql(str string) (result bool) {
result = regexp.MustCompile("^\\s*[\\w\\s,.*()]*\\s*$").MatchString(str)
if !result {

20
db.go Normal file
View File

@ -0,0 +1,20 @@
package gorm
import "database/sql"
type sql_common interface {
Exec(query string, args ...interface{}) (sql.Result, error)
Prepare(query string) (*sql.Stmt, error)
Query(query string, args ...interface{}) (*sql.Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row
}
type sql_db interface {
Begin() (*sql.Tx, error)
SetMaxIdleConns(n int)
}
type sql_tx interface {
Commit() error
Rollback() error
}

2
do.go
View File

@ -14,7 +14,7 @@ import (
type Do struct {
chain *Chain
db *sql.DB
db sql_common
driver string
guessedTableName string
specifiedTableName string

View File

@ -1329,6 +1329,41 @@ func TestSqlNullValue(t *testing.T) {
}
}
func TestTransaction(t *testing.T) {
d := db.Begin()
u := User{Name: "transcation"}
if err := d.Save(&u).Error; err != nil {
t.Errorf("No error should raise, but got", err)
}
if err := d.First(&User{}, "name = ?", "transcation").Error; err != nil {
t.Errorf("Should find saved record, but got", err)
}
d.Rollback()
if err := d.First(&User{}, "name = ?", "transcation").Error; err == nil {
t.Errorf("Should not find record after rollback")
}
d2 := db.Begin()
u2 := User{Name: "transcation-2"}
if err := d2.Save(&u2).Error; err != nil {
t.Errorf("No error should raise, but got", err)
}
d2.Update("age", 90)
if err := d2.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
t.Errorf("Should find saved record, but got", err)
}
d2.Commit()
if err := db.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
t.Errorf("Should be able to find committed record")
}
}
func BenchmarkGorm(b *testing.B) {
for x := 0; x < b.N; x++ {
email := BigEmail{Email: "benchmark@example.org", UserAgent: "pc", RegisteredAt: time.Now()}

View File

@ -15,7 +15,7 @@ type Logger interface {
}
func print(level string, v ...interface{}) {
if logger_disabled {
if logger_disabled && level != "debug" {
return
}

10
main.go
View File

@ -5,7 +5,7 @@ import "database/sql"
var singularTableName bool
type DB struct {
db *sql.DB
db sql_common
driver string
}
@ -16,7 +16,9 @@ func Open(driver, source string) (db DB, err error) {
}
func (s *DB) SetPool(n int) {
s.db.SetMaxIdleConns(n)
if db, ok := s.db.(sql_db); ok {
db.SetMaxIdleConns(n)
}
}
func (s *DB) SetLogger(l interface{}) {
@ -122,3 +124,7 @@ func (s *DB) DropTable(value interface{}) *Chain {
func (s *DB) AutoMigrate(value interface{}) *Chain {
return s.buildChain().AutoMigrate(value)
}
func (s *DB) Begin() *Chain {
return s.buildChain().Begin()
}