2015-02-11 08:43:53 +03:00
package gorm
import (
2015-02-11 10:37:04 +03:00
"database/sql/driver"
2015-02-11 08:43:53 +03:00
"errors"
"fmt"
"reflect"
2015-04-21 10:00:36 +03:00
"strings"
2015-02-11 08:43:53 +03:00
)
2015-07-30 13:19:49 +03:00
// getRealValue returns the values of the struct fields named in columns, read
// from value (a struct, or a pointer to one).
//
// Columns that do not name a field are skipped, so the result may be shorter
// than columns. Field values implementing driver.Valuer are replaced by the
// result of their Value method. A nil pointer yields a nil slice.
func getRealValue(value reflect.Value, columns []string) (results []interface{}) {
	// reflect.Indirect turns a nil pointer into the zero Value, on which
	// FieldByName would panic, so stop early in that case.
	record := reflect.Indirect(value)
	if !record.IsValid() {
		return nil
	}

	for _, name := range columns {
		fieldValue := record.FieldByName(name)
		if !fieldValue.IsValid() {
			continue
		}
		raw := fieldValue.Interface()
		if valuer, ok := raw.(driver.Valuer); ok {
			// NOTE(review): the error from Value is deliberately discarded here.
			raw, _ = valuer.Value()
		}
		results = append(results, raw)
	}
	return results
}
func equalAsString ( a interface { } , b interface { } ) bool {
2016-01-16 06:37:16 +03:00
return toString ( a ) == toString ( b )
2015-02-11 10:37:04 +03:00
}
2015-02-11 08:43:53 +03:00
func Preload ( scope * Scope ) {
2016-01-03 09:21:21 +03:00
if scope . Search . preload == nil || scope . HasError ( ) {
2015-04-21 11:51:52 +03:00
return
}
2015-04-21 10:00:36 +03:00
preloadMap := map [ string ] bool { }
2015-04-21 11:51:52 +03:00
fields := scope . Fields ( )
for _ , preload := range scope . Search . preload {
schema , conditions := preload . schema , preload . conditions
keys := strings . Split ( schema , "." )
currentScope := scope
currentFields := fields
originalConditions := conditions
conditions = [ ] interface { } { }
for i , key := range keys {
var found bool
if preloadMap [ strings . Join ( keys [ : i + 1 ] , "." ) ] {
goto nextLoop
}
if i == len ( keys ) - 1 {
conditions = originalConditions
}
for _ , field := range currentFields {
if field . Name != key || field . Relationship == nil {
continue
2015-04-21 10:00:36 +03:00
}
2015-04-21 11:51:52 +03:00
found = true
switch field . Relationship . Kind {
case "has_one" :
currentScope . handleHasOnePreload ( field , conditions )
case "has_many" :
currentScope . handleHasManyPreload ( field , conditions )
case "belongs_to" :
currentScope . handleBelongsToPreload ( field , conditions )
case "many_to_many" :
2015-11-16 07:19:25 +03:00
currentScope . handleManyToManyPreload ( field , conditions )
2015-04-21 11:51:52 +03:00
default :
currentScope . Err ( errors . New ( "not supported relation" ) )
2015-02-11 08:43:53 +03:00
}
2015-04-21 11:51:52 +03:00
break
}
if ! found {
value := reflect . ValueOf ( currentScope . Value )
if value . Kind ( ) == reflect . Slice && value . Type ( ) . Elem ( ) . Kind ( ) == reflect . Interface {
value = value . Index ( 0 ) . Elem ( )
}
2015-06-11 17:14:36 +03:00
scope . Err ( fmt . Errorf ( "can't find field %s in %s" , key , value . Type ( ) ) )
2015-04-21 11:51:52 +03:00
return
}
preloadMap [ strings . Join ( keys [ : i + 1 ] , "." ) ] = true
nextLoop :
if i < len ( keys ) - 1 {
currentScope = currentScope . getColumnsAsScope ( key )
currentFields = currentScope . Fields ( )
2015-02-11 08:43:53 +03:00
}
}
}
2015-04-21 11:51:52 +03:00
2015-02-11 08:43:53 +03:00
}
2015-02-17 17:55:14 +03:00
// makeSlice returns a pointer to a new, empty (non-nil) slice.
//
// If typ is itself a slice type, the new slice has the same element type as
// typ; otherwise typ is used as the element type. The result is suitable as
// a Find destination.
func makeSlice(typ reflect.Type) interface{} {
	elemType := typ
	if elemType.Kind() == reflect.Slice {
		elemType = elemType.Elem()
	}
	empty := reflect.MakeSlice(reflect.SliceOf(elemType), 0, 0)
	holder := reflect.New(empty.Type())
	holder.Elem().Set(empty)
	return holder.Interface()
}
2015-04-21 11:51:52 +03:00
func ( scope * Scope ) handleHasOnePreload ( field * Field , conditions [ ] interface { } ) {
2015-07-30 13:19:49 +03:00
relation := field . Relationship
2015-07-30 17:36:04 +03:00
2015-07-30 13:19:49 +03:00
primaryKeys := scope . getColumnAsArray ( relation . AssociationForeignFieldNames )
2015-04-21 11:51:52 +03:00
if len ( primaryKeys ) == 0 {
return
}
results := makeSlice ( field . Struct . Type )
2015-07-30 13:19:49 +03:00
scope . Err ( scope . NewDB ( ) . Where ( fmt . Sprintf ( "%v IN (%v)" , toQueryCondition ( scope , relation . ForeignDBNames ) , toQueryMarks ( primaryKeys ) ) , toQueryValues ( primaryKeys ) ... ) . Find ( results , conditions ... ) . Error )
2015-04-22 10:36:10 +03:00
resultValues := reflect . Indirect ( reflect . ValueOf ( results ) )
2015-04-21 11:51:52 +03:00
for i := 0 ; i < resultValues . Len ( ) ; i ++ {
result := resultValues . Index ( i )
if scope . IndirectValue ( ) . Kind ( ) == reflect . Slice {
2015-07-30 13:19:49 +03:00
value := getRealValue ( result , relation . ForeignFieldNames )
2015-04-21 11:51:52 +03:00
objects := scope . IndirectValue ( )
for j := 0 ; j < objects . Len ( ) ; j ++ {
2015-07-30 13:19:49 +03:00
if equalAsString ( getRealValue ( objects . Index ( j ) , relation . AssociationForeignFieldNames ) , value ) {
2015-04-21 11:51:52 +03:00
reflect . Indirect ( objects . Index ( j ) ) . FieldByName ( field . Name ) . Set ( result )
break
}
}
} else {
2015-04-22 10:36:10 +03:00
if err := scope . SetColumn ( field , result ) ; err != nil {
2015-04-21 11:51:52 +03:00
scope . Err ( err )
return
}
}
}
}
func ( scope * Scope ) handleHasManyPreload ( field * Field , conditions [ ] interface { } ) {
2015-07-30 13:19:49 +03:00
relation := field . Relationship
primaryKeys := scope . getColumnAsArray ( relation . AssociationForeignFieldNames )
2015-04-21 11:51:52 +03:00
if len ( primaryKeys ) == 0 {
return
}
results := makeSlice ( field . Struct . Type )
2015-07-30 13:19:49 +03:00
scope . Err ( scope . NewDB ( ) . Where ( fmt . Sprintf ( "%v IN (%v)" , toQueryCondition ( scope , relation . ForeignDBNames ) , toQueryMarks ( primaryKeys ) ) , toQueryValues ( primaryKeys ) ... ) . Find ( results , conditions ... ) . Error )
2015-04-22 10:36:10 +03:00
resultValues := reflect . Indirect ( reflect . ValueOf ( results ) )
2015-04-21 11:51:52 +03:00
if scope . IndirectValue ( ) . Kind ( ) == reflect . Slice {
for i := 0 ; i < resultValues . Len ( ) ; i ++ {
result := resultValues . Index ( i )
2015-07-30 13:19:49 +03:00
value := getRealValue ( result , relation . ForeignFieldNames )
2015-04-21 11:51:52 +03:00
objects := scope . IndirectValue ( )
for j := 0 ; j < objects . Len ( ) ; j ++ {
object := reflect . Indirect ( objects . Index ( j ) )
2015-07-30 13:19:49 +03:00
if equalAsString ( getRealValue ( object , relation . AssociationForeignFieldNames ) , value ) {
2016-02-04 23:19:29 +03:00
if object . Kind ( ) == reflect . Ptr {
object = object . Elem ( )
}
2015-04-21 11:51:52 +03:00
f := object . FieldByName ( field . Name )
f . Set ( reflect . Append ( f , result ) )
break
}
}
}
} else {
scope . SetColumn ( field , resultValues )
}
}
func ( scope * Scope ) handleBelongsToPreload ( field * Field , conditions [ ] interface { } ) {
relation := field . Relationship
2015-07-30 13:19:49 +03:00
primaryKeys := scope . getColumnAsArray ( relation . ForeignFieldNames )
2015-04-21 11:51:52 +03:00
if len ( primaryKeys ) == 0 {
return
}
results := makeSlice ( field . Struct . Type )
2015-07-30 17:36:04 +03:00
scope . Err ( scope . NewDB ( ) . Where ( fmt . Sprintf ( "%v IN (%v)" , toQueryCondition ( scope , relation . AssociationForeignDBNames ) , toQueryMarks ( primaryKeys ) ) , toQueryValues ( primaryKeys ) ... ) . Find ( results , conditions ... ) . Error )
2015-04-22 10:36:10 +03:00
resultValues := reflect . Indirect ( reflect . ValueOf ( results ) )
2015-04-21 11:51:52 +03:00
for i := 0 ; i < resultValues . Len ( ) ; i ++ {
result := resultValues . Index ( i )
if scope . IndirectValue ( ) . Kind ( ) == reflect . Slice {
2015-07-30 13:19:49 +03:00
value := getRealValue ( result , relation . AssociationForeignFieldNames )
2015-04-21 11:51:52 +03:00
objects := scope . IndirectValue ( )
for j := 0 ; j < objects . Len ( ) ; j ++ {
object := reflect . Indirect ( objects . Index ( j ) )
2016-01-17 15:59:15 +03:00
if object . Kind ( ) == reflect . Ptr {
object = reflect . Indirect ( objects . Index ( j ) . Elem ( ) )
}
2015-07-30 13:19:49 +03:00
if equalAsString ( getRealValue ( object , relation . ForeignFieldNames ) , value ) {
2015-04-21 11:51:52 +03:00
object . FieldByName ( field . Name ) . Set ( result )
}
}
} else {
scope . SetColumn ( field , result )
}
}
}
2015-11-16 07:19:25 +03:00
func ( scope * Scope ) handleManyToManyPreload ( field * Field , conditions [ ] interface { } ) {
2015-08-16 10:10:11 +03:00
relation := field . Relationship
joinTableHandler := relation . JoinTableHandler
2015-08-17 18:09:07 +03:00
destType := field . StructField . Struct . Type . Elem ( )
var isPtr bool
if destType . Kind ( ) == reflect . Ptr {
isPtr = true
destType = destType . Elem ( )
}
2015-08-16 10:10:11 +03:00
var sourceKeys [ ] string
2015-08-18 02:43:08 +03:00
var linkHash = make ( map [ string ] [ ] reflect . Value )
2015-08-16 10:10:11 +03:00
for _ , key := range joinTableHandler . SourceForeignKeys ( ) {
sourceKeys = append ( sourceKeys , key . DBName )
}
2015-10-01 02:43:38 +03:00
db := scope . NewDB ( ) . Table ( scope . New ( reflect . New ( destType ) . Interface ( ) ) . TableName ( ) ) . Select ( "*" )
2015-12-09 05:40:12 +03:00
2015-08-18 03:05:44 +03:00
preloadJoinDB := joinTableHandler . JoinWith ( joinTableHandler , db , scope . Value )
2015-11-16 07:19:25 +03:00
2015-08-17 18:09:07 +03:00
if len ( conditions ) > 0 {
preloadJoinDB = preloadJoinDB . Where ( conditions [ 0 ] , conditions [ 1 : ] ... )
}
rows , err := preloadJoinDB . Rows ( )
2015-08-16 10:10:11 +03:00
if scope . Err ( err ) != nil {
return
}
defer rows . Close ( )
columns , _ := rows . Columns ( )
for rows . Next ( ) {
elem := reflect . New ( destType ) . Elem ( )
var values = make ( [ ] interface { } , len ( columns ) )
fields := scope . New ( elem . Addr ( ) . Interface ( ) ) . Fields ( )
2015-12-09 05:40:12 +03:00
var foundFields = map [ string ] bool { }
2015-08-16 10:10:11 +03:00
for index , column := range columns {
2015-12-09 05:40:12 +03:00
if field , ok := fields [ column ] ; ok && ! foundFields [ column ] {
2015-08-16 10:10:11 +03:00
if field . Field . Kind ( ) == reflect . Ptr {
values [ index ] = field . Field . Addr ( ) . Interface ( )
} else {
values [ index ] = reflect . New ( reflect . PtrTo ( field . Field . Type ( ) ) ) . Interface ( )
}
2015-12-09 05:40:12 +03:00
foundFields [ column ] = true
2015-08-16 10:10:11 +03:00
} else {
var i interface { }
values [ index ] = & i
}
}
scope . Err ( rows . Scan ( values ... ) )
var sourceKey [ ] interface { }
2015-12-09 05:40:12 +03:00
var scannedFields = map [ string ] bool { }
2015-08-16 10:10:11 +03:00
for index , column := range columns {
value := values [ index ]
2015-12-09 05:40:12 +03:00
if field , ok := fields [ column ] ; ok && ! scannedFields [ column ] {
2015-08-16 10:10:11 +03:00
if field . Field . Kind ( ) == reflect . Ptr {
field . Field . Set ( reflect . ValueOf ( value ) . Elem ( ) )
} else if v := reflect . ValueOf ( value ) . Elem ( ) . Elem ( ) ; v . IsValid ( ) {
field . Field . Set ( v )
}
2015-12-09 05:40:12 +03:00
scannedFields [ column ] = true
2015-08-16 10:10:11 +03:00
} else if strInSlice ( column , sourceKeys ) {
sourceKey = append ( sourceKey , * ( value . ( * interface { } ) ) )
}
}
2015-08-17 22:28:40 +03:00
if len ( sourceKey ) != 0 {
if isPtr {
linkHash [ toString ( sourceKey ) ] = append ( linkHash [ toString ( sourceKey ) ] , elem . Addr ( ) )
} else {
linkHash [ toString ( sourceKey ) ] = append ( linkHash [ toString ( sourceKey ) ] , elem )
}
2015-08-17 18:09:07 +03:00
}
2015-08-16 10:10:11 +03:00
}
2016-01-03 12:20:24 +03:00
var foreignFieldNames [ ] string
for _ , dbName := range relation . ForeignFieldNames {
2015-08-18 04:08:33 +03:00
if field , ok := scope . FieldByName ( dbName ) ; ok {
2016-01-03 12:20:24 +03:00
foreignFieldNames = append ( foreignFieldNames , field . Name )
2015-08-18 04:08:33 +03:00
}
}
2015-08-16 10:10:11 +03:00
if scope . IndirectValue ( ) . Kind ( ) == reflect . Slice {
objects := scope . IndirectValue ( )
for j := 0 ; j < objects . Len ( ) ; j ++ {
object := reflect . Indirect ( objects . Index ( j ) )
2016-01-03 12:20:24 +03:00
source := getRealValue ( object , foreignFieldNames )
2015-08-18 02:43:08 +03:00
field := object . FieldByName ( field . Name )
for _ , link := range linkHash [ toString ( source ) ] {
field . Set ( reflect . Append ( field , link ) )
2015-08-16 10:10:11 +03:00
}
}
2015-08-16 12:25:25 +03:00
} else {
2015-12-16 11:58:45 +03:00
if object := scope . IndirectValue ( ) ; object . IsValid ( ) {
2016-01-03 12:20:24 +03:00
source := getRealValue ( object , foreignFieldNames )
2015-12-16 11:58:45 +03:00
field := object . FieldByName ( field . Name )
for _ , link := range linkHash [ toString ( source ) ] {
field . Set ( reflect . Append ( field , link ) )
}
2015-08-16 12:25:25 +03:00
}
2015-08-16 10:10:11 +03:00
}
}
2015-07-30 13:19:49 +03:00
func ( scope * Scope ) getColumnAsArray ( columns [ ] string ) ( results [ ] [ ] interface { } ) {
2015-02-11 08:43:53 +03:00
values := scope . IndirectValue ( )
switch values . Kind ( ) {
case reflect . Slice :
for i := 0 ; i < values . Len ( ) ; i ++ {
2015-07-30 13:19:49 +03:00
var result [ ] interface { }
for _ , column := range columns {
2016-01-17 15:59:15 +03:00
value := reflect . Indirect ( values . Index ( i ) )
if value . Kind ( ) == reflect . Ptr {
value = reflect . Indirect ( values . Index ( i ) . Elem ( ) )
}
result = append ( result , value . FieldByName ( column ) . Interface ( ) )
2015-07-30 13:19:49 +03:00
}
results = append ( results , result )
2015-02-11 08:43:53 +03:00
}
case reflect . Struct :
2015-07-30 13:19:49 +03:00
var result [ ] interface { }
for _ , column := range columns {
result = append ( result , values . FieldByName ( column ) . Interface ( ) )
}
return [ ] [ ] interface { } { result }
2015-02-11 08:43:53 +03:00
}
return
}
2015-04-21 10:00:36 +03:00
func ( scope * Scope ) getColumnsAsScope ( column string ) * Scope {
values := scope . IndirectValue ( )
switch values . Kind ( ) {
case reflect . Slice :
2015-04-22 10:36:10 +03:00
modelType := values . Type ( ) . Elem ( )
if modelType . Kind ( ) == reflect . Ptr {
modelType = modelType . Elem ( )
2015-04-21 11:51:52 +03:00
}
2015-04-22 10:36:10 +03:00
fieldStruct , _ := modelType . FieldByName ( column )
2015-04-21 10:00:36 +03:00
var columns reflect . Value
2015-06-05 13:54:52 +03:00
if fieldStruct . Type . Kind ( ) == reflect . Slice || fieldStruct . Type . Kind ( ) == reflect . Ptr {
2015-04-22 10:36:10 +03:00
columns = reflect . New ( reflect . SliceOf ( reflect . PtrTo ( fieldStruct . Type . Elem ( ) ) ) ) . Elem ( )
2015-04-21 10:00:36 +03:00
} else {
2015-04-22 10:36:10 +03:00
columns = reflect . New ( reflect . SliceOf ( reflect . PtrTo ( fieldStruct . Type ) ) ) . Elem ( )
2015-04-21 10:00:36 +03:00
}
for i := 0 ; i < values . Len ( ) ; i ++ {
column := reflect . Indirect ( values . Index ( i ) ) . FieldByName ( column )
2015-06-05 13:54:52 +03:00
if column . Kind ( ) == reflect . Ptr {
column = column . Elem ( )
}
2015-04-21 10:00:36 +03:00
if column . Kind ( ) == reflect . Slice {
for i := 0 ; i < column . Len ( ) ; i ++ {
2015-09-24 14:34:51 +03:00
elem := column . Index ( i )
if elem . CanAddr ( ) {
columns = reflect . Append ( columns , elem . Addr ( ) )
}
2015-04-21 10:00:36 +03:00
}
} else {
2015-09-24 14:34:51 +03:00
if column . CanAddr ( ) {
columns = reflect . Append ( columns , column . Addr ( ) )
}
2015-04-21 10:00:36 +03:00
}
}
return scope . New ( columns . Interface ( ) )
case reflect . Struct :
2015-09-24 14:34:51 +03:00
field := values . FieldByName ( column )
if ! field . CanAddr ( ) {
return nil
}
return scope . New ( field . Addr ( ) . Interface ( ) )
2015-04-21 10:00:36 +03:00
}
return nil
}