Add bulk create associations tests

This commit is contained in:
Jinzhu 2020-04-20 11:47:29 +08:00
parent 7bcd95d4b8
commit 43a814ae70
2 changed files with 380 additions and 98 deletions

View File

@ -18,6 +18,15 @@ func SaveBeforeAssociations(db *gorm.DB) {
continue continue
} }
setupReferences := func(obj reflect.Value, elem reflect.Value) {
for _, ref := range rel.References {
if !ref.OwnPrimaryKey {
pv, _ := ref.PrimaryKey.ValueOf(elem)
ref.ForeignKey.Set(obj, pv)
}
}
}
switch db.Statement.ReflectValue.Kind() { switch db.Statement.ReflectValue.Kind() {
case reflect.Slice: case reflect.Slice:
var ( var (
@ -43,12 +52,7 @@ func SaveBeforeAssociations(db *gorm.DB) {
elems = reflect.Append(elems, rv.Addr()) elems = reflect.Append(elems, rv.Addr())
} }
} else { } else {
for _, ref := range rel.References { setupReferences(obj, rv)
if !ref.OwnPrimaryKey {
pv, _ := ref.PrimaryKey.ValueOf(rv)
ref.ForeignKey.Set(objs[i], pv)
}
}
} }
} }
} }
@ -56,31 +60,20 @@ func SaveBeforeAssociations(db *gorm.DB) {
if elems.Len() > 0 { if elems.Len() > 0 {
if db.AddError(db.Session(&gorm.Session{}).Create(elems.Interface()).Error) == nil { if db.AddError(db.Session(&gorm.Session{}).Create(elems.Interface()).Error) == nil {
for i := 0; i < elems.Len(); i++ { for i := 0; i < elems.Len(); i++ {
for _, ref := range rel.References { setupReferences(objs[i], elems.Index(i))
if !ref.OwnPrimaryKey {
pv, _ := ref.PrimaryKey.ValueOf(elems.Index(i))
ref.ForeignKey.Set(objs[i], pv)
}
}
} }
} }
} }
case reflect.Struct: case reflect.Struct:
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
rv := rel.Field.ReflectValueOf(db.Statement.ReflectValue) // relation reflect value rv := rel.Field.ReflectValueOf(db.Statement.ReflectValue) // relation reflect value
if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(rv); isZero { if rv.Kind() != reflect.Ptr {
if rv.Kind() == reflect.Ptr { rv = rv.Addr()
db.Session(&gorm.Session{}).Create(rv.Interface())
} else {
db.Session(&gorm.Session{}).Create(rv.Addr().Interface())
} }
for _, ref := range rel.References { if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(rv); isZero {
if !ref.OwnPrimaryKey { db.Session(&gorm.Session{}).Create(rv.Interface())
pv, _ := ref.PrimaryKey.ValueOf(rv) setupReferences(db.Statement.ReflectValue, rv)
ref.ForeignKey.Set(db.Statement.ReflectValue, pv)
}
}
} }
} }
} }
@ -113,8 +106,13 @@ func SaveAfterAssociations(db *gorm.DB) {
for i := 0; i < db.Statement.ReflectValue.Len(); i++ { for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
obj := db.Statement.ReflectValue.Index(i) obj := db.Statement.ReflectValue.Index(i)
if rv, zero := rel.Field.ValueOf(obj); !zero {
rv := reflect.ValueOf(rv) if _, zero := rel.Field.ValueOf(obj); !zero {
rv := rel.Field.ReflectValueOf(obj)
if rv.Kind() != reflect.Ptr {
rv = rv.Addr()
}
for _, ref := range rel.References { for _, ref := range rel.References {
if ref.OwnPrimaryKey { if ref.OwnPrimaryKey {
fv, _ := ref.PrimaryKey.ValueOf(obj) fv, _ := ref.PrimaryKey.ValueOf(obj)
@ -125,11 +123,7 @@ func SaveAfterAssociations(db *gorm.DB) {
} }
if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(rv); isZero { if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(rv); isZero {
if isPtr {
elems = reflect.Append(elems, rv) elems = reflect.Append(elems, rv)
} else {
elems = reflect.Append(elems, rv.Addr())
}
} }
} }
} }
@ -140,6 +134,9 @@ func SaveAfterAssociations(db *gorm.DB) {
case reflect.Struct: case reflect.Struct:
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
f := rel.Field.ReflectValueOf(db.Statement.ReflectValue) f := rel.Field.ReflectValueOf(db.Statement.ReflectValue)
if f.Kind() != reflect.Ptr {
f = f.Addr()
}
for _, ref := range rel.References { for _, ref := range rel.References {
if ref.OwnPrimaryKey { if ref.OwnPrimaryKey {
@ -151,11 +148,7 @@ func SaveAfterAssociations(db *gorm.DB) {
} }
if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(f); isZero { if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(f); isZero {
if f.Kind() == reflect.Ptr {
db.Session(&gorm.Session{}).Create(f.Interface()) db.Session(&gorm.Session{}).Create(f.Interface())
} else {
db.Session(&gorm.Session{}).Create(f.Addr().Interface())
}
} }
} }
} }
@ -168,9 +161,8 @@ func SaveAfterAssociations(db *gorm.DB) {
} }
fieldType := rel.Field.IndirectFieldType.Elem() fieldType := rel.Field.IndirectFieldType.Elem()
isPtr := true isPtr := fieldType.Kind() == reflect.Ptr
if fieldType.Kind() != reflect.Ptr { if !isPtr {
isPtr = false
fieldType = reflect.PtrTo(fieldType) fieldType = reflect.PtrTo(fieldType)
} }
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0) elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0)
@ -221,46 +213,71 @@ func SaveAfterAssociations(db *gorm.DB) {
} }
fieldType := rel.Field.IndirectFieldType.Elem() fieldType := rel.Field.IndirectFieldType.Elem()
isPtr := true isPtr := fieldType.Kind() == reflect.Ptr
if fieldType.Kind() != reflect.Ptr { if !isPtr {
isPtr = false
fieldType = reflect.PtrTo(fieldType) fieldType = reflect.PtrTo(fieldType)
} }
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0) elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0)
joins := reflect.MakeSlice(reflect.SliceOf(rel.JoinTable.ModelType), 0, 0)
objs := []reflect.Value{}
switch db.Statement.ReflectValue.Kind() { appendToJoins := func(obj reflect.Value, elem reflect.Value) {
case reflect.Slice: joinValue := reflect.New(rel.JoinTable.ModelType)
for i := 0; i < db.Statement.ReflectValue.Len(); i++ { for _, ref := range rel.References {
db.Statement.ReflectValue.Index(i) if ref.OwnPrimaryKey {
fv, _ := ref.PrimaryKey.ValueOf(obj)
ref.ForeignKey.Set(joinValue, fv)
} else if ref.PrimaryValue != "" {
ref.ForeignKey.Set(joinValue, ref.PrimaryValue)
} else {
fv, _ := ref.PrimaryKey.ValueOf(elem)
ref.ForeignKey.Set(joinValue, fv)
} }
case reflect.Struct: }
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.ReflectValue)) joins = reflect.Append(joins, joinValue)
}
appendToElems := func(v reflect.Value) {
if _, zero := rel.Field.ValueOf(v); !zero {
f := reflect.Indirect(rel.Field.ReflectValueOf(v))
for i := 0; i < f.Len(); i++ { for i := 0; i < f.Len(); i++ {
elem := f.Index(i) elem := f.Index(i)
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.ReflectValue)
ref.ForeignKey.Set(elem, fv)
} else if ref.PrimaryValue != "" {
ref.ForeignKey.Set(elem, ref.PrimaryValue)
}
}
if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(elem); isZero { if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(elem); isZero {
objs = append(objs, v)
if isPtr { if isPtr {
elems = reflect.Append(elems, elem) elems = reflect.Append(elems, elem)
} else { } else {
elems = reflect.Append(elems, elem.Addr()) elems = reflect.Append(elems, elem.Addr())
} }
} else {
appendToJoins(v, elem)
} }
} }
} }
} }
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice:
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
appendToElems(db.Statement.ReflectValue.Index(i))
}
case reflect.Struct:
appendToElems(db.Statement.ReflectValue)
}
if elems.Len() > 0 { if elems.Len() > 0 {
db.Session(&gorm.Session{}).Create(elems.Interface()) db.Session(&gorm.Session{}).Create(elems.Interface())
for i := 0; i < elems.Len(); i++ {
appendToJoins(objs[i], elems.Index(i))
}
}
if joins.Len() > 0 {
db.Session(&gorm.Session{}).Create(joins.Interface())
} }
} }
} }

