Fix associations find

This commit is contained in:
Jinzhu 2020-05-23 11:57:28 +08:00
parent 0f21272c7f
commit 72460df1bd
5 changed files with 86 additions and 8 deletions

View File

@ -26,6 +26,8 @@ func (db *DB) Association(column string) *Association {
if association.Relationship == nil { if association.Relationship == nil {
association.Error = fmt.Errorf("%w: %v", ErrUnsupportedRelation, column) association.Error = fmt.Errorf("%w: %v", ErrUnsupportedRelation, column)
} }
db.Statement.ReflectValue = reflect.Indirect(reflect.ValueOf(db.Statement.Model))
} else { } else {
association.Error = err association.Error = err
} }
@ -36,8 +38,8 @@ func (db *DB) Association(column string) *Association {
func (association *Association) Find(out interface{}, conds ...interface{}) error { func (association *Association) Find(out interface{}, conds ...interface{}) error {
if association.Error == nil { if association.Error == nil {
var ( var (
tx = association.DB queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue)
queryConds = association.Relationship.ToQueryConditions(tx.Statement.ReflectValue) tx = association.DB.Model(out).Table("")
) )
if association.Relationship.JoinTable != nil { if association.Relationship.JoinTable != nil {

View File

@ -83,7 +83,13 @@ func (p *processor) Execute(db *DB) {
db.AddError(err) db.AddError(err)
} }
} }
if stmt.Dest != nil {
stmt.ReflectValue = reflect.Indirect(reflect.ValueOf(stmt.Dest)) stmt.ReflectValue = reflect.Indirect(reflect.ValueOf(stmt.Dest))
if !stmt.ReflectValue.IsValid() {
db.AddError(fmt.Errorf("invalid value"))
}
}
} }
for _, f := range p.fns { for _, f := range p.fns {

6
go.mod
View File

@ -1,12 +1,8 @@
module github.com/jinzhu/gorm module github.com/jinzhu/gorm
go 1.13 go 1.14
require ( require (
github.com/denisenkom/go-mssqldb v0.0.0-20200206145737-bbfc9a55622e // indirect
github.com/go-sql-driver/mysql v1.5.0 // indirect
github.com/jinzhu/inflection v1.0.0 github.com/jinzhu/inflection v1.0.0
github.com/jinzhu/now v1.1.1 github.com/jinzhu/now v1.1.1
github.com/lib/pq v1.3.0 // indirect
github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
) )

73
tests/associations.go Normal file
View File

@ -0,0 +1,73 @@
package tests
import (
"testing"
"github.com/jinzhu/gorm"
)
func TestAssociations(t *testing.T, db *gorm.DB) {
db.Migrator().DropTable(&Account{}, &Company{}, &Pet{}, &Toy{}, &Language{})
db.Migrator().AutoMigrate(&Account{}, &Company{}, &Pet{}, &Toy{}, &Language{})
TestBelongsToAssociations(t, db)
}
func TestBelongsToAssociations(t *testing.T, db *gorm.DB) {
check := func(t *testing.T, user User, old User) {
if old.Company.Name != "" {
if user.CompanyID == nil {
t.Errorf("Company's foreign key should be saved")
} else {
var company Company
db.First(&company, "id = ?", *user.CompanyID)
if company.Name != old.Company.Name {
t.Errorf("Company's name should be same, expects: %v, got %v", old.Company.Name, user.Company.Name)
} else if user.Company.Name != old.Company.Name {
t.Errorf("Company's name should be same, expects: %v, got %v", old.Company.Name, user.Company.Name)
}
}
} else if user.CompanyID != nil {
t.Errorf("Company should not be created for zero value, got: %+v", user.CompanyID)
}
if old.Manager != nil {
if user.ManagerID == nil {
t.Errorf("Manager's foreign key should be saved")
} else {
var manager User
db.First(&manager, "id = ?", *user.ManagerID)
if manager.Name != user.Manager.Name {
t.Errorf("Manager's name should be same")
} else if user.Manager.Name != old.Manager.Name {
t.Errorf("Manager's name should be same")
}
}
} else if user.ManagerID != nil {
t.Errorf("Manager should not be created for zero value, got: %+v", user.ManagerID)
}
}
t.Run("BelongsTo", func(t *testing.T) {
var user = User{
Name: "create",
Age: 18,
Birthday: Now(),
Company: Company{Name: "company-belongs-to-association"},
Manager: &User{Name: "manager-belongs-to-association"},
}
if err := db.Create(&user).Error; err != nil {
t.Fatalf("errors happened when create: %v", err)
}
check(t, user, user)
var user2 User
db.Find(&user2, "id = ?", user.ID)
db.Model(&user2).Association("Company").Find(&user2.Company)
user2.Manager = &User{}
db.Model(&user2).Association("Manager").Find(user2.Manager)
check(t, user2, user)
})
}

View File

@ -20,4 +20,5 @@ func RunTestsSuit(t *testing.T, db *gorm.DB) {
TestGroupBy(t, db) TestGroupBy(t, db)
TestJoins(t, db) TestJoins(t, db)
TestAssociations(t, db)
} }