diff --git a/join_table_handler.go b/join_table_handler.go index ad788412..c7a28cd3 100644 --- a/join_table_handler.go +++ b/join_table_handler.go @@ -142,17 +142,25 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName))) } + var foreignDBNames []string + var foreignFieldNames []string + for _, foreignKey := range s.Source.ForeignKeys { - condString := fmt.Sprintf("%v.%v in (?)", quotedTable, scope.Quote(foreignKey.DBName)) - - keys := scope.getColumnAsArray([]string{scope.Fields()[foreignKey.AssociationDBName].Name}) - values = append(values, toQueryValues(keys)) - - queryConditions = append(queryConditions, condString) + foreignDBNames = append(foreignDBNames, foreignKey.DBName) + foreignFieldNames = append(foreignFieldNames, scope.Fields()[foreignKey.AssociationDBName].Name) } + foreignFieldValues := scope.getColumnAsArray(foreignFieldNames) + + condString := fmt.Sprintf("%v in (%v)", toQueryCondition(scope, foreignDBNames), toQueryMarks(foreignFieldValues)) + + keys := scope.getColumnAsArray(foreignFieldNames) + values = append(values, toQueryValues(keys)) + + queryConditions = append(queryConditions, condString) + return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTable, strings.Join(joinConditions, " AND "))). - Where(strings.Join(queryConditions, " AND "), values...) + Where(condString, toQueryValues(foreignFieldValues)...) } else { db.Error = errors.New("wrong source type for join table handler") return db diff --git a/model_struct.go b/model_struct.go index db6d9a88..f6e035f8 100644 --- a/model_struct.go +++ b/model_struct.go @@ -66,7 +66,6 @@ type Relationship struct { PolymorphicType string PolymorphicDBName string ForeignFieldNames []string - ForeignStructFieldNames []string ForeignDBNames []string AssociationForeignFieldNames []string AssociationForeignStructFieldNames []string @@ -226,7 +225,6 @@ func (scope *Scope) GetModelStruct() *ModelStruct { for _, foreignKey := range foreignKeys { if field, ok := scope.FieldByName(foreignKey); ok { relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, field.DBName) - relationship.ForeignStructFieldNames = append(relationship.ForeignFieldNames, field.Name) joinTableDBName := ToDBName(scopeType.Name()) + "_" + field.DBName relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBName) } diff --git a/preload.go b/preload.go index d5c8da10..dd85c327 100644 --- a/preload.go +++ b/preload.go @@ -267,11 +267,18 @@ func (scope *Scope) handleHasManyToManyPreload(field *Field, conditions []interf } } + var associationForeignStructFieldNames []string + for _, dbName := range relation.AssociationForeignFieldNames { + if field, ok := scope.FieldByName(dbName); ok { + associationForeignStructFieldNames = append(associationForeignStructFieldNames, field.Name) + } + } + if scope.IndirectValue().Kind() == reflect.Slice { objects := scope.IndirectValue() for j := 0; j < objects.Len(); j++ { object := reflect.Indirect(objects.Index(j)) - source := getRealValue(object, relation.AssociationForeignStructFieldNames) + source := getRealValue(object, associationForeignStructFieldNames) field := object.FieldByName(field.Name) for _, link := range linkHash[toString(source)] { field.Set(reflect.Append(field, link)) @@ -279,7 +286,7 @@ func (scope *Scope) handleHasManyToManyPreload(field *Field, conditions []interf } } else { object := scope.IndirectValue() - source := getRealValue(object, relation.AssociationForeignStructFieldNames) + source := getRealValue(object, associationForeignStructFieldNames) field := object.FieldByName(field.Name) for _, link := range linkHash[toString(source)] { field.Set(reflect.Append(field, link)) diff --git a/preload_test.go b/preload_test.go index cde0e739..3dcd325b 100644 --- a/preload_test.go +++ b/preload_test.go @@ -2,6 +2,7 @@ package gorm_test import ( "encoding/json" + "os" "reflect" "testing" ) @@ -603,37 +604,44 @@ func TestNestedPreload9(t *testing.T) { } } -func TestManyToManyPreload(t *testing.T) { +func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) { + if dialect := os.Getenv("GORM_DIALECT"); dialect == "" || dialect == "sqlite" { + return + } + type ( Level1 struct { - ID uint `gorm:"primary_key;"` - Value string + ID uint `gorm:"primary_key;"` + LanguageCode string `gorm:"primary_key"` + Value string } Level2 struct { - ID uint `gorm:"primary_key;"` - Value string - Level1s []Level1 `gorm:"many2many:levels;"` + ID uint `gorm:"primary_key;"` + LanguageCode string `gorm:"primary_key"` + Value string + Level1s []Level1 `gorm:"many2many:levels;"` } ) DB.DropTableIfExists(&Level2{}) DB.DropTableIfExists(&Level1{}) + DB.Table("levels").DropTableIfExists("levels") if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil { panic(err) } - want := Level2{Value: "Bob", Level1s: []Level1{ - {Value: "ru"}, - {Value: "en"}, + want := Level2{Value: "Bob", LanguageCode: "ru", Level1s: []Level1{ + {Value: "ru", LanguageCode: "ru"}, + {Value: "en", LanguageCode: "en"}, }} if err := DB.Save(&want).Error; err != nil { panic(err) } - want2 := Level2{Value: "Tom", Level1s: []Level1{ - {Value: "zh"}, - {Value: "de"}, + want2 := Level2{Value: "Tom", LanguageCode: "zh", Level1s: []Level1{ + {Value: "zh", LanguageCode: "zh"}, + {Value: "de", LanguageCode: "de"}, }} if err := DB.Save(&want2).Error; err != nil { panic(err) @@ -698,6 +706,7 @@ func TestManyToManyPreloadForPointer(t *testing.T) { DB.DropTableIfExists(&Level2{}) DB.DropTableIfExists(&Level1{}) + DB.Table("levels").DropTableIfExists("levels") if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil { panic(err) diff --git a/utils_private.go b/utils_private.go index b82aa807..50549857 100644 --- a/utils_private.go +++ b/utils_private.go @@ -5,6 +5,7 @@ import ( "reflect" "regexp" "runtime" + "strings" ) func fileWithLineNum() string { @@ -72,8 +73,18 @@ func convertInterfaceToMap(values interface{}) map[string]interface{} { return attrs } -func toString(a interface{}) string { - return fmt.Sprintf("%v", a) +func toString(str interface{}) string { + if values, ok := str.([]interface{}); ok { + var results []string + for _, value := range values { + results = append(results, toString(value)) + } + return strings.Join(results, "_") + } else if bytes, ok := str.([]byte); ok { + return string(bytes) + } else { + return fmt.Sprintf("%v", str) + } } func strInSlice(a string, list []string) bool {