From 0f2ceb5a775714a46bc344976324e3e439f8cdcc Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 5 Dec 2016 18:30:07 +0800 Subject: [PATCH] Add gorm:association:source for association operations for plugins to extend GORM --- main.go | 2 +- scope.go | 19 +++++++++---------- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/main.go b/main.go index 4f6377d1..7853456c 100644 --- a/main.go +++ b/main.go @@ -598,7 +598,7 @@ func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate // Association start `Association Mode` to handler relations things easir in that mode, refer: https://jinzhu.github.io/gorm/associations.html#association-mode func (s *DB) Association(column string) *Association { var err error - scope := s.clone().NewScope(s.Value) + var scope = s.Set("gorm:association:source", s.Value).NewScope(s.Value) if primaryField := scope.PrimaryField(); primaryField.IsBlank { err = errors.New("primary key can't be nil") diff --git a/scope.go b/scope.go index ebde05a0..484164ad 100644 --- a/scope.go +++ b/scope.go @@ -982,6 +982,7 @@ func (scope *Scope) shouldSaveAssociations() bool { func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { toScope := scope.db.NewScope(value) + tx := scope.db.Set("gorm:association:source", scope.Value) for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") { fromField, _ := scope.FieldByName(foreignKey) @@ -991,36 +992,34 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { if relationship := fromField.Relationship; relationship != nil { if relationship.Kind == "many_to_many" { joinTableHandler := relationship.JoinTableHandler - scope.Err(joinTableHandler.JoinWith(joinTableHandler, toScope.db, scope.Value).Find(value).Error) + scope.Err(joinTableHandler.JoinWith(joinTableHandler, tx, scope.Value).Find(value).Error) } else if relationship.Kind == "belongs_to" { - query := toScope.db for idx, foreignKey := range relationship.ForeignDBNames { if field, ok := scope.FieldByName(foreignKey); ok { - query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.AssociationForeignDBNames[idx])), field.Field.Interface()) + tx = tx.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.AssociationForeignDBNames[idx])), field.Field.Interface()) } } - scope.Err(query.Find(value).Error) + scope.Err(tx.Find(value).Error) } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { - query := toScope.db for idx, foreignKey := range relationship.ForeignDBNames { if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok { - query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) + tx = tx.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) } } if relationship.PolymorphicType != "" { - query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), relationship.PolymorphicValue) + tx = tx.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), relationship.PolymorphicValue) } - scope.Err(query.Find(value).Error) + scope.Err(tx.Find(value).Error) } } else { sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey())) - scope.Err(toScope.db.Where(sql, fromField.Field.Interface()).Find(value).Error) + scope.Err(tx.Where(sql, fromField.Field.Interface()).Find(value).Error) } return scope } else if toField != nil { sql := fmt.Sprintf("%v = ?", scope.Quote(toField.DBName)) - scope.Err(toScope.db.Where(sql, scope.PrimaryKeyValue()).Find(value).Error) + scope.Err(tx.Where(sql, scope.PrimaryKeyValue()).Find(value).Error) return scope } }