Merge pull request #303 from jnfeinstein/dev_poly

Support polymorphic has-one and has-many associations
This commit is contained in:
Jinzhu 2014-11-28 10:10:47 +08:00
commit 6d13ae4ead
8 changed files with 172 additions and 25 deletions

View File

@ -17,6 +17,7 @@ The fantastic ORM library for Golang, aims to be developer friendly.
* Iteration Support via [Rows](#row--rows) * Iteration Support via [Rows](#row--rows)
* Scopes * Scopes
* sql.Scanner support * sql.Scanner support
* Polymorphism
* Every feature comes with tests * Every feature comes with tests
* Convention Over Configuration * Convention Over Configuration
* Developer Friendly * Developer Friendly
@ -507,6 +508,32 @@ db.Model(&user).Association("Languages").Clear()
// Remove all relations between the user and languages // Remove all relations between the user and languages
``` ```
### Polymorphism
Supports polymorphic has-many and has-one associations.
```go
type Cat struct {
Id int
Name string
Toy Toy `gorm:"polymorphic:Owner;"`
}
type Dog struct {
Id int
Name string
Toy Toy `gorm:"polymorphic:Owner;"`
}
type Toy struct {
Id int
Name string
OwnerId int
OwnerType int
}
```
Note: polymorphic belongs-to and many-to-many are explicitly NOT supported, and will throw errors.
## Advanced Usage ## Advanced Usage
## FirstOrInit ## FirstOrInit

View File

@ -7,11 +7,12 @@ import (
) )
type Association struct { type Association struct {
Scope *Scope Scope *Scope
PrimaryKey interface{} PrimaryKey interface{}
Column string PrimaryType interface{}
Error error Column string
Field *Field Error error
Field *Field
} }
func (association *Association) err(err error) *Association { func (association *Association) err(err error) *Association {
@ -172,7 +173,11 @@ func (association *Association) Count() int {
scope.db.Model("").Table(newScope.QuotedTableName()).Where(whereSql, association.PrimaryKey).Count(&count) scope.db.Model("").Table(newScope.QuotedTableName()).Where(whereSql, association.PrimaryKey).Count(&count)
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(ToSnake(relationship.ForeignKey))) whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(ToSnake(relationship.ForeignKey)))
scope.db.Model("").Table(newScope.QuotedTableName()).Where(whereSql, association.PrimaryKey).Count(&count) countScope := scope.db.Model("").Table(newScope.QuotedTableName()).Where(whereSql, association.PrimaryKey)
if relationship.ForeignType != "" {
countScope = countScope.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(ToSnake(relationship.ForeignType))), association.PrimaryType)
}
countScope.Count(&count)
} else if relationship.Kind == "belongs_to" { } else if relationship.Kind == "belongs_to" {
if v, err := scope.FieldValueByName(association.Column); err == nil { if v, err := scope.FieldValueByName(association.Column); err == nil {
whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(ToSnake(relationship.ForeignKey))) whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(ToSnake(relationship.ForeignKey)))

View File

@ -1,6 +1,29 @@
package gorm_test package gorm_test
import "testing" import "testing"
import "github.com/jinzhu/gorm"
type Cat struct {
Id int
Name string
Toy Toy `gorm:"polymorphic:Owner;"`
}
type Dog struct {
Id int
Name string
Toys []Toy `gorm:"polymorphic:Owner;"`
}
type Toy struct {
Id int
Name string
OwnerId int
OwnerType string
// Define the owner type as a belongs_to so we can ensure it throws an error
Owner Dog `gorm:"foreignkey:owner_id; foreigntype:owner_type;"`
}
func TestHasOneAndHasManyAssociation(t *testing.T) { func TestHasOneAndHasManyAssociation(t *testing.T) {
DB.DropTable(Category{}) DB.DropTable(Category{})
@ -208,3 +231,45 @@ func TestManyToMany(t *testing.T) {
t.Errorf("Relations should be cleared") t.Errorf("Relations should be cleared")
} }
} }
func TestPolymorphic(t *testing.T) {
DB.DropTableIfExists(Cat{})
DB.DropTableIfExists(Dog{})
DB.DropTableIfExists(Toy{})
DB.AutoMigrate(&Cat{})
DB.AutoMigrate(&Dog{})
DB.AutoMigrate(&Toy{})
cat := Cat{Name: "Mr. Bigglesworth", Toy: Toy{Name: "cat nip"}}
dog := Dog{Name: "Pluto", Toys: []Toy{Toy{Name: "orange ball"}, Toy{Name: "yellow ball"}}}
DB.Save(&cat).Save(&dog)
var catToys []Toy
if err := DB.Model(&cat).Related(&catToys, "Toy").Error; err == gorm.RecordNotFound {
t.Errorf("Did not find any has one polymorphic association")
} else if len(catToys) != 1 {
t.Errorf("Should have found only one polymorphic has one association")
} else if catToys[0].Name != cat.Toy.Name {
t.Errorf("Should have found the proper has one polymorphic association")
}
var dogToys []Toy
if err := DB.Model(&dog).Related(&dogToys, "Toys").Error; err == gorm.RecordNotFound {
t.Errorf("Did not find any polymorphic has many associations")
} else if len(dogToys) != len(dog.Toys) {
t.Errorf("Should have found all polymorphic has many associations")
}
if DB.Model(&cat).Association("Toy").Count() != 1 {
t.Errorf("Should return one polymorphic has one association")
}
if DB.Model(&dog).Association("Toys").Count() != 2 {
t.Errorf("Should return two polymorphic has many associations")
}
if DB.Model(&Toy{OwnerId: dog.Id, OwnerType: "dog"}).Related(&dog, "Owner").Error == nil {
t.Errorf("Should have thrown unsupported belongs_to error")
}
}

