diff --git a/callback_query_preload.go b/callback_query_preload.go index 76d6f993..fff252c9 100644 --- a/callback_query_preload.go +++ b/callback_query_preload.go @@ -4,11 +4,17 @@ import ( "errors" "fmt" "reflect" + "strconv" "strings" ) // preloadCallback used to preload associations func preloadCallback(scope *Scope) { + + if _, ok := scope.Get("gorm:auto_preload"); ok { + autoPreload(scope) + } + if scope.Search.preload == nil || scope.HasError() { return } @@ -79,6 +85,25 @@ func preloadCallback(scope *Scope) { } } +func autoPreload(scope *Scope) { + for _, field := range scope.Fields() { + if field.Relationship == nil { + continue + } + + if val, ok := field.TagSettings["PRELOAD"]; ok { + if preload, err := strconv.ParseBool(val); err != nil { + scope.Err(errors.New("invalid preload option")) + return + } else if !preload { + continue + } + } + + scope.Search.Preload(field.Name) + } +} + func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (*DB, []interface{}) { var ( preloadDB = scope.NewDB() diff --git a/preload_test.go b/preload_test.go index c830025c..1b89e77b 100644 --- a/preload_test.go +++ b/preload_test.go @@ -96,6 +96,33 @@ func TestPreload(t *testing.T) { } } +func TestAutoPreload(t *testing.T) { + user1 := getPreloadUser("auto_user1") + DB.Save(user1) + + preloadDB := DB.Set("gorm:auto_preload", true).Where("role = ?", "Preload") + var user User + preloadDB.Find(&user) + checkUserHasPreloadData(user, t) + + user2 := getPreloadUser("auto_user2") + DB.Save(user2) + + var users []User + preloadDB.Find(&users) + + for _, user := range users { + checkUserHasPreloadData(user, t) + } + + var users2 []*User + preloadDB.Find(&users2) + + for _, user := range users2 { + checkUserHasPreloadData(*user, t) + } +} + func TestNestedPreload1(t *testing.T) { type ( Level1 struct {