Add gorm:association:source for association operations for plugins to extend GORM

This commit is contained in:
Jinzhu 2016-12-05 18:30:07 +08:00
parent eb06255b66
commit 0f2ceb5a77
2 changed files with 10 additions and 11 deletions

View File

@ -598,7 +598,7 @@ func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate
// Association start `Association Mode` to handler relations things easir in that mode, refer: https://jinzhu.github.io/gorm/associations.html#association-mode // Association start `Association Mode` to handler relations things easir in that mode, refer: https://jinzhu.github.io/gorm/associations.html#association-mode
func (s *DB) Association(column string) *Association { func (s *DB) Association(column string) *Association {
var err error var err error
scope := s.clone().NewScope(s.Value) var scope = s.Set("gorm:association:source", s.Value).NewScope(s.Value)
if primaryField := scope.PrimaryField(); primaryField.IsBlank { if primaryField := scope.PrimaryField(); primaryField.IsBlank {
err = errors.New("primary key can't be nil") err = errors.New("primary key can't be nil")

View File

@ -982,6 +982,7 @@ func (scope *Scope) shouldSaveAssociations() bool {
func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
toScope := scope.db.NewScope(value) toScope := scope.db.NewScope(value)
tx := scope.db.Set("gorm:association:source", scope.Value)
for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") { for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") {
fromField, _ := scope.FieldByName(foreignKey) fromField, _ := scope.FieldByName(foreignKey)
@ -991,36 +992,34 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
if relationship := fromField.Relationship; relationship != nil { if relationship := fromField.Relationship; relationship != nil {
if relationship.Kind == "many_to_many" { if relationship.Kind == "many_to_many" {
joinTableHandler := relationship.JoinTableHandler joinTableHandler := relationship.JoinTableHandler
scope.Err(joinTableHandler.JoinWith(joinTableHandler, toScope.db, scope.Value).Find(value).Error) scope.Err(joinTableHandler.JoinWith(joinTableHandler, tx, scope.Value).Find(value).Error)
} else if relationship.Kind == "belongs_to" { } else if relationship.Kind == "belongs_to" {
query := toScope.db
for idx, foreignKey := range relationship.ForeignDBNames { for idx, foreignKey := range relationship.ForeignDBNames {
if field, ok := scope.FieldByName(foreignKey); ok { if field, ok := scope.FieldByName(foreignKey); ok {
query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.AssociationForeignDBNames[idx])), field.Field.Interface()) tx = tx.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.AssociationForeignDBNames[idx])), field.Field.Interface())
} }
} }
scope.Err(query.Find(value).Error) scope.Err(tx.Find(value).Error)
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
query := toScope.db
for idx, foreignKey := range relationship.ForeignDBNames { for idx, foreignKey := range relationship.ForeignDBNames {
if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok { if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok {
query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) tx = tx.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
} }
} }
if relationship.PolymorphicType != "" { if relationship.PolymorphicType != "" {
query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), relationship.PolymorphicValue) tx = tx.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), relationship.PolymorphicValue)
} }
scope.Err(query.Find(value).Error) scope.Err(tx.Find(value).Error)
} }
} else { } else {
sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey())) sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey()))
scope.Err(toScope.db.Where(sql, fromField.Field.Interface()).Find(value).Error) scope.Err(tx.Where(sql, fromField.Field.Interface()).Find(value).Error)
} }
return scope return scope
} else if toField != nil { } else if toField != nil {
sql := fmt.Sprintf("%v = ?", scope.Quote(toField.DBName)) sql := fmt.Sprintf("%v = ?", scope.Quote(toField.DBName))
scope.Err(toScope.db.Where(sql, scope.PrimaryKeyValue()).Find(value).Error) scope.Err(tx.Where(sql, scope.PrimaryKeyValue()).Find(value).Error)
return scope return scope
} }
} }