From 4a540f3ac83a6765b8876be67994ecde022c6198 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 27 Oct 2016 10:31:46 +0800 Subject: [PATCH] Add tag to support skip nested save for associations --- association_test.go | 21 ++++++++++ callback_save.go | 97 ++++++++++++++++++++++++--------------------- scope.go | 9 ++++- 3 files changed, 80 insertions(+), 47 deletions(-) diff --git a/association_test.go b/association_test.go index 02974a98..c84f84ed 100644 --- a/association_test.go +++ b/association_test.go @@ -884,3 +884,24 @@ func TestHasManyChildrenWithOneStruct(t *testing.T) { DB.Save(&category) } + +func TestSkipSaveAssociation(t *testing.T) { + type Company struct { + gorm.Model + Name string + } + + type User struct { + gorm.Model + Name string + CompanyID uint + Company Company `gorm:"save_associations:false"` + } + DB.AutoMigrate(&Company{}, &User{}) + + DB.Save(&User{Name: "jinzhu", Company: Company{Name: "skip_save_association"}}) + + if !DB.Where("name = ?", "skip_save_association").First(&Company{}).RecordNotFound() { + t.Errorf("Company skip_save_association should not been saved") + } +} diff --git a/callback_save.go b/callback_save.go index ea9ec174..f4bc918e 100644 --- a/callback_save.go +++ b/callback_save.go @@ -10,22 +10,31 @@ func commitOrRollbackTransactionCallback(scope *Scope) { scope.CommitOrRollback() } +func saveFieldAsAssociation(scope *Scope, field *Field) (bool, *Relationship) { + if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { + if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; !ok || (value != "false" && value != "skip") { + if relationship := field.Relationship; relationship != nil { + return true, relationship + } + } + } + return false, nil +} + func saveBeforeAssociationsCallback(scope *Scope) { if !scope.shouldSaveAssociations() { return } for _, field := range scope.Fields() { - if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { - if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { - fieldValue := field.Field.Addr().Interface() - scope.Err(scope.NewDB().Save(fieldValue).Error) - if len(relationship.ForeignFieldNames) != 0 { - // set value's foreign key - for idx, fieldName := range relationship.ForeignFieldNames { - associationForeignName := relationship.AssociationForeignDBNames[idx] - if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok { - scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface())) - } + if ok, relationship := saveFieldAsAssociation(scope, field); ok && relationship.Kind == "belongs_to" { + fieldValue := field.Field.Addr().Interface() + scope.Err(scope.NewDB().Save(fieldValue).Error) + if len(relationship.ForeignFieldNames) != 0 { + // set value's foreign key + for idx, fieldName := range relationship.ForeignFieldNames { + associationForeignName := relationship.AssociationForeignDBNames[idx] + if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok { + scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface())) } } } @@ -38,41 +47,18 @@ func saveAfterAssociationsCallback(scope *Scope) { return } for _, field := range scope.Fields() { - if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { - if relationship := field.Relationship; relationship != nil && - (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") { - value := field.Field + if ok, relationship := saveFieldAsAssociation(scope, field); ok && + (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") { + value := field.Field - switch value.Kind() { - case reflect.Slice: - for i := 0; i < value.Len(); i++ { - newDB := scope.NewDB() - elem := value.Index(i).Addr().Interface() - newScope := newDB.NewScope(elem) + switch value.Kind() { + case reflect.Slice: + for i := 0; i < value.Len(); i++ { + newDB := scope.NewDB() + elem := value.Index(i).Addr().Interface() + newScope := newDB.NewScope(elem) - if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 { - for idx, fieldName := range relationship.ForeignFieldNames { - associationForeignName := relationship.AssociationForeignDBNames[idx] - if f, ok := scope.FieldByName(associationForeignName); ok { - scope.Err(newScope.SetColumn(fieldName, f.Field.Interface())) - } - } - } - - if relationship.PolymorphicType != "" { - scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) - } - - scope.Err(newDB.Save(elem).Error) - - if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil { - scope.Err(joinTableHandler.Add(joinTableHandler, newDB, scope.Value, newScope.Value)) - } - } - default: - elem := value.Addr().Interface() - newScope := scope.New(elem) - if len(relationship.ForeignFieldNames) != 0 { + if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 { for idx, fieldName := range relationship.ForeignFieldNames { associationForeignName := relationship.AssociationForeignDBNames[idx] if f, ok := scope.FieldByName(associationForeignName); ok { @@ -84,8 +70,29 @@ func saveAfterAssociationsCallback(scope *Scope) { if relationship.PolymorphicType != "" { scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) } - scope.Err(scope.NewDB().Save(elem).Error) + + scope.Err(newDB.Save(elem).Error) + + if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil { + scope.Err(joinTableHandler.Add(joinTableHandler, newDB, scope.Value, newScope.Value)) + } } + default: + elem := value.Addr().Interface() + newScope := scope.New(elem) + if len(relationship.ForeignFieldNames) != 0 { + for idx, fieldName := range relationship.ForeignFieldNames { + associationForeignName := relationship.AssociationForeignDBNames[idx] + if f, ok := scope.FieldByName(associationForeignName); ok { + scope.Err(newScope.SetColumn(fieldName, f.Field.Interface())) + } + } + } + + if relationship.PolymorphicType != "" { + scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) + } + scope.Err(scope.NewDB().Save(elem).Error) } } } diff --git a/scope.go b/scope.go index 4c9c7922..2a7eeea4 100644 --- a/scope.go +++ b/scope.go @@ -964,8 +964,13 @@ func (scope *Scope) changeableField(field *Field) bool { } func (scope *Scope) shouldSaveAssociations() bool { - if saveAssociations, ok := scope.Get("gorm:save_associations"); ok && !saveAssociations.(bool) { - return false + if saveAssociations, ok := scope.Get("gorm:save_associations"); ok { + if v, ok := saveAssociations.(bool); ok && !v { + return false + } + if v, ok := saveAssociations.(string); ok && (v != "skip") { + return false + } } return true && !scope.HasError() }