From 2055e29eb81281289673d7ebc612c245fce7c333 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 14 Mar 2021 10:18:43 +0800 Subject: [PATCH] Refactor nested preload all associations --- callbacks/query.go | 32 +++++++++++--------------------- tests/go.mod | 4 ++-- tests/preload_test.go | 4 ++++ 3 files changed, 17 insertions(+), 23 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index df5b4d60..11753472 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -175,36 +175,26 @@ func Preload(db *gorm.DB) { if db.Error == nil && len(db.Statement.Preloads) > 0 { preloadMap := map[string]map[string][]interface{}{} for name := range db.Statement.Preloads { - if name == clause.Associations { + preloadFields := strings.Split(name, ".") + if preloadFields[0] == clause.Associations { for _, rel := range db.Statement.Schema.Relationships.Relations { if rel.Schema == db.Statement.Schema { if _, ok := preloadMap[rel.Name]; !ok { preloadMap[rel.Name] = map[string][]interface{}{} } + + if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" { + preloadMap[rel.Name][value] = db.Statement.Preloads[name] + } } } } else { - preloadFields := strings.Split(name, ".") - if preloadFields[0] == clause.Associations { - for _, rel := range db.Statement.Schema.Relationships.Relations { - if rel.Schema == db.Statement.Schema { - if _, ok := preloadMap[rel.Name]; !ok { - preloadMap[rel.Name] = map[string][]interface{}{} - } + if _, ok := preloadMap[preloadFields[0]]; !ok { + preloadMap[preloadFields[0]] = map[string][]interface{}{} + } - if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" { - preloadMap[rel.Name][value] = db.Statement.Preloads[name] - } - } - } - } else { - if _, ok := preloadMap[preloadFields[0]]; !ok { - preloadMap[preloadFields[0]] = map[string][]interface{}{} - } - - if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" { - preloadMap[preloadFields[0]][value] = db.Statement.Preloads[name] - } + if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" { + preloadMap[preloadFields[0]][value] = db.Statement.Preloads[name] } } } diff --git a/tests/go.mod b/tests/go.mod index 20d7206a..0765142c 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -7,11 +7,11 @@ require ( github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 github.com/stretchr/testify v1.5.1 - gorm.io/driver/mysql v1.0.4 + gorm.io/driver/mysql v1.0.5 gorm.io/driver/postgres v1.0.8 gorm.io/driver/sqlite v1.1.4 gorm.io/driver/sqlserver v1.0.6 - gorm.io/gorm v1.20.12 + gorm.io/gorm v1.21.3 ) replace gorm.io/gorm => ../ diff --git a/tests/preload_test.go b/tests/preload_test.go index 4b31b12c..c9f5d278 100644 --- a/tests/preload_test.go +++ b/tests/preload_test.go @@ -65,6 +65,10 @@ func TestNestedPreload(t *testing.T) { DB.Preload("Pets.Toy").Find(&user2, "id = ?", user.ID) CheckUser(t, user2, user) + + var user3 User + DB.Preload(clause.Associations+"."+clause.Associations).Find(&user3, "id = ?", user.ID) + CheckUser(t, user3, user) } func TestNestedPreloadForSlice(t *testing.T) {