Refactor preloading many2many for auto preload

This commit is contained in:
Jinzhu 2018-02-10 00:07:16 +08:00
parent ec72a4cb6b
commit 77eb925ea0
3 changed files with 13 additions and 8 deletions

View File

@ -15,7 +15,7 @@ func init() {
// queryCallback used to query data from database // queryCallback used to query data from database
func queryCallback(scope *Scope) { func queryCallback(scope *Scope) {
if _, skip := scope.Get("gorm:skip_query_callback"); skip { if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {
return return
} }

View File

@ -10,6 +10,9 @@ import (
// preloadCallback used to preload associations // preloadCallback used to preload associations
func preloadCallback(scope *Scope) { func preloadCallback(scope *Scope) {
if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {
return
}
if _, ok := scope.Get("gorm:auto_preload"); ok { if _, ok := scope.Get("gorm:auto_preload"); ok {
autoPreload(scope) autoPreload(scope)
@ -325,7 +328,7 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
scope.scan(rows, columns, append(fields, joinTableFields...)) scope.scan(rows, columns, append(fields, joinTableFields...))
scope.New(elem.Addr().Interface()). scope.New(elem.Addr().Interface()).
Set("gorm:skip_query_callback", true). InstanceSet("gorm:skip_query_callback", true).
callCallbacks(scope.db.parent.callbacks.queries) callCallbacks(scope.db.parent.callbacks.queries)
var foreignKeys = make([]interface{}, len(sourceKeys)) var foreignKeys = make([]interface{}, len(sourceKeys))

View File

@ -1631,9 +1631,11 @@ func TestPreloadManyToManyCallbacks(t *testing.T) {
type ( type (
Level2 struct { Level2 struct {
ID uint ID uint
Name string
} }
Level1 struct { Level1 struct {
ID uint ID uint
Name string
Level2s []Level2 `gorm:"many2many:level1_level2s;AssociationForeignKey:ID;ForeignKey:ID"` Level2s []Level2 `gorm:"many2many:level1_level2s;AssociationForeignKey:ID;ForeignKey:ID"`
} }
) )
@ -1647,8 +1649,9 @@ func TestPreloadManyToManyCallbacks(t *testing.T) {
} }
lvl := Level1{ lvl := Level1{
Name: "l1",
Level2s: []Level2{ Level2s: []Level2{
Level2{}, Level2{Name: "l2-1"}, Level2{Name: "l2-2"},
}, },
} }
DB.Save(&lvl) DB.Save(&lvl)
@ -1659,11 +1662,10 @@ func TestPreloadManyToManyCallbacks(t *testing.T) {
called = called + 1 called = called + 1
}) })
found := Level1{ID: lvl.ID} DB.Preload("Level2s").First(&Level1{}, "id = ?", lvl.ID)
DB.Preload("Level2s").First(&found, &found)
if called != 2 { if called != 3 {
t.Errorf("Wanted callback to be called 2 times but got %d", called) t.Errorf("Wanted callback to be called 3 times but got %d", called)
} }
} }