diff --git a/callbacks/associations.go b/callbacks/associations.go index 8cc96029..98e0d254 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -18,6 +18,15 @@ func SaveBeforeAssociations(db *gorm.DB) { continue } + setupReferences := func(obj reflect.Value, elem reflect.Value) { + for _, ref := range rel.References { + if !ref.OwnPrimaryKey { + pv, _ := ref.PrimaryKey.ValueOf(elem) + ref.ForeignKey.Set(obj, pv) + } + } + } + switch db.Statement.ReflectValue.Kind() { case reflect.Slice: var ( @@ -43,12 +52,7 @@ func SaveBeforeAssociations(db *gorm.DB) { elems = reflect.Append(elems, rv.Addr()) } } else { - for _, ref := range rel.References { - if !ref.OwnPrimaryKey { - pv, _ := ref.PrimaryKey.ValueOf(rv) - ref.ForeignKey.Set(objs[i], pv) - } - } + setupReferences(obj, rv) } } } @@ -56,31 +60,20 @@ func SaveBeforeAssociations(db *gorm.DB) { if elems.Len() > 0 { if db.AddError(db.Session(&gorm.Session{}).Create(elems.Interface()).Error) == nil { for i := 0; i < elems.Len(); i++ { - for _, ref := range rel.References { - if !ref.OwnPrimaryKey { - pv, _ := ref.PrimaryKey.ValueOf(elems.Index(i)) - ref.ForeignKey.Set(objs[i], pv) - } - } + setupReferences(objs[i], elems.Index(i)) } } } case reflect.Struct: if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { rv := rel.Field.ReflectValueOf(db.Statement.ReflectValue) // relation reflect value - if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(rv); isZero { - if rv.Kind() == reflect.Ptr { - db.Session(&gorm.Session{}).Create(rv.Interface()) - } else { - db.Session(&gorm.Session{}).Create(rv.Addr().Interface()) - } + if rv.Kind() != reflect.Ptr { + rv = rv.Addr() + } - for _, ref := range rel.References { - if !ref.OwnPrimaryKey { - pv, _ := ref.PrimaryKey.ValueOf(rv) - ref.ForeignKey.Set(db.Statement.ReflectValue, pv) - } - } + if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(rv); isZero { + db.Session(&gorm.Session{}).Create(rv.Interface()) + setupReferences(db.Statement.ReflectValue, rv) } } } @@ -113,8 +106,13 @@ func SaveAfterAssociations(db *gorm.DB) { for i := 0; i < db.Statement.ReflectValue.Len(); i++ { obj := db.Statement.ReflectValue.Index(i) - if rv, zero := rel.Field.ValueOf(obj); !zero { - rv := reflect.ValueOf(rv) + + if _, zero := rel.Field.ValueOf(obj); !zero { + rv := rel.Field.ReflectValueOf(obj) + if rv.Kind() != reflect.Ptr { + rv = rv.Addr() + } + for _, ref := range rel.References { if ref.OwnPrimaryKey { fv, _ := ref.PrimaryKey.ValueOf(obj) @@ -125,11 +123,7 @@ func SaveAfterAssociations(db *gorm.DB) { } if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(rv); isZero { - if isPtr { - elems = reflect.Append(elems, rv) - } else { - elems = reflect.Append(elems, rv.Addr()) - } + elems = reflect.Append(elems, rv) } } } @@ -140,6 +134,9 @@ func SaveAfterAssociations(db *gorm.DB) { case reflect.Struct: if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { f := rel.Field.ReflectValueOf(db.Statement.ReflectValue) + if f.Kind() != reflect.Ptr { + f = f.Addr() + } for _, ref := range rel.References { if ref.OwnPrimaryKey { @@ -151,11 +148,7 @@ func SaveAfterAssociations(db *gorm.DB) { } if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(f); isZero { - if f.Kind() == reflect.Ptr { - db.Session(&gorm.Session{}).Create(f.Interface()) - } else { - db.Session(&gorm.Session{}).Create(f.Addr().Interface()) - } + db.Session(&gorm.Session{}).Create(f.Interface()) } } } @@ -168,9 +161,8 @@ func SaveAfterAssociations(db *gorm.DB) { } fieldType := rel.Field.IndirectFieldType.Elem() - isPtr := true - if fieldType.Kind() != reflect.Ptr { - isPtr = false + isPtr := fieldType.Kind() == reflect.Ptr + if !isPtr { fieldType = reflect.PtrTo(fieldType) } elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0) @@ -221,46 +213,71 @@ func SaveAfterAssociations(db *gorm.DB) { } fieldType := rel.Field.IndirectFieldType.Elem() - isPtr := true - if fieldType.Kind() != reflect.Ptr { - isPtr = false + isPtr := fieldType.Kind() == reflect.Ptr + if !isPtr { fieldType = reflect.PtrTo(fieldType) } elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0) + joins := reflect.MakeSlice(reflect.SliceOf(rel.JoinTable.ModelType), 0, 0) + objs := []reflect.Value{} - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice: - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - db.Statement.ReflectValue.Index(i) + appendToJoins := func(obj reflect.Value, elem reflect.Value) { + joinValue := reflect.New(rel.JoinTable.ModelType) + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + fv, _ := ref.PrimaryKey.ValueOf(obj) + ref.ForeignKey.Set(joinValue, fv) + } else if ref.PrimaryValue != "" { + ref.ForeignKey.Set(joinValue, ref.PrimaryValue) + } else { + fv, _ := ref.PrimaryKey.ValueOf(elem) + ref.ForeignKey.Set(joinValue, fv) + } } - case reflect.Struct: - if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { - f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.ReflectValue)) + + joins = reflect.Append(joins, joinValue) + } + + appendToElems := func(v reflect.Value) { + if _, zero := rel.Field.ValueOf(v); !zero { + f := reflect.Indirect(rel.Field.ReflectValueOf(v)) for i := 0; i < f.Len(); i++ { elem := f.Index(i) - for _, ref := range rel.References { - if ref.OwnPrimaryKey { - fv, _ := ref.PrimaryKey.ValueOf(db.Statement.ReflectValue) - ref.ForeignKey.Set(elem, fv) - } else if ref.PrimaryValue != "" { - ref.ForeignKey.Set(elem, ref.PrimaryValue) - } - } if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(elem); isZero { + objs = append(objs, v) if isPtr { elems = reflect.Append(elems, elem) } else { elems = reflect.Append(elems, elem.Addr()) } + } else { + appendToJoins(v, elem) } } } } + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice: + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + appendToElems(db.Statement.ReflectValue.Index(i)) + } + case reflect.Struct: + appendToElems(db.Statement.ReflectValue) + } + if elems.Len() > 0 { db.Session(&gorm.Session{}).Create(elems.Interface()) + + for i := 0; i < elems.Len(); i++ { + appendToJoins(objs[i], elems.Index(i)) + } + } + + if joins.Len() > 0 { + db.Session(&gorm.Session{}).Create(joins.Interface()) } } } diff --git a/tests/create.go b/tests/create.go index 218e1e59..b4bdd47e 100644 --- a/tests/create.go +++ b/tests/create.go @@ -40,16 +40,53 @@ func TestCreate(t *testing.T, db *gorm.DB) { } else { AssertObjEqual(t, newUser, user, "Name", "Age", "Birthday") } - }) - TestCreateAssociations(t, db) + TestCreateAssociations(t, db) + }) } func TestCreateAssociations(t *testing.T, db *gorm.DB) { + TestCreateBelongsToAssociations(t, db) + TestCreateHasOneAssociations(t, db) + TestCreateHasManyAssociations(t, db) + TestCreateMany2ManyAssociations(t, db) +} + +func TestCreateBelongsToAssociations(t *testing.T, db *gorm.DB) { db.Migrator().DropTable(&Company{}) db.Migrator().AutoMigrate(&Company{}) - t.Run("Create-BelongsToAssociation", func(t *testing.T) { + check := func(t *testing.T, user User) { + if user.Company.Name != "" { + if user.CompanyID == nil { + t.Errorf("Company's foreign key should be saved") + } else { + var company Company + db.First(&company, "id = ?", *user.CompanyID) + if company.Name != user.Company.Name { + t.Errorf("Company's name should be same") + } + } + } else if user.CompanyID != nil { + t.Errorf("Company should not be created for zero value, got: %+v", user.CompanyID) + } + + if user.Manager != nil { + if user.ManagerID == nil { + t.Errorf("Manager's foreign key should be saved") + } else { + var manager User + db.First(&manager, "id = ?", *user.ManagerID) + if manager.Name != user.Manager.Name { + t.Errorf("Manager's name should be same") + } + } + } else if user.ManagerID != nil { + t.Errorf("Manager should not be created for zero value, got: %+v", user.ManagerID) + } + } + + t.Run("BelongsTo", func(t *testing.T) { var user = User{ Name: "create", Age: 18, @@ -62,28 +99,113 @@ func TestCreateAssociations(t *testing.T, db *gorm.DB) { t.Fatalf("errors happened when create: %v", err) } - if user.CompanyID == nil { - t.Errorf("Failed to create belongs to association - Company") - } else { - var company Company - db.First(&company, "id = ?", *user.CompanyID) - if company.Name != "company-belongs-to-association" { - t.Errorf("Failed to query saved belongs to association - Company") - } + check(t, user) + }) + + t.Run("BelongsToForBulkInsert", func(t *testing.T) { + var users = []User{{ + Name: "create-1", + Age: 18, + Birthday: Now(), + Company: Company{Name: "company-belongs-to-association-1"}, + Manager: &User{Name: "manager-belongs-to-association-1"}, + }, { + Name: "create-2", + Age: 28, + Birthday: Now(), + Company: Company{Name: "company-belongs-to-association-2"}, + }, { + Name: "create-3", + Age: 38, + Birthday: Now(), + Company: Company{Name: "company-belongs-to-association-3"}, + Manager: &User{Name: "manager-belongs-to-association-3"}, + }} + + if err := db.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) } - if user.ManagerID == nil { - t.Errorf("Failed to create belongs to association - Manager") - } else { - var manager User - db.First(&manager, "id = ?", *user.ManagerID) - if manager.Name != "manager-belongs-to-association" { - t.Errorf("Failed to query saved belongs to association - Manager") - } + for _, user := range users { + check(t, user) } }) - t.Run("Create-HasOneAssociation", func(t *testing.T) { + t.Run("BelongsToForBulkInsertPtrData", func(t *testing.T) { + var users = []*User{{ + Name: "create-1", + Age: 18, + Birthday: Now(), + Company: Company{Name: "company-belongs-to-association-1"}, + Manager: &User{Name: "manager-belongs-to-association-1"}, + }, { + Name: "create-2", + Age: 28, + Birthday: Now(), + Company: Company{Name: "company-belongs-to-association-2"}, + }, { + Name: "create-3", + Age: 38, + Birthday: Now(), + Company: Company{Name: "company-belongs-to-association-3"}, + Manager: &User{Name: "manager-belongs-to-association-3"}, + }} + + if err := db.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + for _, user := range users { + check(t, *user) + } + }) + + t.Run("BelongsToForBulkInsertWithoutPtr", func(t *testing.T) { + var users = []*User{{ + Name: "create-1", + Age: 18, + Birthday: Now(), + Company: Company{Name: "company-belongs-to-association-1"}, + Manager: &User{Name: "manager-belongs-to-association-1"}, + }, { + Name: "create-2", + Age: 28, + Birthday: Now(), + Company: Company{Name: "company-belongs-to-association-2"}, + }, { + Name: "create-3", + Age: 38, + Birthday: Now(), + Company: Company{Name: "company-belongs-to-association-3"}, + Manager: &User{Name: "manager-belongs-to-association-3"}, + }} + + if err := db.Create(users).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + for _, user := range users { + check(t, *user) + } + }) +} + +func TestCreateHasOneAssociations(t *testing.T, db *gorm.DB) { + check := func(t *testing.T, user User) { + if user.Account.ID == 0 { + t.Errorf("Account should be saved") + } else if user.Account.UserID.Int64 != int64(user.ID) { + t.Errorf("Account's foreign key should be saved") + } else { + var account Account + db.First(&account, "id = ?", user.Account.ID) + if account.Number != user.Account.Number { + t.Errorf("Account's number should be sme") + } + } + } + + t.Run("HasOne", func(t *testing.T) { var user = User{ Name: "create", Age: 18, @@ -95,20 +217,103 @@ func TestCreateAssociations(t *testing.T, db *gorm.DB) { t.Fatalf("errors happened when create: %v", err) } - if user.Account.ID == 0 { - t.Errorf("Failed to create has one association - Account") - } else if user.Account.UserID.Int64 != int64(user.ID) { - t.Errorf("Failed to create has one association - Account") - } else { - var account Account - db.First(&account, "id = ?", user.Account.ID) - if user.Account.Number != "account-has-one-association" { - t.Errorf("Failed to query saved has one association - Account") - } + check(t, user) + }) + + t.Run("HasOneForBulkInsert", func(t *testing.T) { + var users = []User{{ + Name: "create-1", + Age: 18, + Birthday: Now(), + Account: Account{Number: "account-has-one-association-1"}, + }, { + Name: "create-2", + Age: 28, + Birthday: Now(), + Account: Account{Number: "account-has-one-association-2"}, + }, { + Name: "create-3", + Age: 38, + Birthday: Now(), + Account: Account{Number: "account-has-one-association-3"}, + }} + + if err := db.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + for _, user := range users { + check(t, user) } }) - t.Run("Create-HasOneAssociation-Polymorphic", func(t *testing.T) { + t.Run("HasOneForBulkInsertPtrData", func(t *testing.T) { + var users = []*User{{ + Name: "create-1", + Age: 18, + Birthday: Now(), + Account: Account{Number: "account-has-one-association-1"}, + }, { + Name: "create-2", + Age: 28, + Birthday: Now(), + Account: Account{Number: "account-has-one-association-2"}, + }, { + Name: "create-3", + Age: 38, + Birthday: Now(), + Account: Account{Number: "account-has-one-association-3"}, + }} + + if err := db.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + for _, user := range users { + check(t, *user) + } + }) + + t.Run("HasOneForBulkInsertWithoutPtr", func(t *testing.T) { + var users = []User{{ + Name: "create-1", + Age: 18, + Birthday: Now(), + Account: Account{Number: "account-has-one-association-1"}, + }, { + Name: "create-2", + Age: 28, + Birthday: Now(), + Account: Account{Number: "account-has-one-association-2"}, + }, { + Name: "create-3", + Age: 38, + Birthday: Now(), + Account: Account{Number: "account-has-one-association-3"}, + }} + + if err := db.Create(users).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + for _, user := range users { + check(t, user) + } + }) + + checkPet := func(t *testing.T, pet Pet) { + if pet.Toy.OwnerID != fmt.Sprint(pet.ID) || pet.Toy.OwnerType != "pets" { + t.Errorf("Failed to create polymorphic has one association - toy owner id %v, owner type %v", pet.Toy.OwnerID, pet.Toy.OwnerType) + } else { + var toy Toy + db.First(&toy, "owner_id = ? and owner_type = ?", pet.Toy.OwnerID, pet.Toy.OwnerType) + if toy.Name != pet.Toy.Name { + t.Errorf("Failed to query saved polymorphic has one association") + } + } + } + + t.Run("PolymorphicHasOne", func(t *testing.T) { var pet = Pet{ Name: "create", Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic"}, @@ -118,18 +323,75 @@ func TestCreateAssociations(t *testing.T, db *gorm.DB) { t.Fatalf("errors happened when create: %v", err) } - if pet.Toy.OwnerID != fmt.Sprint(pet.ID) || pet.Toy.OwnerType != "pets" { - t.Errorf("Failed to create polymorphic has one association - toy owner id %v, owner type %v", pet.Toy.OwnerID, pet.Toy.OwnerType) - } else { - var toy Toy - db.First(&toy, "owner_id = ? and owner_type = ?", pet.Toy.OwnerID, pet.Toy.OwnerType) - if toy.Name != "Create-HasOneAssociation-Polymorphic" { - t.Errorf("Failed to query saved polymorphic has one association") - } + checkPet(t, pet) + }) + + t.Run("PolymorphicHasOneForBulkInsert", func(t *testing.T) { + var pets = []Pet{{ + Name: "create-1", + Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-1"}, + }, { + Name: "create-2", + Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-2"}, + }, { + Name: "create-3", + Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-3"}, + }} + + if err := db.Create(&pets).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + for _, pet := range pets { + checkPet(t, pet) } }) - t.Run("Create-HasManyAssociation", func(t *testing.T) { + t.Run("PolymorphicHasOneForBulkInsertPtrData", func(t *testing.T) { + var pets = []*Pet{{ + Name: "create-1", + Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-1"}, + }, { + Name: "create-2", + Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-2"}, + }, { + Name: "create-3", + Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-3"}, + }} + + if err := db.Create(&pets).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + for _, pet := range pets { + checkPet(t, *pet) + } + }) + + t.Run("PolymorphicHasOneForBulkInsertWithoutPtr", func(t *testing.T) { + var pets = []*Pet{{ + Name: "create-1", + Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-1"}, + }, { + Name: "create-2", + Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-2"}, + }, { + Name: "create-3", + Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-3"}, + }} + + if err := db.Create(pets).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + for _, pet := range pets { + checkPet(t, *pet) + } + }) +} + +func TestCreateHasManyAssociations(t *testing.T, db *gorm.DB) { + t.Run("HasMany", func(t *testing.T) { var user = User{ Name: "create", Age: 18, @@ -156,7 +418,7 @@ func TestCreateAssociations(t *testing.T, db *gorm.DB) { } }) - t.Run("Create-HasManyAssociation-Polymorphic", func(t *testing.T) { + t.Run("PolymorphicHasMany", func(t *testing.T) { var user = User{ Name: "create", Age: 18, @@ -183,3 +445,6 @@ func TestCreateAssociations(t *testing.T, db *gorm.DB) { } }) } + +func TestCreateMany2ManyAssociations(t *testing.T, db *gorm.DB) { +}