diff --git a/scope_private.go b/scope_private.go index b6949db3..209db5bb 100644 --- a/scope_private.go +++ b/scope_private.go @@ -419,38 +419,50 @@ func (scope *Scope) typeName() string { func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { toScope := scope.db.NewScope(value) + fromScopeType := scope.typeName() + toScopeType := toScope.typeName() + scopeType := "" for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") { - if field, ok := scope.FieldByName(foreignKey); ok { - relationship := field.Relationship - if relationship != nil && relationship.ForeignKey != "" { - foreignKey = relationship.ForeignKey + if keys := strings.Split(foreignKey, "."); len(keys) > 1 { + scopeType = keys[0] + foreignKey = keys[1] + } - if relationship.Kind == "many_to_many" { - joinSql := fmt.Sprintf( - "INNER JOIN %v ON %v.%v = %v.%v", - scope.Quote(relationship.JoinTable), - scope.Quote(relationship.JoinTable), - scope.Quote(ToSnake(relationship.AssociationForeignKey)), - toScope.QuotedTableName(), - scope.Quote(toScope.PrimaryKey())) - whereSql := fmt.Sprintf("%v.%v = ?", scope.Quote(relationship.JoinTable), scope.Quote(ToSnake(relationship.ForeignKey))) - toScope.db.Joins(joinSql).Where(whereSql, scope.PrimaryKeyValue()).Find(value) + if scopeType == "" || scopeType == fromScopeType { + if field, ok := scope.FieldByName(foreignKey); ok { + relationship := field.Relationship + if relationship != nil && relationship.ForeignKey != "" { + foreignKey = relationship.ForeignKey + + if relationship.Kind == "many_to_many" { + joinSql := fmt.Sprintf( + "INNER JOIN %v ON %v.%v = %v.%v", + scope.Quote(relationship.JoinTable), + scope.Quote(relationship.JoinTable), + scope.Quote(ToSnake(relationship.AssociationForeignKey)), + toScope.QuotedTableName(), + scope.Quote(toScope.PrimaryKey())) + whereSql := fmt.Sprintf("%v.%v = ?", scope.Quote(relationship.JoinTable), scope.Quote(ToSnake(relationship.ForeignKey))) + toScope.db.Joins(joinSql).Where(whereSql, scope.PrimaryKeyValue()).Find(value) + return scope + } + } + + // has one + if foreignValue, ok := scope.FieldValueByName(foreignKey); ok { + toScope.inlineCondition(foreignValue).callCallbacks(scope.db.parent.callback.queries) return scope } } - - // has one - if foreignValue, ok := scope.FieldValueByName(foreignKey); ok { - toScope.inlineCondition(foreignValue).callCallbacks(scope.db.parent.callback.queries) - return scope - } } - // has many - if toScope.HasColumn(foreignKey) { - sql := fmt.Sprintf("%v = ?", scope.Quote(ToSnake(foreignKey))) - return toScope.inlineCondition(sql, scope.PrimaryKeyValue()).callCallbacks(scope.db.parent.callback.queries) + if scopeType == "" || scopeType == toScopeType { + // has many + if toScope.HasColumn(foreignKey) { + sql := fmt.Sprintf("%v = ?", scope.Quote(ToSnake(foreignKey))) + return toScope.inlineCondition(sql, scope.PrimaryKeyValue()).callCallbacks(scope.db.parent.callback.queries) + } } } scope.Err(fmt.Errorf("invalid association %v", foreignKeys))