diff --git a/association.go b/association.go index 86b586e4..69be0d01 100644 --- a/association.go +++ b/association.go @@ -7,11 +7,12 @@ import ( ) type Association struct { - Scope *Scope - PrimaryKey interface{} - Column string - Error error - Field *Field + Scope *Scope + PrimaryKey interface{} + PrimaryType interface{} + Column string + Error error + Field *Field } func (association *Association) err(err error) *Association { @@ -172,7 +173,11 @@ func (association *Association) Count() int { scope.db.Model("").Table(newScope.QuotedTableName()).Where(whereSql, association.PrimaryKey).Count(&count) } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(ToSnake(relationship.ForeignKey))) - scope.db.Model("").Table(newScope.QuotedTableName()).Where(whereSql, association.PrimaryKey).Count(&count) + countScope := scope.db.Model("").Table(newScope.QuotedTableName()).Where(whereSql, association.PrimaryKey) + if relationship.ForeignType != "" { + countScope = countScope.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(ToSnake(relationship.ForeignType))), association.PrimaryType) + } + countScope.Count(&count) } else if relationship.Kind == "belongs_to" { if v, err := scope.FieldValueByName(association.Column); err == nil { whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(ToSnake(relationship.ForeignKey))) diff --git a/callback_shared.go b/callback_shared.go index f13cec9d..77f513be 100644 --- a/callback_shared.go +++ b/callback_shared.go @@ -35,6 +35,10 @@ func SaveBeforeAssociations(scope *Scope) { if relationship.ForeignKey != "" { scope.SetColumn(relationship.ForeignKey, newDB.NewScope(value.Interface()).PrimaryKeyValue()) } + if relationship.ForeignType != "" { + scope.Err(fmt.Errorf("gorm does not support polymorphic belongs_to associations")) + return + } } } } @@ -57,10 +61,17 @@ func SaveAfterAssociations(scope *Scope) { if relationship.JoinTable == "" && relationship.ForeignKey != "" { newDB.NewScope(elem).SetColumn(relationship.ForeignKey, scope.PrimaryKeyValue()) } + if relationship.ForeignType != "" { + newDB.NewScope(elem).SetColumn(relationship.ForeignType, scope.TableName()) + } scope.Err(newDB.Save(elem).Error) if relationship.JoinTable != "" { + if relationship.ForeignType != "" { + scope.Err(fmt.Errorf("gorm does not support polymorphic many-to-many associations")) + } + newScope := scope.New(elem) joinTable := relationship.JoinTable foreignKey := ToSnake(relationship.ForeignKey) @@ -89,6 +100,9 @@ func SaveAfterAssociations(scope *Scope) { if relationship.ForeignKey != "" { newDB.NewScope(value.Addr().Interface()).SetColumn(relationship.ForeignKey, scope.PrimaryKeyValue()) } + if relationship.ForeignType != "" { + newDB.NewScope(value.Addr().Interface()).SetColumn(relationship.ForeignType, scope.TableName()) + } scope.Err(newDB.Save(value.Addr().Interface()).Error) } else { destValue := reflect.New(field.Field.Type()).Elem() @@ -101,6 +115,9 @@ func SaveAfterAssociations(scope *Scope) { if relationship.ForeignKey != "" { newDB.NewScope(elem).SetColumn(relationship.ForeignKey, scope.PrimaryKeyValue()) } + if relationship.ForeignType != "" { + newDB.NewScope(value.Addr().Interface()).SetColumn(relationship.ForeignType, scope.TableName()) + } scope.Err(newDB.Save(elem).Error) scope.SetColumn(field.Name, destValue.Interface()) } diff --git a/field.go b/field.go index f00c6b80..af278e44 100644 --- a/field.go +++ b/field.go @@ -10,6 +10,7 @@ import ( type relationship struct { JoinTable string ForeignKey string + ForeignType string AssociationForeignKey string Kind string } diff --git a/main.go b/main.go index 4b0f07f7..49686e23 100644 --- a/main.go +++ b/main.go @@ -406,6 +406,7 @@ func (s *DB) Association(column string) *Association { scope := s.clone().NewScope(s.Value) primaryKey := scope.PrimaryKeyValue() + primaryType := scope.TableName() if reflect.DeepEqual(reflect.ValueOf(primaryKey), reflect.Zero(reflect.ValueOf(primaryKey).Type())) { scope.Err(errors.New("primary key can't be nil")) } @@ -420,7 +421,7 @@ func (s *DB) Association(column string) *Association { scope.Err(fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column)) } - return &Association{Scope: scope, Column: column, Error: s.Error, PrimaryKey: primaryKey, Field: field} + return &Association{Scope: scope, Column: column, Error: s.Error, PrimaryKey: primaryKey, PrimaryType: primaryType, Field: field} } // Set set value by name diff --git a/scope.go b/scope.go index 8c09cedc..a90e4f47 100644 --- a/scope.go +++ b/scope.go @@ -334,8 +334,15 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField, withRelatio scopeTyp := scope.IndirectValue().Type() foreignKey := SnakeToUpperCamel(settings["FOREIGNKEY"]) + foreignType := SnakeToUpperCamel(settings["FOREIGNTYPE"]) associationForeignKey := SnakeToUpperCamel(settings["ASSOCIATIONFOREIGNKEY"]) many2many := settings["MANY2MANY"] + polymorphic := SnakeToUpperCamel(settings["POLYMORPHIC"]) + + if polymorphic != "" { + foreignKey = polymorphic + "Id" + foreignType = polymorphic + "Type" + } switch indirectValue.Kind() { case reflect.Slice: @@ -359,6 +366,7 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField, withRelatio field.Relationship = &relationship{ JoinTable: many2many, ForeignKey: foreignKey, + ForeignType: foreignType, AssociationForeignKey: associationForeignKey, Kind: "has_many", } @@ -400,7 +408,7 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField, withRelatio kind = "has_one" } - field.Relationship = &relationship{ForeignKey: foreignKey, Kind: kind} + field.Relationship = &relationship{ForeignKey: foreignKey, ForeignType: foreignType, Kind: kind} } default: field.IsNormal = true diff --git a/scope_private.go b/scope_private.go index e36dbec7..c2a1428e 100644 --- a/scope_private.go +++ b/scope_private.go @@ -489,29 +489,52 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { foreignKey = keys[1] } + var relationship *relationship + var field *Field + var scopeHasField bool + if field, scopeHasField = scope.FieldByName(foreignKey); scopeHasField { + relationship = field.Relationship + } + if scopeType == "" || scopeType == fromScopeType { - if field, ok := scope.FieldByName(foreignKey); ok { - relationship := field.Relationship + if scopeHasField { if relationship != nil && relationship.ForeignKey != "" { foreignKey = relationship.ForeignKey - - if relationship.Kind == "many_to_many" { - joinSql := fmt.Sprintf( - "INNER JOIN %v ON %v.%v = %v.%v", - scope.Quote(relationship.JoinTable), - scope.Quote(relationship.JoinTable), - scope.Quote(ToSnake(relationship.AssociationForeignKey)), - toScope.QuotedTableName(), - scope.Quote(toScope.PrimaryKey())) - whereSql := fmt.Sprintf("%v.%v = ?", scope.Quote(relationship.JoinTable), scope.Quote(ToSnake(relationship.ForeignKey))) - toScope.db.Joins(joinSql).Where(whereSql, scope.PrimaryKeyValue()).Find(value) - return scope - } } - // has one + if relationship != nil && relationship.Kind == "many_to_many" { + if relationship.ForeignType != "" { + scope.Err(fmt.Errorf("gorm does not support polymorphic many-to-many associations")) + } + joinSql := fmt.Sprintf( + "INNER JOIN %v ON %v.%v = %v.%v", + scope.Quote(relationship.JoinTable), + scope.Quote(relationship.JoinTable), + scope.Quote(ToSnake(relationship.AssociationForeignKey)), + toScope.QuotedTableName(), + scope.Quote(toScope.PrimaryKey())) + whereSql := fmt.Sprintf("%v.%v = ?", scope.Quote(relationship.JoinTable), scope.Quote(ToSnake(relationship.ForeignKey))) + toScope.db.Joins(joinSql).Where(whereSql, scope.PrimaryKeyValue()).Find(value) + return scope + } + + // has many or has one + if toScope.HasColumn(foreignKey) { + toScope.inlineCondition(fmt.Sprintf("%v = ?", scope.Quote(ToSnake(foreignKey))), scope.PrimaryKeyValue()) + if relationship != nil && relationship.ForeignType != "" && toScope.HasColumn(relationship.ForeignType) { + toScope.inlineCondition(fmt.Sprintf("%v = ?", scope.Quote(ToSnake(relationship.ForeignType))), scope.TableName()) + } + toScope.callCallbacks(scope.db.parent.callback.queries) + return scope + } + + // belongs to if foreignValue, err := scope.FieldValueByName(foreignKey); err == nil { sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey())) + if relationship != nil && relationship.ForeignType != "" && scope.HasColumn(relationship.ForeignType) { + scope.Err(fmt.Errorf("gorm does not support polymorphic belongs_to associations")) + return scope + } toScope.inlineCondition(sql, foreignValue).callCallbacks(scope.db.parent.callback.queries) return scope } @@ -519,7 +542,7 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { } if scopeType == "" || scopeType == toScopeType { - // has many + // has many or has one in foreign scope if toScope.HasColumn(foreignKey) { sql := fmt.Sprintf("%v = ?", scope.Quote(ToSnake(foreignKey))) return toScope.inlineCondition(sql, scope.PrimaryKeyValue()).callCallbacks(scope.db.parent.callback.queries)