View File

@ -35,6 +35,10 @@ func SaveBeforeAssociations(scope *Scope) {
if relationship.ForeignKey != "" { if relationship.ForeignKey != "" {
scope.SetColumn(relationship.ForeignKey, newDB.NewScope(value.Interface()).PrimaryKeyValue()) scope.SetColumn(relationship.ForeignKey, newDB.NewScope(value.Interface()).PrimaryKeyValue())
} }
if relationship.ForeignType != "" {
scope.Err(fmt.Errorf("gorm does not support polymorphic belongs_to associations"))
return
}
} }
} }
} }
@ -57,10 +61,17 @@ func SaveAfterAssociations(scope *Scope) {
if relationship.JoinTable == "" && relationship.ForeignKey != "" { if relationship.JoinTable == "" && relationship.ForeignKey != "" {
newDB.NewScope(elem).SetColumn(relationship.ForeignKey, scope.PrimaryKeyValue()) newDB.NewScope(elem).SetColumn(relationship.ForeignKey, scope.PrimaryKeyValue())
} }
if relationship.ForeignType != "" {
newDB.NewScope(elem).SetColumn(relationship.ForeignType, scope.TableName())
}
scope.Err(newDB.Save(elem).Error) scope.Err(newDB.Save(elem).Error)
if relationship.JoinTable != "" { if relationship.JoinTable != "" {
if relationship.ForeignType != "" {
scope.Err(fmt.Errorf("gorm does not support polymorphic many-to-many associations"))
}
newScope := scope.New(elem) newScope := scope.New(elem)
joinTable := relationship.JoinTable joinTable := relationship.JoinTable
foreignKey := ToSnake(relationship.ForeignKey) foreignKey := ToSnake(relationship.ForeignKey)
@ -89,6 +100,9 @@ func SaveAfterAssociations(scope *Scope) {
if relationship.ForeignKey != "" { if relationship.ForeignKey != "" {
newDB.NewScope(value.Addr().Interface()).SetColumn(relationship.ForeignKey, scope.PrimaryKeyValue()) newDB.NewScope(value.Addr().Interface()).SetColumn(relationship.ForeignKey, scope.PrimaryKeyValue())
} }
if relationship.ForeignType != "" {
newDB.NewScope(value.Addr().Interface()).SetColumn(relationship.ForeignType, scope.TableName())
}
scope.Err(newDB.Save(value.Addr().Interface()).Error) scope.Err(newDB.Save(value.Addr().Interface()).Error)
} else { } else {
destValue := reflect.New(field.Field.Type()).Elem() destValue := reflect.New(field.Field.Type()).Elem()
@ -101,6 +115,9 @@ func SaveAfterAssociations(scope *Scope) {
if relationship.ForeignKey != "" { if relationship.ForeignKey != "" {
newDB.NewScope(elem).SetColumn(relationship.ForeignKey, scope.PrimaryKeyValue()) newDB.NewScope(elem).SetColumn(relationship.ForeignKey, scope.PrimaryKeyValue())
} }
if relationship.ForeignType != "" {
newDB.NewScope(value.Addr().Interface()).SetColumn(relationship.ForeignType, scope.TableName())
}
scope.Err(newDB.Save(elem).Error) scope.Err(newDB.Save(elem).Error)
scope.SetColumn(field.Name, destValue.Interface()) scope.SetColumn(field.Name, destValue.Interface())
} }

View File

@ -10,6 +10,7 @@ import (
type relationship struct { type relationship struct {
JoinTable string JoinTable string
ForeignKey string ForeignKey string
ForeignType string
AssociationForeignKey string AssociationForeignKey string
Kind string Kind string
} }

View File

@ -406,6 +406,7 @@ func (s *DB) Association(column string) *Association {
scope := s.clone().NewScope(s.Value) scope := s.clone().NewScope(s.Value)
primaryKey := scope.PrimaryKeyValue() primaryKey := scope.PrimaryKeyValue()
primaryType := scope.TableName()
if reflect.DeepEqual(reflect.ValueOf(primaryKey), reflect.Zero(reflect.ValueOf(primaryKey).Type())) { if reflect.DeepEqual(reflect.ValueOf(primaryKey), reflect.Zero(reflect.ValueOf(primaryKey).Type())) {
scope.Err(errors.New("primary key can't be nil")) scope.Err(errors.New("primary key can't be nil"))
} }
@ -420,7 +421,7 @@ func (s *DB) Association(column string) *Association {
scope.Err(fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column)) scope.Err(fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column))
} }
return &Association{Scope: scope, Column: column, Error: s.Error, PrimaryKey: primaryKey, Field: field} return &Association{Scope: scope, Column: column, Error: s.Error, PrimaryKey: primaryKey, PrimaryType: primaryType, Field: field}
} }
// Set set value by name // Set set value by name

