diff --git a/association.go b/association.go index a9345255..82a2274e 100644 --- a/association.go +++ b/association.go @@ -3,6 +3,7 @@ package gorm import ( "fmt" + "github.com/jinzhu/gorm/clause" "github.com/jinzhu/gorm/schema" ) @@ -31,10 +32,6 @@ func (db *DB) Association(column string) *Association { func (association *Association) Find(out interface{}, conds ...interface{}) error { if association.Error == nil { - for _, ref := range association.Relationship.References { - if ref.OwnPrimaryKey { - } - } } return association.Error @@ -53,9 +50,27 @@ func (association *Association) Delete(values ...interface{}) error { } func (association *Association) Clear() error { - return association.Error + return association.Replace() } -func (association *Association) Count() int { - return 0 +func (association *Association) Count() (count int) { + if association.Error == nil { + var ( + tx = association.DB + conds = association.Relationship.ToQueryConditions(tx.Statement.ReflectValue) + ) + + if association.Relationship.JoinTable != nil { + tx.Clauses(clause.From{Joins: []clause.Join{{ + Table: clause.Table{Name: association.Relationship.JoinTable.Table}, + ON: clause.Where{Exprs: conds}, + }}}) + } else { + tx.Clauses(clause.Where{Exprs: conds}) + } + + association.Error = tx.Count(&count).Error + } + + return } diff --git a/schema/relationship.go b/schema/relationship.go index 4ffea8b3..59aaa7e4 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -6,6 +6,7 @@ import ( "regexp" "strings" + "github.com/jinzhu/gorm/clause" "github.com/jinzhu/inflection" ) @@ -345,3 +346,47 @@ func (rel *Relationship) ParseConstraint() *Constraint { return &constraint } + +func (rel *Relationship) ToQueryConditions(reflectValue reflect.Value) (conds []clause.Expression) { + foreignFields := []*Field{} + relForeignKeys := []string{} + + if rel.JoinTable != nil { + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + foreignFields = append(foreignFields, ref.PrimaryKey) + relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName) + } else if ref.PrimaryValue != "" { + conds = append(conds, clause.Eq{ + Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, + Value: ref.PrimaryValue, + }) + } else { + conds = append(conds, clause.Eq{ + Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, + Value: clause.Column{Table: rel.FieldSchema.Table, Name: ref.PrimaryKey.DBName}, + }) + } + } + } else { + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) + foreignFields = append(foreignFields, ref.PrimaryKey) + } else if ref.PrimaryValue != "" { + conds = append(conds, clause.Eq{ + Column: clause.Column{Table: rel.FieldSchema.Table, Name: ref.ForeignKey.DBName}, + Value: ref.PrimaryValue, + }) + } else { + relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName) + foreignFields = append(foreignFields, ref.ForeignKey) + } + } + } + + _, foreignValues := GetIdentityFieldValuesMap(reflectValue, foreignFields) + column, values := ToQueryValues(relForeignKeys, foreignValues) + conds = append(conds, clause.IN{Column: column, Values: values}) + return +} diff --git a/schema/schema.go b/schema/schema.go index 5a28797b..79faae12 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -8,6 +8,7 @@ import ( "reflect" "sync" + "github.com/jinzhu/gorm/clause" "github.com/jinzhu/gorm/logger" ) @@ -26,6 +27,10 @@ type Schema struct { FieldsByDBName map[string]*Field FieldsWithDefaultDBValue map[string]*Field // fields with default value assigned by database Relationships Relationships + CreateClauses []clause.Interface + QueryClauses []clause.Interface + UpdateClauses []clause.Interface + DeleteClauses []clause.Interface BeforeCreate, AfterCreate bool BeforeUpdate, AfterUpdate bool BeforeDelete, AfterDelete bool