diff --git a/join_table_handler.go b/join_table_handler.go index c7a28cd3..0a81a929 100644 --- a/join_table_handler.go +++ b/join_table_handler.go @@ -134,7 +134,6 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so scope := db.NewScope(source) modelType := scope.GetModelStruct().ModelType var joinConditions []string - var queryConditions []string var values []interface{} if s.Source.ModelType == modelType { destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName() @@ -152,12 +151,15 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so foreignFieldValues := scope.getColumnAsArray(foreignFieldNames) - condString := fmt.Sprintf("%v in (%v)", toQueryCondition(scope, foreignDBNames), toQueryMarks(foreignFieldValues)) + var condString string + if len(foreignFieldValues) > 0 { + condString = fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, foreignDBNames), toQueryMarks(foreignFieldValues)) - keys := scope.getColumnAsArray(foreignFieldNames) - values = append(values, toQueryValues(keys)) - - queryConditions = append(queryConditions, condString) + keys := scope.getColumnAsArray(foreignFieldNames) + values = append(values, toQueryValues(keys)) + } else { + condString = fmt.Sprintf("1 <> 1") + } return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTable, strings.Join(joinConditions, " AND "))). Where(condString, toQueryValues(foreignFieldValues)...) diff --git a/preload.go b/preload.go index efdfed07..2d1aed2f 100644 --- a/preload.go +++ b/preload.go @@ -63,7 +63,7 @@ func Preload(scope *Scope) { case "belongs_to": currentScope.handleBelongsToPreload(field, conditions) case "many_to_many": - currentScope.handleHasManyToManyPreload(field, conditions) + currentScope.handleManyToManyPreload(field, conditions) default: currentScope.Err(errors.New("not supported relation")) } @@ -191,9 +191,8 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{ } } -func (scope *Scope) handleHasManyToManyPreload(field *Field, conditions []interface{}) { +func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) { relation := field.Relationship - joinTableHandler := relation.JoinTableHandler destType := field.StructField.Struct.Type.Elem() var isPtr bool @@ -211,6 +210,7 @@ func (scope *Scope) handleHasManyToManyPreload(field *Field, conditions []interf db := scope.NewDB().Table(scope.New(reflect.New(destType).Interface()).TableName()).Select("*") preloadJoinDB := joinTableHandler.JoinWith(joinTableHandler, db, scope.Value) + if len(conditions) > 0 { preloadJoinDB = preloadJoinDB.Where(conditions[0], conditions[1:]...) } diff --git a/preload_test.go b/preload_test.go index de00f529..d526e324 100644 --- a/preload_test.go +++ b/preload_test.go @@ -689,6 +689,10 @@ func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) { if !reflect.DeepEqual(got4, []Level2{got, got2}) { t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level2{got, got2})) } + + if err := DB.Preload("Level1s").Find(&got4, "value IN (?)", []string{"non-existing"}).Error; err != nil { + panic(err) + } } func TestManyToManyPreloadForPointer(t *testing.T) {