Refactor association

This commit is contained in:
Jinzhu 2020-08-28 12:25:25 +08:00
parent c19a3abefb
commit 94c6bb980b
1 changed files with 37 additions and 55 deletions

View File

@ -43,32 +43,8 @@ func (db *DB) Association(column string) *Association {
func (association *Association) Find(out interface{}, conds ...interface{}) error { func (association *Association) Find(out interface{}, conds ...interface{}) error {
if association.Error == nil { if association.Error == nil {
var ( association.Error = association.buildCondition().Find(out, conds...).Error
queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue)
tx = association.DB.Model(out)
)
if association.Relationship.JoinTable != nil {
if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 {
joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}}
for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
joinStmt.AddClause(queryClause)
} }
joinStmt.Build("WHERE")
tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
}
tx.Clauses(clause.From{Joins: []clause.Join{{
Table: clause.Table{Name: association.Relationship.JoinTable.Table},
ON: clause.Where{Exprs: queryConds},
}}})
} else {
tx.Clauses(clause.Where{Exprs: queryConds})
}
association.Error = tx.Find(out, conds...).Error
}
return association.Error return association.Error
} }
@ -80,7 +56,7 @@ func (association *Association) Append(values ...interface{}) error {
association.Error = association.Replace(values...) association.Error = association.Replace(values...)
} }
default: default:
association.saveAssociation(false, values...) association.saveAssociation( /*clear*/ false, values...)
} }
} }
@ -90,7 +66,7 @@ func (association *Association) Append(values ...interface{}) error {
func (association *Association) Replace(values ...interface{}) error { func (association *Association) Replace(values ...interface{}) error {
if association.Error == nil { if association.Error == nil {
// save associations // save associations
association.saveAssociation(true, values...) association.saveAssociation( /*clear*/ true, values...)
// set old associations's foreign key to null // set old associations's foreign key to null
reflectValue := association.DB.Statement.ReflectValue reflectValue := association.DB.Statement.ReflectValue
@ -234,7 +210,7 @@ func (association *Association) Delete(values ...interface{}) error {
var ( var (
primaryFields, relPrimaryFields []*schema.Field primaryFields, relPrimaryFields []*schema.Field
joinPrimaryKeys, joinRelPrimaryKeys []string joinPrimaryKeys, joinRelPrimaryKeys []string
modelValue = reflect.New(rel.JoinTable.ModelType).Interface() joinValue = reflect.New(rel.JoinTable.ModelType).Interface()
) )
for _, ref := range rel.References { for _, ref := range rel.References {
@ -259,10 +235,11 @@ func (association *Association) Delete(values ...interface{}) error {
relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs) relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs)
conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
association.Error = association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(modelValue).Error association.Error = association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(joinValue).Error
} }
if association.Error == nil { if association.Error == nil {
// clean up deleted values's foreign key
relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields) relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields)
cleanUpDeletedRelations := func(data reflect.Value) { cleanUpDeletedRelations := func(data reflect.Value) {
@ -328,33 +305,8 @@ func (association *Association) Clear() error {
func (association *Association) Count() (count int64) { func (association *Association) Count() (count int64) {
if association.Error == nil { if association.Error == nil {
var ( association.Error = association.buildCondition().Count(&count).Error
conds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue)
modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface()
tx = association.DB.Model(modelValue)
)
if association.Relationship.JoinTable != nil {
if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 {
joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}}
for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
joinStmt.AddClause(queryClause)
} }
joinStmt.Build("WHERE", "LIMIT")
tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
}
tx.Clauses(clause.From{Joins: []clause.Join{{
Table: clause.Table{Name: association.Relationship.JoinTable.Table},
ON: clause.Where{Exprs: conds},
}}})
} else {
tx.Clauses(clause.Where{Exprs: conds})
}
association.Error = tx.Count(&count).Error
}
return return
} }
@ -435,6 +387,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
switch reflectValue.Kind() { switch reflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
if len(values) != reflectValue.Len() { if len(values) != reflectValue.Len() {
// clear old data
if clear && len(values) == 0 { if clear && len(values) == 0 {
for i := 0; i < reflectValue.Len(); i++ { for i := 0; i < reflectValue.Len(); i++ {
if err := association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil { if err := association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil {
@ -467,6 +420,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
association.Error = association.DB.Session(&Session{}).Select(selectedSaveColumns).Model(nil).Save(reflectValue.Index(i).Addr().Interface()).Error association.Error = association.DB.Session(&Session{}).Select(selectedSaveColumns).Model(nil).Save(reflectValue.Index(i).Addr().Interface()).Error
} }
case reflect.Struct: case reflect.Struct:
// clear old data
if clear && len(values) == 0 { if clear && len(values) == 0 {
association.Error = association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) association.Error = association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface())
@ -498,3 +452,31 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
} }
} }
} }
func (association *Association) buildCondition() *DB {
var (
queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue)
modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface()
tx = association.DB.Model(modelValue)
)
if association.Relationship.JoinTable != nil {
if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 {
joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}}
for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
joinStmt.AddClause(queryClause)
}
joinStmt.Build("WHERE")
tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
}
tx.Clauses(clause.From{Joins: []clause.Join{{
Table: clause.Table{Name: association.Relationship.JoinTable.Table},
ON: clause.Where{Exprs: queryConds},
}}})
} else {
tx.Clauses(clause.Where{Exprs: queryConds})
}
return tx
}