diff --git a/association.go b/association.go index 0f94683d..f57c99af 100644 --- a/association.go +++ b/association.go @@ -63,7 +63,11 @@ func (association *Association) Replace(values ...interface{}) *Association { } else { // Polymorphic Relations if relationship.PolymorphicDBName != "" { - newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName()) + value := scope.TableName() + if relationship.PolymorphicValue != "" { + value = relationship.PolymorphicValue + } + newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), value) } // Delete Relations except new created @@ -282,9 +286,13 @@ func (association *Association) Count() int { } if relationship.PolymorphicType != "" { + value := scope.TableName() + if relationship.PolymorphicValue != "" { + value = relationship.PolymorphicValue + } query = query.Where( fmt.Sprintf("%v.%v = ?", scope.New(fieldValue).QuotedTableName(), scope.Quote(relationship.PolymorphicDBName)), - scope.TableName(), + value, ) } diff --git a/callback_query_preload.go b/callback_query_preload.go index d9ec8bdd..13d0b85d 100644 --- a/callback_query_preload.go +++ b/callback_query_preload.go @@ -114,7 +114,11 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) values := toQueryValues(primaryKeys) if relation.PolymorphicType != "" { query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName)) - values = append(values, scope.TableName()) + value := scope.TableName() + if relation.PolymorphicValue != "" { + value = relation.PolymorphicValue + } + values = append(values, value) } results := makeSlice(field.Struct.Type) @@ -163,7 +167,11 @@ func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) values := toQueryValues(primaryKeys) if relation.PolymorphicType != "" { query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName)) - values = append(values, scope.TableName()) + value := scope.TableName() + if relation.PolymorphicValue != "" { + value = relation.PolymorphicValue + } + values = append(values, value) } results := makeSlice(field.Struct.Type) diff --git a/callback_save.go b/callback_save.go index 5ffe53b9..6f4e2652 100644 --- a/callback_save.go +++ b/callback_save.go @@ -60,7 +60,11 @@ func saveAfterAssociationsCallback(scope *Scope) { } if relationship.PolymorphicType != "" { - scope.Err(newScope.SetColumn(relationship.PolymorphicType, scope.TableName())) + value := scope.TableName() + if relationship.PolymorphicValue != "" { + value = relationship.PolymorphicValue + } + scope.Err(newScope.SetColumn(relationship.PolymorphicType, value)) } scope.Err(newDB.Save(elem).Error) @@ -82,7 +86,11 @@ func saveAfterAssociationsCallback(scope *Scope) { } if relationship.PolymorphicType != "" { - scope.Err(newScope.SetColumn(relationship.PolymorphicType, scope.TableName())) + value := scope.TableName() + if relationship.PolymorphicValue != "" { + value = relationship.PolymorphicValue + } + scope.Err(newScope.SetColumn(relationship.PolymorphicType, value)) } scope.Err(scope.NewDB().Save(elem).Error) } diff --git a/migration_test.go b/migration_test.go index 08633283..1545ec2c 100644 --- a/migration_test.go +++ b/migration_test.go @@ -252,7 +252,7 @@ func runMigration() { DB.Exec(fmt.Sprintf("drop table %v;", table)) } - values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Toy{}, &ElementWithIgnoredField{}} + values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}} for _, value := range values { DB.DropTable(value) } diff --git a/model_struct.go b/model_struct.go index 8179a1f3..ad8b9f9a 100644 --- a/model_struct.go +++ b/model_struct.go @@ -99,6 +99,7 @@ type Relationship struct { Kind string PolymorphicType string PolymorphicDBName string + PolymorphicValue string ForeignFieldNames []string ForeignDBNames []string AssociationForeignFieldNames []string @@ -292,6 +293,8 @@ func (scope *Scope) GetModelStruct() *ModelStruct { associationType = polymorphic relationship.PolymorphicType = polymorphicType.Name relationship.PolymorphicDBName = polymorphicType.DBName + // if Dog has multiple set of toys set name of the set (instead of default 'dogs') + relationship.PolymorphicValue = field.TagSettings["VALUE"] polymorphicType.IsForeignKey = true } } @@ -384,6 +387,8 @@ func (scope *Scope) GetModelStruct() *ModelStruct { associationType = polymorphic relationship.PolymorphicType = polymorphicType.Name relationship.PolymorphicDBName = polymorphicType.DBName + // if Cat has several different types of toys set name for each (instead of default 'cats') + relationship.PolymorphicValue = field.TagSettings["VALUE"] polymorphicType.IsForeignKey = true } } diff --git a/polymorphic_test.go b/polymorphic_test.go index df573f97..b3a5c79e 100644 --- a/polymorphic_test.go +++ b/polymorphic_test.go @@ -18,6 +18,13 @@ type Dog struct { Toys []Toy `gorm:"polymorphic:Owner;"` } +type Hamster struct { + Id int + Name string + PreferredToy Toy `gorm:"polymorphic:Owner;value:hamster_preferred"` + OtherToy Toy `gorm:"polymorphic:Owner;value:hamster_other"` +} + type Toy struct { Id int Name string @@ -217,3 +224,143 @@ func TestPolymorphic(t *testing.T) { t.Errorf("Dog's toys should be cleared with Clear") } } + +func TestNamedPolymorphic(t *testing.T) { + hamster := Hamster{Name: "Mr. Hammond", PreferredToy: Toy{Name: "bike"}, OtherToy: Toy{Name: "treadmill"}} + DB.Save(&hamster) + + hamster2 := Hamster{} + DB.Preload("PreferredToy").Preload("OtherToy").Find(&hamster2, hamster.Id) + if hamster2.PreferredToy.Id != hamster.PreferredToy.Id || hamster2.PreferredToy.Name != hamster.PreferredToy.Name { + t.Errorf("Hamster's preferred toy couldn't be preloaded") + } + if hamster2.OtherToy.Id != hamster.OtherToy.Id || hamster2.OtherToy.Name != hamster.OtherToy.Name { + t.Errorf("Hamster's other toy couldn't be preloaded") + } + + // clear to omit Toy.Id in count + hamster2.PreferredToy = Toy{} + hamster2.OtherToy = Toy{} + + if DB.Model(&hamster2).Association("PreferredToy").Count() != 1 { + t.Errorf("Hamster's preferred toy count should be 1") + } + + if DB.Model(&hamster2).Association("OtherToy").Count() != 1 { + t.Errorf("Hamster's other toy count should be 1") + } + + // Query + var hamsterToys []Toy + if DB.Model(&hamster).Related(&hamsterToys, "PreferredToy").RecordNotFound() { + t.Errorf("Did not find any has one polymorphic association") + } else if len(hamsterToys) != 1 { + t.Errorf("Should have found only one polymorphic has one association") + } else if hamsterToys[0].Name != hamster.PreferredToy.Name { + t.Errorf("Should have found the proper has one polymorphic association") + } + + if DB.Model(&hamster).Related(&hamsterToys, "OtherToy").RecordNotFound() { + t.Errorf("Did not find any has one polymorphic association") + } else if len(hamsterToys) != 1 { + t.Errorf("Should have found only one polymorphic has one association") + } else if hamsterToys[0].Name != hamster.OtherToy.Name { + t.Errorf("Should have found the proper has one polymorphic association") + } + + hamsterToy := Toy{} + DB.Model(&hamster).Association("PreferredToy").Find(&hamsterToy) + if hamsterToy.Name != hamster.PreferredToy.Name { + t.Errorf("Should find has one polymorphic association") + } + hamsterToy = Toy{} + DB.Model(&hamster).Association("OtherToy").Find(&hamsterToy) + if hamsterToy.Name != hamster.OtherToy.Name { + t.Errorf("Should find has one polymorphic association") + } + + // Append + DB.Model(&hamster).Association("PreferredToy").Append(&Toy{ + Name: "bike 2", + }) + DB.Model(&hamster).Association("OtherToy").Append(&Toy{ + Name: "treadmill 2", + }) + + hamsterToy = Toy{} + DB.Model(&hamster).Association("PreferredToy").Find(&hamsterToy) + if hamsterToy.Name != "bike 2" { + t.Errorf("Should update has one polymorphic association with Append") + } + + hamsterToy = Toy{} + DB.Model(&hamster).Association("OtherToy").Find(&hamsterToy) + if hamsterToy.Name != "treadmill 2" { + t.Errorf("Should update has one polymorphic association with Append") + } + + if DB.Model(&hamster2).Association("PreferredToy").Count() != 1 { + t.Errorf("Hamster's toys count should be 1 after Append") + } + + if DB.Model(&hamster2).Association("OtherToy").Count() != 1 { + t.Errorf("Hamster's toys count should be 1 after Append") + } + + // Replace + DB.Model(&hamster).Association("PreferredToy").Replace(&Toy{ + Name: "bike 3", + }) + DB.Model(&hamster).Association("OtherToy").Replace(&Toy{ + Name: "treadmill 3", + }) + + hamsterToy = Toy{} + DB.Model(&hamster).Association("PreferredToy").Find(&hamsterToy) + if hamsterToy.Name != "bike 3" { + t.Errorf("Should update has one polymorphic association with Replace") + } + + hamsterToy = Toy{} + DB.Model(&hamster).Association("OtherToy").Find(&hamsterToy) + if hamsterToy.Name != "treadmill 3" { + t.Errorf("Should update has one polymorphic association with Replace") + } + + if DB.Model(&hamster2).Association("PreferredToy").Count() != 1 { + t.Errorf("hamster's toys count should be 1 after Replace") + } + + if DB.Model(&hamster2).Association("OtherToy").Count() != 1 { + t.Errorf("hamster's toys count should be 1 after Replace") + } + + // Clear + DB.Model(&hamster).Association("PreferredToy").Append(&Toy{ + Name: "bike 2", + }) + DB.Model(&hamster).Association("OtherToy").Append(&Toy{ + Name: "treadmill 2", + }) + + if DB.Model(&hamster).Association("PreferredToy").Count() != 1 { + t.Errorf("Hamster's toys should be added with Append") + } + if DB.Model(&hamster).Association("OtherToy").Count() != 1 { + t.Errorf("Hamster's toys should be added with Append") + } + + DB.Model(&hamster).Association("PreferredToy").Clear() + + if DB.Model(&hamster2).Association("PreferredToy").Count() != 0 { + t.Errorf("Hamster's preferred toy should be cleared with Clear") + } + if DB.Model(&hamster2).Association("OtherToy").Count() != 1 { + t.Errorf("Hamster's other toy should be still available") + } + + DB.Model(&hamster).Association("OtherToy").Clear() + if DB.Model(&hamster).Association("OtherToy").Count() != 0 { + t.Errorf("Hamster's other toy should be cleared with Clear") + } +} diff --git a/scope.go b/scope.go index c00cd0a1..494a6d76 100644 --- a/scope.go +++ b/scope.go @@ -999,7 +999,11 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { } if relationship.PolymorphicType != "" { - query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName()) + value := scope.TableName() + if relationship.PolymorphicValue != "" { + value = relationship.PolymorphicValue + } + query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), value) } scope.Err(query.Find(value).Error) }