diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index c54d572d..52dbd8b2 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -27,26 +27,26 @@ package main import ( - _ "github.com/mattn/go-sqlite3" - _ "github.com/go-sql-driver/mysql" - _ "github.com/lib/pq" - "github.com/jinzhu/gorm" + _ "github.com/go-sql-driver/mysql" + "github.com/jinzhu/gorm" + _ "github.com/lib/pq" + _ "github.com/mattn/go-sqlite3" ) -var db gorm.DB +var db *gorm.DB func init() { - var err error - db, err = gorm.Open("sqlite3", "test.db") - // db, err := gorm.Open("postgres", "user=username dbname=password sslmode=disable") - // db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True") - if err != nil { - panic(err) - } - db.LogMode(true) + var err error + db, err = gorm.Open("sqlite3", "test.db") + // db, err = gorm.Open("postgres", "user=username dbname=password sslmode=disable") + // db, err = gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True") + if err != nil { + panic(err) + } + db.LogMode(true) } func main() { - // Your code + // Your code } ``` diff --git a/association.go b/association.go index 0f94683d..9a3a338b 100644 --- a/association.go +++ b/association.go @@ -63,7 +63,7 @@ func (association *Association) Replace(values ...interface{}) *Association { } else { // Polymorphic Relations if relationship.PolymorphicDBName != "" { - newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName()) + newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), relationship.PolymorphicValue) } // Delete Relations except new created @@ -284,7 +284,7 @@ func (association *Association) Count() int { if relationship.PolymorphicType != "" { query = query.Where( fmt.Sprintf("%v.%v = ?", scope.New(fieldValue).QuotedTableName(), scope.Quote(relationship.PolymorphicDBName)), - scope.TableName(), + relationship.PolymorphicValue, ) } diff --git a/callback_query_preload.go b/callback_query_preload.go index d9ec8bdd..efa65bb5 100644 --- a/callback_query_preload.go +++ b/callback_query_preload.go @@ -114,7 +114,7 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) values := toQueryValues(primaryKeys) if relation.PolymorphicType != "" { query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName)) - values = append(values, scope.TableName()) + values = append(values, relation.PolymorphicValue) } results := makeSlice(field.Struct.Type) @@ -163,7 +163,7 @@ func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) values := toQueryValues(primaryKeys) if relation.PolymorphicType != "" { query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName)) - values = append(values, scope.TableName()) + values = append(values, relation.PolymorphicValue) } results := makeSlice(field.Struct.Type) diff --git a/callback_save.go b/callback_save.go index 5ffe53b9..ea9ec174 100644 --- a/callback_save.go +++ b/callback_save.go @@ -60,7 +60,7 @@ func saveAfterAssociationsCallback(scope *Scope) { } if relationship.PolymorphicType != "" { - scope.Err(newScope.SetColumn(relationship.PolymorphicType, scope.TableName())) + scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) } scope.Err(newDB.Save(elem).Error) @@ -82,7 +82,7 @@ func saveAfterAssociationsCallback(scope *Scope) { } if relationship.PolymorphicType != "" { - scope.Err(newScope.SetColumn(relationship.PolymorphicType, scope.TableName())) + scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) } scope.Err(scope.NewDB().Save(elem).Error) } diff --git a/main.go b/main.go index a6add969..52a536d0 100644 --- a/main.go +++ b/main.go @@ -75,6 +75,9 @@ func Open(dialect string, args ...interface{}) (*DB, error) { if err == nil { err = db.DB().Ping() // Send a ping to make sure the database connection is alive. + if err != nil { + db.DB().Close() + } } } diff --git a/migration_test.go b/migration_test.go index 08633283..1545ec2c 100644 --- a/migration_test.go +++ b/migration_test.go @@ -252,7 +252,7 @@ func runMigration() { DB.Exec(fmt.Sprintf("drop table %v;", table)) } - values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Toy{}, &ElementWithIgnoredField{}} + values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}} for _, value := range values { DB.DropTable(value) } diff --git a/model_struct.go b/model_struct.go index 8179a1f3..9a609585 100644 --- a/model_struct.go +++ b/model_struct.go @@ -99,6 +99,7 @@ type Relationship struct { Kind string PolymorphicType string PolymorphicDBName string + PolymorphicValue string ForeignFieldNames []string ForeignDBNames []string AssociationForeignFieldNames []string @@ -292,6 +293,12 @@ func (scope *Scope) GetModelStruct() *ModelStruct { associationType = polymorphic relationship.PolymorphicType = polymorphicType.Name relationship.PolymorphicDBName = polymorphicType.DBName + // if Dog has multiple set of toys set name of the set (instead of default 'dogs') + if value, ok := field.TagSettings["POLYMORPHIC_VALUE"]; ok { + relationship.PolymorphicValue = value + } else { + relationship.PolymorphicValue = scope.TableName() + } polymorphicType.IsForeignKey = true } } @@ -384,6 +391,12 @@ func (scope *Scope) GetModelStruct() *ModelStruct { associationType = polymorphic relationship.PolymorphicType = polymorphicType.Name relationship.PolymorphicDBName = polymorphicType.DBName + // if Cat has several different types of toys set name for each (instead of default 'cats') + if value, ok := field.TagSettings["POLYMORPHIC_VALUE"]; ok { + relationship.PolymorphicValue = value + } else { + relationship.PolymorphicValue = scope.TableName() + } polymorphicType.IsForeignKey = true } } diff --git a/polymorphic_test.go b/polymorphic_test.go index df573f97..d1ecfbbb 100644 --- a/polymorphic_test.go +++ b/polymorphic_test.go @@ -18,6 +18,13 @@ type Dog struct { Toys []Toy `gorm:"polymorphic:Owner;"` } +type Hamster struct { + Id int + Name string + PreferredToy Toy `gorm:"polymorphic:Owner;polymorphic_value:hamster_preferred"` + OtherToy Toy `gorm:"polymorphic:Owner;polymorphic_value:hamster_other"` +} + type Toy struct { Id int Name string @@ -217,3 +224,143 @@ func TestPolymorphic(t *testing.T) { t.Errorf("Dog's toys should be cleared with Clear") } } + +func TestNamedPolymorphic(t *testing.T) { + hamster := Hamster{Name: "Mr. Hammond", PreferredToy: Toy{Name: "bike"}, OtherToy: Toy{Name: "treadmill"}} + DB.Save(&hamster) + + hamster2 := Hamster{} + DB.Preload("PreferredToy").Preload("OtherToy").Find(&hamster2, hamster.Id) + if hamster2.PreferredToy.Id != hamster.PreferredToy.Id || hamster2.PreferredToy.Name != hamster.PreferredToy.Name { + t.Errorf("Hamster's preferred toy couldn't be preloaded") + } + if hamster2.OtherToy.Id != hamster.OtherToy.Id || hamster2.OtherToy.Name != hamster.OtherToy.Name { + t.Errorf("Hamster's other toy couldn't be preloaded") + } + + // clear to omit Toy.Id in count + hamster2.PreferredToy = Toy{} + hamster2.OtherToy = Toy{} + + if DB.Model(&hamster2).Association("PreferredToy").Count() != 1 { + t.Errorf("Hamster's preferred toy count should be 1") + } + + if DB.Model(&hamster2).Association("OtherToy").Count() != 1 { + t.Errorf("Hamster's other toy count should be 1") + } + + // Query + var hamsterToys []Toy + if DB.Model(&hamster).Related(&hamsterToys, "PreferredToy").RecordNotFound() { + t.Errorf("Did not find any has one polymorphic association") + } else if len(hamsterToys) != 1 { + t.Errorf("Should have found only one polymorphic has one association") + } else if hamsterToys[0].Name != hamster.PreferredToy.Name { + t.Errorf("Should have found the proper has one polymorphic association") + } + + if DB.Model(&hamster).Related(&hamsterToys, "OtherToy").RecordNotFound() { + t.Errorf("Did not find any has one polymorphic association") + } else if len(hamsterToys) != 1 { + t.Errorf("Should have found only one polymorphic has one association") + } else if hamsterToys[0].Name != hamster.OtherToy.Name { + t.Errorf("Should have found the proper has one polymorphic association") + } + + hamsterToy := Toy{} + DB.Model(&hamster).Association("PreferredToy").Find(&hamsterToy) + if hamsterToy.Name != hamster.PreferredToy.Name { + t.Errorf("Should find has one polymorphic association") + } + hamsterToy = Toy{} + DB.Model(&hamster).Association("OtherToy").Find(&hamsterToy) + if hamsterToy.Name != hamster.OtherToy.Name { + t.Errorf("Should find has one polymorphic association") + } + + // Append + DB.Model(&hamster).Association("PreferredToy").Append(&Toy{ + Name: "bike 2", + }) + DB.Model(&hamster).Association("OtherToy").Append(&Toy{ + Name: "treadmill 2", + }) + + hamsterToy = Toy{} + DB.Model(&hamster).Association("PreferredToy").Find(&hamsterToy) + if hamsterToy.Name != "bike 2" { + t.Errorf("Should update has one polymorphic association with Append") + } + + hamsterToy = Toy{} + DB.Model(&hamster).Association("OtherToy").Find(&hamsterToy) + if hamsterToy.Name != "treadmill 2" { + t.Errorf("Should update has one polymorphic association with Append") + } + + if DB.Model(&hamster2).Association("PreferredToy").Count() != 1 { + t.Errorf("Hamster's toys count should be 1 after Append") + } + + if DB.Model(&hamster2).Association("OtherToy").Count() != 1 { + t.Errorf("Hamster's toys count should be 1 after Append") + } + + // Replace + DB.Model(&hamster).Association("PreferredToy").Replace(&Toy{ + Name: "bike 3", + }) + DB.Model(&hamster).Association("OtherToy").Replace(&Toy{ + Name: "treadmill 3", + }) + + hamsterToy = Toy{} + DB.Model(&hamster).Association("PreferredToy").Find(&hamsterToy) + if hamsterToy.Name != "bike 3" { + t.Errorf("Should update has one polymorphic association with Replace") + } + + hamsterToy = Toy{} + DB.Model(&hamster).Association("OtherToy").Find(&hamsterToy) + if hamsterToy.Name != "treadmill 3" { + t.Errorf("Should update has one polymorphic association with Replace") + } + + if DB.Model(&hamster2).Association("PreferredToy").Count() != 1 { + t.Errorf("hamster's toys count should be 1 after Replace") + } + + if DB.Model(&hamster2).Association("OtherToy").Count() != 1 { + t.Errorf("hamster's toys count should be 1 after Replace") + } + + // Clear + DB.Model(&hamster).Association("PreferredToy").Append(&Toy{ + Name: "bike 2", + }) + DB.Model(&hamster).Association("OtherToy").Append(&Toy{ + Name: "treadmill 2", + }) + + if DB.Model(&hamster).Association("PreferredToy").Count() != 1 { + t.Errorf("Hamster's toys should be added with Append") + } + if DB.Model(&hamster).Association("OtherToy").Count() != 1 { + t.Errorf("Hamster's toys should be added with Append") + } + + DB.Model(&hamster).Association("PreferredToy").Clear() + + if DB.Model(&hamster2).Association("PreferredToy").Count() != 0 { + t.Errorf("Hamster's preferred toy should be cleared with Clear") + } + if DB.Model(&hamster2).Association("OtherToy").Count() != 1 { + t.Errorf("Hamster's other toy should be still available") + } + + DB.Model(&hamster).Association("OtherToy").Clear() + if DB.Model(&hamster).Association("OtherToy").Count() != 0 { + t.Errorf("Hamster's other toy should be cleared with Clear") + } +} diff --git a/scope.go b/scope.go index c00cd0a1..494a6d76 100644 --- a/scope.go +++ b/scope.go @@ -999,7 +999,11 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { } if relationship.PolymorphicType != "" { - query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName()) + value := scope.TableName() + if relationship.PolymorphicValue != "" { + value = relationship.PolymorphicValue + } + query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), value) } scope.Err(query.Find(value).Error) }