diff --git a/association.go b/association.go index f65e77c2..9405d962 100644 --- a/association.go +++ b/association.go @@ -179,69 +179,71 @@ func (association *Association) Replace(values ...interface{}) error { func (association *Association) Delete(values ...interface{}) error { if association.Error == nil { var ( - reflectValue = association.DB.Statement.ReflectValue - rel = association.Relationship - tx = association.DB - relFields []*schema.Field - foreignKeyFields []*schema.Field - foreignKeys []string - updateAttrs = map[string]interface{}{} + reflectValue = association.DB.Statement.ReflectValue + rel = association.Relationship + tx = association.DB + primaryFields, foreignFields []*schema.Field + foreignKeys []string + updateAttrs = map[string]interface{}{} ) for _, ref := range rel.References { if ref.PrimaryValue == "" { - if rel.JoinTable == nil || !ref.OwnPrimaryKey { - if ref.OwnPrimaryKey { - relFields = append(relFields, ref.ForeignKey) - } else { - relFields = append(relFields, ref.PrimaryKey) - foreignKeyFields = append(foreignKeyFields, ref.ForeignKey) - } - - foreignKeys = append(foreignKeys, ref.ForeignKey.DBName) - updateAttrs[ref.ForeignKey.DBName] = nil - } + primaryFields = append(primaryFields, ref.PrimaryKey) + foreignFields = append(foreignFields, ref.ForeignKey) + foreignKeys = append(foreignKeys, ref.ForeignKey.DBName) + updateAttrs[ref.ForeignKey.DBName] = nil } } - relValuesMap, relQueryValues := schema.GetIdentityFieldValuesMapFromValues(values, relFields) - column, values := schema.ToQueryValues(foreignKeys, relQueryValues) - tx = tx.Session(&Session{}).Where(clause.IN{Column: column, Values: values}) - switch rel.Type { case schema.HasOne, schema.HasMany: - modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() - conds := rel.ToQueryConditions(reflectValue) - tx.Model(modelValue).Clauses(clause.Where{Exprs: conds}).UpdateColumns(updateAttrs) - case schema.BelongsTo: - primaryKeys := []string{} - for _, field := range rel.Schema.PrimaryFields { - primaryKeys = append(primaryKeys, field.DBName) - } - _, queryValues := schema.GetIdentityFieldValuesMap(reflectValue, rel.Schema.PrimaryFields) - if column, values := schema.ToQueryValues(primaryKeys, queryValues); len(values) > 0 { - tx.Where(clause.IN{Column: column, Values: values}) - } + var ( + modelValue = reflect.New(rel.FieldSchema.ModelType).Interface() + _, queryValues = schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) + _, relQueryValues = schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields) + ) - modelValue := reflect.New(rel.Schema.ModelType).Interface() - tx.Model(modelValue).UpdateColumns(updateAttrs) + column, values := schema.ToQueryValues(foreignKeys, queryValues) + relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.PrimaryFieldDBNames, relQueryValues) + + tx.Session(&Session{}).Model(modelValue).Clauses( + clause.IN{Column: column, Values: values}, + clause.IN{Column: relColumn, Values: relValues}, + ).UpdateColumns(updateAttrs) + case schema.BelongsTo: + var ( + modelValue = reflect.New(rel.Schema.ModelType).Interface() + _, queryValues = schema.GetIdentityFieldValuesMap(reflectValue, rel.Schema.PrimaryFields) + _, relQueryValues = schema.GetIdentityFieldValuesMapFromValues(values, primaryFields) + ) + + column, values := schema.ToQueryValues(rel.Schema.PrimaryFieldDBNames, queryValues) + relColumn, relValues := schema.ToQueryValues(foreignKeys, relQueryValues) + + tx.Session(&Session{}).Model(modelValue).Clauses( + clause.IN{Column: column, Values: values}, + clause.IN{Column: relColumn, Values: relValues}, + ).UpdateColumns(updateAttrs) case schema.Many2Many: modelValue := reflect.New(rel.JoinTable.ModelType).Interface() conds := rel.ToQueryConditions(reflectValue) tx.Clauses(clause.Where{Exprs: conds}).Delete(modelValue) } + relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields) + if tx.Error == nil { cleanUpDeletedRelations := func(data reflect.Value) { if _, zero := rel.Field.ValueOf(data); !zero { fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(data)) - fieldValues := make([]interface{}, len(relFields)) + fieldValues := make([]interface{}, len(rel.FieldSchema.PrimaryFields)) switch fieldValue.Kind() { case reflect.Slice, reflect.Array: validFieldValues := reflect.Zero(rel.Field.FieldType) for i := 0; i < fieldValue.Len(); i++ { - for idx, field := range relFields { + for idx, field := range rel.FieldSchema.PrimaryFields { fieldValues[idx], _ = field.ValueOf(fieldValue.Index(i)) } @@ -252,13 +254,18 @@ func (association *Association) Delete(values ...interface{}) error { rel.Field.Set(data, validFieldValues.Interface()) case reflect.Struct: - for idx, field := range relFields { + for idx, field := range rel.FieldSchema.PrimaryFields { fieldValues[idx], _ = field.ValueOf(fieldValue) } if _, ok := relValuesMap[utils.ToStringKey(fieldValues...)]; ok { rel.Field.Set(data, reflect.Zero(rel.FieldSchema.ModelType).Interface()) - for _, field := range foreignKeyFields { - field.Set(data, reflect.Zero(field.FieldType).Interface()) + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + ref.ForeignKey.Set(fieldValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) + } else if ref.PrimaryValue == "" { + // FIXME + ref.ForeignKey.Set(data, reflect.Zero(ref.ForeignKey.FieldType).Interface()) + } } } } @@ -337,9 +344,9 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } case schema.HasMany, schema.Many2Many: elemType := association.Relationship.Field.IndirectFieldType.Elem() - fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(reflectValue)) + fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(source)) if clear { - fieldValue = reflect.New(association.Relationship.Field.IndirectFieldType) + fieldValue = reflect.New(association.Relationship.Field.IndirectFieldType).Elem() } appendToFieldValues := func(ev reflect.Value) { @@ -355,14 +362,14 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ switch rv.Kind() { case reflect.Slice, reflect.Array: for i := 0; i < rv.Len(); i++ { - appendToFieldValues(reflect.Indirect(rv.Index(i))) + appendToFieldValues(reflect.Indirect(rv.Index(i)).Addr()) } case reflect.Struct: - appendToFieldValues(rv) + appendToFieldValues(rv.Addr()) } if association.Error == nil { - association.Error = association.Relationship.Field.Set(source, fieldValue.Addr().Interface()) + association.Error = association.Relationship.Field.Set(source, fieldValue.Interface()) } } } diff --git a/schema/schema.go b/schema/schema.go index caae55ac..e66084a3 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -22,6 +22,7 @@ type Schema struct { PrioritizedPrimaryField *Field DBNames []string PrimaryFields []*Field + PrimaryFieldDBNames []string Fields []*Field FieldsByName map[string]*Field FieldsByDBName map[string]*Field @@ -165,6 +166,10 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) } } + for _, field := range schema.PrimaryFields { + schema.PrimaryFieldDBNames = append(schema.PrimaryFieldDBNames, field.DBName) + } + schema.FieldsWithDefaultDBValue = map[string]*Field{} for db, field := range schema.FieldsByDBName { if field.HasDefaultValue && field.DefaultValueInterface == nil { diff --git a/tests/associations_test.go b/tests/associations_test.go index 2b81a719..08733005 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -467,3 +467,152 @@ func TestPolymorphicHasOneAssociationForSlice(t *testing.T) { DB.Model(&pets).Association("Toy").Clear() AssertAssociationCount(t, pets, "Toy", 0, "After Clear") } + +func TestHasManyAssociation(t *testing.T) { + var user = *GetUser("hasmany", Config{Pets: 2}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + CheckUser(t, user, user) + + // Find + var user2 User + DB.Find(&user2, "id = ?", user.ID) + DB.Model(&user2).Association("Pets").Find(&user2.Pets) + CheckUser(t, user2, user) + + // Count + AssertAssociationCount(t, user, "Pets", 2, "") + + // Append + var pet = Pet{Name: "pet-has-many-append"} + + if err := DB.Model(&user2).Association("Pets").Append(&pet); err != nil { + t.Fatalf("Error happened when append account, got %v", err) + } + + if pet.ID == 0 { + t.Fatalf("Pet's ID should be created") + } + + user.Pets = append(user.Pets, &pet) + CheckUser(t, user2, user) + + AssertAssociationCount(t, user, "Pets", 3, "AfterAppend") + + var pets = []Pet{{Name: "pet-has-many-append-1-1"}, {Name: "pet-has-many-append-1-1"}} + + if err := DB.Model(&user2).Association("Pets").Append(&pets); err != nil { + t.Fatalf("Error happened when append pet, got %v", err) + } + + for _, pet := range pets { + var pet = pet + if pet.ID == 0 { + t.Fatalf("Pet's ID should be created") + } + + user.Pets = append(user.Pets, &pet) + } + + CheckUser(t, user2, user) + + AssertAssociationCount(t, user, "Pets", 5, "AfterAppendSlice") + + // Replace + var pet2 = Pet{Name: "pet-has-many-replace"} + + if err := DB.Model(&user2).Association("Pets").Replace(&pet2); err != nil { + t.Fatalf("Error happened when append pet, got %v", err) + } + + if pet2.ID == 0 { + t.Fatalf("pet2's ID should be created") + } + + user.Pets = []*Pet{&pet2} + CheckUser(t, user2, user) + + AssertAssociationCount(t, user2, "Pets", 1, "AfterReplace") + + // Delete + if err := DB.Model(&user2).Association("Pets").Delete(&Pet{}); err != nil { + t.Fatalf("Error happened when delete pet, got %v", err) + } + AssertAssociationCount(t, user2, "Pets", 1, "after delete non-existing data") + + if err := DB.Model(&user2).Association("Pets").Delete(&pet2); err != nil { + t.Fatalf("Error happened when delete Pets, got %v", err) + } + AssertAssociationCount(t, user2, "Pets", 0, "after delete") + + // Prepare Data for Clear + if err := DB.Model(&user2).Association("Pets").Append(&pet); err != nil { + t.Fatalf("Error happened when append Pets, got %v", err) + } + + AssertAssociationCount(t, user2, "Pets", 1, "after prepare data") + + // Clear + if err := DB.Model(&user2).Association("Pets").Clear(); err != nil { + t.Errorf("Error happened when clear Pets, got %v", err) + } + + AssertAssociationCount(t, user2, "Pets", 0, "after clear") +} + +func TestHasManyAssociationForSlice(t *testing.T) { + var users = []User{ + *GetUser("slice-hasmany-1", Config{Pets: 2}), + *GetUser("slice-hasmany-2", Config{Pets: 0}), + *GetUser("slice-hasmany-3", Config{Pets: 4}), + } + + DB.Create(&users) + + // Count + AssertAssociationCount(t, users, "Pets", 6, "") + + // Find + var pets []Pet + if DB.Model(&users).Association("Pets").Find(&pets); len(pets) != 6 { + t.Errorf("pets count should be %v, but got %v", 6, len(pets)) + } + + // Append + DB.Model(&users).Association("Pets").Append( + &Pet{Name: "pet-slice-append-1"}, + []*Pet{{Name: "pet-slice-append-2-1"}, {Name: "pet-slice-append-2-2"}}, + &Pet{Name: "pet-slice-append-3"}, + ) + + AssertAssociationCount(t, users, "Pets", 10, "After Append") + + // Replace -> same as append + DB.Model(&users).Association("Pets").Replace( + []*Pet{{Name: "pet-slice-replace-1-1"}, {Name: "pet-slice-replace-1-2"}}, + []*Pet{{Name: "pet-slice-replace-2-1"}, {Name: "pet-slice-replace-2-2"}}, + &Pet{Name: "pet-slice-replace-3"}, + ) + + AssertAssociationCount(t, users, "Pets", 5, "After Append") + + // Delete + if err := DB.Model(&users).Association("Pets").Delete(&users[2].Pets); err != nil { + t.Errorf("no error should happend when deleting pet, but got %v", err) + } + + AssertAssociationCount(t, users, "Pets", 4, "after delete") + + if err := DB.Debug().Model(&users).Association("Pets").Delete(users[0].Pets[0], users[1].Pets[1]); err != nil { + t.Errorf("no error should happend when deleting pet, but got %v", err) + } + + AssertAssociationCount(t, users, "Pets", 2, "after delete") + + // Clear + DB.Model(&users).Association("Pets").Clear() + AssertAssociationCount(t, users, "Pets", 0, "After Clear") +}