diff --git a/preload.go b/preload.go index cfc65380..2333cade 100644 --- a/preload.go +++ b/preload.go @@ -95,10 +95,10 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) for i := 0; i < resultsValue.Len(); i++ { result := resultsValue.Index(i) if indirectScopeValue.Kind() == reflect.Slice { - value := getValueFromFields(result, relation.ForeignFieldNames) + foreignValues := getValueFromFields(result, relation.ForeignFieldNames) for j := 0; j < indirectScopeValue.Len(); j++ { - if equalAsString(getValueFromFields(indirectScopeValue.Index(j), relation.AssociationForeignFieldNames), value) { - reflect.Indirect(indirectScopeValue.Index(j)).FieldByName(field.Name).Set(result) + if indirectValue := reflect.Indirect(indirectScopeValue.Index(j)); equalAsString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames), foreignValues) { + indirectValue.FieldByName(field.Name).Set(result) break } } @@ -110,58 +110,72 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) { relation := field.Relationship + + // get relations's primary keys primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames) if len(primaryKeys) == 0 { return } + // find relations results := makeSlice(field.Struct.Type) scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error) - resultValues := reflect.Indirect(reflect.ValueOf(results)) - if scope.IndirectValue().Kind() == reflect.Slice { - for i := 0; i < resultValues.Len(); i++ { - result := resultValues.Index(i) - value := getValueFromFields(result, relation.ForeignFieldNames) - objects := scope.IndirectValue() - for j := 0; j < objects.Len(); j++ { - object := reflect.Indirect(objects.Index(j)) - if equalAsString(getValueFromFields(object, relation.AssociationForeignFieldNames), value) { - f := object.FieldByName(field.Name) - f.Set(reflect.Append(f, result)) + // assign find results + var ( + resultsValue = reflect.Indirect(reflect.ValueOf(results)) + indirectScopeValue = scope.IndirectValue() + ) + + if indirectScopeValue.Kind() == reflect.Slice { + for i := 0; i < resultsValue.Len(); i++ { + result := resultsValue.Index(i) + foreignValues := getValueFromFields(result, relation.ForeignFieldNames) + for j := 0; j < indirectScopeValue.Len(); j++ { + object := reflect.Indirect(indirectScopeValue.Index(j)) + if equalAsString(getValueFromFields(object, relation.AssociationForeignFieldNames), foreignValues) { + objectField := object.FieldByName(field.Name) + objectField.Set(reflect.Append(objectField, result)) break } } } } else { - scope.SetColumn(field, resultValues) + scope.Err(field.Set(resultsValue)) } } func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) { relation := field.Relationship + + // get relations's primary keys primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames) if len(primaryKeys) == 0 { return } + // find relations results := makeSlice(field.Struct.Type) scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error) - resultValues := reflect.Indirect(reflect.ValueOf(results)) - for i := 0; i < resultValues.Len(); i++ { - result := resultValues.Index(i) - if scope.IndirectValue().Kind() == reflect.Slice { + // assign find results + var ( + resultsValue = reflect.Indirect(reflect.ValueOf(results)) + indirectScopeValue = scope.IndirectValue() + ) + + for i := 0; i < resultsValue.Len(); i++ { + result := resultsValue.Index(i) + if indirectScopeValue.Kind() == reflect.Slice { value := getValueFromFields(result, relation.AssociationForeignFieldNames) - objects := scope.IndirectValue() - for j := 0; j < objects.Len(); j++ { - object := reflect.Indirect(objects.Index(j)) + for j := 0; j < indirectScopeValue.Len(); j++ { + object := reflect.Indirect(indirectScopeValue.Index(j)) if equalAsString(getValueFromFields(object, relation.ForeignFieldNames), value) { object.FieldByName(field.Name).Set(result) } } } else { - scope.SetColumn(field, result) + scope.Err(field.Set(result)) } } } @@ -170,24 +184,25 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface var ( relation = field.Relationship joinTableHandler = relation.JoinTableHandler - destType = field.StructField.Struct.Type.Elem() - linkHash = make(map[string][]reflect.Value) - sourceKeys = []string{} + fieldType = field.StructField.Struct.Type.Elem() foreignKeyValue interface{} foreignKeyType = reflect.ValueOf(&foreignKeyValue).Type() + linkHash = map[string][]reflect.Value{} isPtr bool ) - if destType.Kind() == reflect.Ptr { + if fieldType.Kind() == reflect.Ptr { isPtr = true - destType = destType.Elem() + fieldType = fieldType.Elem() } + var sourceKeys = []string{} for _, key := range joinTableHandler.SourceForeignKeys() { sourceKeys = append(sourceKeys, key.DBName) } - preloadJoinDB := scope.NewDB().Table(scope.New(reflect.New(destType).Interface()).TableName()).Select("*") + // generate query with join table + preloadJoinDB := scope.NewDB().Table(scope.New(reflect.New(fieldType).Interface()).TableName()).Select("*") preloadJoinDB = joinTableHandler.JoinWith(joinTableHandler, preloadJoinDB, scope.Value) // preload inline conditions @@ -205,7 +220,7 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface columns, _ := rows.Columns() for rows.Next() { var ( - elem = reflect.New(destType).Elem() + elem = reflect.New(fieldType).Elem() fields = scope.New(elem.Addr().Interface()).Fields() ) @@ -235,10 +250,11 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface indirectScopeValue = scope.IndirectValue() fieldsSourceMap = map[string]reflect.Value{} foreignFieldNames = []string{} + fields = scope.Fields() ) for _, dbName := range relation.ForeignFieldNames { - if field, ok := scope.FieldByName(dbName); ok { + if field, ok := fields[dbName]; ok { foreignFieldNames = append(foreignFieldNames, field.Name) } } @@ -256,60 +272,3 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface fieldsSourceMap[source].Set(reflect.Append(fieldsSourceMap[source], link...)) } } - -func (scope *Scope) getColumnAsArray(columns []string) (results [][]interface{}) { - values := scope.IndirectValue() - switch values.Kind() { - case reflect.Slice: - for i := 0; i < values.Len(); i++ { - var result []interface{} - for _, column := range columns { - result = append(result, reflect.Indirect(values.Index(i)).FieldByName(column).Interface()) - } - results = append(results, result) - } - case reflect.Struct: - var result []interface{} - for _, column := range columns { - result = append(result, values.FieldByName(column).Interface()) - } - return [][]interface{}{result} - } - return -} - -func (scope *Scope) getColumnAsScope(column string) *Scope { - indirectScopeValue := scope.IndirectValue() - - switch indirectScopeValue.Kind() { - case reflect.Slice: - if fieldStruct, ok := scope.GetModelStruct().ModelType.FieldByName(column); ok { - fieldType := fieldStruct.Type - if fieldType.Kind() == reflect.Slice || fieldType.Kind() == reflect.Ptr { - fieldType = fieldType.Elem() - } - - results := reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType))).Elem() - - for i := 0; i < indirectScopeValue.Len(); i++ { - result := reflect.Indirect(reflect.Indirect(indirectScopeValue.Index(i)).FieldByName(column)) - - if result.Kind() == reflect.Slice { - for j := 0; j < result.Len(); j++ { - if elem := result.Index(j); elem.CanAddr() { - results = reflect.Append(results, elem.Addr()) - } - } - } else if result.CanAddr() { - results = reflect.Append(results, result.Addr()) - } - } - return scope.New(results.Interface()) - } - case reflect.Struct: - if field := indirectScopeValue.FieldByName(column); field.CanAddr() { - return scope.New(field.Addr().Interface()) - } - } - return nil -} diff --git a/scope_utils.go b/scope_utils.go new file mode 100644 index 00000000..99957adc --- /dev/null +++ b/scope_utils.go @@ -0,0 +1,61 @@ +package gorm + +import "reflect" + +func (scope *Scope) getColumnAsArray(columns []string) (results [][]interface{}) { + indirectScopeValue := scope.IndirectValue() + switch indirectScopeValue.Kind() { + case reflect.Slice: + for i := 0; i < indirectScopeValue.Len(); i++ { + var result []interface{} + var object = reflect.Indirect(indirectScopeValue.Index(i)) + for _, column := range columns { + result = append(result, object.FieldByName(column).Interface()) + } + results = append(results, result) + } + case reflect.Struct: + var result []interface{} + for _, column := range columns { + result = append(result, indirectScopeValue.FieldByName(column).Interface()) + } + return [][]interface{}{result} + } + return +} + +func (scope *Scope) getColumnAsScope(column string) *Scope { + indirectScopeValue := scope.IndirectValue() + + switch indirectScopeValue.Kind() { + case reflect.Slice: + if fieldStruct, ok := scope.GetModelStruct().ModelType.FieldByName(column); ok { + fieldType := fieldStruct.Type + if fieldType.Kind() == reflect.Slice || fieldType.Kind() == reflect.Ptr { + fieldType = fieldType.Elem() + } + + results := reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType))).Elem() + + for i := 0; i < indirectScopeValue.Len(); i++ { + result := reflect.Indirect(reflect.Indirect(indirectScopeValue.Index(i)).FieldByName(column)) + + if result.Kind() == reflect.Slice { + for j := 0; j < result.Len(); j++ { + if elem := result.Index(j); elem.CanAddr() { + results = reflect.Append(results, elem.Addr()) + } + } + } else if result.CanAddr() { + results = reflect.Append(results, result.Addr()) + } + } + return scope.New(results.Interface()) + } + case reflect.Struct: + if field := indirectScopeValue.FieldByName(column); field.CanAddr() { + return scope.New(field.Addr().Interface()) + } + } + return nil +}