forked from mirror/gorm
Refactor association
This commit is contained in:
parent
c19a3abefb
commit
94c6bb980b
|
@ -43,32 +43,8 @@ func (db *DB) Association(column string) *Association {
|
|||
|
||||
func (association *Association) Find(out interface{}, conds ...interface{}) error {
|
||||
if association.Error == nil {
|
||||
var (
|
||||
queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue)
|
||||
tx = association.DB.Model(out)
|
||||
)
|
||||
|
||||
if association.Relationship.JoinTable != nil {
|
||||
if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 {
|
||||
joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}}
|
||||
for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
|
||||
joinStmt.AddClause(queryClause)
|
||||
}
|
||||
joinStmt.Build("WHERE")
|
||||
tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
|
||||
}
|
||||
|
||||
tx.Clauses(clause.From{Joins: []clause.Join{{
|
||||
Table: clause.Table{Name: association.Relationship.JoinTable.Table},
|
||||
ON: clause.Where{Exprs: queryConds},
|
||||
}}})
|
||||
} else {
|
||||
tx.Clauses(clause.Where{Exprs: queryConds})
|
||||
}
|
||||
|
||||
association.Error = tx.Find(out, conds...).Error
|
||||
association.Error = association.buildCondition().Find(out, conds...).Error
|
||||
}
|
||||
|
||||
return association.Error
|
||||
}
|
||||
|
||||
|
@ -80,7 +56,7 @@ func (association *Association) Append(values ...interface{}) error {
|
|||
association.Error = association.Replace(values...)
|
||||
}
|
||||
default:
|
||||
association.saveAssociation(false, values...)
|
||||
association.saveAssociation( /*clear*/ false, values...)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -90,7 +66,7 @@ func (association *Association) Append(values ...interface{}) error {
|
|||
func (association *Association) Replace(values ...interface{}) error {
|
||||
if association.Error == nil {
|
||||
// save associations
|
||||
association.saveAssociation(true, values...)
|
||||
association.saveAssociation( /*clear*/ true, values...)
|
||||
|
||||
// set old associations's foreign key to null
|
||||
reflectValue := association.DB.Statement.ReflectValue
|
||||
|
@ -234,7 +210,7 @@ func (association *Association) Delete(values ...interface{}) error {
|
|||
var (
|
||||
primaryFields, relPrimaryFields []*schema.Field
|
||||
joinPrimaryKeys, joinRelPrimaryKeys []string
|
||||
modelValue = reflect.New(rel.JoinTable.ModelType).Interface()
|
||||
joinValue = reflect.New(rel.JoinTable.ModelType).Interface()
|
||||
)
|
||||
|
||||
for _, ref := range rel.References {
|
||||
|
@ -259,10 +235,11 @@ func (association *Association) Delete(values ...interface{}) error {
|
|||
relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs)
|
||||
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
|
||||
|
||||
association.Error = association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(modelValue).Error
|
||||
association.Error = association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(joinValue).Error
|
||||
}
|
||||
|
||||
if association.Error == nil {
|
||||
// clean up deleted values's foreign key
|
||||
relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields)
|
||||
|
||||
cleanUpDeletedRelations := func(data reflect.Value) {
|
||||
|
@ -328,33 +305,8 @@ func (association *Association) Clear() error {
|
|||
|
||||
func (association *Association) Count() (count int64) {
|
||||
if association.Error == nil {
|
||||
var (
|
||||
conds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue)
|
||||
modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface()
|
||||
tx = association.DB.Model(modelValue)
|
||||
)
|
||||
|
||||
if association.Relationship.JoinTable != nil {
|
||||
if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 {
|
||||
joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}}
|
||||
for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
|
||||
joinStmt.AddClause(queryClause)
|
||||
}
|
||||
joinStmt.Build("WHERE", "LIMIT")
|
||||
tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
|
||||
}
|
||||
|
||||
tx.Clauses(clause.From{Joins: []clause.Join{{
|
||||
Table: clause.Table{Name: association.Relationship.JoinTable.Table},
|
||||
ON: clause.Where{Exprs: conds},
|
||||
}}})
|
||||
} else {
|
||||
tx.Clauses(clause.Where{Exprs: conds})
|
||||
}
|
||||
|
||||
association.Error = tx.Count(&count).Error
|
||||
association.Error = association.buildCondition().Count(&count).Error
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -435,6 +387,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
|
|||
switch reflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
if len(values) != reflectValue.Len() {
|
||||
// clear old data
|
||||
if clear && len(values) == 0 {
|
||||
for i := 0; i < reflectValue.Len(); i++ {
|
||||
if err := association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil {
|
||||
|
@ -467,6 +420,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
|
|||
association.Error = association.DB.Session(&Session{}).Select(selectedSaveColumns).Model(nil).Save(reflectValue.Index(i).Addr().Interface()).Error
|
||||
}
|
||||
case reflect.Struct:
|
||||
// clear old data
|
||||
if clear && len(values) == 0 {
|
||||
association.Error = association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface())
|
||||
|
||||
|
@ -498,3 +452,31 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (association *Association) buildCondition() *DB {
|
||||
var (
|
||||
queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue)
|
||||
modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface()
|
||||
tx = association.DB.Model(modelValue)
|
||||
)
|
||||
|
||||
if association.Relationship.JoinTable != nil {
|
||||
if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 {
|
||||
joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}}
|
||||
for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
|
||||
joinStmt.AddClause(queryClause)
|
||||
}
|
||||
joinStmt.Build("WHERE")
|
||||
tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
|
||||
}
|
||||
|
||||
tx.Clauses(clause.From{Joins: []clause.Join{{
|
||||
Table: clause.Table{Name: association.Relationship.JoinTable.Table},
|
||||
ON: clause.Where{Exprs: queryConds},
|
||||
}}})
|
||||
} else {
|
||||
tx.Clauses(clause.Where{Exprs: queryConds})
|
||||
}
|
||||
|
||||
return tx
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue