From 85299bfca7172489d7f93a5525ee5ab0d92d514b Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Thu, 25 Apr 2024 20:21:03 +0800 Subject: [PATCH] perf: merge nested preload query when using join (#6990) * pref: merge nest preload query * fix: preload test --- callbacks/preload.go | 14 ++++++- tests/preload_test.go | 86 ++++++++++++++++++++++++++++++++++++++----- 2 files changed, 88 insertions(+), 12 deletions(-) diff --git a/callbacks/preload.go b/callbacks/preload.go index 09f151c7..112343fa 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -123,8 +123,18 @@ func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relati if joined, nestedJoins := isJoined(name); joined { switch rv := db.Statement.ReflectValue; rv.Kind() { case reflect.Slice, reflect.Array: - for i := 0; i < rv.Len(); i++ { - reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv.Index(i)) + if rv.Len() > 0 { + reflectValue := rel.FieldSchema.MakeSlice().Elem() + reflectValue.SetLen(rv.Len()) + for i := 0; i < rv.Len(); i++ { + frv := rel.Field.ReflectValueOf(db.Statement.Context, rv.Index(i)) + if frv.Kind() != reflect.Ptr { + reflectValue.Index(i).Set(frv.Addr()) + } else { + reflectValue.Index(i).Set(frv) + } + } + tx := preloadDB(db, reflectValue, reflectValue.Interface()) if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil { return err diff --git a/tests/preload_test.go b/tests/preload_test.go index 5c87534f..6e0e91ba 100644 --- a/tests/preload_test.go +++ b/tests/preload_test.go @@ -1,14 +1,14 @@ package tests_test import ( + "context" "encoding/json" "regexp" "sort" "strconv" "sync" "testing" - - "github.com/stretchr/testify/require" + "time" "gorm.io/gorm" "gorm.io/gorm/clause" @@ -337,7 +337,7 @@ func TestNestedPreloadWithNestedJoin(t *testing.T) { DB.Migrator().DropTable(&Preload{}, &Join{}, &Nested{}, &Value{}) DB.Migrator().AutoMigrate(&Preload{}, &Join{}, &Nested{}, &Value{}) - value := Value{ + value1 := Value{ Name: "value", Nested: Nested{ Preloads: []*Preload{ @@ -346,32 +346,98 @@ func TestNestedPreloadWithNestedJoin(t *testing.T) { Join: Join{Value: "j1"}, }, } - if err := DB.Create(&value).Error; err != nil { + value2 := Value{ + Name: "value2", + Nested: Nested{ + Preloads: []*Preload{ + {Value: "p3"}, {Value: "p4"}, {Value: "p5"}, + }, + Join: Join{Value: "j2"}, + }, + } + + values := []*Value{&value1, &value2} + if err := DB.Create(&values).Error; err != nil { t.Errorf("failed to create value, got err: %v", err) } var find1 Value - err := DB.Joins("Nested").Joins("Nested.Join").Preload("Nested.Preloads").First(&find1).Error + err := DB.Joins("Nested").Joins("Nested.Join").Preload("Nested.Preloads").First(&find1, value1.ID).Error if err != nil { t.Errorf("failed to find value, got err: %v", err) } - AssertEqual(t, find1, value) + AssertEqual(t, find1, value1) var find2 Value // Joins will automatically add Nested queries. - err = DB.Joins("Nested.Join").Preload("Nested.Preloads").First(&find2).Error + err = DB.Joins("Nested.Join").Preload("Nested.Preloads").First(&find2, value2.ID).Error if err != nil { t.Errorf("failed to find value, got err: %v", err) } - AssertEqual(t, find2, value) + AssertEqual(t, find2, value2) var finds []Value err = DB.Joins("Nested.Join").Joins("Nested").Preload("Nested.Preloads").Find(&finds).Error if err != nil { t.Errorf("failed to find value, got err: %v", err) } - require.Len(t, finds, 1) - AssertEqual(t, finds[0], value) + AssertEqual(t, len(finds), 2) + AssertEqual(t, finds[0], value1) + AssertEqual(t, finds[1], value2) +} + +func TestMergeNestedPreloadWithNestedJoin(t *testing.T) { + users := []User{ + { + Name: "TestMergeNestedPreloadWithNestedJoin-1", + Manager: &User{ + Name: "Alexis Manager", + Tools: []Tools{ + {Name: "Alexis Tool 1"}, + {Name: "Alexis Tool 2"}, + }, + }, + }, + { + Name: "TestMergeNestedPreloadWithNestedJoin-2", + Manager: &User{ + Name: "Jinzhu Manager", + Tools: []Tools{ + {Name: "Jinzhu Tool 1"}, + {Name: "Jinzhu Tool 2"}, + }, + }, + }, + } + + DB.Create(&users) + + query := make([]string, 0) + sess := DB.Session(&gorm.Session{Logger: Tracer{ + Logger: DB.Config.Logger, + Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { + sql, _ := fc() + query = append(query, sql) + }, + }}) + + var result []User + err := sess. + Joins("Manager"). + Preload("Manager.Tools"). + Where("users.name Like ?", "TestMergeNestedPreloadWithNestedJoin%"). + Find(&result).Error + + if err != nil { + t.Fatalf("failed to preload and find users: %v", err) + } + + AssertEqual(t, result, users) + AssertEqual(t, len(query), 2) // Check preload queries are merged + + if !regexp.MustCompile(`SELECT \* FROM .*tools.* WHERE .*IN.*`).MatchString(query[0]) { + t.Fatalf("Expected first query to preload manager tools, got: %s", query[0]) + } } func TestEmbedPreload(t *testing.T) {