From eeb9ba2250b21a764abc8d7b5688a5998a684af6 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 1 Oct 2015 07:43:38 +0800 Subject: [PATCH] Fix query with Joins --- main_test.go | 20 +++++++++++++++++--- preload.go | 2 +- scope_private.go | 3 +++ 3 files changed, 21 insertions(+), 4 deletions(-) diff --git a/main_test.go b/main_test.go index 2683ac61..33503b3d 100644 --- a/main_test.go +++ b/main_test.go @@ -425,21 +425,35 @@ func TestGroup(t *testing.T) { } func TestJoins(t *testing.T) { + var user = User{ + Name: "joins", + Emails: []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}}, + } + DB.Save(&user) + + var result User + DB.Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins").First(&result) + if result.Name != "joins" || result.Id != user.Id { + t.Errorf("Should find all two emails with Join") + } +} + +func TestJoinsWithSelect(t *testing.T) { type result struct { Name string Email string } user := User{ - Name: "joins", + Name: "joins_with_select", Emails: []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}}, } DB.Save(&user) var results []result - DB.Table("users").Select("name, email").Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins").Scan(&results) + DB.Table("users").Select("name, email").Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins_with_select").Scan(&results) if len(results) != 2 || results[0].Email != "join1@example.com" || results[1].Email != "join2@example.com" { - t.Errorf("Should find all two emails with Join") + t.Errorf("Should find all two emails with Join select") } } diff --git a/preload.go b/preload.go index 4621dd91..efdfed07 100644 --- a/preload.go +++ b/preload.go @@ -209,7 +209,7 @@ func (scope *Scope) handleHasManyToManyPreload(field *Field, conditions []interf sourceKeys = append(sourceKeys, key.DBName) } - db := scope.NewDB().Table(scope.New(reflect.New(destType).Interface()).TableName()) + db := scope.NewDB().Table(scope.New(reflect.New(destType).Interface()).TableName()).Select("*") preloadJoinDB := joinTableHandler.JoinWith(joinTableHandler, db, scope.Value) if len(conditions) > 0 { preloadJoinDB = preloadJoinDB.Where(conditions[0], conditions[1:]...) diff --git a/scope_private.go b/scope_private.go index c743c119..16a96628 100644 --- a/scope_private.go +++ b/scope_private.go @@ -213,6 +213,9 @@ var hasCountRegexp = regexp.MustCompile(`(?i)count(.+)`) func (scope *Scope) selectSql() string { if len(scope.Search.selects) == 0 { + if scope.Search.joins != "" { + return fmt.Sprintf("%v.*", scope.QuotedTableName()) + } return "*" } sql := scope.buildSelectQuery(scope.Search.selects)