forked from mirror/gorm
Refactor association
This commit is contained in:
parent
c19a3abefb
commit
94c6bb980b
|
@ -43,32 +43,8 @@ func (db *DB) Association(column string) *Association {
|
||||||
|
|
||||||
func (association *Association) Find(out interface{}, conds ...interface{}) error {
|
func (association *Association) Find(out interface{}, conds ...interface{}) error {
|
||||||
if association.Error == nil {
|
if association.Error == nil {
|
||||||
var (
|
association.Error = association.buildCondition().Find(out, conds...).Error
|
||||||
queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue)
|
|
||||||
tx = association.DB.Model(out)
|
|
||||||
)
|
|
||||||
|
|
||||||
if association.Relationship.JoinTable != nil {
|
|
||||||
if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 {
|
|
||||||
joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}}
|
|
||||||
for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
|
|
||||||
joinStmt.AddClause(queryClause)
|
|
||||||
}
|
|
||||||
joinStmt.Build("WHERE")
|
|
||||||
tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
|
|
||||||
}
|
|
||||||
|
|
||||||
tx.Clauses(clause.From{Joins: []clause.Join{{
|
|
||||||
Table: clause.Table{Name: association.Relationship.JoinTable.Table},
|
|
||||||
ON: clause.Where{Exprs: queryConds},
|
|
||||||
}}})
|
|
||||||
} else {
|
|
||||||
tx.Clauses(clause.Where{Exprs: queryConds})
|
|
||||||
}
|
|
||||||
|
|
||||||
association.Error = tx.Find(out, conds...).Error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return association.Error
|
return association.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -80,7 +56,7 @@ func (association *Association) Append(values ...interface{}) error {
|
||||||
association.Error = association.Replace(values...)
|
association.Error = association.Replace(values...)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
association.saveAssociation(false, values...)
|
association.saveAssociation( /*clear*/ false, values...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -90,7 +66,7 @@ func (association *Association) Append(values ...interface{}) error {
|
||||||
func (association *Association) Replace(values ...interface{}) error {
|
func (association *Association) Replace(values ...interface{}) error {
|
||||||
if association.Error == nil {
|
if association.Error == nil {
|
||||||
// save associations
|
// save associations
|
||||||
association.saveAssociation(true, values...)
|
association.saveAssociation( /*clear*/ true, values...)
|
||||||
|
|
||||||
// set old associations's foreign key to null
|
// set old associations's foreign key to null
|
||||||
reflectValue := association.DB.Statement.ReflectValue
|
reflectValue := association.DB.Statement.ReflectValue
|
||||||
|
@ -234,7 +210,7 @@ func (association *Association) Delete(values ...interface{}) error {
|
||||||
var (
|
var (
|
||||||
primaryFields, relPrimaryFields []*schema.Field
|
primaryFields, relPrimaryFields []*schema.Field
|
||||||
joinPrimaryKeys, joinRelPrimaryKeys []string
|
joinPrimaryKeys, joinRelPrimaryKeys []string
|
||||||
modelValue = reflect.New(rel.JoinTable.ModelType).Interface()
|
joinValue = reflect.New(rel.JoinTable.ModelType).Interface()
|
||||||
)
|
)
|
||||||
|
|
||||||
for _, ref := range rel.References {
|
for _, ref := range rel.References {
|
||||||
|
@ -259,10 +235,11 @@ func (association *Association) Delete(values ...interface{}) error {
|
||||||
relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs)
|
relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs)
|
||||||
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
|
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
|
||||||
|
|
||||||
association.Error = association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(modelValue).Error
|
association.Error = association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(joinValue).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
if association.Error == nil {
|
if association.Error == nil {
|
||||||
|
// clean up deleted values's foreign key
|
||||||
relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields)
|
relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields)
|
||||||
|
|
||||||
cleanUpDeletedRelations := func(data reflect.Value) {
|
cleanUpDeletedRelations := func(data reflect.Value) {
|
||||||
|
@ -328,33 +305,8 @@ func (association *Association) Clear() error {
|
||||||
|
|
||||||
func (association *Association) Count() (count int64) {
|
func (association *Association) Count() (count int64) {
|
||||||
if association.Error == nil {
|
if association.Error == nil {
|
||||||
var (
|
association.Error = association.buildCondition().Count(&count).Error
|
||||||
conds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue)
|
|
||||||
modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface()
|
|
||||||
tx = association.DB.Model(modelValue)
|
|
||||||
)
|
|
||||||
|
|
||||||
if association.Relationship.JoinTable != nil {
|
|
||||||
if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 {
|
|
||||||
joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}}
|
|
||||||
for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
|
|
||||||
joinStmt.AddClause(queryClause)
|
|
||||||
}
|
|
||||||
joinStmt.Build("WHERE", "LIMIT")
|
|
||||||
tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
|
|
||||||
}
|
|
||||||
|
|
||||||
tx.Clauses(clause.From{Joins: []clause.Join{{
|
|
||||||
Table: clause.Table{Name: association.Relationship.JoinTable.Table},
|
|
||||||
ON: clause.Where{Exprs: conds},
|
|
||||||
}}})
|
|
||||||
} else {
|
|
||||||
tx.Clauses(clause.Where{Exprs: conds})
|
|
||||||
}
|
|
||||||
|
|
||||||
association.Error = tx.Count(&count).Error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -435,6 +387,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
|
||||||
switch reflectValue.Kind() {
|
switch reflectValue.Kind() {
|
||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
if len(values) != reflectValue.Len() {
|
if len(values) != reflectValue.Len() {
|
||||||
|
// clear old data
|
||||||
if clear && len(values) == 0 {
|
if clear && len(values) == 0 {
|
||||||
for i := 0; i < reflectValue.Len(); i++ {
|
for i := 0; i < reflectValue.Len(); i++ {
|
||||||
if err := association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil {
|
if err := association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil {
|
||||||
|
@ -467,6 +420,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
|
||||||
association.Error = association.DB.Session(&Session{}).Select(selectedSaveColumns).Model(nil).Save(reflectValue.Index(i).Addr().Interface()).Error
|
association.Error = association.DB.Session(&Session{}).Select(selectedSaveColumns).Model(nil).Save(reflectValue.Index(i).Addr().Interface()).Error
|
||||||
}
|
}
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
|
// clear old data
|
||||||
if clear && len(values) == 0 {
|
if clear && len(values) == 0 {
|
||||||
association.Error = association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface())
|
association.Error = association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface())
|
||||||
|
|
||||||
|
@ -498,3 +452,31 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (association *Association) buildCondition() *DB {
|
||||||
|
var (
|
||||||
|
queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue)
|
||||||
|
modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface()
|
||||||
|
tx = association.DB.Model(modelValue)
|
||||||
|
)
|
||||||
|
|
||||||
|
if association.Relationship.JoinTable != nil {
|
||||||
|
if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 {
|
||||||
|
joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}}
|
||||||
|
for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
|
||||||
|
joinStmt.AddClause(queryClause)
|
||||||
|
}
|
||||||
|
joinStmt.Build("WHERE")
|
||||||
|
tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
|
||||||
|
}
|
||||||
|
|
||||||
|
tx.Clauses(clause.From{Joins: []clause.Join{{
|
||||||
|
Table: clause.Table{Name: association.Relationship.JoinTable.Table},
|
||||||
|
ON: clause.Where{Exprs: queryConds},
|
||||||
|
}}})
|
||||||
|
} else {
|
||||||
|
tx.Clauses(clause.Where{Exprs: queryConds})
|
||||||
|
}
|
||||||
|
|
||||||
|
return tx
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue