Refactor nested preload all associations

This commit is contained in:
Jinzhu 2021-03-14 10:18:43 +08:00
parent c575a4e719
commit 2055e29eb8
3 changed files with 17 additions and 23 deletions

View File

@ -175,36 +175,26 @@ func Preload(db *gorm.DB) {
if db.Error == nil && len(db.Statement.Preloads) > 0 { if db.Error == nil && len(db.Statement.Preloads) > 0 {
preloadMap := map[string]map[string][]interface{}{} preloadMap := map[string]map[string][]interface{}{}
for name := range db.Statement.Preloads { for name := range db.Statement.Preloads {
if name == clause.Associations { preloadFields := strings.Split(name, ".")
if preloadFields[0] == clause.Associations {
for _, rel := range db.Statement.Schema.Relationships.Relations { for _, rel := range db.Statement.Schema.Relationships.Relations {
if rel.Schema == db.Statement.Schema { if rel.Schema == db.Statement.Schema {
if _, ok := preloadMap[rel.Name]; !ok { if _, ok := preloadMap[rel.Name]; !ok {
preloadMap[rel.Name] = map[string][]interface{}{} preloadMap[rel.Name] = map[string][]interface{}{}
} }
if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" {
preloadMap[rel.Name][value] = db.Statement.Preloads[name]
}
} }
} }
} else { } else {
preloadFields := strings.Split(name, ".") if _, ok := preloadMap[preloadFields[0]]; !ok {
if preloadFields[0] == clause.Associations { preloadMap[preloadFields[0]] = map[string][]interface{}{}
for _, rel := range db.Statement.Schema.Relationships.Relations { }
if rel.Schema == db.Statement.Schema {
if _, ok := preloadMap[rel.Name]; !ok {
preloadMap[rel.Name] = map[string][]interface{}{}
}
if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" { if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" {
preloadMap[rel.Name][value] = db.Statement.Preloads[name] preloadMap[preloadFields[0]][value] = db.Statement.Preloads[name]
}
}
}
} else {
if _, ok := preloadMap[preloadFields[0]]; !ok {
preloadMap[preloadFields[0]] = map[string][]interface{}{}
}
if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" {
preloadMap[preloadFields[0]][value] = db.Statement.Preloads[name]
}
} }
} }
} }

View File

@ -7,11 +7,11 @@ require (
github.com/jinzhu/now v1.1.1 github.com/jinzhu/now v1.1.1
github.com/lib/pq v1.6.0 github.com/lib/pq v1.6.0
github.com/stretchr/testify v1.5.1 github.com/stretchr/testify v1.5.1
gorm.io/driver/mysql v1.0.4 gorm.io/driver/mysql v1.0.5
gorm.io/driver/postgres v1.0.8 gorm.io/driver/postgres v1.0.8
gorm.io/driver/sqlite v1.1.4 gorm.io/driver/sqlite v1.1.4
gorm.io/driver/sqlserver v1.0.6 gorm.io/driver/sqlserver v1.0.6
gorm.io/gorm v1.20.12 gorm.io/gorm v1.21.3
) )
replace gorm.io/gorm => ../ replace gorm.io/gorm => ../

View File

@ -65,6 +65,10 @@ func TestNestedPreload(t *testing.T) {
DB.Preload("Pets.Toy").Find(&user2, "id = ?", user.ID) DB.Preload("Pets.Toy").Find(&user2, "id = ?", user.ID)
CheckUser(t, user2, user) CheckUser(t, user2, user)
var user3 User
DB.Preload(clause.Associations+"."+clause.Associations).Find(&user3, "id = ?", user.ID)
CheckUser(t, user3, user)
} }
func TestNestedPreloadForSlice(t *testing.T) { func TestNestedPreloadForSlice(t *testing.T) {