gorm/association.go

428 lines
14 KiB
Go
Raw Normal View History

2020-01-29 14:22:44 +03:00
package gorm
import (
2020-05-20 18:44:50 +03:00
"errors"
"fmt"
2020-05-19 21:03:43 +03:00
"reflect"
2020-05-19 16:50:06 +03:00
"github.com/jinzhu/gorm/clause"
"github.com/jinzhu/gorm/schema"
2020-05-19 21:03:43 +03:00
"github.com/jinzhu/gorm/utils"
)
2020-01-29 14:22:44 +03:00
// Association Mode contains some helper methods to handle relationship things easily.
type Association struct {
DB *DB
Relationship *schema.Relationship
Error error
2020-01-29 14:22:44 +03:00
}
2020-02-23 18:28:35 +03:00
func (db *DB) Association(column string) *Association {
association := &Association{DB: db}
2020-05-24 12:24:23 +03:00
table := db.Statement.Table
if err := db.Statement.Parse(db.Statement.Model); err == nil {
2020-05-24 12:24:23 +03:00
db.Statement.Table = table
association.Relationship = db.Statement.Schema.Relationships.Relations[column]
if association.Relationship == nil {
association.Error = fmt.Errorf("%w: %v", ErrUnsupportedRelation, column)
}
2020-05-23 06:57:28 +03:00
db.Statement.ReflectValue = reflect.Indirect(reflect.ValueOf(db.Statement.Model))
} else {
association.Error = err
}
return association
}
func (association *Association) Find(out interface{}, conds ...interface{}) error {
if association.Error == nil {
2020-05-20 18:44:50 +03:00
var (
2020-05-23 06:57:28 +03:00
queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue)
tx = association.DB.Model(out).Table("")
2020-05-20 18:44:50 +03:00
)
if association.Relationship.JoinTable != nil {
for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
tx.Clauses(queryClause)
}
tx.Clauses(clause.From{Joins: []clause.Join{{
Table: clause.Table{Name: association.Relationship.JoinTable.Table},
ON: clause.Where{Exprs: queryConds},
}}})
} else {
tx.Clauses(clause.Where{Exprs: queryConds})
}
association.Error = tx.Find(out, conds...).Error
}
return association.Error
}
func (association *Association) Append(values ...interface{}) error {
2020-05-20 18:44:50 +03:00
if association.Error == nil {
switch association.Relationship.Type {
case schema.HasOne, schema.BelongsTo:
if len(values) > 0 {
association.Error = association.Replace(values...)
}
default:
association.saveAssociation(false, values...)
}
}
return association.Error
}
func (association *Association) Replace(values ...interface{}) error {
2020-05-20 18:44:50 +03:00
if association.Error == nil {
association.saveAssociation(true, values...)
reflectValue := association.DB.Statement.ReflectValue
rel := association.Relationship
switch rel.Type {
2020-05-24 12:24:23 +03:00
case schema.BelongsTo:
if len(values) == 0 {
updateMap := map[string]interface{}{}
for _, ref := range rel.References {
updateMap[ref.ForeignKey.DBName] = nil
}
association.DB.UpdateColumns(updateMap)
}
2020-05-20 18:44:50 +03:00
case schema.HasOne, schema.HasMany:
var (
2020-05-24 17:52:16 +03:00
tx = association.DB
primaryFields []*schema.Field
foreignKeys []string
updateMap = map[string]interface{}{}
relPrimaryKeys = []string{}
relValues = schema.GetRelationsValues(reflectValue, []*schema.Relationship{rel})
modelValue = reflect.New(rel.FieldSchema.ModelType).Interface()
2020-05-20 18:44:50 +03:00
)
2020-05-24 17:52:16 +03:00
for _, field := range rel.FieldSchema.PrimaryFields {
relPrimaryKeys = append(relPrimaryKeys, field.DBName)
}
if _, qvs := schema.GetIdentityFieldValuesMap(relValues, rel.FieldSchema.PrimaryFields); len(qvs) > 0 {
if column, values := schema.ToQueryValues(relPrimaryKeys, qvs); len(values) > 0 {
tx = tx.Not(clause.IN{Column: column, Values: values})
}
2020-05-24 12:24:23 +03:00
}
2020-05-20 18:44:50 +03:00
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
primaryFields = append(primaryFields, ref.PrimaryKey)
foreignKeys = append(foreignKeys, ref.ForeignKey.DBName)
updateMap[ref.ForeignKey.DBName] = nil
}
}
2020-05-24 17:52:16 +03:00
if _, qvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields); len(qvs) > 0 {
column, values := schema.ToQueryValues(foreignKeys, qvs)
tx.Model(modelValue).Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap)
2020-05-23 16:03:28 +03:00
}
2020-05-20 18:44:50 +03:00
case schema.Many2Many:
var primaryFields, relPrimaryFields []*schema.Field
var foreignKeys, relForeignKeys []string
modelValue := reflect.New(rel.JoinTable.ModelType).Interface()
conds := []clause.Expression{}
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
primaryFields = append(primaryFields, ref.PrimaryKey)
foreignKeys = append(foreignKeys, ref.ForeignKey.DBName)
} else if ref.PrimaryValue != "" {
conds = append(conds, clause.Eq{
Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName},
Value: ref.PrimaryValue,
})
} else {
relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey)
relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName)
}
}
generateConds := func(rv reflect.Value) {
_, values := schema.GetIdentityFieldValuesMap(rv, primaryFields)
column, queryValues := schema.ToQueryValues(foreignKeys, values)
relValue := rel.Field.ReflectValueOf(rv)
_, relValues := schema.GetIdentityFieldValuesMap(relValue, relPrimaryFields)
relColumn, relQueryValues := schema.ToQueryValues(relForeignKeys, relValues)
conds = append(conds, clause.And(
clause.IN{Column: column, Values: queryValues},
clause.Not(clause.IN{Column: relColumn, Values: relQueryValues}),
))
}
switch reflectValue.Kind() {
case reflect.Struct:
generateConds(reflectValue)
case reflect.Slice, reflect.Array:
for i := 0; i < reflectValue.Len(); i++ {
generateConds(reflectValue.Index(i))
}
}
association.DB.Where(conds).Delete(modelValue)
}
}
return association.Error
}
func (association *Association) Delete(values ...interface{}) error {
2020-05-19 21:03:43 +03:00
if association.Error == nil {
var (
2020-05-24 18:28:06 +03:00
reflectValue = association.DB.Statement.ReflectValue
2020-05-24 12:24:23 +03:00
rel = association.Relationship
2020-05-24 18:28:06 +03:00
tx = association.DB
2020-05-24 12:24:23 +03:00
relFields []*schema.Field
foreignKeyFields []*schema.Field
foreignKeys []string
updateAttrs = map[string]interface{}{}
2020-05-19 21:03:43 +03:00
)
for _, ref := range rel.References {
if ref.PrimaryValue == "" {
if rel.JoinTable == nil || !ref.OwnPrimaryKey {
if ref.OwnPrimaryKey {
relFields = append(relFields, ref.ForeignKey)
} else {
relFields = append(relFields, ref.PrimaryKey)
2020-05-24 12:24:23 +03:00
foreignKeyFields = append(foreignKeyFields, ref.ForeignKey)
2020-05-19 21:03:43 +03:00
}
foreignKeys = append(foreignKeys, ref.ForeignKey.DBName)
updateAttrs[ref.ForeignKey.DBName] = nil
}
}
}
relValuesMap, relQueryValues := schema.GetIdentityFieldValuesMapFromValues(values, relFields)
column, values := schema.ToQueryValues(foreignKeys, relQueryValues)
2020-05-24 18:28:06 +03:00
tx = tx.Session(&Session{}).Where(clause.IN{Column: column, Values: values})
2020-05-19 21:03:43 +03:00
2020-05-20 18:44:50 +03:00
switch rel.Type {
2020-05-19 21:03:43 +03:00
case schema.HasOne, schema.HasMany:
modelValue := reflect.New(rel.FieldSchema.ModelType).Interface()
2020-05-24 12:24:23 +03:00
conds := rel.ToQueryConditions(reflectValue)
2020-05-19 21:03:43 +03:00
tx.Model(modelValue).Clauses(clause.Where{Exprs: conds}).UpdateColumns(updateAttrs)
case schema.BelongsTo:
2020-05-24 16:46:33 +03:00
primaryKeys := []string{}
for _, field := range rel.Schema.PrimaryFields {
primaryKeys = append(primaryKeys, field.DBName)
}
_, queryValues := schema.GetIdentityFieldValuesMap(reflectValue, rel.Schema.PrimaryFields)
if column, values := schema.ToQueryValues(primaryKeys, queryValues); len(values) > 0 {
tx.Where(clause.IN{Column: column, Values: values})
}
2020-05-24 12:24:23 +03:00
modelValue := reflect.New(rel.Schema.ModelType).Interface()
tx.Model(modelValue).UpdateColumns(updateAttrs)
2020-05-19 21:03:43 +03:00
case schema.Many2Many:
modelValue := reflect.New(rel.JoinTable.ModelType).Interface()
2020-05-24 12:24:23 +03:00
conds := rel.ToQueryConditions(reflectValue)
2020-05-19 21:03:43 +03:00
tx.Clauses(clause.Where{Exprs: conds}).Delete(modelValue)
}
if tx.Error == nil {
cleanUpDeletedRelations := func(data reflect.Value) {
if _, zero := rel.Field.ValueOf(data); !zero {
fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(data))
2020-05-23 16:03:28 +03:00
fieldValues := make([]interface{}, len(relFields))
2020-05-19 21:03:43 +03:00
switch fieldValue.Kind() {
case reflect.Slice, reflect.Array:
validFieldValues := reflect.Zero(rel.Field.FieldType)
for i := 0; i < fieldValue.Len(); i++ {
for idx, field := range relFields {
2020-05-23 16:03:28 +03:00
fieldValues[idx], _ = field.ValueOf(fieldValue.Index(i))
2020-05-19 21:03:43 +03:00
}
if _, ok := relValuesMap[utils.ToStringKey(fieldValues...)]; !ok {
validFieldValues = reflect.Append(validFieldValues, fieldValue.Index(i))
}
}
2020-05-24 12:24:23 +03:00
rel.Field.Set(data, validFieldValues.Interface())
2020-05-19 21:03:43 +03:00
case reflect.Struct:
for idx, field := range relFields {
2020-05-24 12:24:23 +03:00
fieldValues[idx], _ = field.ValueOf(fieldValue)
2020-05-19 21:03:43 +03:00
}
if _, ok := relValuesMap[utils.ToStringKey(fieldValues...)]; ok {
2020-05-24 12:24:23 +03:00
rel.Field.Set(data, reflect.Zero(rel.FieldSchema.ModelType).Interface())
for _, field := range foreignKeyFields {
field.Set(data, reflect.Zero(field.FieldType).Interface())
}
2020-05-19 21:03:43 +03:00
}
}
}
}
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i < reflectValue.Len(); i++ {
cleanUpDeletedRelations(reflect.Indirect(reflectValue.Index(i)))
}
case reflect.Struct:
cleanUpDeletedRelations(reflectValue)
}
} else {
association.Error = tx.Error
}
}
return association.Error
}
func (association *Association) Clear() error {
2020-05-19 16:50:06 +03:00
return association.Replace()
}
2020-05-24 06:32:59 +03:00
func (association *Association) Count() (count int64) {
2020-05-19 16:50:06 +03:00
if association.Error == nil {
var (
2020-05-24 06:32:59 +03:00
conds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue)
modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface()
tx = association.DB.Model(modelValue)
2020-05-19 16:50:06 +03:00
)
if association.Relationship.JoinTable != nil {
2020-05-19 21:03:43 +03:00
for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
tx.Clauses(queryClause)
}
2020-05-19 16:50:06 +03:00
tx.Clauses(clause.From{Joins: []clause.Join{{
Table: clause.Table{Name: association.Relationship.JoinTable.Table},
ON: clause.Where{Exprs: conds},
}}})
} else {
tx.Clauses(clause.Where{Exprs: conds})
}
association.Error = tx.Count(&count).Error
}
return
2020-02-23 18:28:35 +03:00
}
2020-05-20 18:44:50 +03:00
func (association *Association) saveAssociation(clear bool, values ...interface{}) {
2020-05-24 12:24:23 +03:00
var (
reflectValue = association.DB.Statement.ReflectValue
assignBacks = [][2]reflect.Value{}
assignBack = association.Relationship.Field.FieldType.Kind() == reflect.Struct
)
2020-05-20 18:44:50 +03:00
appendToRelations := func(source, rv reflect.Value, clear bool) {
switch association.Relationship.Type {
case schema.HasOne, schema.BelongsTo:
switch rv.Kind() {
case reflect.Slice, reflect.Array:
if rv.Len() > 0 {
2020-05-24 12:24:23 +03:00
association.Error = association.Relationship.Field.Set(source, rv.Index(0).Addr().Interface())
if assignBack {
assignBacks = append(assignBacks, [2]reflect.Value{source, rv.Index(0)})
}
2020-05-20 18:44:50 +03:00
}
case reflect.Struct:
2020-05-24 12:24:23 +03:00
association.Error = association.Relationship.Field.Set(source, rv.Addr().Interface())
if assignBack {
assignBacks = append(assignBacks, [2]reflect.Value{source, rv})
}
2020-05-20 18:44:50 +03:00
}
case schema.HasMany, schema.Many2Many:
elemType := association.Relationship.Field.IndirectFieldType.Elem()
fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(reflectValue))
if clear {
fieldValue = reflect.New(association.Relationship.Field.IndirectFieldType)
}
appendToFieldValues := func(ev reflect.Value) {
if ev.Type().AssignableTo(elemType) {
fieldValue = reflect.Append(fieldValue, ev)
} else if ev.Type().Elem().AssignableTo(elemType) {
fieldValue = reflect.Append(fieldValue, ev.Elem())
} else {
association.Error = fmt.Errorf("unsupported data type: %v for relation %v", ev.Type(), association.Relationship.Name)
}
}
switch rv.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i < rv.Len(); i++ {
appendToFieldValues(reflect.Indirect(rv.Index(i)))
}
case reflect.Struct:
appendToFieldValues(rv)
}
if association.Error == nil {
2020-05-24 12:24:23 +03:00
association.Error = association.Relationship.Field.Set(source, fieldValue.Addr().Interface())
2020-05-20 18:44:50 +03:00
}
}
}
selectedColumns := []string{association.Relationship.Name}
for _, ref := range association.Relationship.References {
if !ref.OwnPrimaryKey {
selectedColumns = append(selectedColumns, ref.ForeignKey.Name)
}
}
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
if len(values) != reflectValue.Len() {
if clear && len(values) == 0 {
for i := 0; i < reflectValue.Len(); i++ {
2020-05-24 12:24:23 +03:00
association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface())
2020-05-24 15:44:37 +03:00
for _, ref := range association.Relationship.References {
if !ref.OwnPrimaryKey {
ref.ForeignKey.Set(reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface())
}
}
2020-05-20 18:44:50 +03:00
}
break
}
association.Error = errors.New("invalid association values, length doesn't match")
2020-05-24 16:46:33 +03:00
return
2020-05-20 18:44:50 +03:00
}
for i := 0; i < reflectValue.Len(); i++ {
appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear)
2020-05-24 16:46:33 +03:00
if len(values) > 0 {
// TODO support save slice data, sql with case
err := association.DB.Session(&Session{}).Select(selectedColumns).Model(nil).Save(reflectValue.Index(i).Addr().Interface()).Error
association.DB.AddError(err)
2020-05-20 18:44:50 +03:00
}
}
case reflect.Struct:
if clear && len(values) == 0 {
2020-05-24 12:24:23 +03:00
association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface())
2020-05-24 15:44:37 +03:00
for _, ref := range association.Relationship.References {
2020-05-24 18:28:06 +03:00
if !ref.OwnPrimaryKey && ref.PrimaryValue == "" {
2020-05-24 15:44:37 +03:00
ref.ForeignKey.Set(reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface())
}
}
2020-05-20 18:44:50 +03:00
}
for idx, value := range values {
2020-05-24 12:24:23 +03:00
rv := reflect.Indirect(reflect.ValueOf(value))
appendToRelations(reflectValue, rv, clear && idx == 0)
2020-05-20 18:44:50 +03:00
}
2020-05-24 16:46:33 +03:00
if len(values) > 0 {
2020-05-24 17:52:16 +03:00
association.DB.Session(&Session{}).Select(selectedColumns).Model(nil).Save(reflectValue.Addr().Interface())
2020-05-24 15:44:37 +03:00
}
2020-05-24 12:24:23 +03:00
}
for _, assignBack := range assignBacks {
reflect.Indirect(assignBack[1]).Set(association.Relationship.Field.ReflectValueOf(assignBack[0]))
2020-05-20 18:44:50 +03:00
}
}