gorm/association.go

374 lines
14 KiB
Go
Raw Normal View History

2014-07-30 10:30:21 +04:00
package gorm
2014-07-30 12:22:26 +04:00
import (
2016-01-16 07:18:04 +03:00
"errors"
2014-07-30 12:22:26 +04:00
"fmt"
"reflect"
)
2016-01-16 07:18:04 +03:00
// Association Mode contains some helper methods to handle relationship things easily.
2014-07-30 10:30:21 +04:00
type Association struct {
2015-07-30 12:26:10 +03:00
Error error
2016-01-16 07:18:04 +03:00
scope *Scope
column string
field *Field
2014-07-30 10:30:21 +04:00
}
2016-01-15 17:14:21 +03:00
// Find find out all related associations
2014-07-30 16:48:36 +04:00
func (association *Association) Find(value interface{}) *Association {
2016-01-16 07:18:04 +03:00
association.scope.related(value, association.column)
return association.setErr(association.scope.db.Error)
2014-07-30 16:48:36 +04:00
}
2014-07-30 12:22:26 +04:00
2016-01-16 07:18:04 +03:00
// Append append new associations for many2many, has_many, replace current association for has_one, belongs_to
2015-12-26 10:19:56 +03:00
func (association *Association) Append(values ...interface{}) *Association {
if association.Error != nil {
return association
}
2016-01-16 07:18:04 +03:00
if relationship := association.field.Relationship; relationship.Kind == "has_one" {
2015-12-26 10:19:56 +03:00
return association.Replace(values...)
}
return association.saveAssociations(values...)
}
2016-01-15 17:14:21 +03:00
// Replace replace current associations with new one
func (association *Association) Replace(values ...interface{}) *Association {
if association.Error != nil {
return association
}
2015-12-25 17:59:01 +03:00
var (
2016-01-16 07:18:04 +03:00
relationship = association.field.Relationship
scope = association.scope
field = association.field.Field
2015-12-25 17:59:01 +03:00
newDB = scope.NewDB()
)
// Append new values
2016-01-16 07:18:04 +03:00
association.field.Set(reflect.Zero(association.field.Field.Type()))
2015-12-26 10:19:56 +03:00
association.saveAssociations(values...)
2015-12-25 15:43:51 +03:00
2015-12-25 17:59:01 +03:00
// Belongs To
2015-12-25 15:43:51 +03:00
if relationship.Kind == "belongs_to" {
2016-01-15 17:14:21 +03:00
// Set foreign key to be null when clearing value (length equals 0)
2015-12-25 17:59:01 +03:00
if len(values) == 0 {
// Set foreign key to be nil
var foreignKeyMap = map[string]interface{}{}
for _, foreignKey := range relationship.ForeignDBNames {
foreignKeyMap[foreignKey] = nil
2015-12-25 15:43:51 +03:00
}
2015-12-25 17:59:01 +03:00
association.setErr(newDB.Model(scope.Value).UpdateColumn(foreignKeyMap).Error)
}
2015-12-25 15:43:51 +03:00
} else {
2016-01-15 17:14:21 +03:00
// Polymorphic Relations
2015-12-26 10:19:56 +03:00
if relationship.PolymorphicDBName != "" {
newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), relationship.PolymorphicValue)
2015-12-26 10:19:56 +03:00
}
2016-01-16 01:00:18 +03:00
// Delete Relations except new created
2015-12-25 17:59:01 +03:00
if len(values) > 0 {
2016-10-20 08:27:26 +03:00
var associationForeignFieldNames, associationForeignDBNames []string
if relationship.Kind == "many_to_many" {
2016-01-16 01:00:18 +03:00
// if many to many relations, get association fields name from association foreign keys
2016-03-07 07:15:15 +03:00
associationScope := scope.New(reflect.New(field.Type()).Interface())
2016-10-20 08:27:26 +03:00
for idx, dbName := range relationship.AssociationForeignFieldNames {
2016-03-07 07:15:15 +03:00
if field, ok := associationScope.FieldByName(dbName); ok {
associationForeignFieldNames = append(associationForeignFieldNames, field.Name)
2016-10-20 08:27:26 +03:00
associationForeignDBNames = append(associationForeignDBNames, relationship.AssociationForeignDBNames[idx])
2016-03-07 07:15:15 +03:00
}
2016-01-16 01:00:18 +03:00
}
2015-12-25 17:59:01 +03:00
} else {
2016-10-20 08:27:26 +03:00
// If has one/many relations, use primary keys
2016-01-16 01:00:18 +03:00
for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() {
associationForeignFieldNames = append(associationForeignFieldNames, field.Name)
2016-10-20 08:27:26 +03:00
associationForeignDBNames = append(associationForeignDBNames, field.DBName)
2016-01-16 01:00:18 +03:00
}
2015-12-25 17:59:01 +03:00
}
2016-01-16 01:00:18 +03:00
newPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, field.Interface())
if len(newPrimaryKeys) > 0 {
2016-10-20 08:27:26 +03:00
sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, associationForeignDBNames), toQueryMarks(newPrimaryKeys))
newDB = newDB.Where(sql, toQueryValues(newPrimaryKeys)...)
}
2015-12-25 15:43:51 +03:00
}
if relationship.Kind == "many_to_many" {
2016-01-16 01:00:18 +03:00
// if many to many relations, delete related relations from join table
2016-03-07 07:15:15 +03:00
var sourceForeignFieldNames []string
2016-01-16 01:00:18 +03:00
for _, dbName := range relationship.ForeignFieldNames {
2016-03-07 07:15:15 +03:00
if field, ok := scope.FieldByName(dbName); ok {
sourceForeignFieldNames = append(sourceForeignFieldNames, field.Name)
}
2016-01-16 01:00:18 +03:00
}
if sourcePrimaryKeys := scope.getColumnAsArray(sourceForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 {
2016-01-15 17:14:21 +03:00
newDB = newDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(sourcePrimaryKeys)), toQueryValues(sourcePrimaryKeys)...)
2016-01-15 17:14:21 +03:00
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship))
}
2015-12-25 15:43:51 +03:00
} else if relationship.Kind == "has_one" || relationship.Kind == "has_many" {
2016-01-16 01:00:18 +03:00
// has_one or has_many relations, set foreign key to be nil (TODO or delete them?)
var foreignKeyMap = map[string]interface{}{}
for idx, foreignKey := range relationship.ForeignDBNames {
foreignKeyMap[foreignKey] = nil
if field, ok := scope.FieldByName(relationship.AssociationForeignFieldNames[idx]); ok {
newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
}
}
2016-01-16 07:18:04 +03:00
fieldValue := reflect.New(association.field.Field.Type()).Interface()
2015-12-25 19:23:17 +03:00
association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error)
2015-12-25 15:43:51 +03:00
}
}
return association
}
2016-01-15 17:14:21 +03:00
// Delete remove relationship between source & passed arguments, but won't delete those arguments
2015-07-30 12:26:10 +03:00
func (association *Association) Delete(values ...interface{}) *Association {
if association.Error != nil {
return association
}
2015-12-25 19:23:17 +03:00
var (
2016-01-16 07:18:04 +03:00
relationship = association.field.Relationship
scope = association.scope
field = association.field.Field
2015-12-25 19:23:17 +03:00
newDB = scope.NewDB()
)
2014-07-30 18:10:12 +04:00
if len(values) == 0 {
return association
}
2015-12-26 11:06:53 +03:00
var deletingResourcePrimaryFieldNames, deletingResourcePrimaryDBNames []string
2016-01-16 01:00:18 +03:00
for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() {
deletingResourcePrimaryFieldNames = append(deletingResourcePrimaryFieldNames, field.Name)
deletingResourcePrimaryDBNames = append(deletingResourcePrimaryDBNames, field.DBName)
2015-12-26 11:06:53 +03:00
}
2016-01-15 17:53:09 +03:00
deletingPrimaryKeys := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, values...)
2015-12-26 11:06:53 +03:00
2015-07-30 12:26:10 +03:00
if relationship.Kind == "many_to_many" {
2015-12-26 11:06:53 +03:00
// source value's foreign keys
2015-07-30 12:26:10 +03:00
for idx, foreignKey := range relationship.ForeignDBNames {
if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok {
2015-12-25 19:23:17 +03:00
newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
2014-07-30 16:48:36 +04:00
}
}
2014-07-30 18:10:12 +04:00
2016-01-16 01:00:18 +03:00
// get association's foreign fields name
2016-03-07 07:15:15 +03:00
var associationScope = scope.New(reflect.New(field.Type()).Interface())
2016-01-16 01:00:18 +03:00
var associationForeignFieldNames []string
for _, associationDBName := range relationship.AssociationForeignFieldNames {
2016-03-07 07:15:15 +03:00
if field, ok := associationScope.FieldByName(associationDBName); ok {
associationForeignFieldNames = append(associationForeignFieldNames, field.Name)
}
2016-01-16 01:00:18 +03:00
}
2015-12-26 11:06:53 +03:00
// association value's foreign keys
2016-01-16 01:00:18 +03:00
deletingPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, values...)
2015-12-26 11:06:53 +03:00
sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys))
newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...)
2014-07-30 16:48:36 +04:00
2015-12-26 11:06:53 +03:00
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship))
2015-07-30 12:26:10 +03:00
} else {
2015-12-25 19:23:17 +03:00
var foreignKeyMap = map[string]interface{}{}
for _, foreignKey := range relationship.ForeignDBNames {
foreignKeyMap[foreignKey] = nil
}
if relationship.Kind == "belongs_to" {
2015-12-25 19:23:17 +03:00
// find with deleting relation's foreign keys
2016-01-15 17:53:09 +03:00
primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, values...)
2015-12-25 19:23:17 +03:00
newDB = newDB.Where(
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
toQueryValues(primaryKeys)...,
)
2016-01-16 01:00:18 +03:00
// set foreign key to be null if there are some records affected
modelValue := reflect.New(scope.GetModelStruct().ModelType).Interface()
if results := newDB.Model(modelValue).UpdateColumn(foreignKeyMap); results.Error == nil {
if results.RowsAffected > 0 {
scope.updatedAttrsWithValues(foreignKeyMap)
}
} else {
association.setErr(results.Error)
}
} else if relationship.Kind == "has_one" || relationship.Kind == "has_many" {
2015-12-25 19:23:17 +03:00
// find all relations
2016-01-15 17:53:09 +03:00
primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value)
2015-12-25 19:23:17 +03:00
newDB = newDB.Where(
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
toQueryValues(primaryKeys)...,
)
2015-12-25 19:23:17 +03:00
// only include those deleting relations
newDB = newDB.Where(
2015-12-26 11:06:53 +03:00
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, deletingResourcePrimaryDBNames), toQueryMarks(deletingPrimaryKeys)),
toQueryValues(deletingPrimaryKeys)...,
2015-12-25 19:23:17 +03:00
)
2015-12-25 19:23:17 +03:00
// set matched relation's foreign key to be null
2016-01-16 07:18:04 +03:00
fieldValue := reflect.New(association.field.Field.Type()).Interface()
2015-12-26 11:06:53 +03:00
association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error)
}
}
2016-01-16 01:00:18 +03:00
// Remove deleted records from source's field
2015-12-26 11:06:53 +03:00
if association.Error == nil {
2016-01-16 07:18:04 +03:00
if field.Kind() == reflect.Slice {
leftValues := reflect.Zero(field.Type())
2015-12-26 11:06:53 +03:00
2016-01-16 07:18:04 +03:00
for i := 0; i < field.Len(); i++ {
reflectValue := field.Index(i)
2016-01-16 01:00:18 +03:00
primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, reflectValue.Interface())[0]
var isDeleted = false
2015-12-26 11:06:53 +03:00
for _, pk := range deletingPrimaryKeys {
if equalAsString(primaryKey, pk) {
2016-01-16 01:00:18 +03:00
isDeleted = true
break
2015-12-26 11:06:53 +03:00
}
}
2016-01-16 01:00:18 +03:00
if !isDeleted {
2015-12-26 11:06:53 +03:00
leftValues = reflect.Append(leftValues, reflectValue)
}
}
2016-01-16 07:18:04 +03:00
association.field.Set(leftValues)
} else if field.Kind() == reflect.Struct {
primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, field.Interface())[0]
2015-12-26 11:06:53 +03:00
for _, pk := range deletingPrimaryKeys {
if equalAsString(primaryKey, pk) {
2016-01-16 07:18:04 +03:00
association.field.Set(reflect.Zero(field.Type()))
2015-12-26 11:06:53 +03:00
break
}
}
}
2014-07-30 16:48:36 +04:00
}
2015-12-25 19:23:17 +03:00
2014-07-30 12:22:26 +04:00
return association
2014-07-30 10:30:21 +04:00
}
2016-01-15 17:14:21 +03:00
// Clear remove relationship between source & current associations, won't delete those associations
func (association *Association) Clear() *Association {
2015-12-25 14:33:57 +03:00
return association.Replace()
2014-07-30 10:30:21 +04:00
}
2016-01-15 17:14:21 +03:00
// Count return the count of current associations
2014-07-30 17:43:53 +04:00
func (association *Association) Count() int {
2016-01-12 07:16:22 +03:00
var (
count = 0
2016-01-16 07:18:04 +03:00
relationship = association.field.Relationship
scope = association.scope
fieldValue = association.field.Field.Interface()
2016-01-16 01:00:18 +03:00
query = scope.DB()
2016-01-12 07:16:22 +03:00
)
2014-07-30 10:30:21 +04:00
2014-07-31 07:08:26 +04:00
if relationship.Kind == "many_to_many" {
2016-01-16 07:18:04 +03:00
query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, query, scope.Value)
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
2016-01-16 01:00:18 +03:00
primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value)
query = query.Where(
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
toQueryValues(primaryKeys)...,
)
2014-07-31 07:08:26 +04:00
} else if relationship.Kind == "belongs_to" {
2016-01-16 01:00:18 +03:00
primaryKeys := scope.getColumnAsArray(relationship.ForeignFieldNames, scope.Value)
query = query.Where(
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)),
toQueryValues(primaryKeys)...,
)
}
if relationship.PolymorphicType != "" {
query = query.Where(
fmt.Sprintf("%v.%v = ?", scope.New(fieldValue).QuotedTableName(), scope.Quote(relationship.PolymorphicDBName)),
relationship.PolymorphicValue,
2016-01-16 01:00:18 +03:00
)
2014-07-30 16:48:36 +04:00
}
2014-07-30 17:43:53 +04:00
2016-01-16 01:00:18 +03:00
query.Model(fieldValue).Count(&count)
2014-07-30 17:43:53 +04:00
return count
2014-07-30 10:30:21 +04:00
}
2016-01-16 07:18:04 +03:00
// saveAssociations stores each of values as an association of the source.
//
// Every value must be a struct, a pointer to a struct, or a slice of those;
// anything else records an "invalid value type" error. Each record is assigned
// to (or appended onto) the association field and then persisted: through the
// join table handler for many_to_many, otherwise by saving the source value
// with only the association field selected. Errors are recorded on the
// association, which is returned for chaining.
func (association *Association) saveAssociations(values ...interface{}) *Association {
	var (
		scope        = association.scope
		field        = association.field
		relationship = field.Relationship
	)

	saveOne := func(item reflect.Value) {
		// Work on a pointer so the caller's record can be written back to.
		if item.Kind() != reflect.Ptr {
			ptr := reflect.New(item.Type())
			ptr.Elem().Set(item)
			item = ptr
		}

		isManyToMany := relationship.Kind == "many_to_many"

		// A many2many record has to exist before it can be joined.
		if isManyToMany && scope.New(item.Interface()).PrimaryKeyZero() {
			association.setErr(scope.NewDB().Save(item.Interface()).Error)
		}

		// Put the record into the association field, remembering whether a
		// copy (rather than the pointer itself) was stored.
		var (
			fieldType      = field.Field.Type()
			copiedAsStruct bool
			copiedAsElem   bool
		)
		switch {
		case item.Type().AssignableTo(fieldType):
			field.Set(item)
		case item.Type().Elem().AssignableTo(fieldType):
			// struct-typed field: the saved value must be copied back afterwards
			copiedAsStruct = true
			field.Set(item.Elem())
		case fieldType.Kind() == reflect.Slice:
			switch elemType := fieldType.Elem(); {
			case item.Type().AssignableTo(elemType):
				field.Set(reflect.Append(field.Field, item))
			case item.Type().Elem().AssignableTo(elemType):
				// slice of structs: the saved element must be copied back afterwards
				copiedAsElem = true
				field.Set(reflect.Append(field.Field, item.Elem()))
			}
		}

		if isManyToMany {
			association.setErr(relationship.JoinTableHandler.Add(relationship.JoinTableHandler, scope.NewDB(), scope.Value, item.Interface()))
			return
		}

		association.setErr(scope.NewDB().Select(field.Name).Save(scope.Value).Error)

		// Reflect whatever the save changed back into the caller's record.
		switch {
		case copiedAsStruct:
			item.Elem().Set(field.Field)
		case copiedAsElem:
			item.Elem().Set(field.Field.Index(field.Field.Len() - 1))
		}
	}

	for _, value := range values {
		passed := reflect.ValueOf(value)
		switch target := reflect.Indirect(passed); target.Kind() {
		case reflect.Struct:
			saveOne(passed)
		case reflect.Slice:
			for i := 0; i < target.Len(); i++ {
				saveOne(target.Index(i))
			}
		default:
			association.setErr(errors.New("invalid value type"))
		}
	}
	return association
}
// setErr records err on the association when it is non-nil, overwriting any
// earlier error; a nil err leaves the current Error untouched. The association
// is returned for chaining.
func (association *Association) setErr(err error) *Association {
	if err == nil {
		return association
	}
	association.Error = err
	return association
}