forked from mirror/gorm
Add Transaction Support
This commit is contained in:
parent
50a1b6e3e5
commit
d550315548
20
README.md
20
README.md
|
@ -9,6 +9,7 @@ Yet Another ORM library for Go, aims for developer friendly
|
|||
* Callbacks (before/after create/save/update/delete)
|
||||
* Soft Delete
|
||||
* Auto Migration
|
||||
* Transaction
|
||||
* Every feature comes with tests
|
||||
* Convention Over Configuration
|
||||
* Developer Friendly
|
||||
|
@ -617,6 +618,24 @@ func (u User) TableName() string {
|
|||
}
|
||||
```
|
||||
|
||||
## Transaction
|
||||
|
||||
```go
|
||||
tx := db.Begin()
|
||||
|
||||
user := User{Name: "transcation"}
|
||||
|
||||
tx.Save(&u)
|
||||
tx.Update("age": 90)
|
||||
// do whatever
|
||||
|
||||
// rollback
|
||||
tx.Rollback()
|
||||
|
||||
// commit
|
||||
tx.Commit()
|
||||
```
|
||||
|
||||
## Run Raw SQl
|
||||
|
||||
```go
|
||||
|
@ -668,7 +687,6 @@ db.Where("email = ?", "x@example.org").Attrs(User{FromIp: "111.111.111.111"}).Fi
|
|||
```
|
||||
|
||||
## TODO
|
||||
* Transaction
|
||||
* Logger
|
||||
* Join, Having, Group
|
||||
* Index, Unique, Valiations
|
||||
|
|
33
chain.go
33
chain.go
|
@ -1,14 +1,13 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
type Chain struct {
|
||||
db *sql.DB
|
||||
db sql_common
|
||||
driver string
|
||||
value interface{}
|
||||
|
||||
|
@ -243,6 +242,36 @@ func (s *Chain) Related(value interface{}, foreign_keys ...string) *Chain {
|
|||
return s
|
||||
}
|
||||
|
||||
func (s *Chain) Begin() *Chain {
|
||||
if db, ok := s.db.(sql_db); ok {
|
||||
tx, err := db.Begin()
|
||||
s.db = interface{}(tx).(sql_common)
|
||||
s.err(err)
|
||||
} else {
|
||||
s.err(errors.New("Can't start a transaction."))
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Chain) Commit() *Chain {
|
||||
if db, ok := s.db.(sql_tx); ok {
|
||||
s.err(db.Commit())
|
||||
} else {
|
||||
s.err(errors.New("Commit is not supported, no database transaction found."))
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Chain) Rollback() *Chain {
|
||||
if db, ok := s.db.(sql_tx); ok {
|
||||
s.err(db.Rollback())
|
||||
} else {
|
||||
s.err(errors.New("Rollback is not supported, no database transaction found."))
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Chain) validSql(str string) (result bool) {
|
||||
result = regexp.MustCompile("^\\s*[\\w\\s,.*()]*\\s*$").MatchString(str)
|
||||
if !result {
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
package gorm
|
||||
|
||||
import "database/sql"
|
||||
|
||||
type sql_common interface {
|
||||
Exec(query string, args ...interface{}) (sql.Result, error)
|
||||
Prepare(query string) (*sql.Stmt, error)
|
||||
Query(query string, args ...interface{}) (*sql.Rows, error)
|
||||
QueryRow(query string, args ...interface{}) *sql.Row
|
||||
}
|
||||
|
||||
type sql_db interface {
|
||||
Begin() (*sql.Tx, error)
|
||||
SetMaxIdleConns(n int)
|
||||
}
|
||||
|
||||
type sql_tx interface {
|
||||
Commit() error
|
||||
Rollback() error
|
||||
}
|
2
do.go
2
do.go
|
@ -14,7 +14,7 @@ import (
|
|||
|
||||
type Do struct {
|
||||
chain *Chain
|
||||
db *sql.DB
|
||||
db sql_common
|
||||
driver string
|
||||
guessedTableName string
|
||||
specifiedTableName string
|
||||
|
|
35
gorm_test.go
35
gorm_test.go
|
@ -1329,6 +1329,41 @@ func TestSqlNullValue(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestTransaction(t *testing.T) {
|
||||
d := db.Begin()
|
||||
u := User{Name: "transcation"}
|
||||
if err := d.Save(&u).Error; err != nil {
|
||||
t.Errorf("No error should raise, but got", err)
|
||||
}
|
||||
|
||||
if err := d.First(&User{}, "name = ?", "transcation").Error; err != nil {
|
||||
t.Errorf("Should find saved record, but got", err)
|
||||
}
|
||||
|
||||
d.Rollback()
|
||||
|
||||
if err := d.First(&User{}, "name = ?", "transcation").Error; err == nil {
|
||||
t.Errorf("Should not find record after rollback")
|
||||
}
|
||||
|
||||
d2 := db.Begin()
|
||||
u2 := User{Name: "transcation-2"}
|
||||
if err := d2.Save(&u2).Error; err != nil {
|
||||
t.Errorf("No error should raise, but got", err)
|
||||
}
|
||||
d2.Update("age", 90)
|
||||
|
||||
if err := d2.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
|
||||
t.Errorf("Should find saved record, but got", err)
|
||||
}
|
||||
|
||||
d2.Commit()
|
||||
|
||||
if err := db.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
|
||||
t.Errorf("Should be able to find committed record")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGorm(b *testing.B) {
|
||||
for x := 0; x < b.N; x++ {
|
||||
email := BigEmail{Email: "benchmark@example.org", UserAgent: "pc", RegisteredAt: time.Now()}
|
||||
|
|
|
@ -15,7 +15,7 @@ type Logger interface {
|
|||
}
|
||||
|
||||
func print(level string, v ...interface{}) {
|
||||
if logger_disabled {
|
||||
if logger_disabled && level != "debug" {
|
||||
return
|
||||
}
|
||||
|
||||
|
|
10
main.go
10
main.go
|
@ -5,7 +5,7 @@ import "database/sql"
|
|||
var singularTableName bool
|
||||
|
||||
type DB struct {
|
||||
db *sql.DB
|
||||
db sql_common
|
||||
driver string
|
||||
}
|
||||
|
||||
|
@ -16,7 +16,9 @@ func Open(driver, source string) (db DB, err error) {
|
|||
}
|
||||
|
||||
func (s *DB) SetPool(n int) {
|
||||
s.db.SetMaxIdleConns(n)
|
||||
if db, ok := s.db.(sql_db); ok {
|
||||
db.SetMaxIdleConns(n)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DB) SetLogger(l interface{}) {
|
||||
|
@ -122,3 +124,7 @@ func (s *DB) DropTable(value interface{}) *Chain {
|
|||
func (s *DB) AutoMigrate(value interface{}) *Chain {
|
||||
return s.buildChain().AutoMigrate(value)
|
||||
}
|
||||
|
||||
func (s *DB) Begin() *Chain {
|
||||
return s.buildChain().Begin()
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue