Keep refactoring association mode

This commit is contained in:
Jinzhu 2016-01-16 12:18:04 +08:00
parent c84e787b1d
commit dc23ae63bf
4 changed files with 149 additions and 152 deletions

View File

@ -1,27 +1,28 @@
package gorm
import (
"errors"
"fmt"
"reflect"
)
// Association Association Mode contains some helper methods to handle relationship things easily.
// Association Mode contains some helper methods to handle relationship things easily.
type Association struct {
Scope *Scope
Column string
Error error
Field *Field
scope *Scope
column string
field *Field
}
// Find find out all related associations
func (association *Association) Find(value interface{}) *Association {
association.Scope.related(value, association.Column)
return association.setErr(association.Scope.db.Error)
association.scope.related(value, association.column)
return association.setErr(association.scope.db.Error)
}
// Append append new associations for many2many, has_many, will replace current association for has_one, belongs_to
// Append append new associations for many2many, has_many, replace current association for has_one, belongs_to
func (association *Association) Append(values ...interface{}) *Association {
if relationship := association.Field.Relationship; relationship.Kind == "has_one" {
if relationship := association.field.Relationship; relationship.Kind == "has_one" {
return association.Replace(values...)
}
return association.saveAssociations(values...)
@ -30,14 +31,14 @@ func (association *Association) Append(values ...interface{}) *Association {
// Replace replace current associations with new one
func (association *Association) Replace(values ...interface{}) *Association {
var (
relationship = association.Field.Relationship
scope = association.Scope
field = association.Field.Field
relationship = association.field.Relationship
scope = association.scope
field = association.field.Field
newDB = scope.NewDB()
)
// Append new values
association.Field.Set(reflect.Zero(association.Field.Field.Type()))
association.field.Set(reflect.Zero(association.field.Field.Type()))
association.saveAssociations(values...)
// Belongs To
@ -109,7 +110,7 @@ func (association *Association) Replace(values ...interface{}) *Association {
}
}
fieldValue := reflect.New(association.Field.Field.Type()).Interface()
fieldValue := reflect.New(association.field.Field.Type()).Interface()
association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error)
}
}
@ -119,9 +120,9 @@ func (association *Association) Replace(values ...interface{}) *Association {
// Delete remove relationship between source & passed arguments, but won't delete those arguments
func (association *Association) Delete(values ...interface{}) *Association {
var (
relationship = association.Field.Relationship
scope = association.Scope
field = association.Field.Field
relationship = association.field.Relationship
scope = association.scope
field = association.field.Field
newDB = scope.NewDB()
)
@ -196,18 +197,18 @@ func (association *Association) Delete(values ...interface{}) *Association {
)
// set matched relation's foreign key to be null
fieldValue := reflect.New(association.Field.Field.Type()).Interface()
fieldValue := reflect.New(association.field.Field.Type()).Interface()
association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error)
}
}
// Remove deleted records from source's field
if association.Error == nil {
if association.Field.Field.Kind() == reflect.Slice {
leftValues := reflect.Zero(association.Field.Field.Type())
if field.Kind() == reflect.Slice {
leftValues := reflect.Zero(field.Type())
for i := 0; i < association.Field.Field.Len(); i++ {
reflectValue := association.Field.Field.Index(i)
for i := 0; i < field.Len(); i++ {
reflectValue := field.Index(i)
primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, reflectValue.Interface())[0]
var isDeleted = false
for _, pk := range deletingPrimaryKeys {
@ -221,12 +222,12 @@ func (association *Association) Delete(values ...interface{}) *Association {
}
}
association.Field.Set(leftValues)
} else if association.Field.Field.Kind() == reflect.Struct {
primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, association.Field.Field.Interface())[0]
association.field.Set(leftValues)
} else if field.Kind() == reflect.Struct {
primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, field.Interface())[0]
for _, pk := range deletingPrimaryKeys {
if equalAsString(primaryKey, pk) {
association.Field.Set(reflect.Zero(association.Field.Field.Type()))
association.field.Set(reflect.Zero(field.Type()))
break
}
}
@ -245,14 +246,14 @@ func (association *Association) Clear() *Association {
func (association *Association) Count() int {
var (
count = 0
relationship = association.Field.Relationship
scope = association.Scope
fieldValue = association.Field.Field.Interface()
relationship = association.field.Relationship
scope = association.scope
fieldValue = association.field.Field.Interface()
query = scope.DB()
)
if relationship.Kind == "many_to_many" {
query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, scope.DB(), association.Scope.Value)
query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, query, scope.Value)
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value)
query = query.Where(
@ -277,3 +278,81 @@ func (association *Association) Count() int {
query.Model(fieldValue).Count(&count)
return count
}
// saveAssociations save passed values as associations
func (association *Association) saveAssociations(values ...interface{}) *Association {
var (
scope = association.scope
field = association.field
relationship = field.Relationship
)
saveAssociation := func(reflectValue reflect.Value) {
// value has to been pointer
if reflectValue.Kind() != reflect.Ptr {
reflectPtr := reflect.New(reflectValue.Type())
reflectPtr.Elem().Set(reflectValue)
reflectValue = reflectPtr
}
// value has to been saved for many2many
if relationship.Kind == "many_to_many" {
if scope.New(reflectValue.Interface()).PrimaryKeyZero() {
association.setErr(scope.NewDB().Save(reflectValue.Interface()).Error)
}
}
// Assign Fields
var fieldType = field.Field.Type()
var setFieldBackToValue, setSliceFieldBackToValue bool
if reflectValue.Type().AssignableTo(fieldType) {
field.Set(reflectValue)
} else if reflectValue.Type().Elem().AssignableTo(fieldType) {
// if field's type is struct, then need to set value back to argument after save
setFieldBackToValue = true
field.Set(reflectValue.Elem())
} else if fieldType.Kind() == reflect.Slice {
if reflectValue.Type().AssignableTo(fieldType.Elem()) {
field.Set(reflect.Append(field.Field, reflectValue))
} else if reflectValue.Type().Elem().AssignableTo(fieldType.Elem()) {
// if field's type is slice of struct, then need to set value back to argument after save
setSliceFieldBackToValue = true
field.Set(reflect.Append(field.Field, reflectValue.Elem()))
}
}
if relationship.Kind == "many_to_many" {
association.setErr(relationship.JoinTableHandler.Add(relationship.JoinTableHandler, scope.NewDB(), scope.Value, reflectValue.Interface()))
} else {
association.setErr(scope.NewDB().Select(field.Name).Save(scope.Value).Error)
if setFieldBackToValue {
reflectValue.Elem().Set(field.Field)
} else if setSliceFieldBackToValue {
reflectValue.Elem().Set(field.Field.Index(field.Field.Len() - 1))
}
}
}
for _, value := range values {
reflectValue := reflect.ValueOf(value)
indirectReflectValue := reflect.Indirect(reflectValue)
if indirectReflectValue.Kind() == reflect.Struct {
saveAssociation(reflectValue)
} else if indirectReflectValue.Kind() == reflect.Slice {
for i := 0; i < indirectReflectValue.Len(); i++ {
saveAssociation(indirectReflectValue.Index(i))
}
} else {
association.setErr(errors.New("invalid value type"))
}
}
return association
}
func (association *Association) setErr(err error) *Association {
if err != nil {
association.Error = err
}
return association
}

View File

@ -1,122 +0,0 @@
package gorm
import (
"errors"
"fmt"
"reflect"
"strings"
)
func (association *Association) setErr(err error) *Association {
if err != nil {
association.Error = err
}
return association
}
func (association *Association) saveAssociations(values ...interface{}) *Association {
scope := association.Scope
field := association.Field
relationship := association.Field.Relationship
saveAssociation := func(reflectValue reflect.Value) {
// value has to been pointer
if reflectValue.Kind() != reflect.Ptr {
reflectPtr := reflect.New(reflectValue.Type())
reflectPtr.Elem().Set(reflectValue)
reflectValue = reflectPtr
}
// value has to been saved for many2many
if relationship.Kind == "many_to_many" {
if scope.New(reflectValue.Interface()).PrimaryKeyZero() {
association.setErr(scope.NewDB().Save(reflectValue.Interface()).Error)
}
}
// Assign Fields
var fieldType = field.Field.Type()
var setFieldBackToValue, setSliceFieldBackToValue bool
if reflectValue.Type().AssignableTo(fieldType) {
field.Set(reflectValue)
} else if reflectValue.Type().Elem().AssignableTo(fieldType) {
// if field's type is struct, then need to set value back to argument after save
setFieldBackToValue = true
field.Set(reflectValue.Elem())
} else if fieldType.Kind() == reflect.Slice {
if reflectValue.Type().AssignableTo(fieldType.Elem()) {
field.Set(reflect.Append(field.Field, reflectValue))
} else if reflectValue.Type().Elem().AssignableTo(fieldType.Elem()) {
// if field's type is slice of struct, then need to set value back to argument after save
setSliceFieldBackToValue = true
field.Set(reflect.Append(field.Field, reflectValue.Elem()))
}
}
if relationship.Kind == "many_to_many" {
association.setErr(relationship.JoinTableHandler.Add(relationship.JoinTableHandler, scope.NewDB(), scope.Value, reflectValue.Interface()))
} else {
association.setErr(scope.NewDB().Select(field.Name).Save(scope.Value).Error)
if setFieldBackToValue {
reflectValue.Elem().Set(field.Field)
} else if setSliceFieldBackToValue {
reflectValue.Elem().Set(field.Field.Index(field.Field.Len() - 1))
}
}
}
for _, value := range values {
reflectValue := reflect.ValueOf(value)
indirectReflectValue := reflect.Indirect(reflectValue)
if indirectReflectValue.Kind() == reflect.Struct {
saveAssociation(reflectValue)
} else if indirectReflectValue.Kind() == reflect.Slice {
for i := 0; i < indirectReflectValue.Len(); i++ {
saveAssociation(indirectReflectValue.Index(i))
}
} else {
association.setErr(errors.New("invalid value type"))
}
}
return association
}
func toQueryMarks(primaryValues [][]interface{}) string {
var results []string
for _, primaryValue := range primaryValues {
var marks []string
for _ = range primaryValue {
marks = append(marks, "?")
}
if len(marks) > 1 {
results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ",")))
} else {
results = append(results, strings.Join(marks, ""))
}
}
return strings.Join(results, ",")
}
func toQueryCondition(scope *Scope, columns []string) string {
var newColumns []string
for _, column := range columns {
newColumns = append(newColumns, scope.Quote(column))
}
if len(columns) > 1 {
return fmt.Sprintf("(%v)", strings.Join(newColumns, ","))
}
return strings.Join(newColumns, ",")
}
func toQueryValues(primaryValues [][]interface{}) (values []interface{}) {
for _, primaryValue := range primaryValues {
for _, value := range primaryValue {
values = append(values, value)
}
}
return values
}

