diff --git a/association.go b/association.go index 25e1fe8d..db77cc4e 100644 --- a/association.go +++ b/association.go @@ -43,32 +43,8 @@ func (db *DB) Association(column string) *Association { func (association *Association) Find(out interface{}, conds ...interface{}) error { if association.Error == nil { - var ( - queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue) - tx = association.DB.Model(out) - ) - - if association.Relationship.JoinTable != nil { - if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 { - joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}} - for _, queryClause := range association.Relationship.JoinTable.QueryClauses { - joinStmt.AddClause(queryClause) - } - joinStmt.Build("WHERE") - tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars}) - } - - tx.Clauses(clause.From{Joins: []clause.Join{{ - Table: clause.Table{Name: association.Relationship.JoinTable.Table}, - ON: clause.Where{Exprs: queryConds}, - }}}) - } else { - tx.Clauses(clause.Where{Exprs: queryConds}) - } - - association.Error = tx.Find(out, conds...).Error + association.Error = association.buildCondition().Find(out, conds...).Error } - return association.Error } @@ -80,7 +56,7 @@ func (association *Association) Append(values ...interface{}) error { association.Error = association.Replace(values...) } default: - association.saveAssociation(false, values...) + association.saveAssociation( /*clear*/ false, values...) } } @@ -90,7 +66,7 @@ func (association *Association) Append(values ...interface{}) error { func (association *Association) Replace(values ...interface{}) error { if association.Error == nil { // save associations - association.saveAssociation(true, values...) + association.saveAssociation( /*clear*/ true, values...) // set old associations's foreign key to null reflectValue := association.DB.Statement.ReflectValue @@ -234,7 +210,7 @@ func (association *Association) Delete(values ...interface{}) error { var ( primaryFields, relPrimaryFields []*schema.Field joinPrimaryKeys, joinRelPrimaryKeys []string - modelValue = reflect.New(rel.JoinTable.ModelType).Interface() + joinValue = reflect.New(rel.JoinTable.ModelType).Interface() ) for _, ref := range rel.References { @@ -259,10 +235,11 @@ func (association *Association) Delete(values ...interface{}) error { relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs) conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) - association.Error = association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(modelValue).Error + association.Error = association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(joinValue).Error } if association.Error == nil { + // clean up deleted values's foreign key relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields) cleanUpDeletedRelations := func(data reflect.Value) { @@ -328,33 +305,8 @@ func (association *Association) Clear() error { func (association *Association) Count() (count int64) { if association.Error == nil { - var ( - conds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue) - modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface() - tx = association.DB.Model(modelValue) - ) - - if association.Relationship.JoinTable != nil { - if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 { - joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}} - for _, queryClause := range association.Relationship.JoinTable.QueryClauses { - joinStmt.AddClause(queryClause) - } - joinStmt.Build("WHERE", "LIMIT") - tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars}) - } - - tx.Clauses(clause.From{Joins: []clause.Join{{ - Table: clause.Table{Name: association.Relationship.JoinTable.Table}, - ON: clause.Where{Exprs: conds}, - }}}) - } else { - tx.Clauses(clause.Where{Exprs: conds}) - } - - association.Error = tx.Count(&count).Error + association.Error = association.buildCondition().Count(&count).Error } - return } @@ -435,6 +387,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ switch reflectValue.Kind() { case reflect.Slice, reflect.Array: if len(values) != reflectValue.Len() { + // clear old data if clear && len(values) == 0 { for i := 0; i < reflectValue.Len(); i++ { if err := association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil { @@ -467,6 +420,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ association.Error = association.DB.Session(&Session{}).Select(selectedSaveColumns).Model(nil).Save(reflectValue.Index(i).Addr().Interface()).Error } case reflect.Struct: + // clear old data if clear && len(values) == 0 { association.Error = association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) @@ -498,3 +452,31 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } } } + +func (association *Association) buildCondition() *DB { + var ( + queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue) + modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface() + tx = association.DB.Model(modelValue) + ) + + if association.Relationship.JoinTable != nil { + if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 { + joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}} + for _, queryClause := range association.Relationship.JoinTable.QueryClauses { + joinStmt.AddClause(queryClause) + } + joinStmt.Build("WHERE") + tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars}) + } + + tx.Clauses(clause.From{Joins: []clause.Join{{ + Table: clause.Table{Name: association.Relationship.JoinTable.Table}, + ON: clause.Where{Exprs: queryConds}, + }}}) + } else { + tx.Clauses(clause.Where{Exprs: queryConds}) + } + + return tx +}