mirror of https://github.com/go-gorm/gorm.git
Finish implement association support
This commit is contained in:
parent
20cb57b1ac
commit
0f21272c7f
198
association.go
198
association.go
|
@ -1,6 +1,7 @@
|
||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
|
||||||
|
@ -34,16 +35,119 @@ func (db *DB) Association(column string) *Association {
|
||||||
|
|
||||||
func (association *Association) Find(out interface{}, conds ...interface{}) error {
|
func (association *Association) Find(out interface{}, conds ...interface{}) error {
|
||||||
if association.Error == nil {
|
if association.Error == nil {
|
||||||
|
var (
|
||||||
|
tx = association.DB
|
||||||
|
queryConds = association.Relationship.ToQueryConditions(tx.Statement.ReflectValue)
|
||||||
|
)
|
||||||
|
|
||||||
|
if association.Relationship.JoinTable != nil {
|
||||||
|
for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
|
||||||
|
tx.Clauses(queryClause)
|
||||||
|
}
|
||||||
|
|
||||||
|
tx.Clauses(clause.From{Joins: []clause.Join{{
|
||||||
|
Table: clause.Table{Name: association.Relationship.JoinTable.Table},
|
||||||
|
ON: clause.Where{Exprs: queryConds},
|
||||||
|
}}})
|
||||||
|
} else {
|
||||||
|
tx.Clauses(clause.Where{Exprs: queryConds})
|
||||||
|
}
|
||||||
|
|
||||||
|
association.Error = tx.Find(out, conds...).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
return association.Error
|
return association.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (association *Association) Append(values ...interface{}) error {
|
func (association *Association) Append(values ...interface{}) error {
|
||||||
|
if association.Error == nil {
|
||||||
|
switch association.Relationship.Type {
|
||||||
|
case schema.HasOne, schema.BelongsTo:
|
||||||
|
if len(values) > 0 {
|
||||||
|
association.Error = association.Replace(values...)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
association.saveAssociation(false, values...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return association.Error
|
return association.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (association *Association) Replace(values ...interface{}) error {
|
func (association *Association) Replace(values ...interface{}) error {
|
||||||
|
if association.Error == nil {
|
||||||
|
association.saveAssociation(true, values...)
|
||||||
|
reflectValue := association.DB.Statement.ReflectValue
|
||||||
|
rel := association.Relationship
|
||||||
|
|
||||||
|
switch rel.Type {
|
||||||
|
case schema.HasOne, schema.HasMany:
|
||||||
|
var (
|
||||||
|
primaryFields []*schema.Field
|
||||||
|
foreignKeys []string
|
||||||
|
updateMap = map[string]interface{}{}
|
||||||
|
modelValue = reflect.New(rel.FieldSchema.ModelType).Interface()
|
||||||
|
)
|
||||||
|
|
||||||
|
for _, ref := range rel.References {
|
||||||
|
if ref.OwnPrimaryKey {
|
||||||
|
primaryFields = append(primaryFields, ref.PrimaryKey)
|
||||||
|
} else {
|
||||||
|
foreignKeys = append(foreignKeys, ref.ForeignKey.DBName)
|
||||||
|
updateMap[ref.ForeignKey.DBName] = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_, values := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
|
||||||
|
column, queryValues := schema.ToQueryValues(foreignKeys, values)
|
||||||
|
association.DB.Model(modelValue).Where(clause.IN{Column: column, Values: queryValues}).UpdateColumns(updateMap)
|
||||||
|
case schema.Many2Many:
|
||||||
|
var primaryFields, relPrimaryFields []*schema.Field
|
||||||
|
var foreignKeys, relForeignKeys []string
|
||||||
|
modelValue := reflect.New(rel.JoinTable.ModelType).Interface()
|
||||||
|
conds := []clause.Expression{}
|
||||||
|
|
||||||
|
for _, ref := range rel.References {
|
||||||
|
if ref.OwnPrimaryKey {
|
||||||
|
primaryFields = append(primaryFields, ref.PrimaryKey)
|
||||||
|
foreignKeys = append(foreignKeys, ref.ForeignKey.DBName)
|
||||||
|
} else if ref.PrimaryValue != "" {
|
||||||
|
conds = append(conds, clause.Eq{
|
||||||
|
Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName},
|
||||||
|
Value: ref.PrimaryValue,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey)
|
||||||
|
relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
generateConds := func(rv reflect.Value) {
|
||||||
|
_, values := schema.GetIdentityFieldValuesMap(rv, primaryFields)
|
||||||
|
column, queryValues := schema.ToQueryValues(foreignKeys, values)
|
||||||
|
|
||||||
|
relValue := rel.Field.ReflectValueOf(rv)
|
||||||
|
_, relValues := schema.GetIdentityFieldValuesMap(relValue, relPrimaryFields)
|
||||||
|
relColumn, relQueryValues := schema.ToQueryValues(relForeignKeys, relValues)
|
||||||
|
|
||||||
|
conds = append(conds, clause.And(
|
||||||
|
clause.IN{Column: column, Values: queryValues},
|
||||||
|
clause.Not(clause.IN{Column: relColumn, Values: relQueryValues}),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
switch reflectValue.Kind() {
|
||||||
|
case reflect.Struct:
|
||||||
|
generateConds(reflectValue)
|
||||||
|
case reflect.Slice, reflect.Array:
|
||||||
|
for i := 0; i < reflectValue.Len(); i++ {
|
||||||
|
generateConds(reflectValue.Index(i))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
association.DB.Where(conds).Delete(modelValue)
|
||||||
|
}
|
||||||
|
}
|
||||||
return association.Error
|
return association.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -78,7 +182,7 @@ func (association *Association) Delete(values ...interface{}) error {
|
||||||
column, values := schema.ToQueryValues(foreignKeys, relQueryValues)
|
column, values := schema.ToQueryValues(foreignKeys, relQueryValues)
|
||||||
tx.Where(clause.IN{Column: column, Values: values})
|
tx.Where(clause.IN{Column: column, Values: values})
|
||||||
|
|
||||||
switch association.Relationship.Type {
|
switch rel.Type {
|
||||||
case schema.HasOne, schema.HasMany:
|
case schema.HasOne, schema.HasMany:
|
||||||
modelValue := reflect.New(rel.FieldSchema.ModelType).Interface()
|
modelValue := reflect.New(rel.FieldSchema.ModelType).Interface()
|
||||||
tx.Model(modelValue).Clauses(clause.Where{Exprs: conds}).UpdateColumns(updateAttrs)
|
tx.Model(modelValue).Clauses(clause.Where{Exprs: conds}).UpdateColumns(updateAttrs)
|
||||||
|
@ -164,3 +268,95 @@ func (association *Association) Count() (count int) {
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (association *Association) saveAssociation(clear bool, values ...interface{}) {
|
||||||
|
reflectValue := association.DB.Statement.ReflectValue
|
||||||
|
|
||||||
|
appendToRelations := func(source, rv reflect.Value, clear bool) {
|
||||||
|
switch association.Relationship.Type {
|
||||||
|
case schema.HasOne, schema.BelongsTo:
|
||||||
|
switch rv.Kind() {
|
||||||
|
case reflect.Slice, reflect.Array:
|
||||||
|
if rv.Len() > 0 {
|
||||||
|
association.Error = association.Relationship.Field.Set(source, rv.Index(0))
|
||||||
|
}
|
||||||
|
case reflect.Struct:
|
||||||
|
association.Error = association.Relationship.Field.Set(source, rv)
|
||||||
|
}
|
||||||
|
case schema.HasMany, schema.Many2Many:
|
||||||
|
elemType := association.Relationship.Field.IndirectFieldType.Elem()
|
||||||
|
fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(reflectValue))
|
||||||
|
if clear {
|
||||||
|
fieldValue = reflect.New(association.Relationship.Field.IndirectFieldType)
|
||||||
|
}
|
||||||
|
|
||||||
|
appendToFieldValues := func(ev reflect.Value) {
|
||||||
|
if ev.Type().AssignableTo(elemType) {
|
||||||
|
fieldValue = reflect.Append(fieldValue, ev)
|
||||||
|
} else if ev.Type().Elem().AssignableTo(elemType) {
|
||||||
|
fieldValue = reflect.Append(fieldValue, ev.Elem())
|
||||||
|
} else {
|
||||||
|
association.Error = fmt.Errorf("unsupported data type: %v for relation %v", ev.Type(), association.Relationship.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch rv.Kind() {
|
||||||
|
case reflect.Slice, reflect.Array:
|
||||||
|
for i := 0; i < rv.Len(); i++ {
|
||||||
|
appendToFieldValues(reflect.Indirect(rv.Index(i)))
|
||||||
|
}
|
||||||
|
case reflect.Struct:
|
||||||
|
appendToFieldValues(rv)
|
||||||
|
}
|
||||||
|
|
||||||
|
if association.Error == nil {
|
||||||
|
association.Error = association.Relationship.Field.Set(source, fieldValue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
selectedColumns := []string{association.Relationship.Name}
|
||||||
|
hasZero := false
|
||||||
|
for _, ref := range association.Relationship.References {
|
||||||
|
if !ref.OwnPrimaryKey {
|
||||||
|
selectedColumns = append(selectedColumns, ref.ForeignKey.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch reflectValue.Kind() {
|
||||||
|
case reflect.Slice, reflect.Array:
|
||||||
|
if len(values) != reflectValue.Len() {
|
||||||
|
if clear && len(values) == 0 {
|
||||||
|
for i := 0; i < reflectValue.Len(); i++ {
|
||||||
|
association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType))
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
association.Error = errors.New("invalid association values, length doesn't match")
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < reflectValue.Len(); i++ {
|
||||||
|
appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear)
|
||||||
|
|
||||||
|
if !hasZero {
|
||||||
|
_, hasZero = association.DB.Statement.Schema.PrioritizedPrimaryField.ValueOf(reflectValue.Index(i))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case reflect.Struct:
|
||||||
|
if clear && len(values) == 0 {
|
||||||
|
association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType))
|
||||||
|
}
|
||||||
|
|
||||||
|
for idx, value := range values {
|
||||||
|
appendToRelations(reflectValue, reflect.Indirect(reflect.ValueOf(value)), clear && idx == 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, hasZero = association.DB.Statement.Schema.PrioritizedPrimaryField.ValueOf(reflectValue)
|
||||||
|
}
|
||||||
|
|
||||||
|
if hasZero {
|
||||||
|
association.DB.Save(reflectValue.Interface())
|
||||||
|
} else {
|
||||||
|
association.DB.Select(selectedColumns).Save(reflectValue.Interface())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -28,7 +28,7 @@ func SaveBeforeAssociations(db *gorm.DB) {
|
||||||
}
|
}
|
||||||
|
|
||||||
switch db.Statement.ReflectValue.Kind() {
|
switch db.Statement.ReflectValue.Kind() {
|
||||||
case reflect.Slice:
|
case reflect.Slice, reflect.Array:
|
||||||
var (
|
var (
|
||||||
objs []reflect.Value
|
objs []reflect.Value
|
||||||
fieldType = rel.Field.FieldType
|
fieldType = rel.Field.FieldType
|
||||||
|
@ -92,7 +92,7 @@ func SaveAfterAssociations(db *gorm.DB) {
|
||||||
}
|
}
|
||||||
|
|
||||||
switch db.Statement.ReflectValue.Kind() {
|
switch db.Statement.ReflectValue.Kind() {
|
||||||
case reflect.Slice:
|
case reflect.Slice, reflect.Array:
|
||||||
var (
|
var (
|
||||||
fieldType = rel.Field.FieldType
|
fieldType = rel.Field.FieldType
|
||||||
isPtr = fieldType.Kind() == reflect.Ptr
|
isPtr = fieldType.Kind() == reflect.Ptr
|
||||||
|
@ -193,7 +193,7 @@ func SaveAfterAssociations(db *gorm.DB) {
|
||||||
}
|
}
|
||||||
|
|
||||||
switch db.Statement.ReflectValue.Kind() {
|
switch db.Statement.ReflectValue.Kind() {
|
||||||
case reflect.Slice:
|
case reflect.Slice, reflect.Array:
|
||||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||||
appendToElems(db.Statement.ReflectValue.Index(i))
|
appendToElems(db.Statement.ReflectValue.Index(i))
|
||||||
}
|
}
|
||||||
|
@ -260,7 +260,7 @@ func SaveAfterAssociations(db *gorm.DB) {
|
||||||
}
|
}
|
||||||
|
|
||||||
switch db.Statement.ReflectValue.Kind() {
|
switch db.Statement.ReflectValue.Kind() {
|
||||||
case reflect.Slice:
|
case reflect.Slice, reflect.Array:
|
||||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||||
appendToElems(db.Statement.ReflectValue.Index(i))
|
appendToElems(db.Statement.ReflectValue.Index(i))
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue