diff --git a/callback_shared.go b/callback_shared.go index d007ea7c..dd9445c9 100644 --- a/callback_shared.go +++ b/callback_shared.go @@ -16,24 +16,27 @@ func CommitOrRollbackTransaction(scope *Scope) { func SaveBeforeAssociations(scope *Scope) { for _, field := range scope.Fields() { - if field.BeforeAssociation && !field.IsBlank && !field.IsIgnored { - value := reflect.ValueOf(field.Value) - newDB := scope.NewDB() + if !field.IsBlank && !field.IsIgnored { + relationship := field.Relationship + if relationship != nil && relationship.kind == "belongs_to" { + value := reflect.ValueOf(field.Value) + newDB := scope.NewDB() - if value.CanAddr() { - scope.Err(newDB.Save(value.Addr().Interface()).Error) - } else { - // If can't take address, then clone the value and set it back - value = reflect.New(reflect.ValueOf(field.Value).Type()).Elem() - for _, f := range newDB.NewScope(field.Value).Fields() { - value.FieldByName(f.Name).Set(reflect.ValueOf(f.Value)) + if value.CanAddr() { + scope.Err(newDB.Save(value.Addr().Interface()).Error) + } else { + // If can't take address, then clone the value and set it back + value = reflect.New(reflect.ValueOf(field.Value).Type()).Elem() + for _, f := range newDB.NewScope(field.Value).Fields() { + value.FieldByName(f.Name).Set(reflect.ValueOf(f.Value)) + } + scope.Err(newDB.Save(value.Addr().Interface()).Error) + scope.SetColumn(field.Name, value.Interface()) } - scope.Err(newDB.Save(value.Addr().Interface()).Error) - scope.SetColumn(field.Name, value.Interface()) - } - if field.Relationship != nil && field.Relationship.foreignKey != "" { - scope.SetColumn(field.Relationship.foreignKey, newDB.NewScope(value.Interface()).PrimaryKeyValue()) + if relationship.foreignKey != "" { + scope.SetColumn(relationship.foreignKey, newDB.NewScope(value.Interface()).PrimaryKeyValue()) + } } } } @@ -41,66 +44,70 @@ func SaveBeforeAssociations(scope *Scope) { func SaveAfterAssociations(scope *Scope) { for _, field := range scope.Fields() { - if field.AfterAssociation && !field.IsBlank && !field.IsIgnored { - value := reflect.ValueOf(field.Value) + if !field.IsBlank && !field.IsIgnored { + relationship := field.Relationship + if relationship != nil && + (relationship.kind == "has_one" || relationship.kind == "has_many" || relationship.kind == "many_to_many") { + value := reflect.ValueOf(field.Value) - switch value.Kind() { - case reflect.Slice: - for i := 0; i < value.Len(); i++ { - newDB := scope.NewDB() - elem := value.Index(i).Addr().Interface() + switch value.Kind() { + case reflect.Slice: + for i := 0; i < value.Len(); i++ { + newDB := scope.NewDB() + elem := value.Index(i).Addr().Interface() - if field.Relationship != nil && field.Relationship.joinTable == "" && field.Relationship.foreignKey != "" { - newDB.NewScope(elem).SetColumn(field.Relationship.foreignKey, scope.PrimaryKeyValue()) - } + if relationship.joinTable == "" && relationship.foreignKey != "" { + newDB.NewScope(elem).SetColumn(relationship.foreignKey, scope.PrimaryKeyValue()) + } - scope.Err(newDB.Save(elem).Error) + scope.Err(newDB.Save(elem).Error) - if field.Relationship != nil && field.Relationship.joinTable != "" { - newScope := scope.New(elem) - joinTable := field.Relationship.joinTable - foreignKey := ToSnake(field.Relationship.foreignKey) - foreignValue := fmt.Sprintf("%v", scope.PrimaryKeyValue()) - associationForeignKey := ToSnake(field.Relationship.associationForeignKey) - associationForeignValue := fmt.Sprintf("%v", newScope.PrimaryKeyValue()) + if relationship.joinTable != "" { + newScope := scope.New(elem) + joinTable := relationship.joinTable + foreignKey := ToSnake(relationship.foreignKey) + foreignValue := fmt.Sprintf("%v", scope.PrimaryKeyValue()) + associationForeignKey := ToSnake(relationship.associationForeignKey) + associationForeignValue := fmt.Sprintf("%v", newScope.PrimaryKeyValue()) - newScope.Raw(fmt.Sprintf( - "INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v = %v AND %v = %v);", - joinTable, - strings.Join([]string{scope.Quote(foreignKey), scope.Quote(associationForeignKey)}, ","), - strings.Join([]string{newScope.AddToVars(foreignValue), newScope.AddToVars(associationForeignValue)}, ","), - scope.Dialect().SelectFromDummyTable(), - joinTable, - scope.Quote(foreignKey), - newScope.AddToVars(foreignValue), - scope.Quote(associationForeignKey), - newScope.AddToVars(associationForeignValue), - )) - if _, err := scope.DB().Exec(newScope.Sql, newScope.SqlVars...); err != nil { - scope.Err(err) + newScope.Raw(fmt.Sprintf( + "INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v = %v AND %v = %v);", + joinTable, + strings.Join([]string{scope.Quote(foreignKey), scope.Quote(associationForeignKey)}, ","), + strings.Join([]string{newScope.AddToVars(foreignValue), newScope.AddToVars(associationForeignValue)}, ","), + scope.Dialect().SelectFromDummyTable(), + joinTable, + scope.Quote(foreignKey), + newScope.AddToVars(foreignValue), + scope.Quote(associationForeignKey), + newScope.AddToVars(associationForeignValue), + )) + if _, err := scope.DB().Exec(newScope.Sql, newScope.SqlVars...); err != nil { + scope.Err(err) + } } } - } - default: - newDB := scope.NewDB() - if value.CanAddr() { - if field.Relationship != nil { - newDB.NewScope(field.Value).SetColumn(field.Relationship.foreignKey, scope.PrimaryKeyValue()) - } - scope.Err(newDB.Save(field.Value).Error) - } else { - destValue := reflect.New(reflect.TypeOf(field.Value)).Elem() + default: + newDB := scope.NewDB() + if value.CanAddr() { + if relationship.foreignKey != "" { + newDB.NewScope(field.Value).SetColumn(relationship.foreignKey, scope.PrimaryKeyValue()) + } + scope.Err(newDB.Save(field.Value).Error) + } else { + destValue := reflect.New(reflect.TypeOf(field.Value)).Elem() - for _, f := range newDB.NewScope(field.Value).Fields() { - destValue.FieldByName(f.Name).Set(reflect.ValueOf(f.Value)) - } + for _, f := range newDB.NewScope(field.Value).Fields() { + destValue.FieldByName(f.Name).Set(reflect.ValueOf(f.Value)) + } - elem := destValue.Addr().Interface() - if field.Relationship != nil { - newDB.NewScope(elem).SetColumn(field.Relationship.foreignKey, scope.PrimaryKeyValue()) + elem := destValue.Addr().Interface() + if relationship.foreignKey != "" { + newDB.NewScope(elem).SetColumn(relationship.foreignKey, scope.PrimaryKeyValue()) + } + scope.Err(newDB.Save(elem).Error) + scope.SetColumn(field.Name, destValue.Interface()) } - scope.Err(newDB.Save(elem).Error) - scope.SetColumn(field.Name, destValue.Interface()) } } } diff --git a/field.go b/field.go index 57cd80ba..4134e58f 100644 --- a/field.go +++ b/field.go @@ -10,20 +10,19 @@ type relationship struct { joinTable string foreignKey string associationForeignKey string + kind string } type Field struct { - Name string - DBName string - Value interface{} - IsBlank bool - IsIgnored bool - Tag reflect.StructTag - SqlTag string - BeforeAssociation bool - AfterAssociation bool - isPrimaryKey bool - Relationship *relationship + Name string + DBName string + Value interface{} + IsBlank bool + IsIgnored bool + Tag reflect.StructTag + SqlTag string + isPrimaryKey bool + Relationship *relationship } func (f *Field) IsScanner() bool { diff --git a/scope.go b/scope.go index 25747bf4..f1a07b98 100644 --- a/scope.go +++ b/scope.go @@ -284,29 +284,30 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) *Field { } } - field.AfterAssociation = true field.Relationship = &relationship{ joinTable: many2many, foreignKey: foreignKey, associationForeignKey: associationForeignKey, + kind: "has_many", + } + + if many2many != "" { + field.Relationship.kind = "many_to_many" } } case reflect.Struct: if !field.IsTime() && !field.IsScanner() { if foreignKey == "" && scope.HasColumn(field.Name+"Id") { - field.Relationship = &relationship{foreignKey: field.Name + "Id"} - field.BeforeAssociation = true + field.Relationship = &relationship{foreignKey: field.Name + "Id", kind: "belongs_to"} } else if scope.HasColumn(foreignKey) { - field.Relationship = &relationship{foreignKey: foreignKey} - field.BeforeAssociation = true + field.Relationship = &relationship{foreignKey: foreignKey, kind: "belongs_to"} } else { if foreignKey == "" { foreignKey = scopeTyp.Name() + "Id" } if reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() { - field.Relationship = &relationship{foreignKey: foreignKey} + field.Relationship = &relationship{foreignKey: foreignKey, kind: "has_one"} } - field.AfterAssociation = true } } }