From ee6a6827a888fbe3c5d231112484c3d931b026cf Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 27 Jan 2014 08:26:59 +0800 Subject: [PATCH] implement callback shared --- callback_shared.go | 60 ++++++++++++++++++++++++++ field.go | 105 +++++++++++++++++++++++++-------------------- main.go | 4 +- scope.go | 10 ++++- 4 files changed, 130 insertions(+), 49 deletions(-) diff --git a/callback_shared.go b/callback_shared.go index 896413af..00400230 100644 --- a/callback_shared.go +++ b/callback_shared.go @@ -1,5 +1,7 @@ package gorm +import "reflect" + func BeginTransaction(scope *Scope) { scope.Begin() } @@ -9,7 +11,65 @@ func CommitOrRollbackTransaction(scope *Scope) { } func SaveBeforeAssociations(scope *Scope) { + for _, field := range scope.Fields() { + if field.BeforeAssociation && !field.IsBlank && !field.IsIgnored { + value := reflect.ValueOf(field.Value) + newDB := scope.NewDB() + + if value.CanAddr() { + newDB.Save(value.Addr().Interface()) + } else { + // If can't take address, then clone the value and set it back + destValue := reflect.New(reflect.ValueOf(field.Value).Type()).Elem() + for _, f := range newDB.NewScope(field.Value).Fields() { + destValue.FieldByName(f.Name).Set(reflect.ValueOf(f.Value)) + } + newDB.Save(destValue.Addr().Interface()) + scope.SetColumn(field.Name, destValue.Interface()) + } + + if len(field.foreignKey) > 0 { + scope.SetColumn(field.foreignKey, scope.PrimaryKeyValue()) + } + } + } } func SaveAfterAssociations(scope *Scope) { + for _, field := range scope.Fields() { + if field.AfterAssociation && !field.IsBlank && !field.IsIgnored { + value := reflect.ValueOf(field.Value) + + switch value.Kind() { + case reflect.Slice: + for i := 0; i < value.Len(); i++ { + newDB := scope.NewDB() + elem := value.Index(i).Addr().Interface() + + if len(field.foreignKey) > 0 { + newDB.NewScope(elem).SetColumn(field.foreignKey, scope.PrimaryKeyValue()) + } + + newDB.Save(elem) + } + default: + newDB := scope.NewDB() + if value.CanAddr() { + newDB.NewScope(field.Value).SetColumn(field.foreignKey, scope.PrimaryKeyValue()) + newDB.Save(field.Value) + } else { + destValue := reflect.New(reflect.TypeOf(field.Value)).Elem() + + for _, f := range newDB.NewScope(destValue).Fields() { + destValue.FieldByName(f.Name).Set(reflect.ValueOf(f.Value)) + } + + elem := destValue.Addr().Interface() + newDB.NewScope(elem).SetColumn(field.foreignKey, scope.PrimaryKeyValue()) + newDB.Save(elem) + scope.SetColumn(field.Name, destValue.Interface()) + } + } + } + } } diff --git a/field.go b/field.go index 21083227..011289f8 100644 --- a/field.go +++ b/field.go @@ -9,28 +9,32 @@ import ( ) type Field struct { - Name string - DBName string - Value interface{} - IsBlank bool - IsIgnored bool - Tag string - AddationalTag string - Size int - SqlTag string + Name string + DBName string + Value interface{} + IsBlank bool + IsIgnored bool + Tag string + AddationalTag string + Size int + SqlTag string + ForeignKey string + BeforeAssociation bool + AfterAssociation bool - dbName string - model *Model - isBlank bool - ignoreField bool - isPrimaryKey bool - autoCreateTime bool - autoUpdateTime bool foreignKey string beforeAssociation bool afterAssociation bool - reflectValue reflect.Value - structField reflect.StructField + + dbName string + model *Model + isBlank bool + ignoreField bool + isPrimaryKey bool + autoCreateTime bool + autoUpdateTime bool + reflectValue reflect.Value + structField reflect.StructField } func (f *Field) IsScanner() bool { @@ -43,6 +47,43 @@ func (f *Field) IsTime() bool { return is_time } +func (f *Field) parseAssociation() { + elem := reflect.Indirect(reflect.ValueOf(f.Value)) + typ := elem.Type() + + switch elem.Kind() { + case reflect.Slice: + typ = typ.Elem() + + if _, ok := f.Value.([]byte); !ok { + foreignKey := typ.Name() + "Id" + if reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() { + f.ForeignKey = foreignKey + f.foreignKey = foreignKey + } + f.AfterAssociation = true + f.afterAssociation = true + } + case reflect.Struct: + if !f.IsTime() && !f.IsScanner() { + if elem.FieldByName(f.Name + "Id").IsValid() { + f.foreignKey = f.Name + "Id" + f.beforeAssociation = true + f.ForeignKey = f.Name + "Id" + f.BeforeAssociation = true + } else { + foreignKey := typ.Name() + "Id" + if reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() { + f.foreignKey = foreignKey + f.ForeignKey = foreignKey + } + f.afterAssociation = true + f.AfterAssociation = true + } + } + } +} + func (f *Field) parseBlank() { f.isBlank = isBlank(f.reflectValue) } @@ -103,34 +144,6 @@ func (f *Field) sqlTag() (str string) { return typ } -func (f *Field) parseAssociation() { - reflect_value := f.reflectValue - - switch reflect_value.Kind() { - case reflect.Slice: - if _, ok := f.Value.([]byte); !ok { - foreign_key := f.model.typeName() + "Id" - if reflect.New(reflect_value.Type().Elem()).Elem().FieldByName(foreign_key).IsValid() { - f.foreignKey = foreign_key - } - f.afterAssociation = true - } - case reflect.Struct: - if !f.isTime() && !f.isScanner() { - if f.model.reflectData().FieldByName(f.Name + "Id").IsValid() { - f.foreignKey = f.Name + "Id" - f.beforeAssociation = true - } else { - foreign_key := f.model.typeName() + "Id" - if reflect.New(reflect_value.Type()).Elem().FieldByName(foreign_key).IsValid() { - f.foreignKey = foreign_key - } - f.afterAssociation = true - } - } - } -} - func parseSqlTag(str string) (typ string, addational_typ string, size int) { if str == "-" { typ = str diff --git a/main.go b/main.go index bee12997..717e535c 100644 --- a/main.go +++ b/main.go @@ -177,7 +177,7 @@ func (s *DB) UpdateColumns(values interface{}, ignore_protected_attrs ...bool) * } func (s *DB) Save(value interface{}) *DB { - scope := s.clone().newScope(value) + scope := s.clone().NewScope(value) if scope.PrimaryKeyZero() { return scope.callCallbacks(s.parent.callback.creates).db.do(value).db } else { @@ -186,7 +186,7 @@ func (s *DB) Save(value interface{}) *DB { } func (s *DB) Delete(value interface{}) *DB { - return s.clone().newScope(value).callCallbacks(s.parent.callback.deletes).db + return s.clone().NewScope(value).callCallbacks(s.parent.callback.deletes).db } func (s *DB) Raw(sql string, values ...interface{}) *DB { diff --git a/scope.go b/scope.go index c69cac99..ea2a015c 100644 --- a/scope.go +++ b/scope.go @@ -21,7 +21,7 @@ type Scope struct { startedTransaction bool } -func (db *DB) newScope(value interface{}) *Scope { +func (db *DB) NewScope(value interface{}) *Scope { return &Scope{db: db, Search: db.search, Value: value} } @@ -32,6 +32,14 @@ func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope { return scope } +func (scope *Scope) New(value interface{}) *Scope { + return &Scope{db: scope.db.parent, Search: &search{}, Value: value} +} + +func (scope *Scope) NewDB() *DB { + return scope.db.parent +} + func (scope *Scope) DB() sqlCommon { return scope.db.db }