implement callback shared

This commit is contained in:
Jinzhu 2014-01-27 08:26:59 +08:00
parent 7b8e91377b
commit ee6a6827a8
4 changed files with 130 additions and 49 deletions

View File

@ -1,5 +1,7 @@
package gorm package gorm
import "reflect"
func BeginTransaction(scope *Scope) { func BeginTransaction(scope *Scope) {
scope.Begin() scope.Begin()
} }
@ -9,7 +11,65 @@ func CommitOrRollbackTransaction(scope *Scope) {
} }
func SaveBeforeAssociations(scope *Scope) { func SaveBeforeAssociations(scope *Scope) {
for _, field := range scope.Fields() {
if field.BeforeAssociation && !field.IsBlank && !field.IsIgnored {
value := reflect.ValueOf(field.Value)
newDB := scope.NewDB()
if value.CanAddr() {
newDB.Save(value.Addr().Interface())
} else {
// If can't take address, then clone the value and set it back
destValue := reflect.New(reflect.ValueOf(field.Value).Type()).Elem()
for _, f := range newDB.NewScope(field.Value).Fields() {
destValue.FieldByName(f.Name).Set(reflect.ValueOf(f.Value))
}
newDB.Save(destValue.Addr().Interface())
scope.SetColumn(field.Name, destValue.Interface())
}
if len(field.foreignKey) > 0 {
scope.SetColumn(field.foreignKey, scope.PrimaryKeyValue())
}
}
}
} }
func SaveAfterAssociations(scope *Scope) { func SaveAfterAssociations(scope *Scope) {
for _, field := range scope.Fields() {
if field.AfterAssociation && !field.IsBlank && !field.IsIgnored {
value := reflect.ValueOf(field.Value)
switch value.Kind() {
case reflect.Slice:
for i := 0; i < value.Len(); i++ {
newDB := scope.NewDB()
elem := value.Index(i).Addr().Interface()
if len(field.foreignKey) > 0 {
newDB.NewScope(elem).SetColumn(field.foreignKey, scope.PrimaryKeyValue())
}
newDB.Save(elem)
}
default:
newDB := scope.NewDB()
if value.CanAddr() {
newDB.NewScope(field.Value).SetColumn(field.foreignKey, scope.PrimaryKeyValue())
newDB.Save(field.Value)
} else {
destValue := reflect.New(reflect.TypeOf(field.Value)).Elem()
for _, f := range newDB.NewScope(destValue).Fields() {
destValue.FieldByName(f.Name).Set(reflect.ValueOf(f.Value))
}
elem := destValue.Addr().Interface()
newDB.NewScope(elem).SetColumn(field.foreignKey, scope.PrimaryKeyValue())
newDB.Save(elem)
scope.SetColumn(field.Name, destValue.Interface())
}
}
}
}
} }

105
field.go
View File

