diff --git a/chain.go b/chain.go index 28ba14fe..48acc163 100644 --- a/chain.go +++ b/chain.go @@ -50,8 +50,8 @@ func (s *Chain) deleteLastError() { s.Errors = s.Errors[:len(s.Errors)-1] } -func (s *Chain) do(value interface{}) (do *Do) { - do = &Do{ +func (s *Chain) do(value interface{}) *Do { + do := Do{ chain: s, db: s.db, whereClause: s.whereClause, @@ -67,7 +67,7 @@ func (s *Chain) do(value interface{}) (do *Do) { s.value = value do.setModel(value) - return + return &do } func (s *Chain) Model(model interface{}) *Chain { @@ -132,12 +132,22 @@ func (s *Chain) Select(value interface{}) *Chain { } func (s *Chain) Save(value interface{}) *Chain { - s.do(value).save() + do := s.do(value) + tx_started := do.begin() + do.save() + if tx_started { + do.commit() + } return s } func (s *Chain) Delete(value interface{}) *Chain { - s.do(value).delete() + do := s.do(value) + tx_started := do.begin() + do.delete() + if tx_started { + do.commit() + } return s } @@ -146,7 +156,12 @@ func (s *Chain) Update(attrs ...interface{}) *Chain { } func (s *Chain) Updates(values interface{}, ignore_protected_attrs ...bool) *Chain { - s.do(s.value).setUpdateAttrs(values, ignore_protected_attrs...).update() + do := s.do(s.value) + tx_started := do.begin() + do.setUpdateAttrs(values, ignore_protected_attrs...).update() + if tx_started { + do.commit() + } return s } diff --git a/do.go b/do.go index fe483927..cde02ccf 100644 --- a/do.go +++ b/do.go @@ -84,11 +84,11 @@ func (s *Do) exec(sqls ...string) (err error) { return s.err(err) } -func (s *Do) save() (i interface{}) { +func (s *Do) save() (value interface{}) { if s.model.primaryKeyZero() { - return s.create() + value = s.create() } else { - return s.update() + value = s.update() } return } @@ -255,7 +255,7 @@ func (s *Do) prepareUpdateSql(results map[string]interface{}) { return } -func (s *Do) update() (i int64) { +func (s *Do) update() (i interface{}) { update_attrs := s.updateAttrs if len(update_attrs) > 0 { var need_update bool @@ -756,6 +756,23 @@ func (s *Do) autoMigrate() *Do { return s } +func (s *Do) begin() bool { + if db, ok := s.db.(sql_db); ok { + tx, err := db.Begin() + if err == nil { + s.db = interface{}(tx).(sql_common) + return true + } + } + return false +} + +func (s *Do) commit() { + if db, ok := s.db.(sql_tx); ok { + s.err(db.Commit()) + } +} + func (s *Do) initializeWithSearchCondition() { m := Model{data: s.value, do: s} diff --git a/gorm_test.go b/gorm_test.go index 8379a5b8..5b67006c 100644 --- a/gorm_test.go +++ b/gorm_test.go @@ -87,7 +87,7 @@ func init() { // db, err = Open("mysql", "gorm:gorm@/gorm?charset=utf8&parseTime=True") // db, err = Open("sqlite3", "/tmp/gorm.db") // db.SetLogger(log.New(os.Stdout, "\r\n", 0)) - db.LogMode(true) + // db.LogMode(true) if err != nil { panic(fmt.Sprintf("No error should happen when connect database, but got %+v", err)) diff --git a/logger.go b/logger.go index 407738dd..226345fb 100644 --- a/logger.go +++ b/logger.go @@ -5,7 +5,6 @@ import ( "log" "os" "regexp" - "time" )