forked from mirror/gorm
Refactor preloading many2many for auto preload
This commit is contained in:
parent
ec72a4cb6b
commit
77eb925ea0
|
@ -15,7 +15,7 @@ func init() {
|
|||
|
||||
// queryCallback used to query data from database
|
||||
func queryCallback(scope *Scope) {
|
||||
if _, skip := scope.Get("gorm:skip_query_callback"); skip {
|
||||
if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -10,6 +10,9 @@ import (
|
|||
|
||||
// preloadCallback used to preload associations
|
||||
func preloadCallback(scope *Scope) {
|
||||
if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {
|
||||
return
|
||||
}
|
||||
|
||||
if _, ok := scope.Get("gorm:auto_preload"); ok {
|
||||
autoPreload(scope)
|
||||
|
@ -325,7 +328,7 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
|
|||
scope.scan(rows, columns, append(fields, joinTableFields...))
|
||||
|
||||
scope.New(elem.Addr().Interface()).
|
||||
Set("gorm:skip_query_callback", true).
|
||||
InstanceSet("gorm:skip_query_callback", true).
|
||||
callCallbacks(scope.db.parent.callbacks.queries)
|
||||
|
||||
var foreignKeys = make([]interface{}, len(sourceKeys))
|
||||
|
|
|
@ -1630,10 +1630,12 @@ func TestPrefixedPreloadDuplication(t *testing.T) {
|
|||
func TestPreloadManyToManyCallbacks(t *testing.T) {
|
||||
type (
|
||||
Level2 struct {
|
||||
ID uint
|
||||
ID uint
|
||||
Name string
|
||||
}
|
||||
Level1 struct {
|
||||
ID uint
|
||||
Name string
|
||||
Level2s []Level2 `gorm:"many2many:level1_level2s;AssociationForeignKey:ID;ForeignKey:ID"`
|
||||
}
|
||||
)
|
||||
|
@ -1647,8 +1649,9 @@ func TestPreloadManyToManyCallbacks(t *testing.T) {
|
|||
}
|
||||
|
||||
lvl := Level1{
|
||||
Name: "l1",
|
||||
Level2s: []Level2{
|
||||
Level2{},
|
||||
Level2{Name: "l2-1"}, Level2{Name: "l2-2"},
|
||||
},
|
||||
}
|
||||
DB.Save(&lvl)
|
||||
|
@ -1659,11 +1662,10 @@ func TestPreloadManyToManyCallbacks(t *testing.T) {
|
|||
called = called + 1
|
||||
})
|
||||
|
||||
found := Level1{ID: lvl.ID}
|
||||
DB.Preload("Level2s").First(&found, &found)
|
||||
DB.Preload("Level2s").First(&Level1{}, "id = ?", lvl.ID)
|
||||
|
||||
if called != 2 {
|
||||
t.Errorf("Wanted callback to be called 2 times but got %d", called)
|
||||
if called != 3 {
|
||||
t.Errorf("Wanted callback to be called 3 times but got %d", called)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue