mirror of https://github.com/go-gorm/gorm.git
Keep refactoring association mode
This commit is contained in:
parent
c84e787b1d
commit
dc23ae63bf
137
association.go
137
association.go
|
@ -1,27 +1,28 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// Association Association Mode contains some helper methods to handle relationship things easily.
|
||||
// Association Mode contains some helper methods to handle relationship things easily.
|
||||
type Association struct {
|
||||
Scope *Scope
|
||||
Column string
|
||||
Error error
|
||||
Field *Field
|
||||
scope *Scope
|
||||
column string
|
||||
field *Field
|
||||
}
|
||||
|
||||
// Find find out all related associations
|
||||
func (association *Association) Find(value interface{}) *Association {
|
||||
association.Scope.related(value, association.Column)
|
||||
return association.setErr(association.Scope.db.Error)
|
||||
association.scope.related(value, association.column)
|
||||
return association.setErr(association.scope.db.Error)
|
||||
}
|
||||
|
||||
// Append append new associations for many2many, has_many, will replace current association for has_one, belongs_to
|
||||
// Append append new associations for many2many, has_many, replace current association for has_one, belongs_to
|
||||
func (association *Association) Append(values ...interface{}) *Association {
|
||||
if relationship := association.Field.Relationship; relationship.Kind == "has_one" {
|
||||
if relationship := association.field.Relationship; relationship.Kind == "has_one" {
|
||||
return association.Replace(values...)
|
||||
}
|
||||
return association.saveAssociations(values...)
|
||||
|
@ -30,14 +31,14 @@ func (association *Association) Append(values ...interface{}) *Association {
|
|||
// Replace replace current associations with new one
|
||||
func (association *Association) Replace(values ...interface{}) *Association {
|
||||
var (
|
||||
relationship = association.Field.Relationship
|
||||
scope = association.Scope
|
||||
field = association.Field.Field
|
||||
relationship = association.field.Relationship
|
||||
scope = association.scope
|
||||
field = association.field.Field
|
||||
newDB = scope.NewDB()
|
||||
)
|
||||
|
||||
// Append new values
|
||||
association.Field.Set(reflect.Zero(association.Field.Field.Type()))
|
||||
association.field.Set(reflect.Zero(association.field.Field.Type()))
|
||||
association.saveAssociations(values...)
|
||||
|
||||
// Belongs To
|
||||
|
@ -109,7 +110,7 @@ func (association *Association) Replace(values ...interface{}) *Association {
|
|||
}
|
||||
}
|
||||
|
||||
fieldValue := reflect.New(association.Field.Field.Type()).Interface()
|
||||
fieldValue := reflect.New(association.field.Field.Type()).Interface()
|
||||
association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error)
|
||||
}
|
||||
}
|
||||
|
@ -119,9 +120,9 @@ func (association *Association) Replace(values ...interface{}) *Association {
|
|||
// Delete remove relationship between source & passed arguments, but won't delete those arguments
|
||||
func (association *Association) Delete(values ...interface{}) *Association {
|
||||
var (
|
||||
relationship = association.Field.Relationship
|
||||
scope = association.Scope
|
||||
field = association.Field.Field
|
||||
relationship = association.field.Relationship
|
||||
scope = association.scope
|
||||
field = association.field.Field
|
||||
newDB = scope.NewDB()
|
||||
)
|
||||
|
||||
|
@ -196,18 +197,18 @@ func (association *Association) Delete(values ...interface{}) *Association {
|
|||
)
|
||||
|
||||
// set matched relation's foreign key to be null
|
||||
fieldValue := reflect.New(association.Field.Field.Type()).Interface()
|
||||
fieldValue := reflect.New(association.field.Field.Type()).Interface()
|
||||
association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error)
|
||||
}
|
||||
}
|
||||
|
||||
// Remove deleted records from source's field
|
||||
if association.Error == nil {
|
||||
if association.Field.Field.Kind() == reflect.Slice {
|
||||
leftValues := reflect.Zero(association.Field.Field.Type())
|
||||
if field.Kind() == reflect.Slice {
|
||||
leftValues := reflect.Zero(field.Type())
|
||||
|
||||
for i := 0; i < association.Field.Field.Len(); i++ {
|
||||
reflectValue := association.Field.Field.Index(i)
|
||||
for i := 0; i < field.Len(); i++ {
|
||||
reflectValue := field.Index(i)
|
||||
primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, reflectValue.Interface())[0]
|
||||
var isDeleted = false
|
||||
for _, pk := range deletingPrimaryKeys {
|
||||
|
@ -221,12 +222,12 @@ func (association *Association) Delete(values ...interface{}) *Association {
|
|||
}
|
||||
}
|
||||
|
||||
association.Field.Set(leftValues)
|
||||
} else if association.Field.Field.Kind() == reflect.Struct {
|
||||
primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, association.Field.Field.Interface())[0]
|
||||
association.field.Set(leftValues)
|
||||
} else if field.Kind() == reflect.Struct {
|
||||
primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, field.Interface())[0]
|
||||
for _, pk := range deletingPrimaryKeys {
|
||||
if equalAsString(primaryKey, pk) {
|
||||
association.Field.Set(reflect.Zero(association.Field.Field.Type()))
|
||||
association.field.Set(reflect.Zero(field.Type()))
|
||||
break
|
||||
}
|
||||
}
|
||||
|
@ -245,14 +246,14 @@ func (association *Association) Clear() *Association {
|
|||
func (association *Association) Count() int {
|
||||
var (
|
||||
count = 0
|
||||
relationship = association.Field.Relationship
|
||||
scope = association.Scope
|
||||
fieldValue = association.Field.Field.Interface()
|
||||
relationship = association.field.Relationship
|
||||
scope = association.scope
|
||||
fieldValue = association.field.Field.Interface()
|
||||
query = scope.DB()
|
||||
)
|
||||
|
||||
if relationship.Kind == "many_to_many" {
|
||||
query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, scope.DB(), association.Scope.Value)
|
||||
query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, query, scope.Value)
|
||||
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
|
||||
primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value)
|
||||
query = query.Where(
|
||||
|
@ -277,3 +278,81 @@ func (association *Association) Count() int {
|
|||
query.Model(fieldValue).Count(&count)
|
||||
return count
|
||||
}
|
||||
|
||||
// saveAssociations save passed values as associations
|
||||
func (association *Association) saveAssociations(values ...interface{}) *Association {
|
||||
var (
|
||||
scope = association.scope
|
||||
field = association.field
|
||||
relationship = field.Relationship
|
||||
)
|
||||
|
||||
saveAssociation := func(reflectValue reflect.Value) {
|
||||
// value has to been pointer
|
||||
if reflectValue.Kind() != reflect.Ptr {
|
||||
reflectPtr := reflect.New(reflectValue.Type())
|
||||
reflectPtr.Elem().Set(reflectValue)
|
||||
reflectValue = reflectPtr
|
||||
}
|
||||
|
||||
// value has to been saved for many2many
|
||||
if relationship.Kind == "many_to_many" {
|
||||
if scope.New(reflectValue.Interface()).PrimaryKeyZero() {
|
||||
association.setErr(scope.NewDB().Save(reflectValue.Interface()).Error)
|
||||
}
|
||||
}
|
||||
|
||||
// Assign Fields
|
||||
var fieldType = field.Field.Type()
|
||||
var setFieldBackToValue, setSliceFieldBackToValue bool
|
||||
if reflectValue.Type().AssignableTo(fieldType) {
|
||||
field.Set(reflectValue)
|
||||
} else if reflectValue.Type().Elem().AssignableTo(fieldType) {
|
||||
// if field's type is struct, then need to set value back to argument after save
|
||||
setFieldBackToValue = true
|
||||
field.Set(reflectValue.Elem())
|
||||
} else if fieldType.Kind() == reflect.Slice {
|
||||
if reflectValue.Type().AssignableTo(fieldType.Elem()) {
|
||||
field.Set(reflect.Append(field.Field, reflectValue))
|
||||
} else if reflectValue.Type().Elem().AssignableTo(fieldType.Elem()) {
|
||||
// if field's type is slice of struct, then need to set value back to argument after save
|
||||
setSliceFieldBackToValue = true
|
||||
field.Set(reflect.Append(field.Field, reflectValue.Elem()))
|
||||
}
|
||||
}
|
||||
|
||||
if relationship.Kind == "many_to_many" {
|
||||
association.setErr(relationship.JoinTableHandler.Add(relationship.JoinTableHandler, scope.NewDB(), scope.Value, reflectValue.Interface()))
|
||||
} else {
|
||||
association.setErr(scope.NewDB().Select(field.Name).Save(scope.Value).Error)
|
||||
|
||||
if setFieldBackToValue {
|
||||
reflectValue.Elem().Set(field.Field)
|
||||
} else if setSliceFieldBackToValue {
|
||||
reflectValue.Elem().Set(field.Field.Index(field.Field.Len() - 1))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, value := range values {
|
||||
reflectValue := reflect.ValueOf(value)
|
||||
indirectReflectValue := reflect.Indirect(reflectValue)
|
||||
if indirectReflectValue.Kind() == reflect.Struct {
|
||||
saveAssociation(reflectValue)
|
||||
} else if indirectReflectValue.Kind() == reflect.Slice {
|
||||
for i := 0; i < indirectReflectValue.Len(); i++ {
|
||||
saveAssociation(indirectReflectValue.Index(i))
|
||||
}
|
||||
} else {
|
||||
association.setErr(errors.New("invalid value type"))
|
||||
}
|
||||
}
|
||||
return association
|
||||
}
|
||||
|
||||
func (association *Association) setErr(err error) *Association {
|
||||
if err != nil {
|
||||
association.Error = err
|
||||
}
|
||||
return association
|
||||
}
|
||||
|
|
|
@ -1,122 +0,0 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (association *Association) setErr(err error) *Association {
|
||||
if err != nil {
|
||||
association.Error = err
|
||||
}
|
||||
return association
|
||||
}
|
||||
|
||||
func (association *Association) saveAssociations(values ...interface{}) *Association {
|
||||
scope := association.Scope
|
||||
field := association.Field
|
||||
relationship := association.Field.Relationship
|
||||
|
||||
saveAssociation := func(reflectValue reflect.Value) {
|
||||
// value has to been pointer
|
||||
if reflectValue.Kind() != reflect.Ptr {
|
||||
reflectPtr := reflect.New(reflectValue.Type())
|
||||
reflectPtr.Elem().Set(reflectValue)
|
||||
reflectValue = reflectPtr
|
||||
}
|
||||
|
||||
// value has to been saved for many2many
|
||||
if relationship.Kind == "many_to_many" {
|
||||
if scope.New(reflectValue.Interface()).PrimaryKeyZero() {
|
||||
association.setErr(scope.NewDB().Save(reflectValue.Interface()).Error)
|
||||
}
|
||||
}
|
||||
|
||||
// Assign Fields
|
||||
var fieldType = field.Field.Type()
|
||||
var setFieldBackToValue, setSliceFieldBackToValue bool
|
||||
if reflectValue.Type().AssignableTo(fieldType) {
|
||||
field.Set(reflectValue)
|
||||
} else if reflectValue.Type().Elem().AssignableTo(fieldType) {
|
||||
// if field's type is struct, then need to set value back to argument after save
|
||||
setFieldBackToValue = true
|
||||
field.Set(reflectValue.Elem())
|
||||
} else if fieldType.Kind() == reflect.Slice {
|
||||
if reflectValue.Type().AssignableTo(fieldType.Elem()) {
|
||||
field.Set(reflect.Append(field.Field, reflectValue))
|
||||
} else if reflectValue.Type().Elem().AssignableTo(fieldType.Elem()) {
|
||||
// if field's type is slice of struct, then need to set value back to argument after save
|
||||
setSliceFieldBackToValue = true
|
||||
field.Set(reflect.Append(field.Field, reflectValue.Elem()))
|
||||
}
|
||||
}
|
||||
|
||||
if relationship.Kind == "many_to_many" {
|
||||
association.setErr(relationship.JoinTableHandler.Add(relationship.JoinTableHandler, scope.NewDB(), scope.Value, reflectValue.Interface()))
|
||||
} else {
|
||||
association.setErr(scope.NewDB().Select(field.Name).Save(scope.Value).Error)
|
||||
|
||||
if setFieldBackToValue {
|
||||
reflectValue.Elem().Set(field.Field)
|
||||
} else if setSliceFieldBackToValue {
|
||||
reflectValue.Elem().Set(field.Field.Index(field.Field.Len() - 1))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, value := range values {
|
||||
reflectValue := reflect.ValueOf(value)
|
||||
indirectReflectValue := reflect.Indirect(reflectValue)
|
||||
if indirectReflectValue.Kind() == reflect.Struct {
|
||||
saveAssociation(reflectValue)
|
||||
} else if indirectReflectValue.Kind() == reflect.Slice {
|
||||
for i := 0; i < indirectReflectValue.Len(); i++ {
|
||||
saveAssociation(indirectReflectValue.Index(i))
|
||||
}
|
||||
} else {
|
||||
association.setErr(errors.New("invalid value type"))
|
||||
}
|
||||
}
|
||||
return association
|
||||
}
|
||||
|
||||
func toQueryMarks(primaryValues [][]interface{}) string {
|
||||
var results []string
|
||||
|
||||
for _, primaryValue := range primaryValues {
|
||||
var marks []string
|
||||
for _ = range primaryValue {
|
||||
marks = append(marks, "?")
|
||||
}
|
||||
|
||||
if len(marks) > 1 {
|
||||
results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ",")))
|
||||
} else {
|
||||
results = append(results, strings.Join(marks, ""))
|
||||
}
|
||||
}
|
||||
return strings.Join(results, ",")
|
||||
}
|
||||
|
||||
func toQueryCondition(scope *Scope, columns []string) string {
|
||||
var newColumns []string
|
||||
for _, column := range columns {
|
||||
newColumns = append(newColumns, scope.Quote(column))
|
||||
}
|
||||
|
||||
if len(columns) > 1 {
|
||||
return fmt.Sprintf("(%v)", strings.Join(newColumns, ","))
|
||||
}
|
||||
return strings.Join(newColumns, ",")
|
||||
}
|
||||
|
||||
func toQueryValues(primaryValues [][]interface{}) (values []interface{}) {
|
||||
for _, primaryValue := range primaryValues {
|
||||
for _, value := range primaryValue {
|
||||
values = append(values, value)
|
||||
}
|
||||
}
|
||||
return values
|
||||
}
|
2
main.go
2
main.go
|
@ -480,7 +480,7 @@ func (s *DB) Association(column string) *Association {
|
|||
if field.Relationship == nil || len(field.Relationship.ForeignFieldNames) == 0 {
|
||||
err = fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type())
|
||||
} else {
|
||||
return &Association{Scope: scope, Column: column, Field: field}
|
||||
return &Association{scope: scope, column: column, field: field}
|
||||
}
|
||||
} else {
|
||||
err = fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column)
|
||||
|
|
40
utils.go
40
utils.go
|
@ -2,6 +2,7 @@ package gorm
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
@ -100,3 +101,42 @@ type expr struct {
|
|||
func Expr(expression string, args ...interface{}) *expr {
|
||||
return &expr{expr: expression, args: args}
|
||||
}
|
||||
|
||||
func toQueryMarks(primaryValues [][]interface{}) string {
|
||||
var results []string
|
||||
|
||||
for _, primaryValue := range primaryValues {
|
||||
var marks []string
|
||||
for _ = range primaryValue {
|
||||
marks = append(marks, "?")
|
||||
}
|
||||
|
||||
if len(marks) > 1 {
|
||||
results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ",")))
|
||||
} else {
|
||||
results = append(results, strings.Join(marks, ""))
|
||||
}
|
||||
}
|
||||
return strings.Join(results, ",")
|
||||
}
|
||||
|
||||
func toQueryCondition(scope *Scope, columns []string) string {
|
||||
var newColumns []string
|
||||
for _, column := range columns {
|
||||
newColumns = append(newColumns, scope.Quote(column))
|
||||
}
|
||||
|
||||
if len(columns) > 1 {
|
||||
return fmt.Sprintf("(%v)", strings.Join(newColumns, ","))
|
||||
}
|
||||
return strings.Join(newColumns, ",")
|
||||
}
|
||||
|
||||
func toQueryValues(primaryValues [][]interface{}) (values []interface{}) {
|
||||
for _, primaryValue := range primaryValues {
|
||||
for _, value := range primaryValue {
|
||||
values = append(values, value)
|
||||
}
|
||||
}
|
||||
return values
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue