forked from mirror/gorm
Refactor Association Mode
This commit is contained in:
parent
8d716be896
commit
41870191b0
185
association.go
185
association.go
|
@ -1,12 +1,11 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Association Association Mode contains some helper methods to handle relationship things easily.
|
||||
type Association struct {
|
||||
Scope *Scope
|
||||
Column string
|
||||
|
@ -14,86 +13,13 @@ type Association struct {
|
|||
Field *Field
|
||||
}
|
||||
|
||||
func (association *Association) setErr(err error) *Association {
|
||||
if err != nil {
|
||||
association.Error = err
|
||||
}
|
||||
return association
|
||||
}
|
||||
|
||||
// Find find out all related associations
|
||||
func (association *Association) Find(value interface{}) *Association {
|
||||
association.Scope.related(value, association.Column)
|
||||
return association.setErr(association.Scope.db.Error)
|
||||
}
|
||||
|
||||
func (association *Association) saveAssociations(values ...interface{}) *Association {
|
||||
scope := association.Scope
|
||||
field := association.Field
|
||||
relationship := association.Field.Relationship
|
||||
|
||||
saveAssociation := func(reflectValue reflect.Value) {
|
||||
// value has to been pointer
|
||||
if reflectValue.Kind() != reflect.Ptr {
|
||||
reflectPtr := reflect.New(reflectValue.Type())
|
||||
reflectPtr.Elem().Set(reflectValue)
|
||||
reflectValue = reflectPtr
|
||||
}
|
||||
|
||||
// value has to been saved for many2many
|
||||
if relationship.Kind == "many_to_many" {
|
||||
if scope.New(reflectValue.Interface()).PrimaryKeyZero() {
|
||||
association.setErr(scope.NewDB().Save(reflectValue.Interface()).Error)
|
||||
}
|
||||
}
|
||||
|
||||
// Assign Fields
|
||||
var fieldType = field.Field.Type()
|
||||
var setFieldBackToValue, setSliceFieldBackToValue bool
|
||||
if reflectValue.Type().AssignableTo(fieldType) {
|
||||
field.Set(reflectValue)
|
||||
} else if reflectValue.Type().Elem().AssignableTo(fieldType) {
|
||||
// if field's type is struct, then need to set value back to argument after save
|
||||
setFieldBackToValue = true
|
||||
field.Set(reflectValue.Elem())
|
||||
} else if fieldType.Kind() == reflect.Slice {
|
||||
if reflectValue.Type().AssignableTo(fieldType.Elem()) {
|
||||
field.Set(reflect.Append(field.Field, reflectValue))
|
||||
} else if reflectValue.Type().Elem().AssignableTo(fieldType.Elem()) {
|
||||
// if field's type is slice of struct, then need to set value back to argument after save
|
||||
setSliceFieldBackToValue = true
|
||||
field.Set(reflect.Append(field.Field, reflectValue.Elem()))
|
||||
}
|
||||
}
|
||||
|
||||
if relationship.Kind == "many_to_many" {
|
||||
association.setErr(relationship.JoinTableHandler.Add(relationship.JoinTableHandler, scope.NewDB(), scope.Value, reflectValue.Interface()))
|
||||
} else {
|
||||
association.setErr(scope.NewDB().Select(field.Name).Save(scope.Value).Error)
|
||||
|
||||
if setFieldBackToValue {
|
||||
reflectValue.Elem().Set(field.Field)
|
||||
} else if setSliceFieldBackToValue {
|
||||
reflectValue.Elem().Set(field.Field.Index(field.Field.Len() - 1))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, value := range values {
|
||||
reflectValue := reflect.ValueOf(value)
|
||||
indirectReflectValue := reflect.Indirect(reflectValue)
|
||||
if indirectReflectValue.Kind() == reflect.Struct {
|
||||
saveAssociation(reflectValue)
|
||||
} else if indirectReflectValue.Kind() == reflect.Slice {
|
||||
for i := 0; i < indirectReflectValue.Len(); i++ {
|
||||
saveAssociation(indirectReflectValue.Index(i))
|
||||
}
|
||||
} else {
|
||||
association.setErr(errors.New("invalid value type"))
|
||||
}
|
||||
}
|
||||
return association
|
||||
}
|
||||
|
||||
// Append append new associations for many2many, has_many, will replace current association for has_one, belongs_to
|
||||
func (association *Association) Append(values ...interface{}) *Association {
|
||||
if relationship := association.Field.Relationship; relationship.Kind == "has_one" {
|
||||
return association.Replace(values...)
|
||||
|
@ -101,6 +27,7 @@ func (association *Association) Append(values ...interface{}) *Association {
|
|||
return association.saveAssociations(values...)
|
||||
}
|
||||
|
||||
// Replace replace current associations with new one
|
||||
func (association *Association) Replace(values ...interface{}) *Association {
|
||||
var (
|
||||
relationship = association.Field.Relationship
|
||||
|
@ -115,7 +42,7 @@ func (association *Association) Replace(values ...interface{}) *Association {
|
|||
|
||||
// Belongs To
|
||||
if relationship.Kind == "belongs_to" {
|
||||
// Set foreign key to be null only when clearing value
|
||||
// Set foreign key to be null when clearing value (length equals 0)
|
||||
if len(values) == 0 {
|
||||
// Set foreign key to be nil
|
||||
var foreignKeyMap = map[string]interface{}{}
|
||||
|
@ -125,29 +52,21 @@ func (association *Association) Replace(values ...interface{}) *Association {
|
|||
association.setErr(newDB.Model(scope.Value).UpdateColumn(foreignKeyMap).Error)
|
||||
}
|
||||
} else {
|
||||
// Relations
|
||||
// Polymorphic Relations
|
||||
if relationship.PolymorphicDBName != "" {
|
||||
newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName())
|
||||
}
|
||||
|
||||
// Relations except new created
|
||||
if len(values) > 0 {
|
||||
var newPrimaryKeys [][]interface{}
|
||||
var associationForeignFieldNames []string
|
||||
|
||||
if relationship.Kind == "many_to_many" {
|
||||
// If many to many relations, get it from foreign key
|
||||
associationForeignFieldNames = relationship.AssociationForeignFieldNames
|
||||
} else {
|
||||
// If other relations, get real primary keys
|
||||
for _, field := range scope.New(reflect.New(field.Type()).Interface()).Fields() {
|
||||
if field.IsPrimaryKey {
|
||||
associationForeignFieldNames = append(associationForeignFieldNames, field.Name)
|
||||
}
|
||||
}
|
||||
associationForeignFieldNames = relationship.AssociationForeignDBNames
|
||||
}
|
||||
|
||||
newPrimaryKeys = association.getPrimaryKeys(associationForeignFieldNames, field.Interface())
|
||||
newPrimaryKeys := association.getPrimaryKeys(associationForeignFieldNames, field.Interface())
|
||||
|
||||
if len(newPrimaryKeys) > 0 {
|
||||
sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(newPrimaryKeys))
|
||||
|
@ -156,13 +75,11 @@ func (association *Association) Replace(values ...interface{}) *Association {
|
|||
}
|
||||
|
||||
if relationship.Kind == "many_to_many" {
|
||||
for idx, foreignKey := range relationship.ForeignDBNames {
|
||||
if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok {
|
||||
newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
|
||||
}
|
||||
}
|
||||
if sourcePrimaryKeys := association.getPrimaryKeys(relationship.ForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 {
|
||||
newDB = newDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(sourcePrimaryKeys)), toQueryValues(sourcePrimaryKeys)...)
|
||||
|
||||
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship))
|
||||
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship))
|
||||
}
|
||||
} else if relationship.Kind == "has_one" || relationship.Kind == "has_many" {
|
||||
var foreignKeyMap = map[string]interface{}{}
|
||||
for idx, foreignKey := range relationship.ForeignDBNames {
|
||||
|
@ -179,6 +96,7 @@ func (association *Association) Replace(values ...interface{}) *Association {
|
|||
return association
|
||||
}
|
||||
|
||||
// Delete remove relationship between source & passed arguments, but won't delete those arguments
|
||||
func (association *Association) Delete(values ...interface{}) *Association {
|
||||
var (
|
||||
relationship = association.Field.Relationship
|
||||
|
@ -292,10 +210,12 @@ func (association *Association) Delete(values ...interface{}) *Association {
|
|||
return association
|
||||
}
|
||||
|
||||
// Clear remove relationship between source & current associations, won't delete those associations
|
||||
func (association *Association) Clear() *Association {
|
||||
return association.Replace()
|
||||
}
|
||||
|
||||
// Count return the count of current associations
|
||||
func (association *Association) Count() int {
|
||||
var (
|
||||
count = 0
|
||||
|
@ -333,78 +253,3 @@ func (association *Association) Count() int {
|
|||
|
||||
return count
|
||||
}
|
||||
|
||||
func (association *Association) getPrimaryKeys(columns []string, values ...interface{}) (results [][]interface{}) {
|
||||
scope := association.Scope
|
||||
|
||||
for _, value := range values {
|
||||
reflectValue := reflect.Indirect(reflect.ValueOf(value))
|
||||
if reflectValue.Kind() == reflect.Slice {
|
||||
for i := 0; i < reflectValue.Len(); i++ {
|
||||
primaryKeys := []interface{}{}
|
||||
newScope := scope.New(reflectValue.Index(i).Interface())
|
||||
for _, column := range columns {
|
||||
if field, ok := newScope.FieldByName(column); ok {
|
||||
primaryKeys = append(primaryKeys, field.Field.Interface())
|
||||
} else {
|
||||
primaryKeys = append(primaryKeys, "")
|
||||
}
|
||||
}
|
||||
results = append(results, primaryKeys)
|
||||
}
|
||||
} else if reflectValue.Kind() == reflect.Struct {
|
||||
newScope := scope.New(value)
|
||||
var primaryKeys []interface{}
|
||||
for _, column := range columns {
|
||||
if field, ok := newScope.FieldByName(column); ok {
|
||||
primaryKeys = append(primaryKeys, field.Field.Interface())
|
||||
} else {
|
||||
primaryKeys = append(primaryKeys, "")
|
||||
}
|
||||
}
|
||||
|
||||
results = append(results, primaryKeys)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func toQueryMarks(primaryValues [][]interface{}) string {
|
||||
var results []string
|
||||
|
||||
for _, primaryValue := range primaryValues {
|
||||
var marks []string
|
||||
for _ = range primaryValue {
|
||||
marks = append(marks, "?")
|
||||
}
|
||||
|
||||
if len(marks) > 1 {
|
||||
results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ",")))
|
||||
} else {
|
||||
results = append(results, strings.Join(marks, ""))
|
||||
}
|
||||
}
|
||||
return strings.Join(results, ",")
|
||||
}
|
||||
|
||||
func toQueryCondition(scope *Scope, columns []string) string {
|
||||
var newColumns []string
|
||||
for _, column := range columns {
|
||||
newColumns = append(newColumns, scope.Quote(column))
|
||||
}
|
||||
|
||||
if len(columns) > 1 {
|
||||
return fmt.Sprintf("(%v)", strings.Join(newColumns, ","))
|
||||
}
|
||||
return strings.Join(newColumns, ",")
|
||||
}
|
||||
|
||||
func toQueryValues(primaryValues [][]interface{}) (values []interface{}) {
|
||||
for _, primaryValue := range primaryValues {
|
||||
for _, value := range primaryValue {
|
||||
values = append(values, value)
|
||||
}
|
||||
}
|
||||
return values
|
||||
}
|
||||
|
|
|
@ -0,0 +1,158 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (association *Association) setErr(err error) *Association {
|
||||
if err != nil {
|
||||
association.Error = err
|
||||
}
|
||||
return association
|
||||
}
|
||||
|
||||
func (association *Association) saveAssociations(values ...interface{}) *Association {
|
||||
scope := association.Scope
|
||||
field := association.Field
|
||||
relationship := association.Field.Relationship
|
||||
|
||||
saveAssociation := func(reflectValue reflect.Value) {
|
||||
// value has to been pointer
|
||||
if reflectValue.Kind() != reflect.Ptr {
|
||||
reflectPtr := reflect.New(reflectValue.Type())
|
||||
reflectPtr.Elem().Set(reflectValue)
|
||||
reflectValue = reflectPtr
|
||||
}
|
||||
|
||||
// value has to been saved for many2many
|
||||
if relationship.Kind == "many_to_many" {
|
||||
if scope.New(reflectValue.Interface()).PrimaryKeyZero() {
|
||||
association.setErr(scope.NewDB().Save(reflectValue.Interface()).Error)
|
||||
}
|
||||
}
|
||||
|
||||
// Assign Fields
|
||||
var fieldType = field.Field.Type()
|
||||
var setFieldBackToValue, setSliceFieldBackToValue bool
|
||||
if reflectValue.Type().AssignableTo(fieldType) {
|
||||
field.Set(reflectValue)
|
||||
} else if reflectValue.Type().Elem().AssignableTo(fieldType) {
|
||||
// if field's type is struct, then need to set value back to argument after save
|
||||
setFieldBackToValue = true
|
||||
field.Set(reflectValue.Elem())
|
||||
} else if fieldType.Kind() == reflect.Slice {
|
||||
if reflectValue.Type().AssignableTo(fieldType.Elem()) {
|
||||
field.Set(reflect.Append(field.Field, reflectValue))
|
||||
} else if reflectValue.Type().Elem().AssignableTo(fieldType.Elem()) {
|
||||
// if field's type is slice of struct, then need to set value back to argument after save
|
||||
setSliceFieldBackToValue = true
|
||||
field.Set(reflect.Append(field.Field, reflectValue.Elem()))
|
||||
}
|
||||
}
|
||||
|
||||
if relationship.Kind == "many_to_many" {
|
||||
association.setErr(relationship.JoinTableHandler.Add(relationship.JoinTableHandler, scope.NewDB(), scope.Value, reflectValue.Interface()))
|
||||
} else {
|
||||
association.setErr(scope.NewDB().Select(field.Name).Save(scope.Value).Error)
|
||||
|
||||
if setFieldBackToValue {
|
||||
reflectValue.Elem().Set(field.Field)
|
||||
} else if setSliceFieldBackToValue {
|
||||
reflectValue.Elem().Set(field.Field.Index(field.Field.Len() - 1))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, value := range values {
|
||||
reflectValue := reflect.ValueOf(value)
|
||||
indirectReflectValue := reflect.Indirect(reflectValue)
|
||||
if indirectReflectValue.Kind() == reflect.Struct {
|
||||
saveAssociation(reflectValue)
|
||||
} else if indirectReflectValue.Kind() == reflect.Slice {
|
||||
for i := 0; i < indirectReflectValue.Len(); i++ {
|
||||
saveAssociation(indirectReflectValue.Index(i))
|
||||
}
|
||||
} else {
|
||||
association.setErr(errors.New("invalid value type"))
|
||||
}
|
||||
}
|
||||
return association
|
||||
}
|
||||
|
||||
func (association *Association) getPrimaryKeys(columns []string, values ...interface{}) (results [][]interface{}) {
|
||||
scope := association.Scope
|
||||
|
||||
for _, value := range values {
|
||||
reflectValue := reflect.Indirect(reflect.ValueOf(value))
|
||||
if reflectValue.Kind() == reflect.Slice {
|
||||
for i := 0; i < reflectValue.Len(); i++ {
|
||||
primaryKeys := []interface{}{}
|
||||
newScope := scope.New(reflectValue.Index(i).Interface())
|
||||
for _, column := range columns {
|
||||
if field, ok := newScope.FieldByName(column); ok {
|
||||
primaryKeys = append(primaryKeys, field.Field.Interface())
|
||||
} else {
|
||||
primaryKeys = append(primaryKeys, "")
|
||||
}
|
||||
}
|
||||
results = append(results, primaryKeys)
|
||||
}
|
||||
} else if reflectValue.Kind() == reflect.Struct {
|
||||
newScope := scope.New(value)
|
||||
var primaryKeys []interface{}
|
||||
for _, column := range columns {
|
||||
if field, ok := newScope.FieldByName(column); ok {
|
||||
primaryKeys = append(primaryKeys, field.Field.Interface())
|
||||
} else {
|
||||
primaryKeys = append(primaryKeys, "")
|
||||
}
|
||||
}
|
||||
|
||||
results = append(results, primaryKeys)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func toQueryMarks(primaryValues [][]interface{}) string {
|
||||
var results []string
|
||||
|
||||
for _, primaryValue := range primaryValues {
|
||||
var marks []string
|
||||
for _ = range primaryValue {
|
||||
marks = append(marks, "?")
|
||||
}
|
||||
|
||||
if len(marks) > 1 {
|
||||
results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ",")))
|
||||
} else {
|
||||
results = append(results, strings.Join(marks, ""))
|
||||
}
|
||||
}
|
||||
return strings.Join(results, ",")
|
||||
}
|
||||
|
||||
func toQueryCondition(scope *Scope, columns []string) string {
|
||||
var newColumns []string
|
||||
for _, column := range columns {
|
||||
newColumns = append(newColumns, scope.Quote(column))
|
||||
}
|
||||
|
||||
if len(columns) > 1 {
|
||||
return fmt.Sprintf("(%v)", strings.Join(newColumns, ","))
|
||||
}
|
||||
return strings.Join(newColumns, ",")
|
||||
}
|
||||
|
||||
func toQueryValues(primaryValues [][]interface{}) (values []interface{}) {
|
||||
for _, primaryValue := range primaryValues {
|
||||
for _, value := range primaryValue {
|
||||
values = append(values, value)
|
||||
}
|
||||
}
|
||||
return values
|
||||
}
|
Loading…
Reference in New Issue