fix: nested preload with join panic when find (#6877)

This commit is contained in:
black-06 2024-03-09 21:27:19 +08:00 committed by GitHub
parent c4c9aa45e3
commit e4e23d26d2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 27 additions and 4 deletions

View File

@ -121,10 +121,23 @@ func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relati
} }
} else if rel := relationships.Relations[name]; rel != nil { } else if rel := relationships.Relations[name]; rel != nil {
if joined, nestedJoins := isJoined(name); joined { if joined, nestedJoins := isJoined(name); joined {
reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue) switch rv := db.Statement.ReflectValue; rv.Kind() {
tx := preloadDB(db, reflectValue, reflectValue.Interface()) case reflect.Slice, reflect.Array:
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil { for i := 0; i < rv.Len(); i++ {
return err reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv.Index(i))
tx := preloadDB(db, reflectValue, reflectValue.Interface())
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
return err
}
}
case reflect.Struct:
reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv)
tx := preloadDB(db, reflectValue, reflectValue.Interface())
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
return err
}
default:
return gorm.ErrInvalidData
} }
} else { } else {
tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}) tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks})

View File

@ -8,6 +8,8 @@ import (
"sync" "sync"
"testing" "testing"
"github.com/stretchr/testify/require"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
. "gorm.io/gorm/utils/tests" . "gorm.io/gorm/utils/tests"
@ -362,6 +364,14 @@ func TestNestedPreloadWithNestedJoin(t *testing.T) {
t.Errorf("failed to find value, got err: %v", err) t.Errorf("failed to find value, got err: %v", err)
} }
AssertEqual(t, find2, value) AssertEqual(t, find2, value)
var finds []Value
err = DB.Joins("Nested.Join").Joins("Nested").Preload("Nested.Preloads").Find(&finds).Error
if err != nil {
t.Errorf("failed to find value, got err: %v", err)
}
require.Len(t, finds, 1)
AssertEqual(t, finds[0], value)
} }
func TestEmbedPreload(t *testing.T) { func TestEmbedPreload(t *testing.T) {