forked from mirror/gorm
Implement preload support
This commit is contained in:
parent
41697d58d3
commit
b549f9bb9a
|
@ -1,9 +1,196 @@
|
|||
package callbacks
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/jinzhu/gorm/clause"
|
||||
"github.com/jinzhu/gorm/schema"
|
||||
"github.com/jinzhu/gorm/utils"
|
||||
)
|
||||
|
||||
func preload(db *gorm.DB, preloadFields []string, rel *schema.Relationship) {
|
||||
// getRelationsValue get relations's value from a reflect value
|
||||
func getRelationsValue(reflectValue reflect.Value, rels []*schema.Relationship) (reflectResults reflect.Value) {
|
||||
for _, rel := range rels {
|
||||
reflectResults = reflect.MakeSlice(reflect.SliceOf(rel.FieldSchema.ModelType), 0, 0)
|
||||
|
||||
appendToResults := func(value reflect.Value) {
|
||||
if _, isZero := rel.Field.ValueOf(value); !isZero {
|
||||
result := reflect.Indirect(rel.Field.ReflectValueOf(value))
|
||||
switch result.Kind() {
|
||||
case reflect.Struct:
|
||||
reflectResults = reflect.Append(reflectResults, result)
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i < value.Len(); i++ {
|
||||
reflectResults = reflect.Append(reflectResults, reflect.Indirect(value.Index(i)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Struct:
|
||||
appendToResults(reflectValue)
|
||||
case reflect.Slice:
|
||||
for i := 0; i < reflectValue.Len(); i++ {
|
||||
appendToResults(reflectValue.Index(i))
|
||||
}
|
||||
}
|
||||
|
||||
reflectValue = reflectResults
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func getIdentityFieldValuesMap(reflectValue reflect.Value, fields []*schema.Field) (map[string][]reflect.Value, [][]interface{}) {
|
||||
var (
|
||||
fieldValues = make([]reflect.Value, len(fields))
|
||||
results = [][]interface{}{}
|
||||
dataResults = map[string][]reflect.Value{}
|
||||
)
|
||||
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Struct:
|
||||
results = [][]interface{}{make([]interface{}, len(fields))}
|
||||
|
||||
for idx, field := range fields {
|
||||
fieldValues[idx] = field.ReflectValueOf(reflectValue)
|
||||
results[0][idx] = fieldValues[idx].Interface()
|
||||
}
|
||||
|
||||
dataResults[utils.ToStringKey(fieldValues...)] = []reflect.Value{reflectValue}
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i < reflectValue.Len(); i++ {
|
||||
for idx, field := range fields {
|
||||
fieldValues[idx] = field.ReflectValueOf(reflectValue.Index(i))
|
||||
}
|
||||
|
||||
dataKey := utils.ToStringKey(fieldValues...)
|
||||
if _, ok := dataResults[dataKey]; !ok {
|
||||
result := make([]interface{}, len(fieldValues))
|
||||
for idx, fieldValue := range fieldValues {
|
||||
result[idx] = fieldValue.Interface()
|
||||
}
|
||||
results = append(results, result)
|
||||
|
||||
dataResults[dataKey] = []reflect.Value{reflectValue.Index(i)}
|
||||
} else {
|
||||
dataResults[dataKey] = append(dataResults[dataKey], reflectValue.Index(i))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return dataResults, results
|
||||
}
|
||||
|
||||
func preloadData(tx *gorm.DB, resultSchema *schema.Schema, foreignKeys []string, foreignValues [][]interface{}) reflect.Value {
|
||||
results := reflect.MakeSlice(reflect.SliceOf(resultSchema.ModelType), 0, 0)
|
||||
queryValues := make([]interface{}, len(foreignValues))
|
||||
if len(foreignKeys) == 1 {
|
||||
for idx, r := range foreignValues {
|
||||
queryValues[idx] = r[0]
|
||||
}
|
||||
tx.Where(clause.IN{Column: foreignKeys[0], Values: queryValues}).Find(results.Addr().Interface())
|
||||
} else {
|
||||
for idx, r := range foreignValues {
|
||||
queryValues[idx] = r
|
||||
}
|
||||
tx.Where(clause.IN{Column: foreignKeys, Values: queryValues}).Find(results.Addr().Interface())
|
||||
}
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
func preload(tx *gorm.DB, rels []*schema.Relationship, conds []interface{}) {
|
||||
var (
|
||||
reflectValue = tx.Statement.ReflectValue
|
||||
rel = rels[len(rels)-1]
|
||||
relForeignKeys []string
|
||||
relForeignFields []*schema.Field
|
||||
foreignFields []*schema.Field
|
||||
foreignValues [][]interface{}
|
||||
identityMap = map[string][]reflect.Value{}
|
||||
)
|
||||
|
||||
if len(rels) > 1 {
|
||||
reflectValue = getRelationsValue(reflectValue, rels[:len(rels)])
|
||||
}
|
||||
|
||||
if rel.JoinTable != nil {
|
||||
var joinForeignFields, joinRelForeignFields []*schema.Field
|
||||
var joinForeignKeys []string
|
||||
for _, ref := range rel.References {
|
||||
if ref.OwnPrimaryKey {
|
||||
joinForeignKeys = append(joinForeignKeys, ref.ForeignKey.DBName)
|
||||
joinForeignFields = append(joinForeignFields, ref.ForeignKey)
|
||||
foreignFields = append(foreignFields, ref.PrimaryKey)
|
||||
} else if ref.PrimaryValue != "" {
|
||||
tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
|
||||
} else {
|
||||
joinRelForeignFields = append(joinRelForeignFields, ref.ForeignKey)
|
||||
relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName)
|
||||
relForeignFields = append(relForeignFields, ref.PrimaryKey)
|
||||
}
|
||||
}
|
||||
|
||||
joinIdentityMap, joinForeignValues := getIdentityFieldValuesMap(reflectValue, joinForeignFields)
|
||||
joinResults := preloadData(tx, rel.JoinTable, joinForeignKeys, joinForeignValues)
|
||||
|
||||
// convert join identity map to relation identity map
|
||||
fieldValues := make([]reflect.Value, len(foreignFields))
|
||||
joinFieldValues := make([]reflect.Value, len(joinForeignFields))
|
||||
for i := 0; i < joinResults.Len(); i++ {
|
||||
for idx, field := range foreignFields {
|
||||
fieldValues[idx] = field.ReflectValueOf(joinResults.Index(i))
|
||||
}
|
||||
|
||||
for idx, field := range joinForeignFields {
|
||||
joinFieldValues[idx] = field.ReflectValueOf(joinResults.Index(i))
|
||||
}
|
||||
|
||||
if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok {
|
||||
identityMap[utils.ToStringKey(joinFieldValues...)] = results
|
||||
}
|
||||
}
|
||||
|
||||
_, foreignValues = getIdentityFieldValuesMap(joinResults, joinRelForeignFields)
|
||||
} else {
|
||||
for _, ref := range rel.References {
|
||||
if ref.OwnPrimaryKey {
|
||||
relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName)
|
||||
relForeignFields = append(relForeignFields, ref.ForeignKey)
|
||||
foreignFields = append(foreignFields, ref.PrimaryKey)
|
||||
} else if ref.PrimaryValue != "" {
|
||||
tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
|
||||
} else {
|
||||
relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName)
|
||||
relForeignFields = append(relForeignFields, ref.PrimaryKey)
|
||||
foreignFields = append(foreignFields, ref.ForeignKey)
|
||||
}
|
||||
}
|
||||
|
||||
identityMap, foreignValues = getIdentityFieldValuesMap(reflectValue, foreignFields)
|
||||
}
|
||||
|
||||
reflectResults := preloadData(tx, rel.FieldSchema, relForeignKeys, foreignValues)
|
||||
|
||||
fieldValues := make([]reflect.Value, len(foreignFields))
|
||||
for i := 0; i < reflectResults.Len(); i++ {
|
||||
for idx, field := range foreignFields {
|
||||
fieldValues[idx] = field.ReflectValueOf(reflectResults.Index(i))
|
||||
}
|
||||
|
||||
for _, data := range identityMap[utils.ToStringKey(fieldValues...)] {
|
||||
reflectFieldValue := reflect.Indirect(rel.Field.ReflectValueOf(data))
|
||||
switch reflectFieldValue.Kind() {
|
||||
case reflect.Struct:
|
||||
elem := reflectResults.Index(i).Convert(reflectFieldValue.Type().Elem())
|
||||
rel.Field.Set(data, elem.Interface())
|
||||
case reflect.Slice, reflect.Array:
|
||||
elem := reflectResults.Index(i).Convert(reflectFieldValue.Type().Elem())
|
||||
rel.Field.Set(data, reflect.Append(reflectFieldValue, elem).Interface())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -25,6 +25,7 @@ func Query(db *gorm.DB) {
|
|||
}
|
||||
}
|
||||
|
||||
// inline joins
|
||||
if len(db.Statement.Joins) != 0 {
|
||||
joins := []clause.Join{}
|
||||
|
||||
|
@ -101,7 +102,6 @@ func Query(db *gorm.DB) {
|
|||
func Preload(db *gorm.DB) {
|
||||
if len(db.Statement.Preloads) > 0 {
|
||||
preloadMap := map[string][]string{}
|
||||
|
||||
for name := range db.Statement.Preloads {
|
||||
preloadFields := strings.Split(name, ".")
|
||||
for idx := range preloadFields {
|
||||
|
@ -118,27 +118,22 @@ func Preload(db *gorm.DB) {
|
|||
sort.Strings(preloadNames)
|
||||
|
||||
for _, name := range preloadNames {
|
||||
curSchema := db.Statement.Schema
|
||||
preloadFields := preloadMap[name]
|
||||
var (
|
||||
curSchema = db.Statement.Schema
|
||||
preloadFields = preloadMap[name]
|
||||
rels = make([]*schema.Relationship, len(preloadFields))
|
||||
)
|
||||
|
||||
for idx, preloadField := range preloadFields {
|
||||
if rel := curSchema.Relationships.Relations[preloadField]; rel != nil {
|
||||
if idx == len(preloadFields)-1 {
|
||||
conds := db.Statement.Preloads[strings.Join(preloadFields[:idx+1], ".")]
|
||||
|
||||
switch rel.Type {
|
||||
case schema.HasOne:
|
||||
case schema.HasMany:
|
||||
case schema.BelongsTo:
|
||||
case schema.Many2Many:
|
||||
}
|
||||
} else {
|
||||
curSchema = rel.FieldSchema
|
||||
}
|
||||
rels[idx] = rel
|
||||
curSchema = rel.FieldSchema
|
||||
} else {
|
||||
db.AddError(fmt.Errorf("%v: %w", name, gorm.ErrUnsupportedRelation))
|
||||
}
|
||||
}
|
||||
|
||||
preload(db.Session(&gorm.Session{}), rels, db.Statement.Preloads[name])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -95,6 +95,15 @@ func (stmt Statement) QuoteTo(writer clause.Writer, field interface{}) {
|
|||
}
|
||||
case string:
|
||||
stmt.DB.Dialector.QuoteTo(writer, v)
|
||||
case []string:
|
||||
writer.WriteByte('(')
|
||||
for idx, d := range v {
|
||||
if idx != 0 {
|
||||
writer.WriteString(",")
|
||||
}
|
||||
stmt.DB.Dialector.QuoteTo(writer, d)
|
||||
}
|
||||
writer.WriteByte(')')
|
||||
default:
|
||||
stmt.DB.Dialector.QuoteTo(writer, fmt.Sprint(field))
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"reflect"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
@ -38,3 +39,24 @@ func CheckTruth(val interface{}) bool {
|
|||
|
||||
return !reflect.ValueOf(val).IsZero()
|
||||
}
|
||||
|
||||
func ToStringKey(values ...reflect.Value) string {
|
||||
results := make([]string, len(values))
|
||||
|
||||
for idx, value := range values {
|
||||
rv := reflect.Indirect(value).Interface()
|
||||
|
||||
switch v := rv.(type) {
|
||||
case string:
|
||||
results[idx] = v
|
||||
case []byte:
|
||||
results[idx] = string(v)
|
||||
case uint:
|
||||
results[idx] = strconv.FormatUint(uint64(v), 10)
|
||||
default:
|
||||
results[idx] = fmt.Sprint(v)
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(results, "_")
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue