gorm/callbacks/preload.go

174 lines
6.0 KiB
Go
Raw Permalink Normal View History

2020-05-07 05:03:48 +03:00
package callbacks
import (
"fmt"
2020-05-14 07:19:12 +03:00
"reflect"
2020-06-02 04:16:07 +03:00
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
2020-05-07 05:03:48 +03:00
)
func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error {
2020-05-14 07:19:12 +03:00
var (
reflectValue = tx.Statement.ReflectValue
2020-05-14 07:19:12 +03:00
relForeignKeys []string
relForeignFields []*schema.Field
foreignFields []*schema.Field
foreignValues [][]interface{}
identityMap = map[string][]reflect.Value{}
2020-06-01 19:44:48 +03:00
inlineConds []interface{}
2020-05-14 07:19:12 +03:00
)
if rel.JoinTable != nil {
var (
joinForeignFields = make([]*schema.Field, 0, len(rel.References))
joinRelForeignFields = make([]*schema.Field, 0, len(rel.References))
joinForeignKeys = make([]string, 0, len(rel.References))
)
2020-05-14 07:19:12 +03:00
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
joinForeignKeys = append(joinForeignKeys, ref.ForeignKey.DBName)
joinForeignFields = append(joinForeignFields, ref.ForeignKey)
foreignFields = append(foreignFields, ref.PrimaryKey)
} else if ref.PrimaryValue != "" {
2020-05-31 09:41:45 +03:00
tx = tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
2020-05-14 07:19:12 +03:00
} else {
joinRelForeignFields = append(joinRelForeignFields, ref.ForeignKey)
relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName)
relForeignFields = append(relForeignFields, ref.PrimaryKey)
}
}
joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields)
2020-05-23 16:03:28 +03:00
if len(joinForeignValues) == 0 {
return nil
2020-05-23 16:03:28 +03:00
}
2020-05-18 08:07:11 +03:00
joinResults := rel.JoinTable.MakeSlice().Elem()
column, values := schema.ToQueryValues(clause.CurrentTable, joinForeignKeys, joinForeignValues)
if err := tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error; err != nil {
return err
}
2020-05-14 07:19:12 +03:00
// convert join identity map to relation identity map
2020-06-01 03:12:44 +03:00
fieldValues := make([]interface{}, len(joinForeignFields))
joinFieldValues := make([]interface{}, len(joinRelForeignFields))
2020-05-14 07:19:12 +03:00
for i := 0; i < joinResults.Len(); i++ {
joinIndexValue := joinResults.Index(i)
2020-05-23 16:03:28 +03:00
for idx, field := range joinForeignFields {
fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue)
2020-05-14 07:19:12 +03:00
}
2020-05-23 16:03:28 +03:00
for idx, field := range joinRelForeignFields {
joinFieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue)
2020-05-14 07:19:12 +03:00
}
if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok {
2020-06-01 19:44:48 +03:00
joinKey := utils.ToStringKey(joinFieldValues...)
identityMap[joinKey] = append(identityMap[joinKey], results...)
2020-05-14 07:19:12 +03:00
}
}
_, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, joinResults, joinRelForeignFields)
2020-05-14 07:19:12 +03:00
} else {
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName)
relForeignFields = append(relForeignFields, ref.ForeignKey)
foreignFields = append(foreignFields, ref.PrimaryKey)
} else if ref.PrimaryValue != "" {
2020-05-31 09:41:45 +03:00
tx = tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
2020-05-14 07:19:12 +03:00
} else {
relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName)
relForeignFields = append(relForeignFields, ref.PrimaryKey)
foreignFields = append(foreignFields, ref.ForeignKey)
}
}
identityMap, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields)
2020-05-23 16:03:28 +03:00
if len(foreignValues) == 0 {
return nil
2020-05-23 16:03:28 +03:00
}
2020-05-14 07:19:12 +03:00
}
// nested preload
for p, pvs := range preloads {
tx = tx.Preload(p, pvs...)
}
2020-05-18 08:07:11 +03:00
reflectResults := rel.FieldSchema.MakeSlice().Elem()
column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues)
2020-06-01 19:44:48 +03:00
if len(values) != 0 {
for _, cond := range conds {
if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok {
tx = fc(tx)
} else {
inlineConds = append(inlineConds, cond)
}
2020-06-01 19:44:48 +03:00
}
if err := tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error; err != nil {
return err
}
}
2020-05-14 07:19:12 +03:00
2020-06-01 03:12:44 +03:00
fieldValues := make([]interface{}, len(relForeignFields))
2020-06-01 19:44:48 +03:00
// clean up old values before preloading
switch reflectValue.Kind() {
case reflect.Struct:
switch rel.Type {
case schema.HasMany, schema.Many2Many:
2022-03-23 12:24:25 +03:00
tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()))
default:
2022-03-23 12:24:25 +03:00
tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue, reflect.New(rel.Field.FieldType).Interface()))
}
case reflect.Slice, reflect.Array:
for i := 0; i < reflectValue.Len(); i++ {
switch rel.Type {
case schema.HasMany, schema.Many2Many:
2022-03-23 12:24:25 +03:00
tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()))
default:
2022-03-23 12:24:25 +03:00
tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()))
}
}
}
2020-05-14 07:19:12 +03:00
for i := 0; i < reflectResults.Len(); i++ {
2020-06-01 19:44:48 +03:00
elem := reflectResults.Index(i)
for idx, field := range relForeignFields {
fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, elem)
2020-05-14 07:19:12 +03:00
}
datas, ok := identityMap[utils.ToStringKey(fieldValues...)]
if !ok {
return fmt.Errorf("failed to assign association %#v, make sure foreign fields exists", elem.Interface())
}
for _, data := range datas {
reflectFieldValue := rel.Field.ReflectValueOf(tx.Statement.Context, data)
if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() {
reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem()))
}
2020-06-01 19:44:48 +03:00
reflectFieldValue = reflect.Indirect(reflectFieldValue)
switch reflectFieldValue.Kind() {
case reflect.Struct:
2022-03-23 12:24:25 +03:00
tx.AddError(rel.Field.Set(tx.Statement.Context, data, elem.Interface()))
case reflect.Slice, reflect.Array:
if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr {
2022-03-23 12:24:25 +03:00
tx.AddError(rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem).Interface()))
} else {
2022-03-23 12:24:25 +03:00
tx.AddError(rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()))
}
2020-05-14 07:19:12 +03:00
}
}
}
return tx.Error
2020-05-07 05:03:48 +03:00
}