mirror of https://github.com/go-gorm/gorm.git
fix: preload shouldn't overwrite the value of join (#6771)
* fix: preload shouldn't overwrite the value of join * fix lint * fix: join may automatically add nested query
This commit is contained in:
parent
e043924fe7
commit
418ee3fc19
|
@ -3,6 +3,7 @@ package callbacks
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
@ -82,27 +83,80 @@ func embeddedValues(embeddedRelations *schema.Relationships) []string {
|
||||||
return names
|
return names
|
||||||
}
|
}
|
||||||
|
|
||||||
func preloadEmbedded(tx *gorm.DB, relationships *schema.Relationships, s *schema.Schema, preloads map[string][]interface{}, as []interface{}) error {
|
// preloadEntryPoint enters layer by layer. It will call real preload if it finds the right entry point.
|
||||||
if relationships == nil {
|
// If the current relationship is embedded or joined, current query will be ignored.
|
||||||
return nil
|
//
|
||||||
|
//nolint:cyclop
|
||||||
|
func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relationships, preloads map[string][]interface{}, associationsConds []interface{}) error {
|
||||||
|
preloadMap := parsePreloadMap(db.Statement.Schema, preloads)
|
||||||
|
|
||||||
|
// avoid random traversal of the map
|
||||||
|
preloadNames := make([]string, 0, len(preloadMap))
|
||||||
|
for key := range preloadMap {
|
||||||
|
preloadNames = append(preloadNames, key)
|
||||||
}
|
}
|
||||||
preloadMap := parsePreloadMap(s, preloads)
|
sort.Strings(preloadNames)
|
||||||
for name := range preloadMap {
|
|
||||||
if embeddedRelations := relationships.EmbeddedRelations[name]; embeddedRelations != nil {
|
isJoined := func(name string) (joined bool, nestedJoins []string) {
|
||||||
if err := preloadEmbedded(tx, embeddedRelations, s, preloadMap[name], as); err != nil {
|
for _, join := range joins {
|
||||||
|
if _, ok := relationships.Relations[join]; ok && name == join {
|
||||||
|
joined = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
joinNames := strings.SplitN(join, ".", 2)
|
||||||
|
if len(joinNames) == 2 {
|
||||||
|
if _, ok := relationships.Relations[joinNames[0]]; ok && name == joinNames[0] {
|
||||||
|
joined = true
|
||||||
|
nestedJoins = append(nestedJoins, joinNames[1])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return joined, nestedJoins
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, name := range preloadNames {
|
||||||
|
if relations := relationships.EmbeddedRelations[name]; relations != nil {
|
||||||
|
if err := preloadEntryPoint(db, joins, relations, preloadMap[name], associationsConds); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
} else if rel := relationships.Relations[name]; rel != nil {
|
} else if rel := relationships.Relations[name]; rel != nil {
|
||||||
if err := preload(tx, rel, append(preloads[name], as), preloadMap[name]); err != nil {
|
if joined, nestedJoins := isJoined(name); joined {
|
||||||
return err
|
reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue)
|
||||||
|
tx := preloadDB(db, reflectValue, reflectValue.Interface())
|
||||||
|
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks})
|
||||||
|
tx.Statement.ReflectValue = db.Statement.ReflectValue
|
||||||
|
tx.Statement.Unscoped = db.Statement.Unscoped
|
||||||
|
if err := preload(tx, rel, append(preloads[name], associationsConds...), preloadMap[name]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
return fmt.Errorf("%s: %w (embedded) for schema %s", name, gorm.ErrUnsupportedRelation, s.Name)
|
return fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func preloadDB(db *gorm.DB, reflectValue reflect.Value, dest interface{}) *gorm.DB {
|
||||||
|
tx := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true})
|
||||||
|
db.Statement.Settings.Range(func(k, v interface{}) bool {
|
||||||
|
tx.Statement.Settings.Store(k, v)
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := tx.Statement.Parse(dest); err != nil {
|
||||||
|
tx.AddError(err)
|
||||||
|
return tx
|
||||||
|
}
|
||||||
|
tx.Statement.ReflectValue = reflectValue
|
||||||
|
tx.Statement.Unscoped = db.Statement.Unscoped
|
||||||
|
return tx
|
||||||
|
}
|
||||||
|
|
||||||
func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error {
|
func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error {
|
||||||
var (
|
var (
|
||||||
reflectValue = tx.Statement.ReflectValue
|
reflectValue = tx.Statement.ReflectValue
|
||||||
|
|
|
@ -3,7 +3,6 @@ package callbacks
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"sort"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
@ -254,7 +253,6 @@ func BuildQuerySQL(db *gorm.DB) {
|
||||||
}
|
}
|
||||||
|
|
||||||
db.Statement.AddClause(fromClause)
|
db.Statement.AddClause(fromClause)
|
||||||
db.Statement.Joins = nil
|
|
||||||
} else {
|
} else {
|
||||||
db.Statement.AddClauseIfNotExists(clause.From{})
|
db.Statement.AddClauseIfNotExists(clause.From{})
|
||||||
}
|
}
|
||||||
|
@ -272,38 +270,23 @@ func Preload(db *gorm.DB) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
preloadMap := parsePreloadMap(db.Statement.Schema, db.Statement.Preloads)
|
joins := make([]string, 0, len(db.Statement.Joins))
|
||||||
preloadNames := make([]string, 0, len(preloadMap))
|
for _, join := range db.Statement.Joins {
|
||||||
for key := range preloadMap {
|
joins = append(joins, join.Name)
|
||||||
preloadNames = append(preloadNames, key)
|
|
||||||
}
|
}
|
||||||
sort.Strings(preloadNames)
|
|
||||||
|
|
||||||
preloadDB := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true})
|
tx := preloadDB(db, db.Statement.ReflectValue, db.Statement.Dest)
|
||||||
db.Statement.Settings.Range(func(k, v interface{}) bool {
|
if tx.Error != nil {
|
||||||
preloadDB.Statement.Settings.Store(k, v)
|
|
||||||
return true
|
|
||||||
})
|
|
||||||
|
|
||||||
if err := preloadDB.Statement.Parse(db.Statement.Dest); err != nil {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
preloadDB.Statement.ReflectValue = db.Statement.ReflectValue
|
|
||||||
preloadDB.Statement.Unscoped = db.Statement.Unscoped
|
|
||||||
|
|
||||||
for _, name := range preloadNames {
|
db.AddError(preloadEntryPoint(tx, joins, &tx.Statement.Schema.Relationships, db.Statement.Preloads, db.Statement.Preloads[clause.Associations]))
|
||||||
if relations := preloadDB.Statement.Schema.Relationships.EmbeddedRelations[name]; relations != nil {
|
|
||||||
db.AddError(preloadEmbedded(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), relations, db.Statement.Schema, preloadMap[name], db.Statement.Preloads[clause.Associations]))
|
|
||||||
} else if rel := preloadDB.Statement.Schema.Relationships.Relations[name]; rel != nil {
|
|
||||||
db.AddError(preload(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name]))
|
|
||||||
} else {
|
|
||||||
db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func AfterQuery(db *gorm.DB) {
|
func AfterQuery(db *gorm.DB) {
|
||||||
|
// clear the joins after query because preload need it
|
||||||
|
db.Statement.Joins = nil
|
||||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 {
|
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 {
|
||||||
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
||||||
if i, ok := value.(AfterFindInterface); ok {
|
if i, ok := value.(AfterFindInterface); ok {
|
||||||
|
|
|
@ -307,6 +307,63 @@ func TestNestedPreloadWithUnscoped(t *testing.T) {
|
||||||
CheckUserUnscoped(t, *user6, user)
|
CheckUserUnscoped(t, *user6, user)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNestedPreloadWithNestedJoin(t *testing.T) {
|
||||||
|
type (
|
||||||
|
Preload struct {
|
||||||
|
ID uint
|
||||||
|
Value string
|
||||||
|
NestedID uint
|
||||||
|
}
|
||||||
|
Join struct {
|
||||||
|
ID uint
|
||||||
|
Value string
|
||||||
|
NestedID uint
|
||||||
|
}
|
||||||
|
Nested struct {
|
||||||
|
ID uint
|
||||||
|
Preloads []*Preload
|
||||||
|
Join Join
|
||||||
|
ValueID uint
|
||||||
|
}
|
||||||
|
Value struct {
|
||||||
|
ID uint
|
||||||
|
Name string
|
||||||
|
Nested Nested
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
DB.Migrator().DropTable(&Preload{}, &Join{}, &Nested{}, &Value{})
|
||||||
|
DB.Migrator().AutoMigrate(&Preload{}, &Join{}, &Nested{}, &Value{})
|
||||||
|
|
||||||
|
value := Value{
|
||||||
|
Name: "value",
|
||||||
|
Nested: Nested{
|
||||||
|
Preloads: []*Preload{
|
||||||
|
{Value: "p1"}, {Value: "p2"},
|
||||||
|
},
|
||||||
|
Join: Join{Value: "j1"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if err := DB.Create(&value).Error; err != nil {
|
||||||
|
t.Errorf("failed to create value, got err: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var find1 Value
|
||||||
|
err := DB.Joins("Nested").Joins("Nested.Join").Preload("Nested.Preloads").First(&find1).Error
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to find value, got err: %v", err)
|
||||||
|
}
|
||||||
|
AssertEqual(t, find1, value)
|
||||||
|
|
||||||
|
var find2 Value
|
||||||
|
// Joins will automatically add Nested queries.
|
||||||
|
err = DB.Joins("Nested.Join").Preload("Nested.Preloads").First(&find2).Error
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to find value, got err: %v", err)
|
||||||
|
}
|
||||||
|
AssertEqual(t, find2, value)
|
||||||
|
}
|
||||||
|
|
||||||
func TestEmbedPreload(t *testing.T) {
|
func TestEmbedPreload(t *testing.T) {
|
||||||
type Country struct {
|
type Country struct {
|
||||||
ID int `gorm:"primaryKey"`
|
ID int `gorm:"primaryKey"`
|
||||||
|
|
Loading…
Reference in New Issue