View File

@ -334,8 +334,15 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField, withRelatio
scopeTyp := scope.IndirectValue().Type() scopeTyp := scope.IndirectValue().Type()
foreignKey := SnakeToUpperCamel(settings["FOREIGNKEY"]) foreignKey := SnakeToUpperCamel(settings["FOREIGNKEY"])
foreignType := SnakeToUpperCamel(settings["FOREIGNTYPE"])
associationForeignKey := SnakeToUpperCamel(settings["ASSOCIATIONFOREIGNKEY"]) associationForeignKey := SnakeToUpperCamel(settings["ASSOCIATIONFOREIGNKEY"])
many2many := settings["MANY2MANY"] many2many := settings["MANY2MANY"]
polymorphic := SnakeToUpperCamel(settings["POLYMORPHIC"])
if polymorphic != "" {
foreignKey = polymorphic + "Id"
foreignType = polymorphic + "Type"
}
switch indirectValue.Kind() { switch indirectValue.Kind() {
case reflect.Slice: case reflect.Slice:
@ -359,6 +366,7 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField, withRelatio
field.Relationship = &relationship{ field.Relationship = &relationship{
JoinTable: many2many, JoinTable: many2many,
ForeignKey: foreignKey, ForeignKey: foreignKey,
ForeignType: foreignType,
AssociationForeignKey: associationForeignKey, AssociationForeignKey: associationForeignKey,
Kind: "has_many", Kind: "has_many",
} }
@ -400,7 +408,7 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField, withRelatio
kind = "has_one" kind = "has_one"
} }
field.Relationship = &relationship{ForeignKey: foreignKey, Kind: kind} field.Relationship = &relationship{ForeignKey: foreignKey, ForeignType: foreignType, Kind: kind}
} }
default: default:
field.IsNormal = true field.IsNormal = true

View File

