diff --git a/README.md b/README.md index 0f195561..1c06820e 100644 --- a/README.md +++ b/README.md @@ -929,7 +929,7 @@ db.Table("users").Select("users.name, emails.email").Joins("left join emails on db.Joins("inner join emails on emails.user_id = users.id").Where("emails.email = ?", "x@example.org").Find(&user) // find all email addresses for a user -db.Joins("left join users on users.id = emails.user_id").Where("users.name = ?", "jinzhu").Find(&emails) +db.Joins("LEFT JOIN users ON users.id = emails.user_id AND users.name = ?", "jinzhu").Find(&emails) ``` ## Transactions diff --git a/main.go b/main.go index 9581a216..b7f0d2aa 100644 --- a/main.go +++ b/main.go @@ -171,8 +171,8 @@ func (s *DB) Having(query string, values ...interface{}) *DB { return s.clone().search.Having(query, values...).db } -func (s *DB) Joins(query string) *DB { - return s.clone().search.Joins(query).db +func (s *DB) Joins(query string, args ...interface{}) *DB { + return s.clone().search.Joins(query, args...).db } func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB { diff --git a/main_test.go b/main_test.go index 97a3d84e..39189cd3 100644 --- a/main_test.go +++ b/main_test.go @@ -506,10 +506,16 @@ func TestJoins(t *testing.T) { } DB.Save(&user) - var result User - DB.Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins").First(&result) - if result.Name != "joins" || result.Id != user.Id { - t.Errorf("Should find all two emails with Join") + var users1 []User + DB.Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins").Find(&users1) + if len(users1) != 2 { + t.Errorf("should find two users using left join") + } + + var users2 []User + DB.Joins("left join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Where("name = ?", "joins").First(&users2) + if len(users2) != 1 { + t.Errorf("should find one users using left join with conditions") } } diff --git a/scope_private.go b/scope_private.go index 4ed2060c..6b34a4b3 100644 --- a/scope_private.go +++ b/scope_private.go @@ -234,7 +234,7 @@ var hasCountRegexp = regexp.MustCompile(`(?i)count\(.+\)`) func (scope *Scope) selectSql() string { if len(scope.Search.selects) == 0 { - if scope.Search.joins != "" { + if len(scope.Search.joinConditions) > 0 { return fmt.Sprintf("%v.*", scope.QuotedTableName()) } return "*" @@ -263,12 +263,11 @@ func (scope *Scope) groupSql() string { } func (scope *Scope) havingSql() string { - if scope.Search.havingConditions == nil { + if len(scope.Search.havingConditions) == 0 { return "" } var andConditions []string - for _, clause := range scope.Search.havingConditions { if sql := scope.buildWhereCondition(clause); sql != "" { andConditions = append(andConditions, sql) @@ -284,7 +283,14 @@ func (scope *Scope) havingSql() string { } func (scope *Scope) joinsSql() string { - return scope.Search.joins + " " + var joinConditions []string + for _, clause := range scope.Search.joinConditions { + if sql := scope.buildWhereCondition(clause); sql != "" { + joinConditions = append(joinConditions, strings.TrimSuffix(strings.TrimPrefix(sql, "("), ")")) + } + } + + return strings.Join(joinConditions, " ") + " " } func (scope *Scope) prepareQuerySql() { diff --git a/search.go b/search.go index c6d070f0..4e31ae03 100644 --- a/search.go +++ b/search.go @@ -8,12 +8,12 @@ type search struct { orConditions []map[string]interface{} notConditions []map[string]interface{} havingConditions []map[string]interface{} + joinConditions []map[string]interface{} initAttrs []interface{} assignAttrs []interface{} selects map[string]interface{} omits []string orders []string - joins string preload []searchPreload offset int limit int @@ -102,8 +102,8 @@ func (s *search) Having(query string, values ...interface{}) *search { return s } -func (s *search) Joins(query string) *search { - s.joins = query +func (s *search) Joins(query string, values ...interface{}) *search { + s.joinConditions = append(s.joinConditions, map[string]interface{}{"query": query, "args": values}) return s }