View File

@ -480,7 +480,7 @@ func (s *DB) Association(column string) *Association {
if field.Relationship == nil || len(field.Relationship.ForeignFieldNames) == 0 {
err = fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type())
} else {
return &Association{Scope: scope, Column: column, Field: field}
return &Association{scope: scope, column: column, field: field}
}
} else {
err = fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column)

View File

@ -2,6 +2,7 @@ package gorm
import (
"bytes"
"fmt"
"strings"
"sync"
)
@ -100,3 +101,42 @@ type expr struct {
func Expr(expression string, args ...interface{}) *expr {
return &expr{expr: expression, args: args}
}
func toQueryMarks(primaryValues [][]interface{}) string {
var results []string
for _, primaryValue := range primaryValues {
var marks []string
for _ = range primaryValue {
marks = append(marks, "?")
}
if len(marks) > 1 {
results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ",")))
} else {
results = append(results, strings.Join(marks, ""))
}
}
return strings.Join(results, ",")
}
func toQueryCondition(scope *Scope, columns []string) string {
var newColumns []string
for _, column := range columns {
newColumns = append(newColumns, scope.Quote(column))
}
if len(columns) > 1 {
return fmt.Sprintf("(%v)", strings.Join(newColumns, ","))
}
return strings.Join(newColumns, ",")
}
func toQueryValues(primaryValues [][]interface{}) (values []interface{}) {
for _, primaryValue := range primaryValues {
for _, value := range primaryValue {
values = append(values, value)
}
}
return values
}