@ -489,29 +489,52 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
foreignKey = keys[1] foreignKey = keys[1]
} }
var relationship *relationship
var field *Field
var scopeHasField bool
if field, scopeHasField = scope.FieldByName(foreignKey); scopeHasField {
relationship = field.Relationship
}
if scopeType == "" || scopeType == fromScopeType { if scopeType == "" || scopeType == fromScopeType {
if field, ok := scope.FieldByName(foreignKey); ok { if scopeHasField {
relationship := field.Relationship
if relationship != nil && relationship.ForeignKey != "" { if relationship != nil && relationship.ForeignKey != "" {
foreignKey = relationship.ForeignKey foreignKey = relationship.ForeignKey
if relationship.Kind == "many_to_many" {
joinSql := fmt.Sprintf(
"INNER JOIN %v ON %v.%v = %v.%v",
scope.Quote(relationship.JoinTable),
scope.Quote(relationship.JoinTable),
scope.Quote(ToSnake(relationship.AssociationForeignKey)),
toScope.QuotedTableName(),
scope.Quote(toScope.PrimaryKey()))
whereSql := fmt.Sprintf("%v.%v = ?", scope.Quote(relationship.JoinTable), scope.Quote(ToSnake(relationship.ForeignKey)))
toScope.db.Joins(joinSql).Where(whereSql, scope.PrimaryKeyValue()).Find(value)
return scope
}
} }
// has one if relationship != nil && relationship.Kind == "many_to_many" {
if relationship.ForeignType != "" {
scope.Err(fmt.Errorf("gorm does not support polymorphic many-to-many associations"))
}
joinSql := fmt.Sprintf(
"INNER JOIN %v ON %v.%v = %v.%v",
scope.Quote(relationship.JoinTable),
scope.Quote(relationship.JoinTable),
scope.Quote(ToSnake(relationship.AssociationForeignKey)),
toScope.QuotedTableName(),
scope.Quote(toScope.PrimaryKey()))
whereSql := fmt.Sprintf("%v.%v = ?", scope.Quote(relationship.JoinTable), scope.Quote(ToSnake(relationship.ForeignKey)))
toScope.db.Joins(joinSql).Where(whereSql, scope.PrimaryKeyValue()).Find(value)
return scope
}
// has many or has one
if toScope.HasColumn(foreignKey) {
toScope.inlineCondition(fmt.Sprintf("%v = ?", scope.Quote(ToSnake(foreignKey))), scope.PrimaryKeyValue())
if relationship != nil && relationship.ForeignType != "" && toScope.HasColumn(relationship.ForeignType) {
toScope.inlineCondition(fmt.Sprintf("%v = ?", scope.Quote(ToSnake(relationship.ForeignType))), scope.TableName())
}
toScope.callCallbacks(scope.db.parent.callback.queries)
return scope
}
// belongs to
if foreignValue, err := scope.FieldValueByName(foreignKey); err == nil { if foreignValue, err := scope.FieldValueByName(foreignKey); err == nil {
sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey())) sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey()))
if relationship != nil && relationship.ForeignType != "" && scope.HasColumn(relationship.ForeignType) {
scope.Err(fmt.Errorf("gorm does not support polymorphic belongs_to associations"))
return scope
}
toScope.inlineCondition(sql, foreignValue).callCallbacks(scope.db.parent.callback.queries) toScope.inlineCondition(sql, foreignValue).callCallbacks(scope.db.parent.callback.queries)
return scope return scope
} }
@ -519,7 +542,7 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
} }
if scopeType == "" || scopeType == toScopeType { if scopeType == "" || scopeType == toScopeType {
// has many // has many or has one in foreign scope
if toScope.HasColumn(foreignKey) { if toScope.HasColumn(foreignKey) {
sql := fmt.Sprintf("%v = ?", scope.Quote(ToSnake(foreignKey))) sql := fmt.Sprintf("%v = ?", scope.Quote(ToSnake(foreignKey)))
return toScope.inlineCondition(sql, scope.PrimaryKeyValue()).callCallbacks(scope.db.parent.callback.queries) return toScope.inlineCondition(sql, scope.PrimaryKeyValue()).callCallbacks(scope.db.parent.callback.queries)