mirror of https://github.com/go-gorm/gorm.git
Keep refactoring association mode
This commit is contained in:
parent
c84e787b1d
commit
dc23ae63bf
137
association.go
137
association.go
|
@ -1,27 +1,28 @@
|
||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Association Association Mode contains some helper methods to handle relationship things easily.
|
// Association Mode contains some helper methods to handle relationship things easily.
|
||||||
type Association struct {
|
type Association struct {
|
||||||
Scope *Scope
|
|
||||||
Column string
|
|
||||||
Error error
|
Error error
|
||||||
Field *Field
|
scope *Scope
|
||||||
|
column string
|
||||||
|
field *Field
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find find out all related associations
|
// Find find out all related associations
|
||||||
func (association *Association) Find(value interface{}) *Association {
|
func (association *Association) Find(value interface{}) *Association {
|
||||||
association.Scope.related(value, association.Column)
|
association.scope.related(value, association.column)
|
||||||
return association.setErr(association.Scope.db.Error)
|
return association.setErr(association.scope.db.Error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Append append new associations for many2many, has_many, will replace current association for has_one, belongs_to
|
// Append append new associations for many2many, has_many, replace current association for has_one, belongs_to
|
||||||
func (association *Association) Append(values ...interface{}) *Association {
|
func (association *Association) Append(values ...interface{}) *Association {
|
||||||
if relationship := association.Field.Relationship; relationship.Kind == "has_one" {
|
if relationship := association.field.Relationship; relationship.Kind == "has_one" {
|
||||||
return association.Replace(values...)
|
return association.Replace(values...)
|
||||||
}
|
}
|
||||||
return association.saveAssociations(values...)
|
return association.saveAssociations(values...)
|
||||||
|
@ -30,14 +31,14 @@ func (association *Association) Append(values ...interface{}) *Association {
|
||||||
// Replace replace current associations with new one
|
// Replace replace current associations with new one
|
||||||
func (association *Association) Replace(values ...interface{}) *Association {
|
func (association *Association) Replace(values ...interface{}) *Association {
|
||||||
var (
|
var (
|
||||||
relationship = association.Field.Relationship
|
relationship = association.field.Relationship
|
||||||
scope = association.Scope
|
scope = association.scope
|
||||||
field = association.Field.Field
|
field = association.field.Field
|
||||||
newDB = scope.NewDB()
|
newDB = scope.NewDB()
|
||||||
)
|
)
|
||||||
|
|
||||||
// Append new values
|
// Append new values
|
||||||
association.Field.Set(reflect.Zero(association.Field.Field.Type()))
|
association.field.Set(reflect.Zero(association.field.Field.Type()))
|
||||||
association.saveAssociations(values...)
|
association.saveAssociations(values...)
|
||||||
|
|
||||||
// Belongs To
|
// Belongs To
|
||||||
|
@ -109,7 +110,7 @@ func (association *Association) Replace(values ...interface{}) *Association {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fieldValue := reflect.New(association.Field.Field.Type()).Interface()
|
fieldValue := reflect.New(association.field.Field.Type()).Interface()
|
||||||
association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error)
|
association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -119,9 +120,9 @@ func (association *Association) Replace(values ...interface{}) *Association {
|
||||||
// Delete remove relationship between source & passed arguments, but won't delete those arguments
|
// Delete remove relationship between source & passed arguments, but won't delete those arguments
|
||||||
func (association *Association) Delete(values ...interface{}) *Association {
|
func (association *Association) Delete(values ...interface{}) *Association {
|
||||||
var (
|
var (
|
||||||
relationship = association.Field.Relationship
|
relationship = association.field.Relationship
|
||||||
scope = association.Scope
|
scope = association.scope
|
||||||
field = association.Field.Field
|
field = association.field.Field
|
||||||
newDB = scope.NewDB()
|
newDB = scope.NewDB()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -196,18 +197,18 @@ func (association *Association) Delete(values ...interface{}) *Association {
|
||||||
)
|
)
|
||||||
|
|
||||||
// set matched relation's foreign key to be null
|
// set matched relation's foreign key to be null
|
||||||
fieldValue := reflect.New(association.Field.Field.Type()).Interface()
|
fieldValue := reflect.New(association.field.Field.Type()).Interface()
|
||||||
association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error)
|
association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove deleted records from source's field
|
// Remove deleted records from source's field
|
||||||
if association.Error == nil {
|
if association.Error == nil {
|
||||||
if association.Field.Field.Kind() == reflect.Slice {
|
if field.Kind() == reflect.Slice {
|
||||||
leftValues := reflect.Zero(association.Field.Field.Type())
|
leftValues := reflect.Zero(field.Type())
|
||||||
|
|
||||||
for i := 0; i < association.Field.Field.Len(); i++ {
|
for i := 0; i < field.Len(); i++ {
|
||||||
reflectValue := association.Field.Field.Index(i)
|
reflectValue := field.Index(i)
|
||||||
primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, reflectValue.Interface())[0]
|
primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, reflectValue.Interface())[0]
|
||||||
var isDeleted = false
|
var isDeleted = false
|
||||||
for _, pk := range deletingPrimaryKeys {
|
for _, pk := range deletingPrimaryKeys {
|
||||||
|
@ -221,12 +222,12 @@ func (association *Association) Delete(values ...interface{}) *Association {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
association.Field.Set(leftValues)
|
association.field.Set(leftValues)
|
||||||
} else if association.Field.Field.Kind() == reflect.Struct {
|
} else if field.Kind() == reflect.Struct {
|
||||||
primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, association.Field.Field.Interface())[0]
|
primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, field.Interface())[0]
|
||||||
for _, pk := range deletingPrimaryKeys {
|
for _, pk := range deletingPrimaryKeys {
|
||||||
if equalAsString(primaryKey, pk) {
|
if equalAsString(primaryKey, pk) {
|
||||||
association.Field.Set(reflect.Zero(association.Field.Field.Type()))
|
association.field.Set(reflect.Zero(field.Type()))
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -245,14 +246,14 @@ func (association *Association) Clear() *Association {
|
||||||
func (association *Association) Count() int {
|
func (association *Association) Count() int {
|
||||||
var (
|
var (
|
||||||
count = 0
|
count = 0
|
||||||
relationship = association.Field.Relationship
|
relationship = association.field.Relationship
|
||||||
scope = association.Scope
|
scope = association.scope
|
||||||
fieldValue = association.Field.Field.Interface()
|
fieldValue = association.field.Field.Interface()
|
||||||
query = scope.DB()
|
query = scope.DB()
|
||||||
)
|
)
|
||||||
|
|
||||||
if relationship.Kind == "many_to_many" {
|
if relationship.Kind == "many_to_many" {
|
||||||
query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, scope.DB(), association.Scope.Value)
|
query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, query, scope.Value)
|
||||||
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
|
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
|
||||||
primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value)
|
primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value)
|
||||||
query = query.Where(
|
query = query.Where(
|
||||||
|
@ -277,3 +278,81 @@ func (association *Association) Count() int {
|
||||||
query.Model(fieldValue).Count(&count)
|
query.Model(fieldValue).Count(&count)
|
||||||
return count
|
return count
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// saveAssociations save passed values as associations
|
||||||
|
func (association *Association) saveAssociations(values ...interface{}) *Association {
|
||||||
|
var (
|
||||||
|
scope = association.scope
|
||||||
|
field = association.field
|
||||||
|
relationship = field.Relationship
|
||||||
|
)
|
||||||
|
|
||||||
|
saveAssociation := func(reflectValue reflect.Value) {
|
||||||
|
// value has to been pointer
|
||||||
|
if reflectValue.Kind() != reflect.Ptr {
|
||||||
|
reflectPtr := reflect.New(reflectValue.Type())
|
||||||
|
reflectPtr.Elem().Set(reflectValue)
|
||||||
|
reflectValue = reflectPtr
|
||||||
|
}
|
||||||
|
|
||||||
|
// value has to been saved for many2many
|
||||||
|
if relationship.Kind == "many_to_many" {
|
||||||
|
if scope.New(reflectValue.Interface()).PrimaryKeyZero() {
|
||||||
|
association.setErr(scope.NewDB().Save(reflectValue.Interface()).Error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assign Fields
|
||||||
|
var fieldType = field.Field.Type()
|
||||||
|
var setFieldBackToValue, setSliceFieldBackToValue bool
|
||||||
|
if reflectValue.Type().AssignableTo(fieldType) {
|
||||||
|
field.Set(reflectValue)
|
||||||
|
} else if reflectValue.Type().Elem().AssignableTo(fieldType) {
|
||||||
|
// if field's type is struct, then need to set value back to argument after save
|
||||||
|
setFieldBackToValue = true
|
||||||
|
field.Set(reflectValue.Elem())
|
||||||
|
} else if fieldType.Kind() == reflect.Slice {
|
||||||
|
if reflectValue.Type().AssignableTo(fieldType.Elem()) {
|
||||||
|
field.Set(reflect.Append(field.Field, reflectValue))
|
||||||
|
} else if reflectValue.Type().Elem().AssignableTo(fieldType.Elem()) {
|
||||||
|
// if field's type is slice of struct, then need to set value back to argument after save
|
||||||
|
setSliceFieldBackToValue = true
|
||||||
|
field.Set(reflect.Append(field.Field, reflectValue.Elem()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if relationship.Kind == "many_to_many" {
|
||||||
|
association.setErr(relationship.JoinTableHandler.Add(relationship.JoinTableHandler, scope.NewDB(), scope.Value, reflectValue.Interface()))
|
||||||
|
} else {
|
||||||
|
association.setErr(scope.NewDB().Select(field.Name).Save(scope.Value).Error)
|
||||||
|
|
||||||
|
if setFieldBackToValue {
|
||||||
|
reflectValue.Elem().Set(field.Field)
|
||||||
|
} else if setSliceFieldBackToValue {
|
||||||
|
reflectValue.Elem().Set(field.Field.Index(field.Field.Len() - 1))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, value := range values {
|
||||||
|
reflectValue := reflect.ValueOf(value)
|
||||||
|
indirectReflectValue := reflect.Indirect(reflectValue)
|
||||||
|
if indirectReflectValue.Kind() == reflect.Struct {
|
||||||
|
saveAssociation(reflectValue)
|
||||||
|
} else if indirectReflectValue.Kind() == reflect.Slice {
|
||||||
|
for i := 0; i < indirectReflectValue.Len(); i++ {
|
||||||
|
saveAssociation(indirectReflectValue.Index(i))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
association.setErr(errors.New("invalid value type"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return association
|
||||||
|
}
|
||||||
|
|
||||||
|
func (association *Association) setErr(err error) *Association {
|
||||||
|
if err != nil {
|
||||||
|
association.Error = err
|
||||||
|
}
|
||||||
|
return association
|
||||||
|
}
|
||||||
|
|
|
@ -1,122 +0,0 @@
|
||||||
package gorm
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"reflect"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (association *Association) setErr(err error) *Association {
|
|
||||||
if err != nil {
|
|
||||||
association.Error = err
|
|
||||||
}
|
|
||||||
return association
|
|
||||||
}
|
|
||||||
|
|
||||||
func (association *Association) saveAssociations(values ...interface{}) *Association {
|
|
||||||
scope := association.Scope
|
|
||||||
field := association.Field
|
|
||||||
relationship := association.Field.Relationship
|
|
||||||
|
|
||||||
saveAssociation := func(reflectValue reflect.Value) {
|
|
||||||
// value has to been pointer
|
|
||||||
if reflectValue.Kind() != reflect.Ptr {
|
|
||||||
reflectPtr := reflect.New(reflectValue.Type())
|
|
||||||
reflectPtr.Elem().Set(reflectValue)
|
|
||||||
reflectValue = reflectPtr
|
|
||||||
}
|
|
||||||
|
|
||||||
// value has to been saved for many2many
|
|
||||||
if relationship.Kind == "many_to_many" {
|
|
||||||
if scope.New(reflectValue.Interface()).PrimaryKeyZero() {
|
|
||||||
association.setErr(scope.NewDB().Save(reflectValue.Interface()).Error)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Assign Fields
|
|
||||||
var fieldType = field.Field.Type()
|
|
||||||
var setFieldBackToValue, setSliceFieldBackToValue bool
|
|
||||||
if reflectValue.Type().AssignableTo(fieldType) {
|
|
||||||
field.Set(reflectValue)
|
|
||||||
} else if reflectValue.Type().Elem().AssignableTo(fieldType) {
|
|
||||||
// if field's type is struct, then need to set value back to argument after save
|
|
||||||
setFieldBackToValue = true
|
|
||||||
field.Set(reflectValue.Elem())
|
|
||||||
} else if fieldType.Kind() == reflect.Slice {
|
|
||||||
if reflectValue.Type().AssignableTo(fieldType.Elem()) {
|
|
||||||
field.Set(reflect.Append(field.Field, reflectValue))
|
|
||||||
} else if reflectValue.Type().Elem().AssignableTo(fieldType.Elem()) {
|
|
||||||
// if field's type is slice of struct, then need to set value back to argument after save
|
|
||||||
setSliceFieldBackToValue = true
|
|
||||||
field.Set(reflect.Append(field.Field, reflectValue.Elem()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if relationship.Kind == "many_to_many" {
|
|
||||||
association.setErr(relationship.JoinTableHandler.Add(relationship.JoinTableHandler, scope.NewDB(), scope.Value, reflectValue.Interface()))
|
|
||||||
} else {
|
|
||||||
association.setErr(scope.NewDB().Select(field.Name).Save(scope.Value).Error)
|
|
||||||
|
|
||||||
if setFieldBackToValue {
|
|
||||||
reflectValue.Elem().Set(field.Field)
|
|
||||||
} else if setSliceFieldBackToValue {
|
|
||||||
reflectValue.Elem().Set(field.Field.Index(field.Field.Len() - 1))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, value := range values {
|
|
||||||
reflectValue := reflect.ValueOf(value)
|
|
||||||
indirectReflectValue := reflect.Indirect(reflectValue)
|
|
||||||
if indirectReflectValue.Kind() == reflect.Struct {
|
|
||||||
saveAssociation(reflectValue)
|
|
||||||
} else if indirectReflectValue.Kind() == reflect.Slice {
|
|
||||||
for i := 0; i < indirectReflectValue.Len(); i++ {
|
|
||||||
saveAssociation(indirectReflectValue.Index(i))
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
association.setErr(errors.New("invalid value type"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return association
|
|
||||||
}
|
|
||||||
|
|
||||||
func toQueryMarks(primaryValues [][]interface{}) string {
|
|
||||||
var results []string
|
|
||||||
|
|
||||||
for _, primaryValue := range primaryValues {
|
|
||||||
var marks []string
|
|
||||||
for _ = range primaryValue {
|
|
||||||
marks = append(marks, "?")
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(marks) > 1 {
|
|
||||||
results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ",")))
|
|
||||||
} else {
|
|
||||||
results = append(results, strings.Join(marks, ""))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return strings.Join(results, ",")
|
|
||||||
}
|
|
||||||
|
|
||||||
func toQueryCondition(scope *Scope, columns []string) string {
|
|
||||||
var newColumns []string
|
|
||||||
for _, column := range columns {
|
|
||||||
newColumns = append(newColumns, scope.Quote(column))
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(columns) > 1 {
|
|
||||||
return fmt.Sprintf("(%v)", strings.Join(newColumns, ","))
|
|
||||||
}
|
|
||||||
return strings.Join(newColumns, ",")
|
|
||||||
}
|
|
||||||
|
|
||||||
func toQueryValues(primaryValues [][]interface{}) (values []interface{}) {
|
|
||||||
for _, primaryValue := range primaryValues {
|
|
||||||
for _, value := range primaryValue {
|
|
||||||
values = append(values, value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return values
|
|
||||||
}
|
|
2
main.go
2
main.go
|
@ -480,7 +480,7 @@ func (s *DB) Association(column string) *Association {
|
||||||
if field.Relationship == nil || len(field.Relationship.ForeignFieldNames) == 0 {
|
if field.Relationship == nil || len(field.Relationship.ForeignFieldNames) == 0 {
|
||||||
err = fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type())
|
err = fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type())
|
||||||
} else {
|
} else {
|
||||||
return &Association{Scope: scope, Column: column, Field: field}
|
return &Association{scope: scope, column: column, field: field}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
err = fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column)
|
err = fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column)
|
||||||
|
|
40
utils.go
40
utils.go
|
@ -2,6 +2,7 @@ package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
@ -100,3 +101,42 @@ type expr struct {
|
||||||
func Expr(expression string, args ...interface{}) *expr {
|
func Expr(expression string, args ...interface{}) *expr {
|
||||||
return &expr{expr: expression, args: args}
|
return &expr{expr: expression, args: args}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func toQueryMarks(primaryValues [][]interface{}) string {
|
||||||
|
var results []string
|
||||||
|
|
||||||
|
for _, primaryValue := range primaryValues {
|
||||||
|
var marks []string
|
||||||
|
for _ = range primaryValue {
|
||||||
|
marks = append(marks, "?")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(marks) > 1 {
|
||||||
|
results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ",")))
|
||||||
|
} else {
|
||||||
|
results = append(results, strings.Join(marks, ""))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return strings.Join(results, ",")
|
||||||
|
}
|
||||||
|
|
||||||
|
func toQueryCondition(scope *Scope, columns []string) string {
|
||||||
|
var newColumns []string
|
||||||
|
for _, column := range columns {
|
||||||
|
newColumns = append(newColumns, scope.Quote(column))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(columns) > 1 {
|
||||||
|
return fmt.Sprintf("(%v)", strings.Join(newColumns, ","))
|
||||||
|
}
|
||||||
|
return strings.Join(newColumns, ",")
|
||||||
|
}
|
||||||
|
|
||||||
|
func toQueryValues(primaryValues [][]interface{}) (values []interface{}) {
|
||||||
|
for _, primaryValue := range primaryValues {
|
||||||
|
for _, value := range primaryValue {
|
||||||
|
values = append(values, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return values
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue