diff --git a/callback_delete.go b/callback_delete.go index 7b35bca4..9dc5b692 100644 --- a/callback_delete.go +++ b/callback_delete.go @@ -10,6 +10,8 @@ func BeforeDelete(scope *Scope) { } func Delete(scope *Scope) { + defer scope.Trace(time.Now()) + if scope.HasError() { return } @@ -33,7 +35,9 @@ func AfterDelete(scope *Scope) { } func init() { + DefaultCallback.Delete().Register("begin_transaction", BeginTransaction) DefaultCallback.Delete().Register("before_delete", BeforeDelete) DefaultCallback.Delete().Register("delete", Delete) DefaultCallback.Delete().Register("after_delete", AfterDelete) + DefaultCallback.Delete().Register("commit_or_rollback_transaction", CommitOrRollbackTransaction) } diff --git a/callback_shared.go b/callback_shared.go new file mode 100644 index 00000000..3ff5f104 --- /dev/null +++ b/callback_shared.go @@ -0,0 +1,9 @@ +package gorm + +func BeginTransaction(scope *Scope) { + scope.Begin() +} + +func CommitOrRollbackTransaction(scope *Scope) { + scope.CommitOrRollback() +} diff --git a/gorm_test.go b/gorm_test.go index e3eba586..1a533d20 100644 --- a/gorm_test.go +++ b/gorm_test.go @@ -1504,7 +1504,6 @@ func TestTransaction(t *testing.T) { if err := tx2.Save(&u2).Error; err != nil { t.Errorf("No error should raise, but got", err) } - tx2.Update("age", 90) if err := tx2.First(&User{}, "name = ?", "transcation-2").Error; err != nil { t.Errorf("Should find saved record, but got", err) diff --git a/scope.go b/scope.go index d72388c3..6d908cab 100644 --- a/scope.go +++ b/scope.go @@ -5,17 +5,19 @@ import ( "fmt" "github.com/jinzhu/gorm/dialect" "strings" + "time" "reflect" "regexp" ) type Scope struct { - Value interface{} - Search *search - Sql string - SqlVars []interface{} - db *DB + Value interface{} + Search *search + Sql string + SqlVars []interface{} + db *DB + startedTransaction bool } func (db *DB) newScope(value interface{}) *Scope { @@ -52,6 +54,21 @@ func (scope *Scope) PrimaryKey() string { return "id" } +func (scope *Scope) PrimaryKeyZero() bool { + return isBlank(reflect.ValueOf(scope.PrimaryKeyValue())) +} + +func (scope *Scope) PrimaryKeyValue() interface{} { + data := reflect.Indirect(reflect.ValueOf(scope.Value)) + + if data.Kind() == reflect.Struct { + if field := data.FieldByName(snakeToUpperCamel(scope.PrimaryKey())); field.IsValid() { + return field.Interface() + } + } + return 0 +} + func (scope *Scope) HasColumn(name string) bool { data := reflect.Indirect(reflect.ValueOf(scope.Value)) @@ -146,3 +163,31 @@ func (scope *Scope) Exec() { scope.Err(err) } } + +func (scope *Scope) Trace(t time.Time) { + if len(scope.Sql) > 0 { + scope.db.slog(scope.Sql, t, scope.SqlVars...) + } +} + +func (scope *Scope) Begin() *Scope { + if tx, err := scope.DB().(sqlDb).Begin(); err == nil { + scope.db.db = interface{}(tx).(sqlCommon) + scope.startedTransaction = true + } + return scope +} + +func (scope *Scope) CommitOrRollback() *Scope { + if scope.startedTransaction { + if db, ok := scope.db.db.(sqlTx); ok { + if scope.HasError() { + db.Rollback() + } else { + db.Commit() + } + scope.db.db = scope.db.parent.db + } + } + return scope +}