Test shared association

This commit is contained in:
Jinzhu 2020-05-24 21:46:33 +08:00
parent 2db33730b6
commit 677c745b62
2 changed files with 67 additions and 13 deletions

View File

@ -195,6 +195,8 @@ func (association *Association) Delete(values ...interface{}) error {
foreignKeys = append(foreignKeys, ref.ForeignKey.DBName) foreignKeys = append(foreignKeys, ref.ForeignKey.DBName)
updateAttrs[ref.ForeignKey.DBName] = nil updateAttrs[ref.ForeignKey.DBName] = nil
} }
} else {
tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryKey})
} }
} }
@ -208,6 +210,15 @@ func (association *Association) Delete(values ...interface{}) error {
conds := rel.ToQueryConditions(reflectValue) conds := rel.ToQueryConditions(reflectValue)
tx.Model(modelValue).Clauses(clause.Where{Exprs: conds}).UpdateColumns(updateAttrs) tx.Model(modelValue).Clauses(clause.Where{Exprs: conds}).UpdateColumns(updateAttrs)
case schema.BelongsTo: case schema.BelongsTo:
primaryKeys := []string{}
for _, field := range rel.Schema.PrimaryFields {
primaryKeys = append(primaryKeys, field.DBName)
}
_, queryValues := schema.GetIdentityFieldValuesMap(reflectValue, rel.Schema.PrimaryFields)
if column, values := schema.ToQueryValues(primaryKeys, queryValues); len(values) > 0 {
tx.Where(clause.IN{Column: column, Values: values})
}
modelValue := reflect.New(rel.Schema.ModelType).Interface() modelValue := reflect.New(rel.Schema.ModelType).Interface()
tx.Model(modelValue).UpdateColumns(updateAttrs) tx.Model(modelValue).UpdateColumns(updateAttrs)
case schema.Many2Many: case schema.Many2Many:
@ -353,7 +364,6 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
} }
selectedColumns := []string{association.Relationship.Name} selectedColumns := []string{association.Relationship.Name}
hasZero := false
for _, ref := range association.Relationship.References { for _, ref := range association.Relationship.References {
if !ref.OwnPrimaryKey { if !ref.OwnPrimaryKey {
selectedColumns = append(selectedColumns, ref.ForeignKey.Name) selectedColumns = append(selectedColumns, ref.ForeignKey.Name)
@ -375,13 +385,16 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
break break
} }
association.Error = errors.New("invalid association values, length doesn't match") association.Error = errors.New("invalid association values, length doesn't match")
return
} }
for i := 0; i < reflectValue.Len(); i++ { for i := 0; i < reflectValue.Len(); i++ {
appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear) appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear)
if !hasZero { if len(values) > 0 {
_, hasZero = association.DB.Statement.Schema.PrioritizedPrimaryField.ValueOf(reflectValue.Index(i)) // TODO support save slice data, sql with case
err := association.DB.Session(&Session{}).Select(selectedColumns).Model(nil).Save(reflectValue.Index(i).Addr().Interface()).Error
association.DB.AddError(err)
} }
} }
case reflect.Struct: case reflect.Struct:
@ -399,13 +412,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
appendToRelations(reflectValue, rv, clear && idx == 0) appendToRelations(reflectValue, rv, clear && idx == 0)
} }
_, hasZero = association.DB.Statement.Schema.PrioritizedPrimaryField.ValueOf(reflectValue)
}
if len(values) > 0 { if len(values) > 0 {
if hasZero {
association.DB.Create(reflectValue.Addr().Interface())
} else {
association.DB.Select(selectedColumns).Model(nil).Save(reflectValue.Addr().Interface()) association.DB.Select(selectedColumns).Model(nil).Save(reflectValue.Addr().Interface())
} }
} }

View File

@ -164,20 +164,49 @@ func TestBelongsToAssociationForSlice(t *testing.T) {
// Find // Find
var companies []Company var companies []Company
if DB.Model(users).Association("Company").Find(&companies); len(companies) != 3 { if DB.Model(&users).Association("Company").Find(&companies); len(companies) != 3 {
t.Errorf("companies count should be %v, but got %v", 3, len(companies)) t.Errorf("companies count should be %v, but got %v", 3, len(companies))
} }
var managers []User var managers []User
if DB.Model(users).Association("Manager").Find(&managers); len(managers) != 2 { if DB.Model(&users).Association("Manager").Find(&managers); len(managers) != 2 {
t.Errorf("managers count should be %v, but got %v", 2, len(managers)) t.Errorf("managers count should be %v, but got %v", 2, len(managers))
} }
// Append // Append
DB.Model(&users).Association("Company").Append(
&Company{Name: "company-slice-append-1"},
&Company{Name: "company-slice-append-2"},
&Company{Name: "company-slice-append-3"},
)
// Replace AssertAssociationCount(t, users, "Company", 3, "After Append")
DB.Model(&users).Association("Manager").Append(
GetUser("manager-slice-belongs-to-1", Config{}),
GetUser("manager-slice-belongs-to-2", Config{}),
GetUser("manager-slice-belongs-to-3", Config{}),
)
AssertAssociationCount(t, users, "Manager", 3, "After Append")
if err := DB.Model(&users).Association("Manager").Append(
GetUser("manager-slice-belongs-to-test-1", Config{}),
).Error; err == nil {
t.Errorf("unmatched length when update user's manager")
}
// Replace -> same as append
// Delete // Delete
if err := DB.Model(&users).Association("Company").Delete(&users[0].Company); err != nil {
t.Errorf("no error should happend when deleting company, but got %v", err)
}
if users[0].CompanyID != nil || users[0].Company.ID != 0 {
t.Errorf("users[0]'s company should be deleted'")
}
AssertAssociationCount(t, users, "Company", 2, "After Delete")
// Clear // Clear
DB.Model(&users).Association("Company").Clear() DB.Model(&users).Association("Company").Clear()
@ -185,4 +214,22 @@ func TestBelongsToAssociationForSlice(t *testing.T) {
DB.Model(&users).Association("Manager").Clear() DB.Model(&users).Association("Manager").Clear()
AssertAssociationCount(t, users, "Manager", 0, "After Clear") AssertAssociationCount(t, users, "Manager", 0, "After Clear")
// shared company
company := Company{Name: "shared"}
if err := DB.Model(&users[0]).Association("Company").Append(&company); err != nil {
t.Errorf("Error happened when append company to user, got %v", err)
}
if err := DB.Model(&users[1]).Association("Company").Append(&company); err != nil {
t.Errorf("Error happened when append company to user, got %v", err)
}
if users[0].CompanyID == nil || users[1].CompanyID == nil || *users[0].CompanyID != *users[1].CompanyID {
t.Errorf("user's company id should exists and equal, but its: %v, %v", users[0].CompanyID, users[1].CompanyID)
}
DB.Model(&users[0]).Association("Company").Delete(&company)
AssertAssociationCount(t, users[0], "Company", 0, "After Delete")
AssertAssociationCount(t, users[1], "Company", 1, "After other user Delete")
} }