From eab146a275263f4e2a6253c8e9980eaba46047ba Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 27 Jan 2014 11:56:04 +0800 Subject: [PATCH] Add getter setter for scope --- scope.go | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/scope.go b/scope.go index 3ea4646c..27efde54 100644 --- a/scope.go +++ b/scope.go @@ -13,17 +13,17 @@ import ( ) type Scope struct { - Value interface{} - Search *search - Sql string - SqlVars []interface{} - db *DB - startedTransaction bool + Value interface{} + Search *search + Sql string + SqlVars []interface{} + db *DB + _values map[string]interface{} } func (db *DB) NewScope(value interface{}) *Scope { db.Value = value - return &Scope{db: db, Search: db.search, Value: value} + return &Scope{db: db, Search: db.search, Value: value, _values: map[string]interface{}{}} } func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope { @@ -271,6 +271,15 @@ func (scope *Scope) Exec() { } } +func (scope *Scope) Get(name string) (value interface{}, ok bool) { + value, ok = scope._values[name] + return +} + +func (scope *Scope) Set(name string, value interface{}) { + scope._values[name] = value +} + func (scope *Scope) Trace(t time.Time) { if len(scope.Sql) > 0 { scope.db.slog(scope.Sql, t, scope.SqlVars...) @@ -281,14 +290,14 @@ func (scope *Scope) Begin() *Scope { if db, ok := scope.DB().(sqlDb); ok { if tx, err := db.Begin(); err == nil { scope.db.db = interface{}(tx).(sqlCommon) - scope.startedTransaction = true + scope.Set("gorm:started_transaction", true) } } return scope } func (scope *Scope) CommitOrRollback() *Scope { - if scope.startedTransaction { + if _, ok := scope.Get("gorm:started_transaction"); ok { if db, ok := scope.db.db.(sqlTx); ok { if scope.HasError() { db.Rollback()