diff --git a/association.go b/association.go index 9a3a338b..14fd1c35 100644 --- a/association.go +++ b/association.go @@ -68,26 +68,28 @@ func (association *Association) Replace(values ...interface{}) *Association { // Delete Relations except new created if len(values) > 0 { - var associationForeignFieldNames []string + var associationForeignFieldNames, associationForeignDBNames []string if relationship.Kind == "many_to_many" { // if many to many relations, get association fields name from association foreign keys associationScope := scope.New(reflect.New(field.Type()).Interface()) - for _, dbName := range relationship.AssociationForeignFieldNames { + for idx, dbName := range relationship.AssociationForeignFieldNames { if field, ok := associationScope.FieldByName(dbName); ok { associationForeignFieldNames = append(associationForeignFieldNames, field.Name) + associationForeignDBNames = append(associationForeignDBNames, relationship.AssociationForeignDBNames[idx]) } } } else { - // If other relations, use primary keys + // If has one/many relations, use primary keys for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() { associationForeignFieldNames = append(associationForeignFieldNames, field.Name) + associationForeignDBNames = append(associationForeignDBNames, field.DBName) } } newPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, field.Interface()) if len(newPrimaryKeys) > 0 { - sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(newPrimaryKeys)) + sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, associationForeignDBNames), toQueryMarks(newPrimaryKeys)) newDB = newDB.Where(sql, toQueryValues(newPrimaryKeys)...) } } diff --git a/association_test.go b/association_test.go index ad56d84e..02974a98 100644 --- a/association_test.go +++ b/association_test.go @@ -872,3 +872,15 @@ func TestLongForeignKey(t *testing.T) { func TestLongForeignKeyWithShortDest(t *testing.T) { testForeignKey(t, &ReallyLongThingThatReferencesShort{}, "ShortID", &Short{}, "ID") } + +func TestHasManyChildrenWithOneStruct(t *testing.T) { + category := Category{ + Name: "main", + Categories: []Category{ + {Name: "sub1"}, + {Name: "sub2"}, + }, + } + + DB.Save(&category) +} diff --git a/migration_test.go b/migration_test.go index 07603c19..8b3c4ab6 100644 --- a/migration_test.go +++ b/migration_test.go @@ -180,6 +180,9 @@ type Post struct { type Category struct { gorm.Model Name string + + Categories []Category + CategoryID *uint } type Comment struct {