Refactor nested preload all associations

This commit is contained in:
Jinzhu 2021-03-14 10:18:43 +08:00
parent c575a4e719
commit 2055e29eb8
3 changed files with 17 additions and 23 deletions

View File

@ -175,15 +175,6 @@ func Preload(db *gorm.DB) {
if db.Error == nil && len(db.Statement.Preloads) > 0 { if db.Error == nil && len(db.Statement.Preloads) > 0 {
preloadMap := map[string]map[string][]interface{}{} preloadMap := map[string]map[string][]interface{}{}
for name := range db.Statement.Preloads { for name := range db.Statement.Preloads {
if name == clause.Associations {
for _, rel := range db.Statement.Schema.Relationships.Relations {
if rel.Schema == db.Statement.Schema {
if _, ok := preloadMap[rel.Name]; !ok {
preloadMap[rel.Name] = map[string][]interface{}{}
}
}
}
} else {
preloadFields := strings.Split(name, ".") preloadFields := strings.Split(name, ".")
if preloadFields[0] == clause.Associations { if preloadFields[0] == clause.Associations {
for _, rel := range db.Statement.Schema.Relationships.Relations { for _, rel := range db.Statement.Schema.Relationships.Relations {
@ -207,7 +198,6 @@ func Preload(db *gorm.DB) {
} }
} }
} }
}
preloadNames := make([]string, 0, len(preloadMap)) preloadNames := make([]string, 0, len(preloadMap))
for key := range preloadMap { for key := range preloadMap {

View File

@ -7,11 +7,11 @@ require (
github.com/jinzhu/now v1.1.1 github.com/jinzhu/now v1.1.1
github.com/lib/pq v1.6.0 github.com/lib/pq v1.6.0
github.com/stretchr/testify v1.5.1 github.com/stretchr/testify v1.5.1
gorm.io/driver/mysql v1.0.4 gorm.io/driver/mysql v1.0.5
gorm.io/driver/postgres v1.0.8 gorm.io/driver/postgres v1.0.8
gorm.io/driver/sqlite v1.1.4 gorm.io/driver/sqlite v1.1.4
gorm.io/driver/sqlserver v1.0.6 gorm.io/driver/sqlserver v1.0.6
gorm.io/gorm v1.20.12 gorm.io/gorm v1.21.3
) )
replace gorm.io/gorm => ../ replace gorm.io/gorm => ../

View File

@ -65,6 +65,10 @@ func TestNestedPreload(t *testing.T) {
DB.Preload("Pets.Toy").Find(&user2, "id = ?", user.ID) DB.Preload("Pets.Toy").Find(&user2, "id = ?", user.ID)
CheckUser(t, user2, user) CheckUser(t, user2, user)
var user3 User
DB.Preload(clause.Associations+"."+clause.Associations).Find(&user3, "id = ?", user.ID)
CheckUser(t, user3, user)
} }
func TestNestedPreloadForSlice(t *testing.T) { func TestNestedPreloadForSlice(t *testing.T) {