diff --git a/association.go b/association.go index a889157b..ab9090ac 100644 --- a/association.go +++ b/association.go @@ -26,6 +26,8 @@ func (db *DB) Association(column string) *Association { if association.Relationship == nil { association.Error = fmt.Errorf("%w: %v", ErrUnsupportedRelation, column) } + + db.Statement.ReflectValue = reflect.Indirect(reflect.ValueOf(db.Statement.Model)) } else { association.Error = err } @@ -36,8 +38,8 @@ func (db *DB) Association(column string) *Association { func (association *Association) Find(out interface{}, conds ...interface{}) error { if association.Error == nil { var ( - tx = association.DB - queryConds = association.Relationship.ToQueryConditions(tx.Statement.ReflectValue) + queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue) + tx = association.DB.Model(out).Table("") ) if association.Relationship.JoinTable != nil { diff --git a/callbacks.go b/callbacks.go index 6c70b392..61cebc81 100644 --- a/callbacks.go +++ b/callbacks.go @@ -83,7 +83,13 @@ func (p *processor) Execute(db *DB) { db.AddError(err) } } - stmt.ReflectValue = reflect.Indirect(reflect.ValueOf(stmt.Dest)) + + if stmt.Dest != nil { + stmt.ReflectValue = reflect.Indirect(reflect.ValueOf(stmt.Dest)) + if !stmt.ReflectValue.IsValid() { + db.AddError(fmt.Errorf("invalid value")) + } + } } for _, f := range p.fns { diff --git a/go.mod b/go.mod index 3e067d3c..d3421e1b 100644 --- a/go.mod +++ b/go.mod @@ -1,12 +1,8 @@ module github.com/jinzhu/gorm -go 1.13 +go 1.14 require ( - github.com/denisenkom/go-mssqldb v0.0.0-20200206145737-bbfc9a55622e // indirect - github.com/go-sql-driver/mysql v1.5.0 // indirect github.com/jinzhu/inflection v1.0.0 github.com/jinzhu/now v1.1.1 - github.com/lib/pq v1.3.0 // indirect - github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect ) diff --git a/tests/associations.go b/tests/associations.go new file mode 100644 index 00000000..7e93e81e --- /dev/null +++ b/tests/associations.go @@ -0,0 +1,73 @@ +package tests + +import ( + "testing" + + "github.com/jinzhu/gorm" +) + +func TestAssociations(t *testing.T, db *gorm.DB) { + db.Migrator().DropTable(&Account{}, &Company{}, &Pet{}, &Toy{}, &Language{}) + db.Migrator().AutoMigrate(&Account{}, &Company{}, &Pet{}, &Toy{}, &Language{}) + + TestBelongsToAssociations(t, db) +} + +func TestBelongsToAssociations(t *testing.T, db *gorm.DB) { + check := func(t *testing.T, user User, old User) { + if old.Company.Name != "" { + if user.CompanyID == nil { + t.Errorf("Company's foreign key should be saved") + } else { + var company Company + db.First(&company, "id = ?", *user.CompanyID) + if company.Name != old.Company.Name { + t.Errorf("Company's name should be same, expects: %v, got %v", old.Company.Name, user.Company.Name) + } else if user.Company.Name != old.Company.Name { + t.Errorf("Company's name should be same, expects: %v, got %v", old.Company.Name, user.Company.Name) + } + } + } else if user.CompanyID != nil { + t.Errorf("Company should not be created for zero value, got: %+v", user.CompanyID) + } + + if old.Manager != nil { + if user.ManagerID == nil { + t.Errorf("Manager's foreign key should be saved") + } else { + var manager User + db.First(&manager, "id = ?", *user.ManagerID) + if manager.Name != user.Manager.Name { + t.Errorf("Manager's name should be same") + } else if user.Manager.Name != old.Manager.Name { + t.Errorf("Manager's name should be same") + } + } + } else if user.ManagerID != nil { + t.Errorf("Manager should not be created for zero value, got: %+v", user.ManagerID) + } + } + + t.Run("BelongsTo", func(t *testing.T) { + var user = User{ + Name: "create", + Age: 18, + Birthday: Now(), + Company: Company{Name: "company-belongs-to-association"}, + Manager: &User{Name: "manager-belongs-to-association"}, + } + + if err := db.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + check(t, user, user) + + var user2 User + db.Find(&user2, "id = ?", user.ID) + db.Model(&user2).Association("Company").Find(&user2.Company) + user2.Manager = &User{} + db.Model(&user2).Association("Manager").Find(user2.Manager) + check(t, user2, user) + }) +} diff --git a/tests/tests.go b/tests/tests.go index cc9c1a78..87005a71 100644 --- a/tests/tests.go +++ b/tests/tests.go @@ -20,4 +20,5 @@ func RunTestsSuit(t *testing.T, db *gorm.DB) { TestGroupBy(t, db) TestJoins(t, db) + TestAssociations(t, db) }