2015-02-11 08:43:53 +03:00
package gorm
import (
"errors"
"fmt"
"reflect"
2017-04-28 01:53:39 +03:00
"strconv"
2015-04-21 10:00:36 +03:00
"strings"
2015-02-11 08:43:53 +03:00
)
2016-01-17 15:51:11 +03:00
// preloadCallback used to preload associations
2016-01-17 10:30:42 +03:00
func preloadCallback ( scope * Scope ) {
2018-02-09 19:07:16 +03:00
if _ , skip := scope . InstanceGet ( "gorm:skip_query_callback" ) ; skip {
return
}
2017-04-28 01:53:39 +03:00
2018-08-19 02:11:27 +03:00
if ap , ok := scope . Get ( "gorm:auto_preload" ) ; ok {
// If gorm:auto_preload IS NOT a bool then auto preload.
// Else if it IS a bool, use the value
if apb , ok := ap . ( bool ) ; ! ok {
autoPreload ( scope )
} else if apb {
autoPreload ( scope )
}
2017-04-28 01:53:39 +03:00
}
2016-01-03 09:21:21 +03:00
if scope . Search . preload == nil || scope . HasError ( ) {
2015-04-21 11:51:52 +03:00
return
}
2016-01-15 10:53:53 +03:00
var (
preloadedMap = map [ string ] bool { }
fields = scope . Fields ( )
)
2015-04-21 11:51:52 +03:00
for _ , preload := range scope . Search . preload {
2016-01-15 10:53:53 +03:00
var (
preloadFields = strings . Split ( preload . schema , "." )
currentScope = scope
currentFields = fields
)
2015-04-21 11:51:52 +03:00
2016-01-15 10:53:53 +03:00
for idx , preloadField := range preloadFields {
var currentPreloadConditions [ ] interface { }
2015-04-21 11:51:52 +03:00
2016-05-23 19:54:51 +03:00
if currentScope == nil {
continue
}
2016-01-15 10:53:53 +03:00
// if not preloaded
if preloadKey := strings . Join ( preloadFields [ : idx + 1 ] , "." ) ; ! preloadedMap [ preloadKey ] {
// assign search conditions to last preload
if idx == len ( preloadFields ) - 1 {
currentPreloadConditions = preload . conditions
2015-04-21 10:00:36 +03:00
}
2016-01-15 10:53:53 +03:00
for _ , field := range currentFields {
if field . Name != preloadField || field . Relationship == nil {
continue
}
switch field . Relationship . Kind {
case "has_one" :
currentScope . handleHasOnePreload ( field , currentPreloadConditions )
case "has_many" :
currentScope . handleHasManyPreload ( field , currentPreloadConditions )
case "belongs_to" :
currentScope . handleBelongsToPreload ( field , currentPreloadConditions )
case "many_to_many" :
currentScope . handleManyToManyPreload ( field , currentPreloadConditions )
default :
scope . Err ( errors . New ( "unsupported relation" ) )
}
preloadedMap [ preloadKey ] = true
break
2015-02-11 08:43:53 +03:00
}
2015-04-21 11:51:52 +03:00
2016-01-15 10:53:53 +03:00
if ! preloadedMap [ preloadKey ] {
scope . Err ( fmt . Errorf ( "can't preload field %s for %s" , preloadField , currentScope . GetModelStruct ( ) . ModelType ) )
return
2015-04-21 11:51:52 +03:00
}
}
2016-01-15 10:53:53 +03:00
// preload next level
if idx < len ( preloadFields ) - 1 {
currentScope = currentScope . getColumnAsScope ( preloadField )
2016-05-23 19:54:51 +03:00
if currentScope != nil {
currentFields = currentScope . Fields ( )
}
2015-02-11 08:43:53 +03:00
}
}
}
}
2017-04-28 01:53:39 +03:00
// autoPreload registers a preload on the scope's search for every field that
// has a relationship. A field whose PRELOAD tag setting parses as false is
// skipped; a PRELOAD value that is not a valid boolean records an
// "invalid preload option" error on the scope and stops processing.
func autoPreload(scope *Scope) {
	for _, field := range scope.Fields() {
		if field.Relationship == nil {
			continue
		}

		if raw, tagged := field.TagSettings["PRELOAD"]; tagged {
			enabled, err := strconv.ParseBool(raw)
			if err != nil {
				scope.Err(errors.New("invalid preload option"))
				return
			}
			if !enabled {
				continue
			}
		}

		scope.Search.Preload(field.Name)
	}
}
2016-02-15 16:29:47 +03:00
// generatePreloadDBWithConditions splits preload conditions into two groups.
// Conditions of type func(*DB) *DB are applied, in order, to a fresh DB
// obtained from scope.NewDB; all other conditions are returned unchanged, in
// their original order, alongside the resulting DB.
func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (*DB, []interface{}) {
	db := scope.NewDB()

	var remaining []interface{}
	for _, condition := range conditions {
		switch apply := condition.(type) {
		case func(*DB) *DB:
			db = apply(db)
		default:
			remaining = append(remaining, condition)
		}
	}

	return db, remaining
}
2016-01-17 15:51:11 +03:00
// handleHasOnePreload used to preload has one associations
2015-04-21 11:51:52 +03:00
func ( scope * Scope ) handleHasOnePreload ( field * Field , conditions [ ] interface { } ) {
2015-07-30 13:19:49 +03:00
relation := field . Relationship
2015-07-30 17:36:04 +03:00
2016-01-15 10:53:53 +03:00
// get relations's primary keys
2016-01-15 17:53:09 +03:00
primaryKeys := scope . getColumnAsArray ( relation . AssociationForeignFieldNames , scope . Value )
2015-04-21 11:51:52 +03:00
if len ( primaryKeys ) == 0 {
return
}
2016-02-15 16:29:47 +03:00
// preload conditions
preloadDB , preloadConditions := scope . generatePreloadDBWithConditions ( conditions )
2016-01-15 10:53:53 +03:00
// find relations
2016-06-01 14:46:45 +03:00
query := fmt . Sprintf ( "%v IN (%v)" , toQueryCondition ( scope , relation . ForeignDBNames ) , toQueryMarks ( primaryKeys ) )
values := toQueryValues ( primaryKeys )
if relation . PolymorphicType != "" {
query += fmt . Sprintf ( " AND %v = ?" , scope . Quote ( relation . PolymorphicDBName ) )
2016-10-06 15:33:48 +03:00
values = append ( values , relation . PolymorphicValue )
2016-06-01 14:46:45 +03:00
}
2015-04-21 11:51:52 +03:00
results := makeSlice ( field . Struct . Type )
2016-06-01 14:46:45 +03:00
scope . Err ( preloadDB . Where ( query , values ... ) . Find ( results , preloadConditions ... ) . Error )
2015-04-21 11:51:52 +03:00
2016-01-15 10:53:53 +03:00
// assign find results
var (
2016-01-18 07:20:27 +03:00
resultsValue = indirect ( reflect . ValueOf ( results ) )
2016-01-15 10:53:53 +03:00
indirectScopeValue = scope . IndirectValue ( )
)
2016-06-16 12:58:25 +03:00
if indirectScopeValue . Kind ( ) == reflect . Slice {
for j := 0 ; j < indirectScopeValue . Len ( ) ; j ++ {
for i := 0 ; i < resultsValue . Len ( ) ; i ++ {
result := resultsValue . Index ( i )
foreignValues := getValueFromFields ( result , relation . ForeignFieldNames )
2016-01-18 07:20:27 +03:00
if indirectValue := indirect ( indirectScopeValue . Index ( j ) ) ; equalAsString ( getValueFromFields ( indirectValue , relation . AssociationForeignFieldNames ) , foreignValues ) {
2016-01-15 15:37:41 +03:00
indirectValue . FieldByName ( field . Name ) . Set ( result )
2015-04-21 11:51:52 +03:00
break
}
}
2016-06-16 12:58:25 +03:00
}
} else {
for i := 0 ; i < resultsValue . Len ( ) ; i ++ {
result := resultsValue . Index ( i )
2016-01-15 10:53:53 +03:00
scope . Err ( field . Set ( result ) )
2015-04-21 11:51:52 +03:00
}
}
}
2016-01-17 15:51:11 +03:00
// handleHasManyPreload used to preload has many associations
2015-04-21 11:51:52 +03:00
func ( scope * Scope ) handleHasManyPreload ( field * Field , conditions [ ] interface { } ) {
2015-07-30 13:19:49 +03:00
relation := field . Relationship
2016-01-15 15:37:41 +03:00
// get relations's primary keys
2016-01-15 17:53:09 +03:00
primaryKeys := scope . getColumnAsArray ( relation . AssociationForeignFieldNames , scope . Value )
2015-04-21 11:51:52 +03:00
if len ( primaryKeys ) == 0 {
return
}
2016-02-15 16:29:47 +03:00
// preload conditions
preloadDB , preloadConditions := scope . generatePreloadDBWithConditions ( conditions )
2016-01-15 15:37:41 +03:00
// find relations
2016-06-01 14:46:45 +03:00
query := fmt . Sprintf ( "%v IN (%v)" , toQueryCondition ( scope , relation . ForeignDBNames ) , toQueryMarks ( primaryKeys ) )
values := toQueryValues ( primaryKeys )
if relation . PolymorphicType != "" {
query += fmt . Sprintf ( " AND %v = ?" , scope . Quote ( relation . PolymorphicDBName ) )
2016-10-06 15:33:48 +03:00
values = append ( values , relation . PolymorphicValue )
2016-06-01 14:46:45 +03:00
}
2015-04-21 11:51:52 +03:00
results := makeSlice ( field . Struct . Type )
2016-06-01 14:46:45 +03:00
scope . Err ( preloadDB . Where ( query , values ... ) . Find ( results , preloadConditions ... ) . Error )
2016-01-15 15:37:41 +03:00
// assign find results
var (
2016-01-18 07:20:27 +03:00
resultsValue = indirect ( reflect . ValueOf ( results ) )
2016-01-15 15:37:41 +03:00
indirectScopeValue = scope . IndirectValue ( )
)
if indirectScopeValue . Kind ( ) == reflect . Slice {
2016-05-09 15:15:35 +03:00
preloadMap := make ( map [ string ] [ ] reflect . Value )
2016-01-15 15:37:41 +03:00
for i := 0 ; i < resultsValue . Len ( ) ; i ++ {
result := resultsValue . Index ( i )
foreignValues := getValueFromFields ( result , relation . ForeignFieldNames )
2016-05-09 15:15:35 +03:00
preloadMap [ toString ( foreignValues ) ] = append ( preloadMap [ toString ( foreignValues ) ] , result )
}
2016-05-09 17:42:07 +03:00
for j := 0 ; j < indirectScopeValue . Len ( ) ; j ++ {
object := indirect ( indirectScopeValue . Index ( j ) )
2016-05-09 15:15:35 +03:00
objectRealValue := getValueFromFields ( object , relation . AssociationForeignFieldNames )
2016-09-27 23:55:04 +03:00
f := object . FieldByName ( field . Name )
2016-05-09 17:42:07 +03:00
if results , ok := preloadMap [ toString ( objectRealValue ) ] ; ok {
2016-05-09 15:15:35 +03:00
f . Set ( reflect . Append ( f , results ... ) )
2016-09-27 23:55:04 +03:00
} else {
f . Set ( reflect . MakeSlice ( f . Type ( ) , 0 , 0 ) )
2015-04-21 11:51:52 +03:00
}
}
} else {
2016-01-15 15:37:41 +03:00
scope . Err ( field . Set ( resultsValue ) )
2015-04-21 11:51:52 +03:00
}
}
2016-01-17 15:51:11 +03:00
// handleBelongsToPreload used to preload belongs to associations
2015-04-21 11:51:52 +03:00
func ( scope * Scope ) handleBelongsToPreload ( field * Field , conditions [ ] interface { } ) {
relation := field . Relationship
2016-01-15 15:37:41 +03:00
2016-02-15 16:29:47 +03:00
// preload conditions
preloadDB , preloadConditions := scope . generatePreloadDBWithConditions ( conditions )
2016-01-15 15:37:41 +03:00
// get relations's primary keys
2016-01-15 17:53:09 +03:00
primaryKeys := scope . getColumnAsArray ( relation . ForeignFieldNames , scope . Value )
2015-04-21 11:51:52 +03:00
if len ( primaryKeys ) == 0 {
return
}
2016-01-15 15:37:41 +03:00
// find relations
2015-04-21 11:51:52 +03:00
results := makeSlice ( field . Struct . Type )
2016-02-15 16:29:47 +03:00
scope . Err ( preloadDB . Where ( fmt . Sprintf ( "%v IN (%v)" , toQueryCondition ( scope , relation . AssociationForeignDBNames ) , toQueryMarks ( primaryKeys ) ) , toQueryValues ( primaryKeys ) ... ) . Find ( results , preloadConditions ... ) . Error )
2015-04-21 11:51:52 +03:00
2016-01-15 15:37:41 +03:00
// assign find results
var (
2016-01-18 07:20:27 +03:00
resultsValue = indirect ( reflect . ValueOf ( results ) )
2016-01-15 15:37:41 +03:00
indirectScopeValue = scope . IndirectValue ( )
)
for i := 0 ; i < resultsValue . Len ( ) ; i ++ {
result := resultsValue . Index ( i )
if indirectScopeValue . Kind ( ) == reflect . Slice {
2016-01-15 10:53:53 +03:00
value := getValueFromFields ( result , relation . AssociationForeignFieldNames )
2016-01-15 15:37:41 +03:00
for j := 0 ; j < indirectScopeValue . Len ( ) ; j ++ {
2016-01-18 07:20:27 +03:00
object := indirect ( indirectScopeValue . Index ( j ) )
2016-01-15 10:53:53 +03:00
if equalAsString ( getValueFromFields ( object , relation . ForeignFieldNames ) , value ) {
2015-04-21 11:51:52 +03:00
object . FieldByName ( field . Name ) . Set ( result )
}
}
} else {
2016-01-15 15:37:41 +03:00
scope . Err ( field . Set ( result ) )
2015-04-21 11:51:52 +03:00
}
}
}
2016-01-17 15:51:11 +03:00
// handleManyToManyPreload used to preload many to many associations
2015-11-16 07:19:25 +03:00
func ( scope * Scope ) handleManyToManyPreload ( field * Field , conditions [ ] interface { } ) {
2016-01-13 05:11:31 +03:00
var (
relation = field . Relationship
joinTableHandler = relation . JoinTableHandler
2016-02-14 12:21:40 +03:00
fieldType = field . Struct . Type . Elem ( )
2016-01-15 05:08:22 +03:00
foreignKeyValue interface { }
foreignKeyType = reflect . ValueOf ( & foreignKeyValue ) . Type ( )
2016-01-15 15:37:41 +03:00
linkHash = map [ string ] [ ] reflect . Value { }
2016-01-13 05:11:31 +03:00
isPtr bool
)
2016-01-15 15:37:41 +03:00
if fieldType . Kind ( ) == reflect . Ptr {
2015-08-17 18:09:07 +03:00
isPtr = true
2016-01-15 15:37:41 +03:00
fieldType = fieldType . Elem ( )
2015-08-17 18:09:07 +03:00
}
2015-08-16 10:10:11 +03:00
2016-01-15 15:37:41 +03:00
var sourceKeys = [ ] string { }
2015-08-16 10:10:11 +03:00
for _ , key := range joinTableHandler . SourceForeignKeys ( ) {
sourceKeys = append ( sourceKeys , key . DBName )
}
2016-02-15 16:29:47 +03:00
// preload conditions
preloadDB , preloadConditions := scope . generatePreloadDBWithConditions ( conditions )
2016-01-15 15:37:41 +03:00
// generate query with join table
2016-02-14 12:21:40 +03:00
newScope := scope . New ( reflect . New ( fieldType ) . Interface ( ) )
2017-08-02 02:05:11 +03:00
preloadDB = preloadDB . Table ( newScope . TableName ( ) ) . Model ( newScope . Value )
if len ( preloadDB . search . selects ) == 0 {
preloadDB = preloadDB . Select ( "*" )
}
2016-02-15 16:29:47 +03:00
preloadDB = joinTableHandler . JoinWith ( joinTableHandler , preloadDB , scope . Value )
2016-02-14 12:21:40 +03:00
2016-01-13 05:11:31 +03:00
// preload inline conditions
2016-02-15 16:29:47 +03:00
if len ( preloadConditions ) > 0 {
preloadDB = preloadDB . Where ( preloadConditions [ 0 ] , preloadConditions [ 1 : ] ... )
2015-08-17 18:09:07 +03:00
}
2016-01-13 05:11:31 +03:00
2016-02-15 16:29:47 +03:00
rows , err := preloadDB . Rows ( )
2015-08-16 10:10:11 +03:00
if scope . Err ( err ) != nil {
return
}
defer rows . Close ( )
columns , _ := rows . Columns ( )
for rows . Next ( ) {
2016-01-15 05:08:22 +03:00
var (
2016-01-15 15:37:41 +03:00
elem = reflect . New ( fieldType ) . Elem ( )
2016-03-10 12:13:48 +03:00
fields = scope . New ( elem . Addr ( ) . Interface ( ) ) . Fields ( )
2016-01-15 05:08:22 +03:00
)
// register foreign keys in join tables
2016-03-10 12:13:48 +03:00
var joinTableFields [ ] * Field
2016-01-15 05:08:22 +03:00
for _ , sourceKey := range sourceKeys {
2016-03-10 12:13:48 +03:00
joinTableFields = append ( joinTableFields , & Field { StructField : & StructField { DBName : sourceKey , IsNormal : true } , Field : reflect . New ( foreignKeyType ) . Elem ( ) } )
2015-08-16 10:10:11 +03:00
}
2016-03-10 12:13:48 +03:00
scope . scan ( rows , columns , append ( fields , joinTableFields ... ) )
2015-08-16 10:10:11 +03:00
2018-02-09 18:22:53 +03:00
scope . New ( elem . Addr ( ) . Interface ( ) ) .
2018-02-09 19:07:16 +03:00
InstanceSet ( "gorm:skip_query_callback" , true ) .
2018-02-09 18:22:53 +03:00
callCallbacks ( scope . db . parent . callbacks . queries )
2016-01-15 05:08:22 +03:00
var foreignKeys = make ( [ ] interface { } , len ( sourceKeys ) )
2016-03-10 12:13:48 +03:00
// generate hashed forkey keys in join table
for idx , joinTableField := range joinTableFields {
if ! joinTableField . Field . IsNil ( ) {
foreignKeys [ idx ] = joinTableField . Field . Elem ( ) . Interface ( )
}
2015-08-16 10:10:11 +03:00
}
2016-01-15 05:08:22 +03:00
hashedSourceKeys := toString ( foreignKeys )
2015-08-16 10:10:11 +03:00
2016-01-15 05:08:22 +03:00
if isPtr {
linkHash [ hashedSourceKeys ] = append ( linkHash [ hashedSourceKeys ] , elem . Addr ( ) )
} else {
linkHash [ hashedSourceKeys ] = append ( linkHash [ hashedSourceKeys ] , elem )
2015-08-17 18:09:07 +03:00
}
2015-08-16 10:10:11 +03:00
}
2017-03-24 04:28:06 +03:00
if err := rows . Err ( ) ; err != nil {
scope . Err ( err )
}
2016-01-15 06:04:49 +03:00
// assign find results
var (
indirectScopeValue = scope . IndirectValue ( )
2016-05-10 09:43:50 +03:00
fieldsSourceMap = map [ string ] [ ] reflect . Value { }
2016-01-15 06:04:49 +03:00
foreignFieldNames = [ ] string { }
)
2016-01-03 12:20:24 +03:00
for _ , dbName := range relation . ForeignFieldNames {
2016-03-10 12:13:48 +03:00
if field , ok := scope . FieldByName ( dbName ) ; ok {
2016-01-03 12:20:24 +03:00
foreignFieldNames = append ( foreignFieldNames , field . Name )
2015-08-18 04:08:33 +03:00
}
}
2016-01-15 06:04:49 +03:00
if indirectScopeValue . Kind ( ) == reflect . Slice {
for j := 0 ; j < indirectScopeValue . Len ( ) ; j ++ {
2016-01-18 07:20:27 +03:00
object := indirect ( indirectScopeValue . Index ( j ) )
2016-05-10 09:43:50 +03:00
key := toString ( getValueFromFields ( object , foreignFieldNames ) )
fieldsSourceMap [ key ] = append ( fieldsSourceMap [ key ] , object . FieldByName ( field . Name ) )
2015-08-16 12:25:25 +03:00
}
2016-01-15 06:04:49 +03:00
} else if indirectScopeValue . IsValid ( ) {
2016-05-10 09:43:50 +03:00
key := toString ( getValueFromFields ( indirectScopeValue , foreignFieldNames ) )
fieldsSourceMap [ key ] = append ( fieldsSourceMap [ key ] , indirectScopeValue . FieldByName ( field . Name ) )
2016-01-15 06:04:49 +03:00
}
for source , link := range linkHash {
2016-05-10 09:43:50 +03:00
for i , field := range fieldsSourceMap [ source ] {
//If not 0 this means Value is a pointer and we already added preloaded models to it
if fieldsSourceMap [ source ] [ i ] . Len ( ) != 0 {
continue
}
field . Set ( reflect . Append ( fieldsSourceMap [ source ] [ i ] , link ... ) )
}
2015-08-16 10:10:11 +03:00
}
}