diff --git a/README.md b/README.md index 4496f152..d43f0549 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,7 @@ Yet Another ORM library for Go, aims for developer friendly * Callbacks (before/after create/save/update/delete) * Soft Delete * Auto Migration +* Transaction * Every feature comes with tests * Convention Over Configuration * Developer Friendly @@ -617,6 +618,24 @@ func (u User) TableName() string { } ``` +## Transaction + +```go +tx := db.Begin() + +user := User{Name: "transcation"} + +tx.Save(&u) +tx.Update("age": 90) +// do whatever + +// rollback +tx.Rollback() + +// commit +tx.Commit() +``` + ## Run Raw SQl ```go @@ -668,7 +687,6 @@ db.Where("email = ?", "x@example.org").Attrs(User{FromIp: "111.111.111.111"}).Fi ``` ## TODO -* Transaction * Logger * Join, Having, Group * Index, Unique, Valiations diff --git a/chain.go b/chain.go index d17cda24..aa65e357 100644 --- a/chain.go +++ b/chain.go @@ -1,14 +1,13 @@ package gorm import ( - "database/sql" "errors" "fmt" "regexp" ) type Chain struct { - db *sql.DB + db sql_common driver string value interface{} @@ -243,6 +242,36 @@ func (s *Chain) Related(value interface{}, foreign_keys ...string) *Chain { return s } +func (s *Chain) Begin() *Chain { + if db, ok := s.db.(sql_db); ok { + tx, err := db.Begin() + s.db = interface{}(tx).(sql_common) + s.err(err) + } else { + s.err(errors.New("Can't start a transaction.")) + } + + return s +} + +func (s *Chain) Commit() *Chain { + if db, ok := s.db.(sql_tx); ok { + s.err(db.Commit()) + } else { + s.err(errors.New("Commit is not supported, no database transaction found.")) + } + return s +} + +func (s *Chain) Rollback() *Chain { + if db, ok := s.db.(sql_tx); ok { + s.err(db.Rollback()) + } else { + s.err(errors.New("Rollback is not supported, no database transaction found.")) + } + return s +} + func (s *Chain) validSql(str string) (result bool) { result = regexp.MustCompile("^\\s*[\\w\\s,.*()]*\\s*$").MatchString(str) if !result { diff --git a/db.go b/db.go new file mode 100644 index 00000000..6d62d5a4 --- /dev/null +++ b/db.go @@ -0,0 +1,20 @@ +package gorm + +import "database/sql" + +type sql_common interface { + Exec(query string, args ...interface{}) (sql.Result, error) + Prepare(query string) (*sql.Stmt, error) + Query(query string, args ...interface{}) (*sql.Rows, error) + QueryRow(query string, args ...interface{}) *sql.Row +} + +type sql_db interface { + Begin() (*sql.Tx, error) + SetMaxIdleConns(n int) +} + +type sql_tx interface { + Commit() error + Rollback() error +} diff --git a/do.go b/do.go index d80693aa..49c4dc80 100644 --- a/do.go +++ b/do.go @@ -14,7 +14,7 @@ import ( type Do struct { chain *Chain - db *sql.DB + db sql_common driver string guessedTableName string specifiedTableName string diff --git a/gorm_test.go b/gorm_test.go index 0cbe21f8..74aaf950 100644 --- a/gorm_test.go +++ b/gorm_test.go @@ -1329,6 +1329,41 @@ func TestSqlNullValue(t *testing.T) { } } +func TestTransaction(t *testing.T) { + d := db.Begin() + u := User{Name: "transcation"} + if err := d.Save(&u).Error; err != nil { + t.Errorf("No error should raise, but got", err) + } + + if err := d.First(&User{}, "name = ?", "transcation").Error; err != nil { + t.Errorf("Should find saved record, but got", err) + } + + d.Rollback() + + if err := d.First(&User{}, "name = ?", "transcation").Error; err == nil { + t.Errorf("Should not find record after rollback") + } + + d2 := db.Begin() + u2 := User{Name: "transcation-2"} + if err := d2.Save(&u2).Error; err != nil { + t.Errorf("No error should raise, but got", err) + } + d2.Update("age", 90) + + if err := d2.First(&User{}, "name = ?", "transcation-2").Error; err != nil { + t.Errorf("Should find saved record, but got", err) + } + + d2.Commit() + + if err := db.First(&User{}, "name = ?", "transcation-2").Error; err != nil { + t.Errorf("Should be able to find committed record") + } +} + func BenchmarkGorm(b *testing.B) { for x := 0; x < b.N; x++ { email := BigEmail{Email: "benchmark@example.org", UserAgent: "pc", RegisteredAt: time.Now()} diff --git a/logger.go b/logger.go index a67eabfa..b23869f8 100644 --- a/logger.go +++ b/logger.go @@ -15,7 +15,7 @@ type Logger interface { } func print(level string, v ...interface{}) { - if logger_disabled { + if logger_disabled && level != "debug" { return } diff --git a/main.go b/main.go index bbea1265..3d7d9d02 100644 --- a/main.go +++ b/main.go @@ -5,7 +5,7 @@ import "database/sql" var singularTableName bool type DB struct { - db *sql.DB + db sql_common driver string } @@ -16,7 +16,9 @@ func Open(driver, source string) (db DB, err error) { } func (s *DB) SetPool(n int) { - s.db.SetMaxIdleConns(n) + if db, ok := s.db.(sql_db); ok { + db.SetMaxIdleConns(n) + } } func (s *DB) SetLogger(l interface{}) { @@ -122,3 +124,7 @@ func (s *DB) DropTable(value interface{}) *Chain { func (s *DB) AutoMigrate(value interface{}) *Chain { return s.buildChain().AutoMigrate(value) } + +func (s *DB) Begin() *Chain { + return s.buildChain().Begin() +}