diff --git a/association.go b/association.go index 89ea2b9f..0ce2904f 100644 --- a/association.go +++ b/association.go @@ -87,11 +87,11 @@ func (association *Association) Delete(values ...interface{}) *Association { } else { relationship := association.Field.Relationship // many to many - if relationship.joinTable != "" { + if relationship.kind == "many_to_many" { whereSql := fmt.Sprintf("%v.%v IN (?)", relationship.joinTable, scope.Quote(ToSnake(relationship.associationForeignKey))) scope.db.Table(relationship.joinTable).Where(whereSql, primaryKeys).Delete("") } else { - association.err(errors.New("only many to many support delete")) + association.err(errors.New("delete only support many to many")) } } return association @@ -105,25 +105,36 @@ func (association *Association) Clear(value interface{}) *Association { return association } -func (association *Association) Count() (count int) { +func (association *Association) Count() int { + count := -1 relationship := association.Field.Relationship scope := association.Scope field := scope.IndirectValue().FieldByName(association.Column) fieldValue := field.Interface() + newScope := scope.New(fieldValue) - // many to many - if relationship.joinTable != "" { - newScope := scope.New(fieldValue) + if relationship.kind == "many_to_many" { whereSql := fmt.Sprintf("%v.%v IN (SELECT %v.%v FROM %v WHERE %v.%v = ?)", newScope.QuotedTableName(), scope.Quote(newScope.PrimaryKey()), relationship.joinTable, - scope.Quote(relationship.associationForeignKey), + scope.Quote(ToSnake(relationship.associationForeignKey)), relationship.joinTable, relationship.joinTable, - scope.Quote(relationship.foreignKey)) - scope.db.Table(newScope.QuotedTableName()).Where(whereSql, scope.PrimaryKey()).Count(&count) + scope.Quote(ToSnake(relationship.foreignKey))) + scope.db.Table(newScope.QuotedTableName()).Where(whereSql, association.PrimaryKey).NewScope("").count(&count) + } else if relationship.kind == "has_many" { + whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(ToSnake(relationship.foreignKey))) + scope.db.Table(newScope.QuotedTableName()).Where(whereSql, association.PrimaryKey).NewScope("").count(&count) + } else if relationship.kind == "has_one" { + whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), relationship.foreignKey) + scope.db.Table(newScope.QuotedTableName()).Where(whereSql, association.PrimaryKey).NewScope("").count(&count) + } else if relationship.kind == "belongs_to" { + if v, ok := scope.FieldByName(association.Column); ok { + whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), relationship.foreignKey) + scope.db.Table(newScope.QuotedTableName()).Where(whereSql, v).NewScope("").count(&count) + } } - // association.Scope.related(value, association.Column) - return -1 + + return count } diff --git a/association_test.go b/association_test.go index 3c8f7f5c..561725a5 100644 --- a/association_test.go +++ b/association_test.go @@ -145,6 +145,10 @@ func TestManyToMany(t *testing.T) { t.Errorf("Should be able to find many to many relations") } + if db.Model(&user).Association("Languages").Count() != len([]string{"ZH", "EN"}) { + t.Errorf("Count should return correct result") + } + // Append db.Model(&user).Association("Languages").Append(&Language{Name: "DE"}) if db.Where("name = ?", "DE").First(&Language{}).RecordNotFound() { @@ -161,9 +165,7 @@ func TestManyToMany(t *testing.T) { totalLanguages := []string{"ZH", "EN", "DE", "AA", "BB", "CC", "DD", "EE"} - newLanguages = []Language{} - db.Model(&user).Related(&newLanguages, "Languages") - if len(newLanguages) != len(totalLanguages) { + if db.Model(&user).Association("Languages").Count() != len(totalLanguages) { t.Errorf("All appended languages should be saved") } @@ -171,10 +173,7 @@ func TestManyToMany(t *testing.T) { var language Language db.Where("name = ?", "EE").First(&language) db.Model(&user).Association("Languages").Delete(language, &language) - - newLanguages = []Language{} - db.Model(&user).Related(&newLanguages, "Languages") - if len(newLanguages) != len(totalLanguages)-1 { + if db.Model(&user).Association("Languages").Count() != len(totalLanguages)-1 { t.Errorf("Relations should be deleted with Delete") } if db.Where("name = ?", "EE").First(&Language{}).RecordNotFound() { @@ -184,9 +183,7 @@ func TestManyToMany(t *testing.T) { languages = []Language{} db.Where("name IN (?)", []string{"CC", "DD"}).Find(&languages) db.Model(&user).Association("Languages").Delete(languages, &languages) - newLanguages = []Language{} - db.Model(&user).Related(&newLanguages, "Languages") - if len(newLanguages) != len(totalLanguages)-3 { + if db.Model(&user).Association("Languages").Count() != len(totalLanguages)-3 { t.Errorf("Relations should be deleted with Delete") }