forked from mirror/gorm
implement callback shared
This commit is contained in:
parent
7b8e91377b
commit
ee6a6827a8
|
@ -1,5 +1,7 @@
|
|||
package gorm
|
||||
|
||||
import "reflect"
|
||||
|
||||
func BeginTransaction(scope *Scope) {
|
||||
scope.Begin()
|
||||
}
|
||||
|
@ -9,7 +11,65 @@ func CommitOrRollbackTransaction(scope *Scope) {
|
|||
}
|
||||
|
||||
func SaveBeforeAssociations(scope *Scope) {
|
||||
for _, field := range scope.Fields() {
|
||||
if field.BeforeAssociation && !field.IsBlank && !field.IsIgnored {
|
||||
value := reflect.ValueOf(field.Value)
|
||||
newDB := scope.NewDB()
|
||||
|
||||
if value.CanAddr() {
|
||||
newDB.Save(value.Addr().Interface())
|
||||
} else {
|
||||
// If can't take address, then clone the value and set it back
|
||||
destValue := reflect.New(reflect.ValueOf(field.Value).Type()).Elem()
|
||||
for _, f := range newDB.NewScope(field.Value).Fields() {
|
||||
destValue.FieldByName(f.Name).Set(reflect.ValueOf(f.Value))
|
||||
}
|
||||
newDB.Save(destValue.Addr().Interface())
|
||||
scope.SetColumn(field.Name, destValue.Interface())
|
||||
}
|
||||
|
||||
if len(field.foreignKey) > 0 {
|
||||
scope.SetColumn(field.foreignKey, scope.PrimaryKeyValue())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func SaveAfterAssociations(scope *Scope) {
|
||||
for _, field := range scope.Fields() {
|
||||
if field.AfterAssociation && !field.IsBlank && !field.IsIgnored {
|
||||
value := reflect.ValueOf(field.Value)
|
||||
|
||||
switch value.Kind() {
|
||||
case reflect.Slice:
|
||||
for i := 0; i < value.Len(); i++ {
|
||||
newDB := scope.NewDB()
|
||||
elem := value.Index(i).Addr().Interface()
|
||||
|
||||
if len(field.foreignKey) > 0 {
|
||||
newDB.NewScope(elem).SetColumn(field.foreignKey, scope.PrimaryKeyValue())
|
||||
}
|
||||
|
||||
newDB.Save(elem)
|
||||
}
|
||||
default:
|
||||
newDB := scope.NewDB()
|
||||
if value.CanAddr() {
|
||||
newDB.NewScope(field.Value).SetColumn(field.foreignKey, scope.PrimaryKeyValue())
|
||||
newDB.Save(field.Value)
|
||||
} else {
|
||||
destValue := reflect.New(reflect.TypeOf(field.Value)).Elem()
|
||||
|
||||
for _, f := range newDB.NewScope(destValue).Fields() {
|
||||
destValue.FieldByName(f.Name).Set(reflect.ValueOf(f.Value))
|
||||
}
|
||||
|
||||
elem := destValue.Addr().Interface()
|
||||
newDB.NewScope(elem).SetColumn(field.foreignKey, scope.PrimaryKeyValue())
|
||||
newDB.Save(elem)
|
||||
scope.SetColumn(field.Name, destValue.Interface())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
105
field.go
105
field.go
|
@ -9,28 +9,32 @@ import (
|
|||
)
|
||||
|
||||
type Field struct {
|
||||
Name string
|
||||
DBName string
|
||||
Value interface{}
|
||||
IsBlank bool
|
||||
IsIgnored bool
|
||||
Tag string
|
||||
AddationalTag string
|
||||
Size int
|
||||
SqlTag string
|
||||
Name string
|
||||
DBName string
|
||||
Value interface{}
|
||||
IsBlank bool
|
||||
IsIgnored bool
|
||||
Tag string
|
||||
AddationalTag string
|
||||
Size int
|
||||
SqlTag string
|
||||
ForeignKey string
|
||||
BeforeAssociation bool
|
||||
AfterAssociation bool
|
||||
|
||||
dbName string
|
||||
model *Model
|
||||
isBlank bool
|
||||
ignoreField bool
|
||||
isPrimaryKey bool
|
||||
autoCreateTime bool
|
||||
autoUpdateTime bool
|
||||
foreignKey string
|
||||
beforeAssociation bool
|
||||
afterAssociation bool
|
||||
reflectValue reflect.Value
|
||||
structField reflect.StructField
|
||||
|
||||
dbName string
|
||||
model *Model
|
||||
isBlank bool
|
||||
ignoreField bool
|
||||
isPrimaryKey bool
|
||||
autoCreateTime bool
|
||||
autoUpdateTime bool
|
||||
reflectValue reflect.Value
|
||||
structField reflect.StructField
|
||||
}
|
||||
|
||||
func (f *Field) IsScanner() bool {
|
||||
|
@ -43,6 +47,43 @@ func (f *Field) IsTime() bool {
|
|||
return is_time
|
||||
}
|
||||
|
||||
func (f *Field) parseAssociation() {
|
||||
elem := reflect.Indirect(reflect.ValueOf(f.Value))
|
||||
typ := elem.Type()
|
||||
|
||||
switch elem.Kind() {
|
||||
case reflect.Slice:
|
||||
typ = typ.Elem()
|
||||
|
||||
if _, ok := f.Value.([]byte); !ok {
|
||||
foreignKey := typ.Name() + "Id"
|
||||
if reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() {
|
||||
f.ForeignKey = foreignKey
|
||||
f.foreignKey = foreignKey
|
||||
}
|
||||
f.AfterAssociation = true
|
||||
f.afterAssociation = true
|
||||
}
|
||||
case reflect.Struct:
|
||||
if !f.IsTime() && !f.IsScanner() {
|
||||
if elem.FieldByName(f.Name + "Id").IsValid() {
|
||||
f.foreignKey = f.Name + "Id"
|
||||
f.beforeAssociation = true
|
||||
f.ForeignKey = f.Name + "Id"
|
||||
f.BeforeAssociation = true
|
||||
} else {
|
||||
foreignKey := typ.Name() + "Id"
|
||||
if reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() {
|
||||
f.foreignKey = foreignKey
|
||||
f.ForeignKey = foreignKey
|
||||
}
|
||||
f.afterAssociation = true
|
||||
f.AfterAssociation = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Field) parseBlank() {
|
||||
f.isBlank = isBlank(f.reflectValue)
|
||||
}
|
||||
|
@ -103,34 +144,6 @@ func (f *Field) sqlTag() (str string) {
|
|||
return typ
|
||||
}
|
||||
|
||||
func (f *Field) parseAssociation() {
|
||||
reflect_value := f.reflectValue
|
||||
|
||||
switch reflect_value.Kind() {
|
||||
case reflect.Slice:
|
||||
if _, ok := f.Value.([]byte); !ok {
|
||||
foreign_key := f.model.typeName() + "Id"
|
||||
if reflect.New(reflect_value.Type().Elem()).Elem().FieldByName(foreign_key).IsValid() {
|
||||
f.foreignKey = foreign_key
|
||||
}
|
||||
f.afterAssociation = true
|
||||
}
|
||||
case reflect.Struct:
|
||||
if !f.isTime() && !f.isScanner() {
|
||||
if f.model.reflectData().FieldByName(f.Name + "Id").IsValid() {
|
||||
f.foreignKey = f.Name + "Id"
|
||||
f.beforeAssociation = true
|
||||
} else {
|
||||
foreign_key := f.model.typeName() + "Id"
|
||||
if reflect.New(reflect_value.Type()).Elem().FieldByName(foreign_key).IsValid() {
|
||||
f.foreignKey = foreign_key
|
||||
}
|
||||
f.afterAssociation = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func parseSqlTag(str string) (typ string, addational_typ string, size int) {
|
||||
if str == "-" {
|
||||
typ = str
|
||||
|
|
4
main.go
4
main.go
|
@ -177,7 +177,7 @@ func (s *DB) UpdateColumns(values interface{}, ignore_protected_attrs ...bool) *
|
|||
}
|
||||
|
||||
func (s *DB) Save(value interface{}) *DB {
|
||||
scope := s.clone().newScope(value)
|
||||
scope := s.clone().NewScope(value)
|
||||
if scope.PrimaryKeyZero() {
|
||||
return scope.callCallbacks(s.parent.callback.creates).db.do(value).db
|
||||
} else {
|
||||
|
@ -186,7 +186,7 @@ func (s *DB) Save(value interface{}) *DB {
|
|||
}
|
||||
|
||||
func (s *DB) Delete(value interface{}) *DB {
|
||||
return s.clone().newScope(value).callCallbacks(s.parent.callback.deletes).db
|
||||
return s.clone().NewScope(value).callCallbacks(s.parent.callback.deletes).db
|
||||
}
|
||||
|
||||
func (s *DB) Raw(sql string, values ...interface{}) *DB {
|
||||
|
|
10
scope.go
10
scope.go
|
@ -21,7 +21,7 @@ type Scope struct {
|
|||
startedTransaction bool
|
||||
}
|
||||
|
||||
func (db *DB) newScope(value interface{}) *Scope {
|
||||
func (db *DB) NewScope(value interface{}) *Scope {
|
||||
return &Scope{db: db, Search: db.search, Value: value}
|
||||
}
|
||||
|
||||
|
@ -32,6 +32,14 @@ func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope {
|
|||
return scope
|
||||
}
|
||||
|
||||
func (scope *Scope) New(value interface{}) *Scope {
|
||||
return &Scope{db: scope.db.parent, Search: &search{}, Value: value}
|
||||
}
|
||||
|
||||
func (scope *Scope) NewDB() *DB {
|
||||
return scope.db.parent
|
||||
}
|
||||
|
||||
func (scope *Scope) DB() sqlCommon {
|
||||
return scope.db.db
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue