diff --git a/join_table_handler.go b/join_table_handler.go index 878bb491..006701a6 100644 --- a/join_table_handler.go +++ b/join_table_handler.go @@ -133,17 +133,17 @@ func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sour func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB { var ( - scope = db.NewScope(source) - modelType = scope.GetModelStruct().ModelType - quotedTable = scope.Quote(handler.Table(db)) - joinConditions []string - values []interface{} + scope = db.NewScope(source) + tableName = handler.Table(db) + quotedTableName = scope.Quote(tableName) + joinConditions []string + values []interface{} ) - if s.Source.ModelType == modelType { + if s.Source.ModelType == scope.GetModelStruct().ModelType { destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName() for _, foreignKey := range s.Destination.ForeignKeys { - joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName))) + joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTableName, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName))) } var foreignDBNames []string @@ -158,7 +158,12 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so var condString string if len(foreignFieldValues) > 0 { - condString = fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, foreignDBNames), toQueryMarks(foreignFieldValues)) + var quotedForeignDBNames []string + for _, dbName := range foreignDBNames { + quotedForeignDBNames = append(quotedForeignDBNames, tableName+"."+dbName) + } + + condString = fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, quotedForeignDBNames), toQueryMarks(foreignFieldValues)) keys := scope.getColumnAsArray(foreignFieldNames) values = append(values, toQueryValues(keys)) @@ -166,7 +171,7 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so condString = fmt.Sprintf("1 <> 1") } - return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTable, strings.Join(joinConditions, " AND "))). + return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTableName, strings.Join(joinConditions, " AND "))). Where(condString, toQueryValues(foreignFieldValues)...) } else { db.Error = errors.New("wrong source type for join table handler")