diff --git a/preload.go b/preload.go index 02e9eb38..7ca8ca55 100644 --- a/preload.go +++ b/preload.go @@ -34,62 +34,68 @@ func Preload(scope *Scope) { switch relation.Kind { case "has_one": - condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName)) - scope.NewDB().Where(condition, scope.getColumnAsArray(primaryName)).Find(results, conditions...) + if primaryKeys := scope.getColumnAsArray(primaryName); len(primaryKeys) > 0 { + condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName)) + scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...) - resultValues := reflect.Indirect(reflect.ValueOf(results)) - for i := 0; i < resultValues.Len(); i++ { - result := resultValues.Index(i) - if isSlice { - value := getRealValue(result, relation.ForeignFieldName) - objects := scope.IndirectValue() - for j := 0; j < objects.Len(); j++ { - if equalAsString(getRealValue(objects.Index(j), primaryName), value) { - reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result) - break + resultValues := reflect.Indirect(reflect.ValueOf(results)) + for i := 0; i < resultValues.Len(); i++ { + result := resultValues.Index(i) + if isSlice { + value := getRealValue(result, relation.ForeignFieldName) + objects := scope.IndirectValue() + for j := 0; j < objects.Len(); j++ { + if equalAsString(getRealValue(objects.Index(j), primaryName), value) { + reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result) + break + } } + } else { + scope.SetColumn(field, result) } - } else { - scope.SetColumn(field, result) } } case "has_many": - condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName)) - scope.NewDB().Where(condition, scope.getColumnAsArray(primaryName)).Find(results, conditions...) - resultValues := reflect.Indirect(reflect.ValueOf(results)) - if isSlice { - for i := 0; i < resultValues.Len(); i++ { - result := resultValues.Index(i) - value := getRealValue(result, relation.ForeignFieldName) - objects := scope.IndirectValue() - for j := 0; j < objects.Len(); j++ { - object := reflect.Indirect(objects.Index(j)) - if equalAsString(getRealValue(object, primaryName), value) { - f := object.FieldByName(field.Name) - f.Set(reflect.Append(f, result)) - break - } - } - } - } else { - scope.SetColumn(field, resultValues) - } - case "belongs_to": - scope.NewDB().Where(scope.getColumnAsArray(relation.ForeignFieldName)).Find(results, conditions...) - resultValues := reflect.Indirect(reflect.ValueOf(results)) - for i := 0; i < resultValues.Len(); i++ { - result := resultValues.Index(i) + if primaryKeys := scope.getColumnAsArray(primaryName); len(primaryKeys) > 0 { + condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName)) + scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...) + resultValues := reflect.Indirect(reflect.ValueOf(results)) if isSlice { - value := getRealValue(result, associationPrimaryKey) - objects := scope.IndirectValue() - for j := 0; j < objects.Len(); j++ { - object := reflect.Indirect(objects.Index(j)) - if equalAsString(getRealValue(object, relation.ForeignFieldName), value) { - object.FieldByName(field.Name).Set(result) + for i := 0; i < resultValues.Len(); i++ { + result := resultValues.Index(i) + value := getRealValue(result, relation.ForeignFieldName) + objects := scope.IndirectValue() + for j := 0; j < objects.Len(); j++ { + object := reflect.Indirect(objects.Index(j)) + if equalAsString(getRealValue(object, primaryName), value) { + f := object.FieldByName(field.Name) + f.Set(reflect.Append(f, result)) + break + } } } } else { - scope.SetColumn(field, result) + scope.SetColumn(field, resultValues) + } + } + case "belongs_to": + if primaryKeys := scope.getColumnAsArray(relation.ForeignFieldName); len(primaryKeys) > 0 { + scope.NewDB().Where(primaryKeys).Find(results, conditions...) + resultValues := reflect.Indirect(reflect.ValueOf(results)) + for i := 0; i < resultValues.Len(); i++ { + result := resultValues.Index(i) + if isSlice { + value := getRealValue(result, associationPrimaryKey) + objects := scope.IndirectValue() + for j := 0; j < objects.Len(); j++ { + object := reflect.Indirect(objects.Index(j)) + if equalAsString(getRealValue(object, relation.ForeignFieldName), value) { + object.FieldByName(field.Name).Set(result) + } + } + } else { + scope.SetColumn(field, result) + } } } case "many_to_many":