@ -9,28 +9,32 @@ import (
) )
type Field struct { type Field struct {
Name string Name string
DBName string DBName string
Value interface{} Value interface{}
IsBlank bool IsBlank bool
IsIgnored bool IsIgnored bool
Tag string Tag string
AddationalTag string AddationalTag string
Size int Size int
SqlTag string SqlTag string
ForeignKey string
BeforeAssociation bool
AfterAssociation bool
dbName string
model *Model
isBlank bool
ignoreField bool
isPrimaryKey bool
autoCreateTime bool
autoUpdateTime bool
foreignKey string foreignKey string
beforeAssociation bool beforeAssociation bool
afterAssociation bool afterAssociation bool
reflectValue reflect.Value
structField reflect.StructField dbName string
model *Model
isBlank bool
ignoreField bool
isPrimaryKey bool
autoCreateTime bool
autoUpdateTime bool
reflectValue reflect.Value
structField reflect.StructField
} }
func (f *Field) IsScanner() bool { func (f *Field) IsScanner() bool {
@ -43,6 +47,43 @@ func (f *Field) IsTime() bool {
return is_time return is_time
} }
func (f *Field) parseAssociation() {
elem := reflect.Indirect(reflect.ValueOf(f.Value))
typ := elem.Type()
switch elem.Kind() {
case reflect.Slice:
typ = typ.Elem()
if _, ok := f.Value.([]byte); !ok {
foreignKey := typ.Name() + "Id"
if reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() {
f.ForeignKey = foreignKey
f.foreignKey = foreignKey
}
f.AfterAssociation = true
f.afterAssociation = true
}
case reflect.Struct:
if !f.IsTime() && !f.IsScanner() {
if elem.FieldByName(f.Name + "Id").IsValid() {
f.foreignKey = f.Name + "Id"
f.beforeAssociation = true
f.ForeignKey = f.Name + "Id"
f.BeforeAssociation = true
} else {
foreignKey := typ.Name() + "Id"
if reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() {
f.foreignKey = foreignKey
f.ForeignKey = foreignKey
}
f.afterAssociation = true
f.AfterAssociation = true
}
}
}
}
func (f *Field) parseBlank() { func (f *Field) parseBlank() {
f.isBlank = isBlank(f.reflectValue) f.isBlank = isBlank(f.reflectValue)
} }
@ -103,34 +144,6 @@ func (f *Field) sqlTag() (str string) {
return typ return typ
} }
func (f *Field) parseAssociation() {
reflect_value := f.reflectValue
switch reflect_value.Kind() {
case reflect.Slice:
if _, ok := f.Value.([]byte); !ok {
foreign_key := f.model.typeName() + "Id"
if reflect.New(reflect_value.Type().Elem()).Elem().FieldByName(foreign_key).IsValid() {
f.foreignKey = foreign_key
}
f.afterAssociation = true
}
case reflect.Struct:
if !f.isTime() && !f.isScanner() {
if f.model.reflectData().FieldByName(f.Name + "Id").IsValid() {
f.foreignKey = f.Name + "Id"
f.beforeAssociation = true
} else {
foreign_key := f.model.typeName() + "Id"
if reflect.New(reflect_value.Type()).Elem().FieldByName(foreign_key).IsValid() {
f.foreignKey = foreign_key
}
f.afterAssociation = true
}
}
}
}
func parseSqlTag(str string) (typ string, addational_typ string, size int) { func parseSqlTag(str string) (typ string, addational_typ string, size int) {
if str == "-" { if str == "-" {
typ = str typ = str

View File

@ -177,7 +177,7 @@ func (s *DB) UpdateColumns(values interface{}, ignore_protected_attrs ...bool) *
} }
func (s *DB) Save(value interface{}) *DB { func (s *DB) Save(value interface{}) *DB {
scope := s.clone().newScope(value) scope := s.clone().NewScope(value)
if scope.PrimaryKeyZero() { if scope.PrimaryKeyZero() {
return scope.callCallbacks(s.parent.callback.creates).db.do(value).db return scope.callCallbacks(s.parent.callback.creates).db.do(value).db
} else { } else {
@ -186,7 +186,7 @@ func (s *DB) Save(value interface{}) *DB {
} }
func (s *DB) Delete(value interface{}) *DB { func (s *DB) Delete(value interface{}) *DB {
return s.clone().newScope(value).callCallbacks(s.parent.callback.deletes).db return s.clone().NewScope(value).callCallbacks(s.parent.callback.deletes).db
} }
func (s *DB) Raw(sql string, values ...interface{}) *DB { func (s *DB) Raw(sql string, values ...interface{}) *DB {

View File

@ -21,7 +21,7 @@ type Scope struct {
startedTransaction bool startedTransaction bool
} }
func (db *DB) newScope(value interface{}) *Scope { func (db *DB) NewScope(value interface{}) *Scope {
return &Scope{db: db, Search: db.search, Value: value} return &Scope{db: db, Search: db.search, Value: value}
} }
@ -32,6 +32,14 @@ func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope {
return scope return scope
} }
func (scope *Scope) New(value interface{}) *Scope {
return &Scope{db: scope.db.parent, Search: &search{}, Value: value}
}
func (scope *Scope) NewDB() *DB {
return scope.db.parent
}
func (scope *Scope) DB() sqlCommon { func (scope *Scope) DB() sqlCommon {
return scope.db.db return scope.db.db
} }