From 24e0de116a4b1b63f67397af87f24aa3aedcb102 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 11 Feb 2015 19:08:42 +0800 Subject: [PATCH] Add inline condition support for Preload --- preload.go | 8 ++++---- preload_test.go | 13 +++++++++++++ 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/preload.go b/preload.go index 4e48db53..c4381b2a 100644 --- a/preload.go +++ b/preload.go @@ -36,7 +36,7 @@ func Preload(scope *Scope) { } if scope.Search.Preload != nil { - for key := range scope.Search.Preload { + for key, conditions := range scope.Search.Preload { for _, field := range fields { if field.Name == key && field.Relationship != nil { results := makeSlice(field.Field) @@ -47,7 +47,7 @@ func Preload(scope *Scope) { switch relation.Kind { case "has_one": condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName())) - scope.NewDB().Find(results, condition, scope.getColumnAsArray(primaryName)) + scope.NewDB().Where(condition, scope.getColumnAsArray(primaryName)).Find(results, conditions...) resultValues := reflect.Indirect(reflect.ValueOf(results)) for i := 0; i < resultValues.Len(); i++ { @@ -67,7 +67,7 @@ func Preload(scope *Scope) { } case "has_many": condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName())) - scope.NewDB().Find(results, condition, scope.getColumnAsArray(primaryName)) + scope.NewDB().Where(condition, scope.getColumnAsArray(primaryName)).Find(results, conditions...) resultValues := reflect.Indirect(reflect.ValueOf(results)) if isSlice { for i := 0; i < resultValues.Len(); i++ { @@ -87,7 +87,7 @@ func Preload(scope *Scope) { scope.SetColumn(field, resultValues) } case "belongs_to": - scope.NewDB().Find(results, scope.getColumnAsArray(relation.ForeignKey)) + scope.NewDB().Where(scope.getColumnAsArray(relation.ForeignKey)).Find(results, conditions...) resultValues := reflect.Indirect(reflect.ValueOf(results)) for i := 0; i < resultValues.Len(); i++ { result := resultValues.Index(i) diff --git a/preload_test.go b/preload_test.go index 2a0e967b..0c2e9e18 100644 --- a/preload_test.go +++ b/preload_test.go @@ -84,4 +84,17 @@ func TestPreload(t *testing.T) { for _, user := range users2 { checkUserHasPreloadData(*user, t) } + + var users3 []*User + DB.Where("role = ?", "Preload").Preload("Emails", "email = ?", user3.Emails[0].Email).Find(&users3) + + for _, user := range users3 { + if user.Name == user3.Name { + if len(user.Emails) != 1 { + t.Errorf("should only preload one emails for user3 when with condition") + } + } else if len(user.Emails) != 0 { + t.Errorf("should not preload any emails for other users when with condition") + } + } }