mirror of https://github.com/go-gorm/gorm.git
Add InstanceSet, InstanceGet fomr Scope
This commit is contained in:
parent
9012c9ea09
commit
39ac95adbb
|
@ -63,7 +63,7 @@ func TestRelated(t *testing.T) {
|
||||||
CreditCard: CreditCard{Number: "1234567890"},
|
CreditCard: CreditCard{Number: "1234567890"},
|
||||||
}
|
}
|
||||||
|
|
||||||
db.Debug().Save(&user)
|
db.Save(&user)
|
||||||
|
|
||||||
if user.CreditCard.Id == 0 {
|
if user.CreditCard.Id == 0 {
|
||||||
t.Errorf("After user save, credit card should have id")
|
t.Errorf("After user save, credit card should have id")
|
||||||
|
|
|
@ -17,7 +17,7 @@ func Query(scope *Scope) {
|
||||||
)
|
)
|
||||||
|
|
||||||
var dest = scope.IndirectValue()
|
var dest = scope.IndirectValue()
|
||||||
if value, ok := scope.Get("gorm:query_destination"); ok {
|
if value, ok := scope.InstanceGet("gorm:query_destination"); ok {
|
||||||
dest = reflect.Indirect(reflect.ValueOf(value))
|
dest = reflect.Indirect(reflect.ValueOf(value))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -7,16 +7,16 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func AssignUpdateAttributes(scope *Scope) {
|
func AssignUpdateAttributes(scope *Scope) {
|
||||||
if attrs, ok := scope.Get("gorm:update_interface"); ok {
|
if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok {
|
||||||
if maps := convertInterfaceToMap(attrs); len(maps) > 0 {
|
if maps := convertInterfaceToMap(attrs); len(maps) > 0 {
|
||||||
protected, ok := scope.Get("gorm:ignore_protected_attrs")
|
protected, ok := scope.Get("gorm:ignore_protected_attrs")
|
||||||
_, updateColumn := scope.Get("gorm:update_column")
|
_, updateColumn := scope.Get("gorm:update_column")
|
||||||
updateAttrs, hasUpdate := scope.updatedAttrsWithValues(maps, ok && protected.(bool))
|
updateAttrs, hasUpdate := scope.updatedAttrsWithValues(maps, ok && protected.(bool))
|
||||||
|
|
||||||
if updateColumn {
|
if updateColumn {
|
||||||
scope.Set("gorm:update_attrs", maps)
|
scope.InstanceSet("gorm:update_attrs", maps)
|
||||||
} else if len(updateAttrs) > 0 {
|
} else if len(updateAttrs) > 0 {
|
||||||
scope.Set("gorm:update_attrs", updateAttrs)
|
scope.InstanceSet("gorm:update_attrs", updateAttrs)
|
||||||
} else if !hasUpdate {
|
} else if !hasUpdate {
|
||||||
scope.SkipLeft()
|
scope.SkipLeft()
|
||||||
return
|
return
|
||||||
|
@ -44,7 +44,7 @@ func Update(scope *Scope) {
|
||||||
if !scope.HasError() {
|
if !scope.HasError() {
|
||||||
var sqls []string
|
var sqls []string
|
||||||
|
|
||||||
updateAttrs, ok := scope.Get("gorm:update_attrs")
|
updateAttrs, ok := scope.InstanceGet("gorm:update_attrs")
|
||||||
if ok {
|
if ok {
|
||||||
for key, value := range updateAttrs.(map[string]interface{}) {
|
for key, value := range updateAttrs.(map[string]interface{}) {
|
||||||
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(key), scope.AddToVars(value)))
|
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(key), scope.AddToVars(value)))
|
||||||
|
|
18
main.go
18
main.go
|
@ -185,7 +185,7 @@ func (s *DB) Rows() (*sql.Rows, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Scan(dest interface{}) *DB {
|
func (s *DB) Scan(dest interface{}) *DB {
|
||||||
scope := s.clone().Set("gorm:query_destination", dest).NewScope(s.Value)
|
scope := s.clone().NewScope(s.Value).InstanceSet("gorm:query_destination", dest)
|
||||||
Query(scope)
|
Query(scope)
|
||||||
return scope.db
|
return scope.db
|
||||||
}
|
}
|
||||||
|
@ -213,7 +213,7 @@ func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB {
|
||||||
}
|
}
|
||||||
c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(s.parent.callback.creates)
|
c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(s.parent.callback.creates)
|
||||||
} else if len(c.search.AssignAttrs) > 0 {
|
} else if len(c.search.AssignAttrs) > 0 {
|
||||||
c.Set("gorm:update_interface", s.search.AssignAttrs).NewScope(out).callCallbacks(s.parent.callback.updates)
|
c.NewScope(out).InstanceSet("gorm:update_interface", s.search.AssignAttrs).callCallbacks(s.parent.callback.updates)
|
||||||
}
|
}
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
@ -223,10 +223,9 @@ func (s *DB) Update(attrs ...interface{}) *DB {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB {
|
func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB {
|
||||||
return s.clone().
|
return s.clone().NewScope(s.Value).
|
||||||
Set("gorm:update_interface", values).
|
|
||||||
Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0).
|
Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0).
|
||||||
NewScope(s.Value).
|
InstanceSet("gorm:update_interface", values).
|
||||||
callCallbacks(s.parent.callback.updates).db
|
callCallbacks(s.parent.callback.updates).db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -235,10 +234,9 @@ func (s *DB) UpdateColumn(attrs ...interface{}) *DB {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) UpdateColumns(values interface{}) *DB {
|
func (s *DB) UpdateColumns(values interface{}) *DB {
|
||||||
return s.clone().
|
return s.clone().NewScope(s.Value).
|
||||||
Set("gorm:update_interface", values).
|
|
||||||
Set("gorm:update_column", true).
|
Set("gorm:update_column", true).
|
||||||
NewScope(s.Value).
|
InstanceSet("gorm:update_interface", values).
|
||||||
callCallbacks(s.parent.callback.updates).db
|
callCallbacks(s.parent.callback.updates).db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -404,6 +402,10 @@ func (s *DB) Association(column string) *Association {
|
||||||
|
|
||||||
// Set set value by name
|
// Set set value by name
|
||||||
func (s *DB) Set(name string, value interface{}) *DB {
|
func (s *DB) Set(name string, value interface{}) *DB {
|
||||||
|
return s.clone().set(name, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DB) set(name string, value interface{}) *DB {
|
||||||
s.values[name] = value
|
s.values[name] = value
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,7 +6,11 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *DB) clone() *DB {
|
func (s *DB) clone() *DB {
|
||||||
db := DB{db: s.db, parent: s.parent, logMode: s.logMode, Value: s.Value, Error: s.Error, values: s.values}
|
db := DB{db: s.db, parent: s.parent, logMode: s.logMode, Value: s.Value, Error: s.Error, values: map[string]interface{}{}}
|
||||||
|
|
||||||
|
for key, value := range s.values {
|
||||||
|
db.values[key] = value
|
||||||
|
}
|
||||||
|
|
||||||
if s.search == nil {
|
if s.search == nil {
|
||||||
db.search = &search{}
|
db.search = &search{}
|
||||||
|
|
14
main_test.go
14
main_test.go
|
@ -494,6 +494,20 @@ func TestHstore(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSetAndGet(t *testing.T) {
|
||||||
|
if value, ok := db.Set("hello", "world").Get("hello"); !ok {
|
||||||
|
t.Errorf("Should be able to get setting after set")
|
||||||
|
} else {
|
||||||
|
if value.(string) != "world" {
|
||||||
|
t.Errorf("Setted value should not be changed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := db.Get("non_existing"); ok {
|
||||||
|
t.Errorf("Get non existing key should return error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestCompatibilityMode(t *testing.T) {
|
func TestCompatibilityMode(t *testing.T) {
|
||||||
db, _ := gorm.Open("testdb", "")
|
db, _ := gorm.Open("testdb", "")
|
||||||
testdb.SetQueryFunc(func(query string) (driver.Rows, error) {
|
testdb.SetQueryFunc(func(query string) (driver.Rows, error) {
|
||||||
|
|
|
@ -332,7 +332,6 @@ func TestCount(t *testing.T) {
|
||||||
func TestNot(t *testing.T) {
|
func TestNot(t *testing.T) {
|
||||||
var users1, users2, users3, users4, users5, users6, users7, users8 []User
|
var users1, users2, users3, users4, users5, users6, users7, users8 []User
|
||||||
db.Find(&users1)
|
db.Find(&users1)
|
||||||
|
|
||||||
db.Not(users1[0].Id).Find(&users2)
|
db.Not(users1[0].Id).Find(&users2)
|
||||||
|
|
||||||
if len(users1)-len(users2) != 1 {
|
if len(users1)-len(users2) != 1 {
|
||||||
|
|
26
scope.go
26
scope.go
|
@ -20,6 +20,7 @@ type Scope struct {
|
||||||
db *DB
|
db *DB
|
||||||
skipLeft bool
|
skipLeft bool
|
||||||
primaryKey string
|
primaryKey string
|
||||||
|
instanceId string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) IndirectValue() reflect.Value {
|
func (scope *Scope) IndirectValue() reflect.Value {
|
||||||
|
@ -362,8 +363,9 @@ func (scope *Scope) Exec() *Scope {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set set value by name
|
// Set set value by name
|
||||||
func (scope *Scope) Set(name string, value interface{}) {
|
func (scope *Scope) Set(name string, value interface{}) *Scope {
|
||||||
scope.db.Set(name, value)
|
scope.db.set(name, value)
|
||||||
|
return scope
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get get value by name
|
// Get get value by name
|
||||||
|
@ -371,6 +373,22 @@ func (scope *Scope) Get(name string) (interface{}, bool) {
|
||||||
return scope.db.Get(name)
|
return scope.db.Get(name)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InstanceId get InstanceId for scope
|
||||||
|
func (scope *Scope) InstanceId() string {
|
||||||
|
if scope.instanceId == "" {
|
||||||
|
scope.instanceId = fmt.Sprintf("%v", &scope)
|
||||||
|
}
|
||||||
|
return scope.instanceId
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) InstanceSet(name string, value interface{}) *Scope {
|
||||||
|
return scope.Set(name+scope.InstanceId(), value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) InstanceGet(name string) (interface{}, bool) {
|
||||||
|
return scope.Get(name + scope.InstanceId())
|
||||||
|
}
|
||||||
|
|
||||||
// Trace print sql log
|
// Trace print sql log
|
||||||
func (scope *Scope) Trace(t time.Time) {
|
func (scope *Scope) Trace(t time.Time) {
|
||||||
if len(scope.Sql) > 0 {
|
if len(scope.Sql) > 0 {
|
||||||
|
@ -383,7 +401,7 @@ func (scope *Scope) Begin() *Scope {
|
||||||
if db, ok := scope.DB().(sqlDb); ok {
|
if db, ok := scope.DB().(sqlDb); ok {
|
||||||
if tx, err := db.Begin(); err == nil {
|
if tx, err := db.Begin(); err == nil {
|
||||||
scope.db.db = interface{}(tx).(sqlCommon)
|
scope.db.db = interface{}(tx).(sqlCommon)
|
||||||
scope.Set("gorm:started_transaction", true)
|
scope.InstanceSet("gorm:started_transaction", true)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return scope
|
return scope
|
||||||
|
@ -391,7 +409,7 @@ func (scope *Scope) Begin() *Scope {
|
||||||
|
|
||||||
// CommitOrRollback commit current transaction if there is no error, otherwise rollback it
|
// CommitOrRollback commit current transaction if there is no error, otherwise rollback it
|
||||||
func (scope *Scope) CommitOrRollback() *Scope {
|
func (scope *Scope) CommitOrRollback() *Scope {
|
||||||
if _, ok := scope.Get("gorm:started_transaction"); ok {
|
if _, ok := scope.InstanceGet("gorm:started_transaction"); ok {
|
||||||
if db, ok := scope.db.db.(sqlTx); ok {
|
if db, ok := scope.db.db.(sqlTx); ok {
|
||||||
if scope.HasError() {
|
if scope.HasError() {
|
||||||
db.Rollback()
|
db.Rollback()
|
||||||
|
|
Loading…
Reference in New Issue