gorm/callbacks/preload.go

156 lines
5.3 KiB
Go
Raw Normal View History

2020-05-07 05:03:48 +03:00
package callbacks
import (
2020-05-14 07:19:12 +03:00
"reflect"
2020-06-02 04:16:07 +03:00
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
2020-05-07 05:03:48 +03:00
)
func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) {
2020-05-14 07:19:12 +03:00
var (
reflectValue = db.Statement.ReflectValue
2020-05-14 07:19:12 +03:00
rel = rels[len(rels)-1]
tx = db.Session(&gorm.Session{NewDB: true})
2020-05-14 07:19:12 +03:00
relForeignKeys []string
relForeignFields []*schema.Field
foreignFields []*schema.Field
foreignValues [][]interface{}
identityMap = map[string][]reflect.Value{}
2020-06-01 19:44:48 +03:00
inlineConds []interface{}
2020-05-14 07:19:12 +03:00
)
if len(rels) > 1 {
2020-05-23 19:52:25 +03:00
reflectValue = schema.GetRelationsValues(reflectValue, rels[:len(rels)-1])
2020-05-14 07:19:12 +03:00
}
if rel.JoinTable != nil {
var joinForeignFields, joinRelForeignFields []*schema.Field
var joinForeignKeys []string
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
joinForeignKeys = append(joinForeignKeys, ref.ForeignKey.DBName)
joinForeignFields = append(joinForeignFields, ref.ForeignKey)
foreignFields = append(foreignFields, ref.PrimaryKey)
} else if ref.PrimaryValue != "" {
2020-05-31 09:41:45 +03:00
tx = tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
2020-05-14 07:19:12 +03:00
} else {
joinRelForeignFields = append(joinRelForeignFields, ref.ForeignKey)
relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName)
relForeignFields = append(relForeignFields, ref.PrimaryKey)
}
}
2020-05-23 16:03:28 +03:00
joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(reflectValue, foreignFields)
if len(joinForeignValues) == 0 {
return
}
2020-05-18 08:07:11 +03:00
joinResults := rel.JoinTable.MakeSlice().Elem()
column, values := schema.ToQueryValues(rel.JoinTable.Table, joinForeignKeys, joinForeignValues)
db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error)
2020-05-14 07:19:12 +03:00
// convert join identity map to relation identity map
2020-06-01 03:12:44 +03:00
fieldValues := make([]interface{}, len(joinForeignFields))
joinFieldValues := make([]interface{}, len(joinRelForeignFields))
2020-05-14 07:19:12 +03:00
for i := 0; i < joinResults.Len(); i++ {
2020-05-23 16:03:28 +03:00
for idx, field := range joinForeignFields {
fieldValues[idx], _ = field.ValueOf(joinResults.Index(i))
2020-05-14 07:19:12 +03:00
}
2020-05-23 16:03:28 +03:00
for idx, field := range joinRelForeignFields {
joinFieldValues[idx], _ = field.ValueOf(joinResults.Index(i))
2020-05-14 07:19:12 +03:00
}
if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok {
2020-06-01 19:44:48 +03:00
joinKey := utils.ToStringKey(joinFieldValues...)
identityMap[joinKey] = append(identityMap[joinKey], results...)
2020-05-14 07:19:12 +03:00
}
}
2020-05-18 08:07:11 +03:00
_, foreignValues = schema.GetIdentityFieldValuesMap(joinResults, joinRelForeignFields)
2020-05-14 07:19:12 +03:00
} else {
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName)
relForeignFields = append(relForeignFields, ref.ForeignKey)
foreignFields = append(foreignFields, ref.PrimaryKey)
} else if ref.PrimaryValue != "" {
2020-05-31 09:41:45 +03:00
tx = tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
2020-05-14 07:19:12 +03:00
} else {
relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName)
relForeignFields = append(relForeignFields, ref.PrimaryKey)
foreignFields = append(foreignFields, ref.ForeignKey)
}
}
2020-05-18 08:07:11 +03:00
identityMap, foreignValues = schema.GetIdentityFieldValuesMap(reflectValue, foreignFields)
2020-05-23 16:03:28 +03:00
if len(foreignValues) == 0 {
return
}
2020-05-14 07:19:12 +03:00
}
2020-05-18 08:07:11 +03:00
reflectResults := rel.FieldSchema.MakeSlice().Elem()
column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues)
2020-06-01 19:44:48 +03:00
for _, cond := range conds {
if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok {
tx = fc(tx)
} else {
inlineConds = append(inlineConds, cond)
}
}
db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error)
2020-05-14 07:19:12 +03:00
2020-06-01 03:12:44 +03:00
fieldValues := make([]interface{}, len(relForeignFields))
2020-06-01 19:44:48 +03:00
// clean up old values before preloading
switch reflectValue.Kind() {
case reflect.Struct:
switch rel.Type {
case schema.HasMany, schema.Many2Many:
2020-11-10 13:38:24 +03:00
rel.Field.Set(reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface())
default:
rel.Field.Set(reflectValue, reflect.New(rel.Field.FieldType).Interface())
}
case reflect.Slice, reflect.Array:
for i := 0; i < reflectValue.Len(); i++ {
switch rel.Type {
case schema.HasMany, schema.Many2Many:
2020-11-10 13:38:24 +03:00
rel.Field.Set(reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface())
default:
rel.Field.Set(reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface())
}
}
}
2020-05-14 07:19:12 +03:00
for i := 0; i < reflectResults.Len(); i++ {
2020-06-01 19:44:48 +03:00
elem := reflectResults.Index(i)
for idx, field := range relForeignFields {
2020-06-01 19:44:48 +03:00
fieldValues[idx], _ = field.ValueOf(elem)
2020-05-14 07:19:12 +03:00
}
for _, data := range identityMap[utils.ToStringKey(fieldValues...)] {
2020-05-29 18:38:03 +03:00
reflectFieldValue := rel.Field.ReflectValueOf(data)
if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() {
reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem()))
}
2020-06-01 19:44:48 +03:00
2020-05-29 18:38:03 +03:00
reflectFieldValue = reflect.Indirect(reflectFieldValue)
2020-05-14 07:19:12 +03:00
switch reflectFieldValue.Kind() {
case reflect.Struct:
rel.Field.Set(data, reflectResults.Index(i).Interface())
2020-05-14 07:19:12 +03:00
case reflect.Slice, reflect.Array:
if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr {
2020-06-01 19:44:48 +03:00
rel.Field.Set(data, reflect.Append(reflectFieldValue, elem).Interface())
2020-05-23 19:52:25 +03:00
} else {
2020-06-01 19:44:48 +03:00
rel.Field.Set(data, reflect.Append(reflectFieldValue, elem.Elem()).Interface())
}
2020-05-14 07:19:12 +03:00
}
}
}
2020-05-07 05:03:48 +03:00
}