From 39ac95adbbc84c2091bd044e06145bd81d952628 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 20 Aug 2014 17:05:02 +0800 Subject: [PATCH] Add InstanceSet, InstanceGet fomr Scope --- association_test.go | 2 +- callback_query.go | 2 +- callback_update.go | 8 ++++---- main.go | 18 ++++++++++-------- main_private.go | 6 +++++- main_test.go | 14 ++++++++++++++ query_test.go | 1 - scope.go | 26 ++++++++++++++++++++++---- 8 files changed, 57 insertions(+), 20 deletions(-) diff --git a/association_test.go b/association_test.go index 5f765b04..457b2183 100644 --- a/association_test.go +++ b/association_test.go @@ -63,7 +63,7 @@ func TestRelated(t *testing.T) { CreditCard: CreditCard{Number: "1234567890"}, } - db.Debug().Save(&user) + db.Save(&user) if user.CreditCard.Id == 0 { t.Errorf("After user save, credit card should have id") diff --git a/callback_query.go b/callback_query.go index acdbde20..439e3f96 100644 --- a/callback_query.go +++ b/callback_query.go @@ -17,7 +17,7 @@ func Query(scope *Scope) { ) var dest = scope.IndirectValue() - if value, ok := scope.Get("gorm:query_destination"); ok { + if value, ok := scope.InstanceGet("gorm:query_destination"); ok { dest = reflect.Indirect(reflect.ValueOf(value)) } diff --git a/callback_update.go b/callback_update.go index b8178ab5..a286fb25 100644 --- a/callback_update.go +++ b/callback_update.go @@ -7,16 +7,16 @@ import ( ) func AssignUpdateAttributes(scope *Scope) { - if attrs, ok := scope.Get("gorm:update_interface"); ok { + if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok { if maps := convertInterfaceToMap(attrs); len(maps) > 0 { protected, ok := scope.Get("gorm:ignore_protected_attrs") _, updateColumn := scope.Get("gorm:update_column") updateAttrs, hasUpdate := scope.updatedAttrsWithValues(maps, ok && protected.(bool)) if updateColumn { - scope.Set("gorm:update_attrs", maps) + scope.InstanceSet("gorm:update_attrs", maps) } else if len(updateAttrs) > 0 { - scope.Set("gorm:update_attrs", updateAttrs) + scope.InstanceSet("gorm:update_attrs", updateAttrs) } else if !hasUpdate { scope.SkipLeft() return @@ -44,7 +44,7 @@ func Update(scope *Scope) { if !scope.HasError() { var sqls []string - updateAttrs, ok := scope.Get("gorm:update_attrs") + updateAttrs, ok := scope.InstanceGet("gorm:update_attrs") if ok { for key, value := range updateAttrs.(map[string]interface{}) { sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(key), scope.AddToVars(value))) diff --git a/main.go b/main.go index 587ebf31..514c5fa8 100644 --- a/main.go +++ b/main.go @@ -185,7 +185,7 @@ func (s *DB) Rows() (*sql.Rows, error) { } func (s *DB) Scan(dest interface{}) *DB { - scope := s.clone().Set("gorm:query_destination", dest).NewScope(s.Value) + scope := s.clone().NewScope(s.Value).InstanceSet("gorm:query_destination", dest) Query(scope) return scope.db } @@ -213,7 +213,7 @@ func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB { } c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(s.parent.callback.creates) } else if len(c.search.AssignAttrs) > 0 { - c.Set("gorm:update_interface", s.search.AssignAttrs).NewScope(out).callCallbacks(s.parent.callback.updates) + c.NewScope(out).InstanceSet("gorm:update_interface", s.search.AssignAttrs).callCallbacks(s.parent.callback.updates) } return c } @@ -223,10 +223,9 @@ func (s *DB) Update(attrs ...interface{}) *DB { } func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB { - return s.clone(). - Set("gorm:update_interface", values). + return s.clone().NewScope(s.Value). Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0). - NewScope(s.Value). + InstanceSet("gorm:update_interface", values). callCallbacks(s.parent.callback.updates).db } @@ -235,10 +234,9 @@ func (s *DB) UpdateColumn(attrs ...interface{}) *DB { } func (s *DB) UpdateColumns(values interface{}) *DB { - return s.clone(). - Set("gorm:update_interface", values). + return s.clone().NewScope(s.Value). Set("gorm:update_column", true). - NewScope(s.Value). + InstanceSet("gorm:update_interface", values). callCallbacks(s.parent.callback.updates).db } @@ -404,6 +402,10 @@ func (s *DB) Association(column string) *Association { // Set set value by name func (s *DB) Set(name string, value interface{}) *DB { + return s.clone().set(name, value) +} + +func (s *DB) set(name string, value interface{}) *DB { s.values[name] = value return s } diff --git a/main_private.go b/main_private.go index dab50f2a..b9326360 100644 --- a/main_private.go +++ b/main_private.go @@ -6,7 +6,11 @@ import ( ) func (s *DB) clone() *DB { - db := DB{db: s.db, parent: s.parent, logMode: s.logMode, Value: s.Value, Error: s.Error, values: s.values} + db := DB{db: s.db, parent: s.parent, logMode: s.logMode, Value: s.Value, Error: s.Error, values: map[string]interface{}{}} + + for key, value := range s.values { + db.values[key] = value + } if s.search == nil { db.search = &search{} diff --git a/main_test.go b/main_test.go index 10cd022b..e0b39e9f 100644 --- a/main_test.go +++ b/main_test.go @@ -494,6 +494,20 @@ func TestHstore(t *testing.T) { } } +func TestSetAndGet(t *testing.T) { + if value, ok := db.Set("hello", "world").Get("hello"); !ok { + t.Errorf("Should be able to get setting after set") + } else { + if value.(string) != "world" { + t.Errorf("Setted value should not be changed") + } + } + + if _, ok := db.Get("non_existing"); ok { + t.Errorf("Get non existing key should return error") + } +} + func TestCompatibilityMode(t *testing.T) { db, _ := gorm.Open("testdb", "") testdb.SetQueryFunc(func(query string) (driver.Rows, error) { diff --git a/query_test.go b/query_test.go index db335b02..ce42ae5b 100644 --- a/query_test.go +++ b/query_test.go @@ -332,7 +332,6 @@ func TestCount(t *testing.T) { func TestNot(t *testing.T) { var users1, users2, users3, users4, users5, users6, users7, users8 []User db.Find(&users1) - db.Not(users1[0].Id).Find(&users2) if len(users1)-len(users2) != 1 { diff --git a/scope.go b/scope.go index d65f5649..bf17e704 100644 --- a/scope.go +++ b/scope.go @@ -20,6 +20,7 @@ type Scope struct { db *DB skipLeft bool primaryKey string + instanceId string } func (scope *Scope) IndirectValue() reflect.Value { @@ -362,8 +363,9 @@ func (scope *Scope) Exec() *Scope { } // Set set value by name -func (scope *Scope) Set(name string, value interface{}) { - scope.db.Set(name, value) +func (scope *Scope) Set(name string, value interface{}) *Scope { + scope.db.set(name, value) + return scope } // Get get value by name @@ -371,6 +373,22 @@ func (scope *Scope) Get(name string) (interface{}, bool) { return scope.db.Get(name) } +// InstanceId get InstanceId for scope +func (scope *Scope) InstanceId() string { + if scope.instanceId == "" { + scope.instanceId = fmt.Sprintf("%v", &scope) + } + return scope.instanceId +} + +func (scope *Scope) InstanceSet(name string, value interface{}) *Scope { + return scope.Set(name+scope.InstanceId(), value) +} + +func (scope *Scope) InstanceGet(name string) (interface{}, bool) { + return scope.Get(name + scope.InstanceId()) +} + // Trace print sql log func (scope *Scope) Trace(t time.Time) { if len(scope.Sql) > 0 { @@ -383,7 +401,7 @@ func (scope *Scope) Begin() *Scope { if db, ok := scope.DB().(sqlDb); ok { if tx, err := db.Begin(); err == nil { scope.db.db = interface{}(tx).(sqlCommon) - scope.Set("gorm:started_transaction", true) + scope.InstanceSet("gorm:started_transaction", true) } } return scope @@ -391,7 +409,7 @@ func (scope *Scope) Begin() *Scope { // CommitOrRollback commit current transaction if there is no error, otherwise rollback it func (scope *Scope) CommitOrRollback() *Scope { - if _, ok := scope.Get("gorm:started_transaction"); ok { + if _, ok := scope.InstanceGet("gorm:started_transaction"); ok { if db, ok := scope.db.db.(sqlTx); ok { if scope.HasError() { db.Rollback()