diff --git a/main_test.go b/main_test.go index 540c1e13..b6d1c17c 100644 --- a/main_test.go +++ b/main_test.go @@ -1351,10 +1351,17 @@ func TestRelated(t *testing.T) { if len(emails) != 2 { t.Errorf("Should have two emails") } + + var emails2 []Email + db.Model(&user).Where("email = ?", "jinzhu@example.com").Related(&emails2) + if len(emails2) != 1 { + t.Errorf("Should have two emails") + } + var user1 User db.Model(&user).Related(&user1.Emails) if len(user1.Emails) != 2 { - t.Errorf("Should have two emails") + t.Errorf("Should have only one email match related condition") } var address1 Address @@ -1554,9 +1561,9 @@ func TestTransaction(t *testing.T) { t.Errorf("Should find saved record, but got", err) } - if sql_tx, ok := tx.CommonDB().(*sql.Tx); !ok || sql_tx == nil { - t.Errorf("Should return the underlying sql.Tx") - } + if sql_tx, ok := tx.CommonDB().(*sql.Tx); !ok || sql_tx == nil { + t.Errorf("Should return the underlying sql.Tx") + } tx.Rollback() diff --git a/scope_private.go b/scope_private.go index c2e5627e..72f631cc 100644 --- a/scope_private.go +++ b/scope_private.go @@ -393,8 +393,7 @@ func (scope *Scope) typeName() string { } func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { - toScope := scope.New(value) - toScope.db = scope.db + toScope := scope.db.NewScope(value) for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") { if foreignValue, ok := scope.FieldByName(foreignKey); ok {