implement callback shared

This commit is contained in:
Jinzhu 2014-01-27 08:26:59 +08:00
parent 7b8e91377b
commit ee6a6827a8
4 changed files with 130 additions and 49 deletions

View File

@ -1,5 +1,7 @@
package gorm
import "reflect"
func BeginTransaction(scope *Scope) {
scope.Begin()
}
@ -9,7 +11,65 @@ func CommitOrRollbackTransaction(scope *Scope) {
}
func SaveBeforeAssociations(scope *Scope) {
for _, field := range scope.Fields() {
if field.BeforeAssociation && !field.IsBlank && !field.IsIgnored {
value := reflect.ValueOf(field.Value)
newDB := scope.NewDB()
if value.CanAddr() {
newDB.Save(value.Addr().Interface())
} else {
// If can't take address, then clone the value and set it back
destValue := reflect.New(reflect.ValueOf(field.Value).Type()).Elem()
for _, f := range newDB.NewScope(field.Value).Fields() {
destValue.FieldByName(f.Name).Set(reflect.ValueOf(f.Value))
}
newDB.Save(destValue.Addr().Interface())
scope.SetColumn(field.Name, destValue.Interface())
}
if len(field.foreignKey) > 0 {
scope.SetColumn(field.foreignKey, scope.PrimaryKeyValue())
}
}
}
}
func SaveAfterAssociations(scope *Scope) {
for _, field := range scope.Fields() {
if field.AfterAssociation && !field.IsBlank && !field.IsIgnored {
value := reflect.ValueOf(field.Value)
switch value.Kind() {
case reflect.Slice:
for i := 0; i < value.Len(); i++ {
newDB := scope.NewDB()
elem := value.Index(i).Addr().Interface()
if len(field.foreignKey) > 0 {
newDB.NewScope(elem).SetColumn(field.foreignKey, scope.PrimaryKeyValue())
}
newDB.Save(elem)
}
default:
newDB := scope.NewDB()
if value.CanAddr() {
newDB.NewScope(field.Value).SetColumn(field.foreignKey, scope.PrimaryKeyValue())
newDB.Save(field.Value)
} else {
destValue := reflect.New(reflect.TypeOf(field.Value)).Elem()
for _, f := range newDB.NewScope(destValue).Fields() {
destValue.FieldByName(f.Name).Set(reflect.ValueOf(f.Value))
}
elem := destValue.Addr().Interface()
newDB.NewScope(elem).SetColumn(field.foreignKey, scope.PrimaryKeyValue())
newDB.Save(elem)
scope.SetColumn(field.Name, destValue.Interface())
}
}
}
}
}

105
field.go
View File

@ -9,28 +9,32 @@ import (
)
type Field struct {
Name string
DBName string
Value interface{}
IsBlank bool
IsIgnored bool
Tag string
AddationalTag string
Size int
SqlTag string
Name string
DBName string
Value interface{}
IsBlank bool
IsIgnored bool
Tag string
AddationalTag string
Size int
SqlTag string
ForeignKey string
BeforeAssociation bool
AfterAssociation bool
dbName string
model *Model
isBlank bool
ignoreField bool
isPrimaryKey bool
autoCreateTime bool
autoUpdateTime bool
foreignKey string
beforeAssociation bool
afterAssociation bool
reflectValue reflect.Value
structField reflect.StructField
dbName string
model *Model
isBlank bool
ignoreField bool
isPrimaryKey bool
autoCreateTime bool
autoUpdateTime bool
reflectValue reflect.Value
structField reflect.StructField
}
func (f *Field) IsScanner() bool {
@ -43,6 +47,43 @@ func (f *Field) IsTime() bool {
return is_time
}
func (f *Field) parseAssociation() {
elem := reflect.Indirect(reflect.ValueOf(f.Value))
typ := elem.Type()
switch elem.Kind() {
case reflect.Slice:
typ = typ.Elem()
if _, ok := f.Value.([]byte); !ok {
foreignKey := typ.Name() + "Id"
if reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() {
f.ForeignKey = foreignKey
f.foreignKey = foreignKey
}
f.AfterAssociation = true
f.afterAssociation = true
}
case reflect.Struct:
if !f.IsTime() && !f.IsScanner() {
if elem.FieldByName(f.Name + "Id").IsValid() {
f.foreignKey = f.Name + "Id"
f.beforeAssociation = true
f.ForeignKey = f.Name + "Id"
f.BeforeAssociation = true
} else {
foreignKey := typ.Name() + "Id"
if reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() {
f.foreignKey = foreignKey
f.ForeignKey = foreignKey
}
f.afterAssociation = true
f.AfterAssociation = true
}
}
}
}
func (f *Field) parseBlank() {
f.isBlank = isBlank(f.reflectValue)
}
@ -103,34 +144,6 @@ func (f *Field) sqlTag() (str string) {
return typ
}
func (f *Field) parseAssociation() {
reflect_value := f.reflectValue
switch reflect_value.Kind() {
case reflect.Slice:
if _, ok := f.Value.([]byte); !ok {
foreign_key := f.model.typeName() + "Id"
if reflect.New(reflect_value.Type().Elem()).Elem().FieldByName(foreign_key).IsValid() {
f.foreignKey = foreign_key
}
f.afterAssociation = true
}
case reflect.Struct:
if !f.isTime() && !f.isScanner() {
if f.model.reflectData().FieldByName(f.Name + "Id").IsValid() {
f.foreignKey = f.Name + "Id"
f.beforeAssociation = true
} else {
foreign_key := f.model.typeName() + "Id"
if reflect.New(reflect_value.Type()).Elem().FieldByName(foreign_key).IsValid() {
f.foreignKey = foreign_key
}
f.afterAssociation = true
}
}
}
}
func parseSqlTag(str string) (typ string, addational_typ string, size int) {
if str == "-" {
typ = str

View File

@ -177,7 +177,7 @@ func (s *DB) UpdateColumns(values interface{}, ignore_protected_attrs ...bool) *
}
func (s *DB) Save(value interface{}) *DB {
scope := s.clone().newScope(value)
scope := s.clone().NewScope(value)
if scope.PrimaryKeyZero() {
return scope.callCallbacks(s.parent.callback.creates).db.do(value).db
} else {
@ -186,7 +186,7 @@ func (s *DB) Save(value interface{}) *DB {
}
func (s *DB) Delete(value interface{}) *DB {
return s.clone().newScope(value).callCallbacks(s.parent.callback.deletes).db
return s.clone().NewScope(value).callCallbacks(s.parent.callback.deletes).db
}
func (s *DB) Raw(sql string, values ...interface{}) *DB {

View File

@ -21,7 +21,7 @@ type Scope struct {
startedTransaction bool
}
func (db *DB) newScope(value interface{}) *Scope {
func (db *DB) NewScope(value interface{}) *Scope {
return &Scope{db: db, Search: db.search, Value: value}
}
@ -32,6 +32,14 @@ func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope {
return scope
}
func (scope *Scope) New(value interface{}) *Scope {
return &Scope{db: scope.db.parent, Search: &search{}, Value: value}
}
func (scope *Scope) NewDB() *DB {
return scope.db.parent
}
func (scope *Scope) DB() sqlCommon {
return scope.db.db
}