diff --git a/association.go b/association.go index 5181279c..8dd844ed 100644 --- a/association.go +++ b/association.go @@ -75,10 +75,8 @@ func (association *Association) Delete(values ...interface{}) *Association { relationship := association.Field.Relationship // many to many if relationship.Kind == "many_to_many" { - sql := fmt.Sprintf("%v.%v = ? AND %v.%v IN (?)", - scope.Quote(relationship.JoinTable), scope.Quote(relationship.ForeignDBName), - scope.Quote(relationship.JoinTable), scope.Quote(relationship.AssociationForeignDBName)) - query := scope.NewDB().Table(relationship.JoinTable).Where(sql, association.PrimaryKey, primaryKeys) + sql := fmt.Sprintf("%v = ? AND %v IN (?)", scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.AssociationForeignDBName)) + query := scope.NewDB().Where(sql, association.PrimaryKey, primaryKeys) if err := scope.db.GetJoinTableHandler(relationship.JoinTable).Delete(query, relationship); err == nil { leftValues := reflect.Zero(association.Field.Field.Type()) for i := 0; i < association.Field.Field.Len(); i++ { @@ -132,8 +130,8 @@ func (association *Association) Replace(values ...interface{}) *Association { addedPrimaryKeys = append(addedPrimaryKeys, primaryKey) } - sql := fmt.Sprintf("%v.%v = ? AND %v.%v NOT IN (?)", scope.Quote(relationship.JoinTable), scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.JoinTable), scope.Quote(relationship.AssociationForeignDBName)) - query := scope.NewDB().Table(relationship.JoinTable).Where(sql, association.PrimaryKey, addedPrimaryKeys) + sql := fmt.Sprintf("%v = ? AND %v NOT IN (?)", scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.AssociationForeignDBName)) + query := scope.NewDB().Where(sql, association.PrimaryKey, addedPrimaryKeys) association.setErr(scope.db.GetJoinTableHandler(relationship.JoinTable).Delete(query, relationship)) } else { association.setErr(errors.New("replace only support many to many")) @@ -145,8 +143,8 @@ func (association *Association) Clear() *Association { relationship := association.Field.Relationship scope := association.Scope if relationship.Kind == "many_to_many" { - sql := fmt.Sprintf("%v.%v = ?", relationship.JoinTable, scope.Quote(relationship.ForeignDBName)) - query := scope.NewDB().Table(relationship.JoinTable).Where(sql, association.PrimaryKey) + sql := fmt.Sprintf("%v = ?", scope.Quote(relationship.ForeignDBName)) + query := scope.NewDB().Where(sql, association.PrimaryKey) if err := scope.db.GetJoinTableHandler(relationship.JoinTable).Delete(query, relationship); err == nil { association.Field.Set(reflect.Zero(association.Field.Field.Type())) } else { @@ -165,8 +163,7 @@ func (association *Association) Count() int { newScope := scope.New(association.Field.Field.Interface()) if relationship.Kind == "many_to_many" { - query := scope.DB().Table(relationship.JoinTable). - Select("COUNT(DISTINCT ?)", relationship.AssociationForeignDBName). + query := scope.DB().Select("COUNT(DISTINCT ?)", relationship.AssociationForeignDBName). Where(relationship.ForeignDBName+" = ?", association.PrimaryKey) scope.db.GetJoinTableHandler(relationship.JoinTable).Scope(query, relationship).Row().Scan(&count) } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { diff --git a/join_table.go b/join_table.go index 7d9c419e..3ffa4f87 100644 --- a/join_table.go +++ b/join_table.go @@ -6,6 +6,7 @@ import ( ) type JoinTableHandler interface { + Table(*DB, *Relationship) string Add(*DB, *Relationship, interface{}, interface{}) error Delete(*DB, *Relationship) error Scope(*DB, *Relationship) *DB @@ -13,17 +14,22 @@ type JoinTableHandler interface { type defaultJoinTableHandler struct{} -func (*defaultJoinTableHandler) Add(db *DB, relationship *Relationship, foreignValue interface{}, associationValue interface{}) error { +func (s *defaultJoinTableHandler) Table(db *DB, relationship *Relationship) string { + return relationship.JoinTable +} + +func (s *defaultJoinTableHandler) Add(db *DB, relationship *Relationship, foreignValue interface{}, associationValue interface{}) error { scope := db.NewScope("") quotedForeignDBName := scope.Quote(relationship.ForeignDBName) quotedAssociationDBName := scope.Quote(relationship.AssociationForeignDBName) + table := s.Table(db, relationship) sql := fmt.Sprintf( "INSERT INTO %v (%v) SELECT ?,? %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v = ? AND %v = ?);", - scope.Quote(relationship.JoinTable), + scope.Quote(table), strings.Join([]string{quotedForeignDBName, quotedAssociationDBName}, ","), scope.Dialect().SelectFromDummyTable(), - scope.Quote(relationship.JoinTable), + scope.Quote(table), quotedForeignDBName, quotedAssociationDBName, ) @@ -31,12 +37,12 @@ func (*defaultJoinTableHandler) Add(db *DB, relationship *Relationship, foreignV return db.Exec(sql, foreignValue, associationValue, foreignValue, associationValue).Error } -func (*defaultJoinTableHandler) Delete(db *DB, relationship *Relationship) error { - return db.Delete("").Error +func (s *defaultJoinTableHandler) Delete(db *DB, relationship *Relationship) error { + return db.Table(s.Table(db, relationship)).Delete("").Error } -func (*defaultJoinTableHandler) Scope(db *DB, relationship *Relationship) *DB { - return db +func (s *defaultJoinTableHandler) Scope(db *DB, relationship *Relationship) *DB { + return db.Table(s.Table(db, relationship)) } var DefaultJoinTableHandler = &defaultJoinTableHandler{} diff --git a/join_table_test.go b/join_table_test.go index 76866e39..2624fdb2 100644 --- a/join_table_test.go +++ b/join_table_test.go @@ -21,6 +21,10 @@ type PersonAddress struct { CreatedAt time.Time } +func (*PersonAddress) Table(db *gorm.DB, relationship *gorm.Relationship) string { + return relationship.JoinTable +} + func (*PersonAddress) Add(db *gorm.DB, relationship *gorm.Relationship, foreignValue interface{}, associationValue interface{}) error { return db.Where(map[string]interface{}{ relationship.ForeignDBName: foreignValue, @@ -36,8 +40,9 @@ func (*PersonAddress) Delete(db *gorm.DB, relationship *gorm.Relationship) error return db.Delete(&PersonAddress{}).Error } -func (*PersonAddress) Scope(db *gorm.DB, relationship *gorm.Relationship) *gorm.DB { - return db.Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", relationship.JoinTable, relationship.JoinTable)) +func (pa *PersonAddress) Scope(db *gorm.DB, relationship *gorm.Relationship) *gorm.DB { + table := pa.Table(db, relationship) + return db.Table(table).Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table)) } func TestJoinTable(t *testing.T) { diff --git a/scope_private.go b/scope_private.go index f11ee901..5b69c0da 100644 --- a/scope_private.go +++ b/scope_private.go @@ -402,14 +402,17 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { if fromField != nil { if relationship := fromField.Relationship; relationship != nil { if relationship.Kind == "many_to_many" { + joinTableHandler := scope.db.GetJoinTableHandler(relationship.JoinTable) + quotedJoinTable := scope.Quote(joinTableHandler.Table(scope.db, relationship)) + joinSql := fmt.Sprintf( "INNER JOIN %v ON %v.%v = %v.%v", - scope.Quote(relationship.JoinTable), - scope.Quote(relationship.JoinTable), + quotedJoinTable, + quotedJoinTable, scope.Quote(relationship.AssociationForeignDBName), toScope.QuotedTableName(), scope.Quote(toScope.PrimaryKey())) - whereSql := fmt.Sprintf("%v.%v = ?", scope.Quote(relationship.JoinTable), scope.Quote(relationship.ForeignDBName)) + whereSql := fmt.Sprintf("%v.%v = ?", quotedJoinTable, scope.Quote(relationship.ForeignDBName)) scope.Err(toScope.db.Joins(joinSql).Where(whereSql, scope.PrimaryKeyValue()).Find(value).Error) } else if relationship.Kind == "belongs_to" { sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey())) @@ -441,16 +444,18 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { func (scope *Scope) createJoinTable(field *StructField) { if relationship := field.Relationship; relationship != nil && relationship.JoinTable != "" { - if !scope.Dialect().HasTable(scope, relationship.JoinTable) { + joinTableHandler := scope.db.GetJoinTableHandler(relationship.JoinTable) + joinTable := joinTableHandler.Table(scope.db, relationship) + if !scope.Dialect().HasTable(scope, joinTable) { primaryKeySqlType := scope.Dialect().SqlTag(scope.PrimaryKeyField().Field, 255) scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v)", - scope.Quote(relationship.JoinTable), + scope.Quote(joinTable), strings.Join([]string{ scope.Quote(relationship.ForeignDBName) + " " + primaryKeySqlType, scope.Quote(relationship.AssociationForeignDBName) + " " + primaryKeySqlType}, ",")), ).Error) } - scope.NewDB().Table(relationship.JoinTable).AutoMigrate(scope.db.GetJoinTableHandler(relationship.JoinTable)) + scope.NewDB().Table(joinTable).AutoMigrate() } }