View File

@ -40,16 +40,53 @@ func TestCreate(t *testing.T, db *gorm.DB) {
} else { } else {
AssertObjEqual(t, newUser, user, "Name", "Age", "Birthday") AssertObjEqual(t, newUser, user, "Name", "Age", "Birthday")
} }
})
TestCreateAssociations(t, db) TestCreateAssociations(t, db)
})
} }
func TestCreateAssociations(t *testing.T, db *gorm.DB) { func TestCreateAssociations(t *testing.T, db *gorm.DB) {
TestCreateBelongsToAssociations(t, db)
TestCreateHasOneAssociations(t, db)
TestCreateHasManyAssociations(t, db)
TestCreateMany2ManyAssociations(t, db)
}
func TestCreateBelongsToAssociations(t *testing.T, db *gorm.DB) {
db.Migrator().DropTable(&Company{}) db.Migrator().DropTable(&Company{})
db.Migrator().AutoMigrate(&Company{}) db.Migrator().AutoMigrate(&Company{})
t.Run("Create-BelongsToAssociation", func(t *testing.T) { check := func(t *testing.T, user User) {
if user.Company.Name != "" {
if user.CompanyID == nil {
t.Errorf("Company's foreign key should be saved")
} else {
var company Company
db.First(&company, "id = ?", *user.CompanyID)
if company.Name != user.Company.Name {
t.Errorf("Company's name should be same")
}
}
} else if user.CompanyID != nil {
t.Errorf("Company should not be created for zero value, got: %+v", user.CompanyID)
}
if user.Manager != nil {
if user.ManagerID == nil {
t.Errorf("Manager's foreign key should be saved")
} else {
var manager User
db.First(&manager, "id = ?", *user.ManagerID)
if manager.Name != user.Manager.Name {
t.Errorf("Manager's name should be same")
}
}
} else if user.ManagerID != nil {
t.Errorf("Manager should not be created for zero value, got: %+v", user.ManagerID)
}
}
t.Run("BelongsTo", func(t *testing.T) {
var user = User{ var user = User{
Name: "create", Name: "create",
Age: 18, Age: 18,
@ -62,28 +99,113 @@ func TestCreateAssociations(t *testing.T, db *gorm.DB) {
t.Fatalf("errors happened when create: %v", err) t.Fatalf("errors happened when create: %v", err)
} }
if user.CompanyID == nil { check(t, user)
t.Errorf("Failed to create belongs to association - Company") })
} else {
var company Company t.Run("BelongsToForBulkInsert", func(t *testing.T) {
db.First(&company, "id = ?", *user.CompanyID) var users = []User{{
if company.Name != "company-belongs-to-association" { Name: "create-1",
t.Errorf("Failed to query saved belongs to association - Company") Age: 18,
} Birthday: Now(),
Company: Company{Name: "company-belongs-to-association-1"},
Manager: &User{Name: "manager-belongs-to-association-1"},
}, {
Name: "create-2",
Age: 28,
Birthday: Now(),
Company: Company{Name: "company-belongs-to-association-2"},
}, {
Name: "create-3",
Age: 38,
Birthday: Now(),
Company: Company{Name: "company-belongs-to-association-3"},
Manager: &User{Name: "manager-belongs-to-association-3"},
}}
if err := db.Create(&users).Error; err != nil {
t.Fatalf("errors happened when create: %v", err)
} }
if user.ManagerID == nil { for _, user := range users {
t.Errorf("Failed to create belongs to association - Manager") check(t, user)
} else {
var manager User
db.First(&manager, "id = ?", *user.ManagerID)
if manager.Name != "manager-belongs-to-association" {
t.Errorf("Failed to query saved belongs to association - Manager")
}
} }
}) })
t.Run("Create-HasOneAssociation", func(t *testing.T) { t.Run("BelongsToForBulkInsertPtrData", func(t *testing.T) {
var users = []*User{{
Name: "create-1",
Age: 18,
Birthday: Now(),
Company: Company{Name: "company-belongs-to-association-1"},
Manager: &User{Name: "manager-belongs-to-association-1"},
}, {
Name: "create-2",
Age: 28,
Birthday: Now(),
Company: Company{Name: "company-belongs-to-association-2"},
}, {
Name: "create-3",
Age: 38,
Birthday: Now(),
Company: Company{Name: "company-belongs-to-association-3"},
Manager: &User{Name: "manager-belongs-to-association-3"},
}}
if err := db.Create(&users).Error; err != nil {
t.Fatalf("errors happened when create: %v", err)
}
for _, user := range users {
check(t, *user)
}
})
t.Run("BelongsToForBulkInsertWithoutPtr", func(t *testing.T) {
var users = []*User{{
Name: "create-1",
Age: 18,
Birthday: Now(),
Company: Company{Name: "company-belongs-to-association-1"},
Manager: &User{Name: "manager-belongs-to-association-1"},
}, {
Name: "create-2",
Age: 28,
Birthday: Now(),
Company: Company{Name: "company-belongs-to-association-2"},
}, {
Name: "create-3",
Age: 38,
Birthday: Now(),
Company: Company{Name: "company-belongs-to-association-3"},
Manager: &User{Name: "manager-belongs-to-association-3"},
}}
if err := db.Create(users).Error; err != nil {
t.Fatalf("errors happened when create: %v", err)
}
for _, user := range users {
check(t, *user)
}
})
}
func TestCreateHasOneAssociations(t *testing.T, db *gorm.DB) {
check := func(t *testing.T, user User) {
if user.Account.ID == 0 {
t.Errorf("Account should be saved")
} else if user.Account.UserID.Int64 != int64(user.ID) {
t.Errorf("Account's foreign key should be saved")
} else {
var account Account
db.First(&account, "id = ?", user.Account.ID)
if account.Number != user.Account.Number {
t.Errorf("Account's number should be sme")
}
}
}
t.Run("HasOne", func(t *testing.T) {
var user = User{ var user = User{
Name: "create", Name: "create",
Age: 18, Age: 18,
@ -95,20 +217,103 @@ func TestCreateAssociations(t *testing.T, db *gorm.DB) {
t.Fatalf("errors happened when create: %v", err) t.Fatalf("errors happened when create: %v", err)
} }
if user.Account.ID == 0 { check(t, user)
t.Errorf("Failed to create has one association - Account") })
} else if user.Account.UserID.Int64 != int64(user.ID) {
t.Errorf("Failed to create has one association - Account") t.Run("HasOneForBulkInsert", func(t *testing.T) {
} else { var users = []User{{
var account Account Name: "create-1",
db.First(&account, "id = ?", user.Account.ID) Age: 18,
if user.Account.Number != "account-has-one-association" { Birthday: Now(),
t.Errorf("Failed to query saved has one association - Account") Account: Account{Number: "account-has-one-association-1"},
}, {
Name: "create-2",
Age: 28,
Birthday: Now(),
Account: Account{Number: "account-has-one-association-2"},
}, {
Name: "create-3",
Age: 38,
Birthday: Now(),
Account: Account{Number: "account-has-one-association-3"},
}}
if err := db.Create(&users).Error; err != nil {
t.Fatalf("errors happened when create: %v", err)
} }
for _, user := range users {
check(t, user)
} }
}) })
t.Run("Create-HasOneAssociation-Polymorphic", func(t *testing.T) { t.Run("HasOneForBulkInsertPtrData", func(t *testing.T) {
var users = []*User{{
Name: "create-1",
Age: 18,
Birthday: Now(),
Account: Account{Number: "account-has-one-association-1"},
}, {
Name: "create-2",
Age: 28,
Birthday: Now(),
Account: Account{Number: "account-has-one-association-2"},
}, {
Name: "create-3",
Age: 38,
Birthday: Now(),
Account: Account{Number: "account-has-one-association-3"},
}}
if err := db.Create(&users).Error; err != nil {
t.Fatalf("errors happened when create: %v", err)
}
for _, user := range users {
check(t, *user)
}
})
t.Run("HasOneForBulkInsertWithoutPtr", func(t *testing.T) {
var users = []User{{
Name: "create-1",
Age: 18,
Birthday: Now(),
Account: Account{Number: "account-has-one-association-1"},
}, {
Name: "create-2",
Age: 28,
Birthday: Now(),
Account: Account{Number: "account-has-one-association-2"},
}, {
Name: "create-3",
Age: 38,
Birthday: Now(),
Account: Account{Number: "account-has-one-association-3"},
}}
if err := db.Create(users).Error; err != nil {
t.Fatalf("errors happened when create: %v", err)
}
for _, user := range users {
check(t, user)
}
})
checkPet := func(t *testing.T, pet Pet) {
if pet.Toy.OwnerID != fmt.Sprint(pet.ID) || pet.Toy.OwnerType != "pets" {
t.Errorf("Failed to create polymorphic has one association - toy owner id %v, owner type %v", pet.Toy.OwnerID, pet.Toy.OwnerType)
} else {
var toy Toy
db.First(&toy, "owner_id = ? and owner_type = ?", pet.Toy.OwnerID, pet.Toy.OwnerType)
if toy.Name != pet.Toy.Name {
t.Errorf("Failed to query saved polymorphic has one association")
}
}
}
t.Run("PolymorphicHasOne", func(t *testing.T) {
var pet = Pet{ var pet = Pet{
Name: "create", Name: "create",
Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic"}, Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic"},
@ -118,18 +323,75 @@ func TestCreateAssociations(t *testing.T, db *gorm.DB) {
t.Fatalf("errors happened when create: %v", err) t.Fatalf("errors happened when create: %v", err)
} }
if pet.Toy.OwnerID != fmt.Sprint(pet.ID) || pet.Toy.OwnerType != "pets" { checkPet(t, pet)
t.Errorf("Failed to create polymorphic has one association - toy owner id %v, owner type %v", pet.Toy.OwnerID, pet.Toy.OwnerType) })
} else {
var toy Toy t.Run("PolymorphicHasOneForBulkInsert", func(t *testing.T) {
db.First(&toy, "owner_id = ? and owner_type = ?", pet.Toy.OwnerID, pet.Toy.OwnerType) var pets = []Pet{{
if toy.Name != "Create-HasOneAssociation-Polymorphic" { Name: "create-1",
t.Errorf("Failed to query saved polymorphic has one association") Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-1"},
}, {
Name: "create-2",
Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-2"},
}, {
Name: "create-3",
Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-3"},
}}
if err := db.Create(&pets).Error; err != nil {
t.Fatalf("errors happened when create: %v", err)
} }
for _, pet := range pets {
checkPet(t, pet)
} }
}) })
t.Run("Create-HasManyAssociation", func(t *testing.T) { t.Run("PolymorphicHasOneForBulkInsertPtrData", func(t *testing.T) {
var pets = []*Pet{{
Name: "create-1",
Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-1"},
}, {
Name: "create-2",
Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-2"},
}, {
Name: "create-3",
Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-3"},
}}
if err := db.Create(&pets).Error; err != nil {
t.Fatalf("errors happened when create: %v", err)
}
for _, pet := range pets {
checkPet(t, *pet)
}
})
t.Run("PolymorphicHasOneForBulkInsertWithoutPtr", func(t *testing.T) {
var pets = []*Pet{{
Name: "create-1",
Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-1"},
}, {
Name: "create-2",
Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-2"},
}, {
Name: "create-3",
Toy: Toy{Name: "Create-HasOneAssociation-Polymorphic-3"},
}}
if err := db.Create(pets).Error; err != nil {
t.Fatalf("errors happened when create: %v", err)
}
for _, pet := range pets {
checkPet(t, *pet)
}
})
}
func TestCreateHasManyAssociations(t *testing.T, db *gorm.DB) {
t.Run("HasMany", func(t *testing.T) {
var user = User{ var user = User{
Name: "create", Name: "create",
Age: 18, Age: 18,
@ -156,7 +418,7 @@ func TestCreateAssociations(t *testing.T, db *gorm.DB) {
} }
}) })
t.Run("Create-HasManyAssociation-Polymorphic", func(t *testing.T) { t.Run("PolymorphicHasMany", func(t *testing.T) {
var user = User{ var user = User{
Name: "create", Name: "create",
Age: 18, Age: 18,
@ -183,3 +445,6 @@ func TestCreateAssociations(t *testing.T, db *gorm.DB) {
} }
}) })
} }
func TestCreateMany2ManyAssociations(t *testing.T, db *gorm.DB) {
}