diff --git a/callbacks/associations.go b/callbacks/associations.go index 1df0103a..283a2666 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -10,15 +10,18 @@ import ( func SaveBeforeAssociations(db *gorm.DB) { if db.Statement.Schema != nil { + // Save Belongs To associations for _, rel := range db.Statement.Schema.Relationships.BelongsTo { creatable, updatable, saveRef := saveAssociationCheck(db, rel.Field) + if !(creatable || updatable) { + continue + } switch db.Statement.ReflectValue.Kind() { case reflect.Slice: case reflect.Struct: if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { f := rel.Field.ReflectValueOf(db.Statement.ReflectValue) - _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(f) if isZero && creatable { @@ -51,6 +54,51 @@ func SaveBeforeAssociations(db *gorm.DB) { } } +func SaveAfterAssociations(db *gorm.DB) { + // Save Has One associations + for _, rel := range db.Statement.Schema.Relationships.HasOne { + creatable, updatable, saveRef := saveAssociationCheck(db, rel.Field) + if !(creatable || updatable) { + continue + } + + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice: + case reflect.Struct: + if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { + f := rel.Field.ReflectValueOf(db.Statement.ReflectValue) + + if saveRef { + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + fv, _ := ref.PrimaryKey.ValueOf(db.Statement.ReflectValue) + ref.ForeignKey.Set(f, fv) + } + } + } + + _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(f) + + if isZero && creatable { + if f.Kind() == reflect.Ptr { + db.Session(&gorm.Session{}).Create(f.Interface()) + } else { + db.Session(&gorm.Session{}).Create(f.Addr().Interface()) + } + } else if !isZero && updatable { + if f.Kind() == reflect.Ptr { + db.Session(&gorm.Session{}).Save(f.Interface()) + } else { + db.Session(&gorm.Session{}).Save(f.Addr().Interface()) + } + } else { + continue + } + } + } + } +} + func saveAssociationCheck(db *gorm.DB, field *schema.Field) (bool, bool, bool) { creatable := field.Creatable updatable := field.Updatable diff --git a/callbacks/create.go b/callbacks/create.go index 829c9c4c..9dc8dc67 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -151,9 +151,6 @@ func CreateWithReturning(db *gorm.DB) { } } -func SaveAfterAssociations(db *gorm.DB) { -} - func AfterCreate(db *gorm.DB) { if db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) { callMethod := func(value interface{}) bool { diff --git a/tests/create.go b/tests/create.go index 74a010dc..b8e9245b 100644 --- a/tests/create.go +++ b/tests/create.go @@ -81,4 +81,29 @@ func TestCreateAssociations(t *testing.T, db *gorm.DB) { } } }) + + t.Run("Create-HasOneAssociation", func(t *testing.T) { + var user = User{ + Name: "create", + Age: 18, + Birthday: Now(), + Account: Account{Number: "account-has-one-association"}, + } + + if err := db.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + if user.Account.ID == 0 { + t.Errorf("Failed to create has one association - Account") + } else if user.Account.UserID.Int64 != int64(user.ID) { + t.Errorf("Failed to create has one association - Account") + } else { + var account Account + db.First(&account, "id = ?", user.Account.ID) + if user.Account.Number != "account-has-one-association" { + t.Errorf("Failed to query saved has one association - Account") + } + } + }) }