fix: nested preload with join panic when find (#6877)

This commit is contained in:
black-06 2024-03-09 21:27:19 +08:00 committed by GitHub
parent c4c9aa45e3
commit e4e23d26d2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 27 additions and 4 deletions

View File

@ -121,10 +121,23 @@ func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relati
}
} else if rel := relationships.Relations[name]; rel != nil {
if joined, nestedJoins := isJoined(name); joined {
reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue)
tx := preloadDB(db, reflectValue, reflectValue.Interface())
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
return err
switch rv := db.Statement.ReflectValue; rv.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i < rv.Len(); i++ {
reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv.Index(i))
tx := preloadDB(db, reflectValue, reflectValue.Interface())
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
return err
}
}
case reflect.Struct:
reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv)
tx := preloadDB(db, reflectValue, reflectValue.Interface())
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
return err
}
default:
return gorm.ErrInvalidData
}
} else {
tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks})

View File

@ -8,6 +8,8 @@ import (
"sync"
"testing"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"gorm.io/gorm/clause"
. "gorm.io/gorm/utils/tests"
@ -362,6 +364,14 @@ func TestNestedPreloadWithNestedJoin(t *testing.T) {
t.Errorf("failed to find value, got err: %v", err)
}
AssertEqual(t, find2, value)
var finds []Value
err = DB.Joins("Nested.Join").Joins("Nested").Preload("Nested.Preloads").Find(&finds).Error
if err != nil {
t.Errorf("failed to find value, got err: %v", err)
}
require.Len(t, finds, 1)
AssertEqual(t, finds[0], value)
}
func TestEmbedPreload(t *testing.T) {