2015-02-11 08:43:53 +03:00
package gorm
import (
2015-02-11 10:37:04 +03:00
"database/sql/driver"
2015-02-11 08:43:53 +03:00
"errors"
"fmt"
"reflect"
2015-04-21 10:00:36 +03:00
"strings"
2015-02-11 08:43:53 +03:00
)
2015-07-30 13:19:49 +03:00
// getRealValue reads the named struct fields from value (a struct or pointer
// to struct) and returns their values in column order.
//
// Columns that do not exist on the struct are skipped, so the result may be
// shorter than columns. A field value that implements driver.Valuer is
// replaced by the result of its Value method; that method's error is ignored.
func getRealValue(value reflect.Value, columns []string) (results []interface{}) {
	target := reflect.Indirect(value)
	for _, column := range columns {
		fieldValue := target.FieldByName(column)
		if !fieldValue.IsValid() {
			continue
		}

		result := fieldValue.Interface()
		if valuer, ok := result.(driver.Valuer); ok {
			result, _ = valuer.Value()
		}
		results = append(results, result)
	}
	return results
}
// equalAsString reports whether a and b render identically under the %v
// verb. It is used to match key values whose Go types may differ (for example
// an int64 primary key against an int foreign key).
func equalAsString(a interface{}, b interface{}) bool {
	left := fmt.Sprintf("%v", a)
	right := fmt.Sprintf("%v", b)
	return left == right
}
2015-02-11 08:43:53 +03:00
func Preload ( scope * Scope ) {
2015-04-21 11:51:52 +03:00
if scope . Search . preload == nil {
return
}
2015-04-21 10:00:36 +03:00
preloadMap := map [ string ] bool { }
2015-04-21 11:51:52 +03:00
fields := scope . Fields ( )
for _ , preload := range scope . Search . preload {
schema , conditions := preload . schema , preload . conditions
keys := strings . Split ( schema , "." )
currentScope := scope
currentFields := fields
originalConditions := conditions
conditions = [ ] interface { } { }
for i , key := range keys {
var found bool
if preloadMap [ strings . Join ( keys [ : i + 1 ] , "." ) ] {
goto nextLoop
}
if i == len ( keys ) - 1 {
conditions = originalConditions
}
for _ , field := range currentFields {
if field . Name != key || field . Relationship == nil {
continue
2015-04-21 10:00:36 +03:00
}
2015-04-21 11:51:52 +03:00
found = true
switch field . Relationship . Kind {
case "has_one" :
currentScope . handleHasOnePreload ( field , conditions )
case "has_many" :
currentScope . handleHasManyPreload ( field , conditions )
case "belongs_to" :
currentScope . handleBelongsToPreload ( field , conditions )
case "many_to_many" :
2015-08-16 10:10:11 +03:00
currentScope . handleHasManyToManyPreload ( field , conditions )
2015-04-21 11:51:52 +03:00
default :
currentScope . Err ( errors . New ( "not supported relation" ) )
2015-02-11 08:43:53 +03:00
}
2015-04-21 11:51:52 +03:00
break
}
if ! found {
value := reflect . ValueOf ( currentScope . Value )
if value . Kind ( ) == reflect . Slice && value . Type ( ) . Elem ( ) . Kind ( ) == reflect . Interface {
value = value . Index ( 0 ) . Elem ( )
}
2015-06-11 17:14:36 +03:00
scope . Err ( fmt . Errorf ( "can't find field %s in %s" , key , value . Type ( ) ) )
2015-04-21 11:51:52 +03:00
return
}
preloadMap [ strings . Join ( keys [ : i + 1 ] , "." ) ] = true
nextLoop :
if i < len ( keys ) - 1 {
currentScope = currentScope . getColumnsAsScope ( key )
currentFields = currentScope . Fields ( )
2015-02-11 08:43:53 +03:00
}
}
}
2015-04-21 11:51:52 +03:00
2015-02-11 08:43:53 +03:00
}
2015-02-17 17:55:14 +03:00
// makeSlice returns a pointer to a new, empty slice, suitable as a query
// destination.
//
// If typ is itself a slice type, the element type of the new slice is typ's
// element type. Otherwise the element type is typ. For example, both []T and T
// produce *[]T.
func makeSlice(typ reflect.Type) interface{} {
	elemType := typ
	if elemType.Kind() == reflect.Slice {
		elemType = elemType.Elem()
	}

	ptr := reflect.New(reflect.SliceOf(elemType))
	ptr.Elem().Set(reflect.MakeSlice(ptr.Elem().Type(), 0, 0))
	return ptr.Interface()
}
2015-04-21 11:51:52 +03:00
func ( scope * Scope ) handleHasOnePreload ( field * Field , conditions [ ] interface { } ) {
2015-07-30 13:19:49 +03:00
relation := field . Relationship
2015-07-30 17:36:04 +03:00
2015-07-30 13:19:49 +03:00
primaryKeys := scope . getColumnAsArray ( relation . AssociationForeignFieldNames )
2015-04-21 11:51:52 +03:00
if len ( primaryKeys ) == 0 {
return
}
results := makeSlice ( field . Struct . Type )
2015-07-30 13:19:49 +03:00
scope . Err ( scope . NewDB ( ) . Where ( fmt . Sprintf ( "%v IN (%v)" , toQueryCondition ( scope , relation . ForeignDBNames ) , toQueryMarks ( primaryKeys ) ) , toQueryValues ( primaryKeys ) ... ) . Find ( results , conditions ... ) . Error )
2015-04-22 10:36:10 +03:00
resultValues := reflect . Indirect ( reflect . ValueOf ( results ) )
2015-04-21 11:51:52 +03:00
for i := 0 ; i < resultValues . Len ( ) ; i ++ {
result := resultValues . Index ( i )
if scope . IndirectValue ( ) . Kind ( ) == reflect . Slice {
2015-07-30 13:19:49 +03:00
value := getRealValue ( result , relation . ForeignFieldNames )
2015-04-21 11:51:52 +03:00
objects := scope . IndirectValue ( )
for j := 0 ; j < objects . Len ( ) ; j ++ {
2015-07-30 13:19:49 +03:00
if equalAsString ( getRealValue ( objects . Index ( j ) , relation . AssociationForeignFieldNames ) , value ) {
2015-04-21 11:51:52 +03:00
reflect . Indirect ( objects . Index ( j ) ) . FieldByName ( field . Name ) . Set ( result )
break
}
}
} else {
2015-04-22 10:36:10 +03:00
if err := scope . SetColumn ( field , result ) ; err != nil {
2015-04-21 11:51:52 +03:00
scope . Err ( err )
return
}
}
}
}
func ( scope * Scope ) handleHasManyPreload ( field * Field , conditions [ ] interface { } ) {
2015-07-30 13:19:49 +03:00
relation := field . Relationship
primaryKeys := scope . getColumnAsArray ( relation . AssociationForeignFieldNames )
2015-04-21 11:51:52 +03:00
if len ( primaryKeys ) == 0 {
return
}
results := makeSlice ( field . Struct . Type )
2015-07-30 13:19:49 +03:00
scope . Err ( scope . NewDB ( ) . Where ( fmt . Sprintf ( "%v IN (%v)" , toQueryCondition ( scope , relation . ForeignDBNames ) , toQueryMarks ( primaryKeys ) ) , toQueryValues ( primaryKeys ) ... ) . Find ( results , conditions ... ) . Error )
2015-04-22 10:36:10 +03:00
resultValues := reflect . Indirect ( reflect . ValueOf ( results ) )
2015-04-21 11:51:52 +03:00
if scope . IndirectValue ( ) . Kind ( ) == reflect . Slice {
for i := 0 ; i < resultValues . Len ( ) ; i ++ {
result := resultValues . Index ( i )
2015-07-30 13:19:49 +03:00
value := getRealValue ( result , relation . ForeignFieldNames )
2015-04-21 11:51:52 +03:00
objects := scope . IndirectValue ( )
for j := 0 ; j < objects . Len ( ) ; j ++ {
object := reflect . Indirect ( objects . Index ( j ) )
2015-07-30 13:19:49 +03:00
if equalAsString ( getRealValue ( object , relation . AssociationForeignFieldNames ) , value ) {
2015-04-21 11:51:52 +03:00
f := object . FieldByName ( field . Name )
f . Set ( reflect . Append ( f , result ) )
break
}
}
}
} else {
scope . SetColumn ( field , resultValues )
}
}
func ( scope * Scope ) handleBelongsToPreload ( field * Field , conditions [ ] interface { } ) {
relation := field . Relationship
2015-07-30 13:19:49 +03:00
primaryKeys := scope . getColumnAsArray ( relation . ForeignFieldNames )
2015-04-21 11:51:52 +03:00
if len ( primaryKeys ) == 0 {
return
}
results := makeSlice ( field . Struct . Type )
2015-07-30 17:36:04 +03:00
scope . Err ( scope . NewDB ( ) . Where ( fmt . Sprintf ( "%v IN (%v)" , toQueryCondition ( scope , relation . AssociationForeignDBNames ) , toQueryMarks ( primaryKeys ) ) , toQueryValues ( primaryKeys ) ... ) . Find ( results , conditions ... ) . Error )
2015-04-22 10:36:10 +03:00
resultValues := reflect . Indirect ( reflect . ValueOf ( results ) )
2015-04-21 11:51:52 +03:00
for i := 0 ; i < resultValues . Len ( ) ; i ++ {
result := resultValues . Index ( i )
if scope . IndirectValue ( ) . Kind ( ) == reflect . Slice {
2015-07-30 13:19:49 +03:00
value := getRealValue ( result , relation . AssociationForeignFieldNames )
2015-04-21 11:51:52 +03:00
objects := scope . IndirectValue ( )
for j := 0 ; j < objects . Len ( ) ; j ++ {
object := reflect . Indirect ( objects . Index ( j ) )
2015-07-30 13:19:49 +03:00
if equalAsString ( getRealValue ( object , relation . ForeignFieldNames ) , value ) {
2015-04-21 11:51:52 +03:00
object . FieldByName ( field . Name ) . Set ( result )
}
}
} else {
scope . SetColumn ( field , result )
}
}
}
2015-08-16 10:10:11 +03:00
func ( scope * Scope ) handleHasManyToManyPreload ( field * Field , conditions [ ] interface { } ) {
relation := field . Relationship
joinTableHandler := relation . JoinTableHandler
destType := joinTableHandler . DestinationType ( )
db := scope . NewDB ( ) . Table ( scope . db . NewScope ( reflect . New ( destType ) . Elem ( ) . Interface ( ) ) . TableName ( ) )
var destKeys [ ] string
var sourceKeys [ ] string
linkHash := make ( map [ string ] [ ] string )
for _ , key := range joinTableHandler . DestinationForeignKeys ( ) {
destKeys = append ( destKeys , key . DBName )
}
for _ , key := range joinTableHandler . SourceForeignKeys ( ) {
sourceKeys = append ( sourceKeys , key . DBName )
}
results := reflect . New ( field . Struct . Type ) . Elem ( )
rows , err := joinTableHandler . PreloadWithJoin ( joinTableHandler , db , scope . Value , conditions ... ) . Rows ( )
if scope . Err ( err ) != nil {
return
}
defer rows . Close ( )
columns , _ := rows . Columns ( )
for rows . Next ( ) {
elem := reflect . New ( destType ) . Elem ( )
var values = make ( [ ] interface { } , len ( columns ) )
fields := scope . New ( elem . Addr ( ) . Interface ( ) ) . Fields ( )
for index , column := range columns {
if field , ok := fields [ column ] ; ok {
if field . Field . Kind ( ) == reflect . Ptr {
values [ index ] = field . Field . Addr ( ) . Interface ( )
} else {
values [ index ] = reflect . New ( reflect . PtrTo ( field . Field . Type ( ) ) ) . Interface ( )
}
} else {
var i interface { }
values [ index ] = & i
}
}
scope . Err ( rows . Scan ( values ... ) )
var destKey [ ] interface { }
var sourceKey [ ] interface { }
for index , column := range columns {
value := values [ index ]
if field , ok := fields [ column ] ; ok {
if field . Field . Kind ( ) == reflect . Ptr {
field . Field . Set ( reflect . ValueOf ( value ) . Elem ( ) )
} else if v := reflect . ValueOf ( value ) . Elem ( ) . Elem ( ) ; v . IsValid ( ) {
field . Field . Set ( v )
}
} else if strInSlice ( column , destKeys ) {
destKey = append ( destKey , * ( value . ( * interface { } ) ) )
} else if strInSlice ( column , sourceKeys ) {
sourceKey = append ( sourceKey , * ( value . ( * interface { } ) ) )
}
}
if len ( destKey ) != 0 && len ( sourceKey ) != 0 {
linkHash [ toString ( sourceKey ) ] = append ( linkHash [ toString ( sourceKey ) ] , toString ( destKey ) )
}
results = reflect . Append ( results , elem )
}
if scope . IndirectValue ( ) . Kind ( ) == reflect . Slice {
objects := scope . IndirectValue ( )
for j := 0 ; j < objects . Len ( ) ; j ++ {
var checked [ ] string
object := reflect . Indirect ( objects . Index ( j ) )
2015-08-16 12:36:23 +03:00
source := getRealValue ( object , relation . AssociationForeignStructFieldNames )
2015-08-16 10:10:11 +03:00
for i := 0 ; i < results . Len ( ) ; i ++ {
result := results . Index ( i )
2015-08-16 12:36:23 +03:00
value := getRealValue ( result , relation . ForeignStructFieldNames )
2015-08-16 10:10:11 +03:00
if strInSlice ( toString ( value ) , linkHash [ toString ( source ) ] ) && ! strInSlice ( toString ( value ) , checked ) {
f := object . FieldByName ( field . Name )
f . Set ( reflect . Append ( f , result ) )
checked = append ( checked , toString ( value ) )
continue
}
}
}
2015-08-16 12:25:25 +03:00
} else {
object := scope . IndirectValue ( )
var checked [ ] string
2015-08-16 12:36:23 +03:00
source := getRealValue ( object , relation . AssociationForeignStructFieldNames )
2015-08-16 12:25:25 +03:00
for i := 0 ; i < results . Len ( ) ; i ++ {
result := results . Index ( i )
2015-08-16 12:36:23 +03:00
value := getRealValue ( result , relation . ForeignStructFieldNames )
2015-08-16 12:25:25 +03:00
if strInSlice ( toString ( value ) , linkHash [ toString ( source ) ] ) && ! strInSlice ( toString ( value ) , checked ) {
f := object . FieldByName ( field . Name )
f . Set ( reflect . Append ( f , result ) )
checked = append ( checked , toString ( value ) )
continue
}
}
2015-08-16 10:10:11 +03:00
}
}
2015-07-30 13:19:49 +03:00
func ( scope * Scope ) getColumnAsArray ( columns [ ] string ) ( results [ ] [ ] interface { } ) {
2015-02-11 08:43:53 +03:00
values := scope . IndirectValue ( )
switch values . Kind ( ) {
case reflect . Slice :
for i := 0 ; i < values . Len ( ) ; i ++ {
2015-07-30 13:19:49 +03:00
var result [ ] interface { }
for _ , column := range columns {
result = append ( result , reflect . Indirect ( values . Index ( i ) ) . FieldByName ( column ) . Interface ( ) )
}
results = append ( results , result )
2015-02-11 08:43:53 +03:00
}
case reflect . Struct :
2015-07-30 13:19:49 +03:00
var result [ ] interface { }
for _ , column := range columns {
result = append ( result , values . FieldByName ( column ) . Interface ( ) )
}
return [ ] [ ] interface { } { result }
2015-02-11 08:43:53 +03:00
}
return
}
2015-04-21 10:00:36 +03:00
func ( scope * Scope ) getColumnsAsScope ( column string ) * Scope {
values := scope . IndirectValue ( )
switch values . Kind ( ) {
case reflect . Slice :
2015-04-22 10:36:10 +03:00
modelType := values . Type ( ) . Elem ( )
if modelType . Kind ( ) == reflect . Ptr {
modelType = modelType . Elem ( )
2015-04-21 11:51:52 +03:00
}
2015-04-22 10:36:10 +03:00
fieldStruct , _ := modelType . FieldByName ( column )
2015-04-21 10:00:36 +03:00
var columns reflect . Value
2015-06-05 13:54:52 +03:00
if fieldStruct . Type . Kind ( ) == reflect . Slice || fieldStruct . Type . Kind ( ) == reflect . Ptr {
2015-04-22 10:36:10 +03:00
columns = reflect . New ( reflect . SliceOf ( reflect . PtrTo ( fieldStruct . Type . Elem ( ) ) ) ) . Elem ( )
2015-04-21 10:00:36 +03:00
} else {
2015-04-22 10:36:10 +03:00
columns = reflect . New ( reflect . SliceOf ( reflect . PtrTo ( fieldStruct . Type ) ) ) . Elem ( )
2015-04-21 10:00:36 +03:00
}
for i := 0 ; i < values . Len ( ) ; i ++ {
column := reflect . Indirect ( values . Index ( i ) ) . FieldByName ( column )
2015-06-05 13:54:52 +03:00
if column . Kind ( ) == reflect . Ptr {
column = column . Elem ( )
}
2015-04-21 10:00:36 +03:00
if column . Kind ( ) == reflect . Slice {
for i := 0 ; i < column . Len ( ) ; i ++ {
columns = reflect . Append ( columns , column . Index ( i ) . Addr ( ) )
}
} else {
columns = reflect . Append ( columns , column . Addr ( ) )
}
}
return scope . New ( columns . Interface ( ) )
case reflect . Struct :
return scope . New ( values . FieldByName ( column ) . Addr ( ) . Interface ( ) )
}
return nil
}