gorm/callbacks/preload.go

176 lines
6.1 KiB
Go
Raw Normal View History

2020-05-07 05:03:48 +03:00
package callbacks
import (
"fmt"
2020-05-14 07:19:12 +03:00
"reflect"
2020-06-02 04:16:07 +03:00
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
2020-05-07 05:03:48 +03:00
)
func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) {
2020-05-14 07:19:12 +03:00
var (
reflectValue = db.Statement.ReflectValue
tx = db.Session(&gorm.Session{NewDB: true}).Model(nil).Session(&gorm.Session{SkipHooks: db.Statement.SkipHooks})
2020-05-14 07:19:12 +03:00
relForeignKeys []string
relForeignFields []*schema.Field
foreignFields []*schema.Field
foreignValues [][]interface{}
identityMap = map[string][]reflect.Value{}
2020-06-01 19:44:48 +03:00
inlineConds []interface{}
2020-05-14 07:19:12 +03:00
)
2021-01-15 12:15:59 +03:00
db.Statement.Settings.Range(func(k, v interface{}) bool {
tx.Statement.Settings.Store(k, v)
return true
})
2020-05-14 07:19:12 +03:00
if rel.JoinTable != nil {
var (
joinForeignFields = make([]*schema.Field, 0, len(rel.References))
joinRelForeignFields = make([]*schema.Field, 0, len(rel.References))
joinForeignKeys = make([]string, 0, len(rel.References))
)
2020-05-14 07:19:12 +03:00
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
joinForeignKeys = append(joinForeignKeys, ref.ForeignKey.DBName)
joinForeignFields = append(joinForeignFields, ref.ForeignKey)
foreignFields = append(foreignFields, ref.PrimaryKey)
} else if ref.PrimaryValue != "" {
2020-05-31 09:41:45 +03:00
tx = tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
2020-05-14 07:19:12 +03:00
} else {
joinRelForeignFields = append(joinRelForeignFields, ref.ForeignKey)
relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName)
relForeignFields = append(relForeignFields, ref.PrimaryKey)
}
}
joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, reflectValue, foreignFields)
2020-05-23 16:03:28 +03:00
if len(joinForeignValues) == 0 {
return
}
2020-05-18 08:07:11 +03:00
joinResults := rel.JoinTable.MakeSlice().Elem()
column, values := schema.ToQueryValues(clause.CurrentTable, joinForeignKeys, joinForeignValues)
db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error)
2020-05-14 07:19:12 +03:00
// convert join identity map to relation identity map
2020-06-01 03:12:44 +03:00
fieldValues := make([]interface{}, len(joinForeignFields))
joinFieldValues := make([]interface{}, len(joinRelForeignFields))
2020-05-14 07:19:12 +03:00
for i := 0; i < joinResults.Len(); i++ {
joinIndexValue := joinResults.Index(i)
2020-05-23 16:03:28 +03:00
for idx, field := range joinForeignFields {
fieldValues[idx], _ = field.ValueOf(db.Statement.Context, joinIndexValue)
2020-05-14 07:19:12 +03:00
}
2020-05-23 16:03:28 +03:00
for idx, field := range joinRelForeignFields {
joinFieldValues[idx], _ = field.ValueOf(db.Statement.Context, joinIndexValue)
2020-05-14 07:19:12 +03:00
}
if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok {
2020-06-01 19:44:48 +03:00
joinKey := utils.ToStringKey(joinFieldValues...)
identityMap[joinKey] = append(identityMap[joinKey], results...)
2020-05-14 07:19:12 +03:00
}
}
_, foreignValues = schema.GetIdentityFieldValuesMap(db.Statement.Context, joinResults, joinRelForeignFields)
2020-05-14 07:19:12 +03:00
} else {
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName)
relForeignFields = append(relForeignFields, ref.ForeignKey)
foreignFields = append(foreignFields, ref.PrimaryKey)
} else if ref.PrimaryValue != "" {
2020-05-31 09:41:45 +03:00
tx = tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
2020-05-14 07:19:12 +03:00
} else {
relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName)
relForeignFields = append(relForeignFields, ref.PrimaryKey)
foreignFields = append(foreignFields, ref.ForeignKey)
}
}
identityMap, foreignValues = schema.GetIdentityFieldValuesMap(db.Statement.Context, reflectValue, foreignFields)
2020-05-23 16:03:28 +03:00
if len(foreignValues) == 0 {
return
}
2020-05-14 07:19:12 +03:00
}
// nested preload
for p, pvs := range preloads {
tx = tx.Preload(p, pvs...)
}
2020-05-18 08:07:11 +03:00
reflectResults := rel.FieldSchema.MakeSlice().Elem()
column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues)
2020-06-01 19:44:48 +03:00
if len(values) != 0 {
for _, cond := range conds {
if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok {
tx = fc(tx)
} else {
inlineConds = append(inlineConds, cond)
}
2020-06-01 19:44:48 +03:00
}
db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error)
}
2020-05-14 07:19:12 +03:00
2020-06-01 03:12:44 +03:00
fieldValues := make([]interface{}, len(relForeignFields))
2020-06-01 19:44:48 +03:00
// clean up old values before preloading
switch reflectValue.Kind() {
case reflect.Struct:
switch rel.Type {
case schema.HasMany, schema.Many2Many:
rel.Field.Set(db.Statement.Context, reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface())
default:
rel.Field.Set(db.Statement.Context, reflectValue, reflect.New(rel.Field.FieldType).Interface())
}
case reflect.Slice, reflect.Array:
for i := 0; i < reflectValue.Len(); i++ {
switch rel.Type {
case schema.HasMany, schema.Many2Many:
rel.Field.Set(db.Statement.Context, reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface())
default:
rel.Field.Set(db.Statement.Context, reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface())
}
}
}
2020-05-14 07:19:12 +03:00
for i := 0; i < reflectResults.Len(); i++ {
2020-06-01 19:44:48 +03:00
elem := reflectResults.Index(i)
for idx, field := range relForeignFields {
fieldValues[idx], _ = field.ValueOf(db.Statement.Context, elem)
2020-05-14 07:19:12 +03:00
}
datas, ok := identityMap[utils.ToStringKey(fieldValues...)]
if !ok {
db.AddError(fmt.Errorf("failed to assign association %#v, make sure foreign fields exists",
elem.Interface()))
continue
}
for _, data := range datas {
reflectFieldValue := rel.Field.ReflectValueOf(db.Statement.Context, data)
if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() {
reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem()))
}
2020-06-01 19:44:48 +03:00
reflectFieldValue = reflect.Indirect(reflectFieldValue)
switch reflectFieldValue.Kind() {
case reflect.Struct:
rel.Field.Set(db.Statement.Context, data, elem.Interface())
case reflect.Slice, reflect.Array:
if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr {
rel.Field.Set(db.Statement.Context, data, reflect.Append(reflectFieldValue, elem).Interface())
} else {
rel.Field.Set(db.Statement.Context, data, reflect.Append(reflectFieldValue, elem.Elem()).Interface())
}
2020-05-14 07:19:12 +03:00
}
}
}
2020-05-07 05:03:48 +03:00
}