diff --git a/association.go b/association.go index cbec41c3..a8a67be7 100644 --- a/association.go +++ b/association.go @@ -26,7 +26,7 @@ func (association *Association) Find(value interface{}) *Association { return association.setErr(association.Scope.db.Error) } -func (association *Association) Append(values ...interface{}) *Association { +func (association *Association) saveAssociations(values ...interface{}) *Association { scope := association.Scope field := association.Field relationship := association.Field.Relationship @@ -81,6 +81,13 @@ func (association *Association) Append(values ...interface{}) *Association { return association } +func (association *Association) Append(values ...interface{}) *Association { + if relationship := association.Field.Relationship; relationship.Kind == "has_one" { + return association.Replace(values...) + } + return association.saveAssociations(values...) +} + func (association *Association) Replace(values ...interface{}) *Association { var ( relationship = association.Field.Relationship @@ -91,7 +98,7 @@ func (association *Association) Replace(values ...interface{}) *Association { // Append new values association.Field.Set(reflect.Zero(association.Field.Field.Type())) - association.Append(values...) + association.saveAssociations(values...) // Belongs To if relationship.Kind == "belongs_to" { @@ -114,6 +121,10 @@ func (association *Association) Replace(values ...interface{}) *Association { } } + if relationship.PolymorphicDBName != "" { + newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName()) + } + // Relations except new created if len(values) > 0 { var newPrimaryKeys [][]interface{} diff --git a/field.go b/field.go index db1fdd8f..6bfac0bb 100644 --- a/field.go +++ b/field.go @@ -42,6 +42,10 @@ func (field *Field) Set(value interface{}) error { reflectValue = reflect.ValueOf(value) } + if !reflectValue.IsValid() { + return nil + } + if reflectValue.Type().ConvertibleTo(field.Field.Type()) { field.Field.Set(reflectValue.Convert(field.Field.Type())) } else { diff --git a/migration_test.go b/migration_test.go index 819a378a..57d9530a 100644 --- a/migration_test.go +++ b/migration_test.go @@ -15,7 +15,7 @@ func runMigration() { DB.Exec(fmt.Sprintf("drop table %v;", table)) } - values := []interface{}{&Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}} + values := []interface{}{&Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Toy{}} for _, value := range values { DB.DropTable(value) } diff --git a/polymorphic_test.go b/polymorphic_test.go index 78b99feb..f926fdd4 100644 --- a/polymorphic_test.go +++ b/polymorphic_test.go @@ -1,6 +1,11 @@ package gorm_test -import "testing" +import ( + "fmt" + "reflect" + "sort" + "testing" +) type Cat struct { Id int @@ -21,15 +26,30 @@ type Toy struct { OwnerType string } -func TestPolymorphic(t *testing.T) { - DB.AutoMigrate(&Cat{}) - DB.AutoMigrate(&Dog{}) - DB.AutoMigrate(&Toy{}) +var compareToys = func(toys []Toy, contents []string) bool { + var toyContents []string + for _, toy := range toys { + toyContents = append(toyContents, toy.Name) + } + sort.Strings(toyContents) + sort.Strings(contents) + return reflect.DeepEqual(toyContents, contents) +} - cat := Cat{Name: "Mr. Bigglesworth", Toy: Toy{Name: "cat nip"}} - dog := Dog{Name: "Pluto", Toys: []Toy{Toy{Name: "orange ball"}, Toy{Name: "yellow ball"}}} +func TestPolymorphic(t *testing.T) { + cat := Cat{Name: "Mr. Bigglesworth", Toy: Toy{Name: "cat toy"}} + dog := Dog{Name: "Pluto", Toys: []Toy{Toy{Name: "dog toy 1"}, Toy{Name: "dog toy 2"}}} DB.Save(&cat).Save(&dog) + if DB.Model(&cat).Association("Toy").Count() != 1 { + t.Errorf("Cat's toys count should be 1") + } + + if DB.Model(&cat).Association("Toy").Count() != 1 { + t.Errorf("Dog's toys count should be 2") + } + + // Query var catToys []Toy if DB.Model(&cat).Related(&catToys, "Toy").RecordNotFound() { t.Errorf("Did not find any has one polymorphic association") @@ -46,11 +66,52 @@ func TestPolymorphic(t *testing.T) { t.Errorf("Should have found all polymorphic has many associations") } + var catToy Toy + DB.Model(&cat).Association("Toy").Find(&catToy) + if catToy.Name != cat.Toy.Name { + t.Errorf("Should find has one polymorphic association") + } + + var dogToys1 []Toy + DB.Model(&dog).Association("Toys").Find(&dogToys1) + if !compareToys(dogToys1, []string{"dog toy 1", "dog toy 2"}) { + t.Errorf("Should find has many polymorphic association") + } + + // Append + DB.Model(&cat).Association("Toy").Append(&Toy{ + Name: "cat toy 2", + }) + + var catToy2 Toy + DB.Model(&cat).Association("Toy").Find(&catToy2) + if catToy2.Name != "cat toy 2" { + t.Errorf("Should update has one polymorphic association with Append") + } + if DB.Model(&cat).Association("Toy").Count() != 1 { - t.Errorf("Should return one polymorphic has one association") + t.Errorf("Cat's toys count should be 1 after Append") } if DB.Model(&dog).Association("Toys").Count() != 2 { t.Errorf("Should return two polymorphic has many associations") } + + DB.Model(&dog).Association("Toys").Append(&Toy{ + Name: "dog toy 3", + }) + + var dogToys2 []Toy + DB.Model(&dog).Association("Toys").Find(&dogToys2) + fmt.Println(dogToys2) + if !compareToys(dogToys2, []string{"dog toy 1", "dog toy 2", "dog toy 3"}) { + t.Errorf("Dog's toys should be updated with Append") + } + + if DB.Model(&dog).Association("Toys").Count() != 3 { + t.Errorf("Should return three polymorphic has many associations") + } + // Replace + // Delete + // Clear }