2015-02-11 08:43:53 +03:00
package gorm
import (
"errors"
"fmt"
"reflect"
2015-04-21 10:00:36 +03:00
"strings"
2015-02-11 08:43:53 +03:00
)
2016-01-17 15:51:11 +03:00
// preloadCallback used to preload associations
2016-01-17 10:30:42 +03:00
func preloadCallback ( scope * Scope ) {
2016-01-03 09:21:21 +03:00
if scope . Search . preload == nil || scope . HasError ( ) {
2015-04-21 11:51:52 +03:00
return
}
2016-01-15 10:53:53 +03:00
var (
preloadedMap = map [ string ] bool { }
fields = scope . Fields ( )
)
2015-04-21 11:51:52 +03:00
for _ , preload := range scope . Search . preload {
2016-01-15 10:53:53 +03:00
var (
preloadFields = strings . Split ( preload . schema , "." )
currentScope = scope
currentFields = fields
)
2015-04-21 11:51:52 +03:00
2016-01-15 10:53:53 +03:00
for idx , preloadField := range preloadFields {
var currentPreloadConditions [ ] interface { }
2015-04-21 11:51:52 +03:00
2016-01-15 10:53:53 +03:00
// if not preloaded
if preloadKey := strings . Join ( preloadFields [ : idx + 1 ] , "." ) ; ! preloadedMap [ preloadKey ] {
// assign search conditions to last preload
if idx == len ( preloadFields ) - 1 {
currentPreloadConditions = preload . conditions
2015-04-21 10:00:36 +03:00
}
2016-01-15 10:53:53 +03:00
for _ , field := range currentFields {
if field . Name != preloadField || field . Relationship == nil {
continue
}
switch field . Relationship . Kind {
case "has_one" :
currentScope . handleHasOnePreload ( field , currentPreloadConditions )
case "has_many" :
currentScope . handleHasManyPreload ( field , currentPreloadConditions )
case "belongs_to" :
currentScope . handleBelongsToPreload ( field , currentPreloadConditions )
case "many_to_many" :
currentScope . handleManyToManyPreload ( field , currentPreloadConditions )
default :
scope . Err ( errors . New ( "unsupported relation" ) )
}
preloadedMap [ preloadKey ] = true
break
2015-02-11 08:43:53 +03:00
}
2015-04-21 11:51:52 +03:00
2016-01-15 10:53:53 +03:00
if ! preloadedMap [ preloadKey ] {
scope . Err ( fmt . Errorf ( "can't preload field %s for %s" , preloadField , currentScope . GetModelStruct ( ) . ModelType ) )
return
2015-04-21 11:51:52 +03:00
}
}
2016-01-15 10:53:53 +03:00
// preload next level
if idx < len ( preloadFields ) - 1 {
currentScope = currentScope . getColumnAsScope ( preloadField )
2015-04-21 11:51:52 +03:00
currentFields = currentScope . Fields ( )
2015-02-11 08:43:53 +03:00
}
}
}
}
2016-01-17 15:51:11 +03:00
// handleHasOnePreload used to preload has one associations
2015-04-21 11:51:52 +03:00
func ( scope * Scope ) handleHasOnePreload ( field * Field , conditions [ ] interface { } ) {
2015-07-30 13:19:49 +03:00
relation := field . Relationship
2015-07-30 17:36:04 +03:00
2016-01-15 10:53:53 +03:00
// get relations's primary keys
2016-01-15 17:53:09 +03:00
primaryKeys := scope . getColumnAsArray ( relation . AssociationForeignFieldNames , scope . Value )
2015-04-21 11:51:52 +03:00
if len ( primaryKeys ) == 0 {
return
}
2016-01-15 10:53:53 +03:00
// find relations
2015-04-21 11:51:52 +03:00
results := makeSlice ( field . Struct . Type )
2015-07-30 13:19:49 +03:00
scope . Err ( scope . NewDB ( ) . Where ( fmt . Sprintf ( "%v IN (%v)" , toQueryCondition ( scope , relation . ForeignDBNames ) , toQueryMarks ( primaryKeys ) ) , toQueryValues ( primaryKeys ) ... ) . Find ( results , conditions ... ) . Error )
2015-04-21 11:51:52 +03:00
2016-01-15 10:53:53 +03:00
// assign find results
var (
2016-01-18 07:20:27 +03:00
resultsValue = indirect ( reflect . ValueOf ( results ) )
2016-01-15 10:53:53 +03:00
indirectScopeValue = scope . IndirectValue ( )
)
for i := 0 ; i < resultsValue . Len ( ) ; i ++ {
result := resultsValue . Index ( i )
if indirectScopeValue . Kind ( ) == reflect . Slice {
2016-01-15 15:37:41 +03:00
foreignValues := getValueFromFields ( result , relation . ForeignFieldNames )
2016-01-15 10:53:53 +03:00
for j := 0 ; j < indirectScopeValue . Len ( ) ; j ++ {
2016-01-18 07:20:27 +03:00
if indirectValue := indirect ( indirectScopeValue . Index ( j ) ) ; equalAsString ( getValueFromFields ( indirectValue , relation . AssociationForeignFieldNames ) , foreignValues ) {
2016-01-15 15:37:41 +03:00
indirectValue . FieldByName ( field . Name ) . Set ( result )
2015-04-21 11:51:52 +03:00
break
}
}
} else {
2016-01-15 10:53:53 +03:00
scope . Err ( field . Set ( result ) )
2015-04-21 11:51:52 +03:00
}
}
}
2016-01-17 15:51:11 +03:00
// handleHasManyPreload used to preload has many associations
2015-04-21 11:51:52 +03:00
func ( scope * Scope ) handleHasManyPreload ( field * Field , conditions [ ] interface { } ) {
2015-07-30 13:19:49 +03:00
relation := field . Relationship
2016-01-15 15:37:41 +03:00
// get relations's primary keys
2016-01-15 17:53:09 +03:00
primaryKeys := scope . getColumnAsArray ( relation . AssociationForeignFieldNames , scope . Value )
2015-04-21 11:51:52 +03:00
if len ( primaryKeys ) == 0 {
return
}
2016-01-15 15:37:41 +03:00
// find relations
2015-04-21 11:51:52 +03:00
results := makeSlice ( field . Struct . Type )
2015-07-30 13:19:49 +03:00
scope . Err ( scope . NewDB ( ) . Where ( fmt . Sprintf ( "%v IN (%v)" , toQueryCondition ( scope , relation . ForeignDBNames ) , toQueryMarks ( primaryKeys ) ) , toQueryValues ( primaryKeys ) ... ) . Find ( results , conditions ... ) . Error )
2016-01-15 15:37:41 +03:00
// assign find results
var (
2016-01-18 07:20:27 +03:00
resultsValue = indirect ( reflect . ValueOf ( results ) )
2016-01-15 15:37:41 +03:00
indirectScopeValue = scope . IndirectValue ( )
)
if indirectScopeValue . Kind ( ) == reflect . Slice {
for i := 0 ; i < resultsValue . Len ( ) ; i ++ {
result := resultsValue . Index ( i )
foreignValues := getValueFromFields ( result , relation . ForeignFieldNames )
for j := 0 ; j < indirectScopeValue . Len ( ) ; j ++ {
2016-01-18 07:20:27 +03:00
object := indirect ( indirectScopeValue . Index ( j ) )
2016-01-15 15:37:41 +03:00
if equalAsString ( getValueFromFields ( object , relation . AssociationForeignFieldNames ) , foreignValues ) {
objectField := object . FieldByName ( field . Name )
objectField . Set ( reflect . Append ( objectField , result ) )
2015-04-21 11:51:52 +03:00
break
}
}
}
} else {
2016-01-15 15:37:41 +03:00
scope . Err ( field . Set ( resultsValue ) )
2015-04-21 11:51:52 +03:00
}
}
2016-01-17 15:51:11 +03:00
// handleBelongsToPreload used to preload belongs to associations
2015-04-21 11:51:52 +03:00
func ( scope * Scope ) handleBelongsToPreload ( field * Field , conditions [ ] interface { } ) {
relation := field . Relationship
2016-01-15 15:37:41 +03:00
// get relations's primary keys
2016-01-15 17:53:09 +03:00
primaryKeys := scope . getColumnAsArray ( relation . ForeignFieldNames , scope . Value )
2015-04-21 11:51:52 +03:00
if len ( primaryKeys ) == 0 {
return
}
2016-01-15 15:37:41 +03:00
// find relations
2015-04-21 11:51:52 +03:00
results := makeSlice ( field . Struct . Type )
2015-07-30 17:36:04 +03:00
scope . Err ( scope . NewDB ( ) . Where ( fmt . Sprintf ( "%v IN (%v)" , toQueryCondition ( scope , relation . AssociationForeignDBNames ) , toQueryMarks ( primaryKeys ) ) , toQueryValues ( primaryKeys ) ... ) . Find ( results , conditions ... ) . Error )
2015-04-21 11:51:52 +03:00
2016-01-15 15:37:41 +03:00
// assign find results
var (
2016-01-18 07:20:27 +03:00
resultsValue = indirect ( reflect . ValueOf ( results ) )
2016-01-15 15:37:41 +03:00
indirectScopeValue = scope . IndirectValue ( )
)
for i := 0 ; i < resultsValue . Len ( ) ; i ++ {
result := resultsValue . Index ( i )
if indirectScopeValue . Kind ( ) == reflect . Slice {
2016-01-15 10:53:53 +03:00
value := getValueFromFields ( result , relation . AssociationForeignFieldNames )
2016-01-15 15:37:41 +03:00
for j := 0 ; j < indirectScopeValue . Len ( ) ; j ++ {
2016-01-18 07:20:27 +03:00
object := indirect ( indirectScopeValue . Index ( j ) )
2016-01-15 10:53:53 +03:00
if equalAsString ( getValueFromFields ( object , relation . ForeignFieldNames ) , value ) {
2015-04-21 11:51:52 +03:00
object . FieldByName ( field . Name ) . Set ( result )
}
}
} else {
2016-01-15 15:37:41 +03:00
scope . Err ( field . Set ( result ) )
2015-04-21 11:51:52 +03:00
}
}
}
2016-01-17 15:51:11 +03:00
// handleManyToManyPreload used to preload many to many associations
2015-11-16 07:19:25 +03:00
func ( scope * Scope ) handleManyToManyPreload ( field * Field , conditions [ ] interface { } ) {
2016-01-13 05:11:31 +03:00
var (
relation = field . Relationship
joinTableHandler = relation . JoinTableHandler
2016-02-14 12:21:40 +03:00
fieldType = field . Struct . Type . Elem ( )
2016-01-15 05:08:22 +03:00
foreignKeyValue interface { }
foreignKeyType = reflect . ValueOf ( & foreignKeyValue ) . Type ( )
2016-01-15 15:37:41 +03:00
linkHash = map [ string ] [ ] reflect . Value { }
2016-01-13 05:11:31 +03:00
isPtr bool
)
2016-01-15 15:37:41 +03:00
if fieldType . Kind ( ) == reflect . Ptr {
2015-08-17 18:09:07 +03:00
isPtr = true
2016-01-15 15:37:41 +03:00
fieldType = fieldType . Elem ( )
2015-08-17 18:09:07 +03:00
}
2015-08-16 10:10:11 +03:00
2016-01-15 15:37:41 +03:00
var sourceKeys = [ ] string { }
2015-08-16 10:10:11 +03:00
for _ , key := range joinTableHandler . SourceForeignKeys ( ) {
sourceKeys = append ( sourceKeys , key . DBName )
}
2016-01-15 15:37:41 +03:00
// generate query with join table
2016-02-14 12:21:40 +03:00
newScope := scope . New ( reflect . New ( fieldType ) . Interface ( ) )
preloadJoinDB := scope . NewDB ( ) . Table ( newScope . TableName ( ) ) . Select ( "*" )
2016-01-15 06:04:49 +03:00
preloadJoinDB = joinTableHandler . JoinWith ( joinTableHandler , preloadJoinDB , scope . Value )
2015-11-16 07:19:25 +03:00
2016-02-14 12:21:40 +03:00
if primaryField := newScope . PrimaryField ( ) ; primaryField != nil {
preloadJoinDB = preloadJoinDB . Order ( fmt . Sprintf ( "%v.%v %v" , newScope . QuotedTableName ( ) , newScope . Quote ( primaryField . DBName ) , "ASC" ) )
}
2016-01-13 05:11:31 +03:00
// preload inline conditions
2015-08-17 18:09:07 +03:00
if len ( conditions ) > 0 {
preloadJoinDB = preloadJoinDB . Where ( conditions [ 0 ] , conditions [ 1 : ] ... )
}
2016-01-13 05:11:31 +03:00
2015-08-17 18:09:07 +03:00
rows , err := preloadJoinDB . Rows ( )
2015-08-16 10:10:11 +03:00
if scope . Err ( err ) != nil {
return
}
defer rows . Close ( )
columns , _ := rows . Columns ( )
for rows . Next ( ) {
2016-01-15 05:08:22 +03:00
var (
2016-01-15 15:37:41 +03:00
elem = reflect . New ( fieldType ) . Elem ( )
2016-01-15 05:08:22 +03:00
fields = scope . New ( elem . Addr ( ) . Interface ( ) ) . Fields ( )
)
// register foreign keys in join tables
for _ , sourceKey := range sourceKeys {
fields [ sourceKey ] = & Field { Field : reflect . New ( foreignKeyType ) . Elem ( ) }
2015-08-16 10:10:11 +03:00
}
2016-01-15 05:08:22 +03:00
scope . scan ( rows , columns , fields )
2015-08-16 10:10:11 +03:00
2016-01-15 05:08:22 +03:00
// generate hashed forkey keys in join table
var foreignKeys = make ( [ ] interface { } , len ( sourceKeys ) )
for idx , sourceKey := range sourceKeys {
foreignKeys [ idx ] = fields [ sourceKey ] . Field . Elem ( ) . Interface ( )
2015-08-16 10:10:11 +03:00
}
2016-01-15 05:08:22 +03:00
hashedSourceKeys := toString ( foreignKeys )
2015-08-16 10:10:11 +03:00
2016-01-15 05:08:22 +03:00
if isPtr {
linkHash [ hashedSourceKeys ] = append ( linkHash [ hashedSourceKeys ] , elem . Addr ( ) )
} else {
linkHash [ hashedSourceKeys ] = append ( linkHash [ hashedSourceKeys ] , elem )
2015-08-17 18:09:07 +03:00
}
2015-08-16 10:10:11 +03:00
}
2016-01-15 06:04:49 +03:00
// assign find results
var (
indirectScopeValue = scope . IndirectValue ( )
fieldsSourceMap = map [ string ] reflect . Value { }
foreignFieldNames = [ ] string { }
2016-01-15 15:37:41 +03:00
fields = scope . Fields ( )
2016-01-15 06:04:49 +03:00
)
2016-01-03 12:20:24 +03:00
for _ , dbName := range relation . ForeignFieldNames {
2016-01-15 15:37:41 +03:00
if field , ok := fields [ dbName ] ; ok {
2016-01-03 12:20:24 +03:00
foreignFieldNames = append ( foreignFieldNames , field . Name )
2015-08-18 04:08:33 +03:00
}
}
2016-01-15 06:04:49 +03:00
if indirectScopeValue . Kind ( ) == reflect . Slice {
for j := 0 ; j < indirectScopeValue . Len ( ) ; j ++ {
2016-01-18 07:20:27 +03:00
object := indirect ( indirectScopeValue . Index ( j ) )
2016-01-15 10:53:53 +03:00
fieldsSourceMap [ toString ( getValueFromFields ( object , foreignFieldNames ) ) ] = object . FieldByName ( field . Name )
2015-08-16 12:25:25 +03:00
}
2016-01-15 06:04:49 +03:00
} else if indirectScopeValue . IsValid ( ) {
2016-01-15 10:53:53 +03:00
fieldsSourceMap [ toString ( getValueFromFields ( indirectScopeValue , foreignFieldNames ) ) ] = indirectScopeValue . FieldByName ( field . Name )
2016-01-15 06:04:49 +03:00
}
for source , link := range linkHash {
fieldsSourceMap [ source ] . Set ( reflect . Append ( fieldsSourceMap [ source ] , link ... ) )
2015-08-16 10:10:11 +03:00
}
}