mirror of https://github.com/go-gorm/gorm.git
Fix preload many2many with multiple primary keys
This commit is contained in:
parent
6a6c1bf762
commit
9982134955
|
@ -142,17 +142,25 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so
|
|||
joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName)))
|
||||
}
|
||||
|
||||
for _, foreignKey := range s.Source.ForeignKeys {
|
||||
condString := fmt.Sprintf("%v.%v in (?)", quotedTable, scope.Quote(foreignKey.DBName))
|
||||
var foreignDBNames []string
|
||||
var foreignFieldNames []string
|
||||
|
||||
keys := scope.getColumnAsArray([]string{scope.Fields()[foreignKey.AssociationDBName].Name})
|
||||
for _, foreignKey := range s.Source.ForeignKeys {
|
||||
foreignDBNames = append(foreignDBNames, foreignKey.DBName)
|
||||
foreignFieldNames = append(foreignFieldNames, scope.Fields()[foreignKey.AssociationDBName].Name)
|
||||
}
|
||||
|
||||
foreignFieldValues := scope.getColumnAsArray(foreignFieldNames)
|
||||
|
||||
condString := fmt.Sprintf("%v in (%v)", toQueryCondition(scope, foreignDBNames), toQueryMarks(foreignFieldValues))
|
||||
|
||||
keys := scope.getColumnAsArray(foreignFieldNames)
|
||||
values = append(values, toQueryValues(keys))
|
||||
|
||||
queryConditions = append(queryConditions, condString)
|
||||
}
|
||||
|
||||
return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTable, strings.Join(joinConditions, " AND "))).
|
||||
Where(strings.Join(queryConditions, " AND "), values...)
|
||||
Where(condString, toQueryValues(foreignFieldValues)...)
|
||||
} else {
|
||||
db.Error = errors.New("wrong source type for join table handler")
|
||||
return db
|
||||
|
|
|
@ -66,7 +66,6 @@ type Relationship struct {
|
|||
PolymorphicType string
|
||||
PolymorphicDBName string
|
||||
ForeignFieldNames []string
|
||||
ForeignStructFieldNames []string
|
||||
ForeignDBNames []string
|
||||
AssociationForeignFieldNames []string
|
||||
AssociationForeignStructFieldNames []string
|
||||
|
@ -226,7 +225,6 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||
for _, foreignKey := range foreignKeys {
|
||||
if field, ok := scope.FieldByName(foreignKey); ok {
|
||||
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, field.DBName)
|
||||
relationship.ForeignStructFieldNames = append(relationship.ForeignFieldNames, field.Name)
|
||||
joinTableDBName := ToDBName(scopeType.Name()) + "_" + field.DBName
|
||||
relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBName)
|
||||
}
|
||||
|
|
11
preload.go
11
preload.go
|
@ -267,11 +267,18 @@ func (scope *Scope) handleHasManyToManyPreload(field *Field, conditions []interf
|
|||
}
|
||||
}
|
||||
|
||||
var associationForeignStructFieldNames []string
|
||||
for _, dbName := range relation.AssociationForeignFieldNames {
|
||||
if field, ok := scope.FieldByName(dbName); ok {
|
||||
associationForeignStructFieldNames = append(associationForeignStructFieldNames, field.Name)
|
||||
}
|
||||
}
|
||||
|
||||
if scope.IndirectValue().Kind() == reflect.Slice {
|
||||
objects := scope.IndirectValue()
|
||||
for j := 0; j < objects.Len(); j++ {
|
||||
object := reflect.Indirect(objects.Index(j))
|
||||
source := getRealValue(object, relation.AssociationForeignStructFieldNames)
|
||||
source := getRealValue(object, associationForeignStructFieldNames)
|
||||
field := object.FieldByName(field.Name)
|
||||
for _, link := range linkHash[toString(source)] {
|
||||
field.Set(reflect.Append(field, link))
|
||||
|
@ -279,7 +286,7 @@ func (scope *Scope) handleHasManyToManyPreload(field *Field, conditions []interf
|
|||
}
|
||||
} else {
|
||||
object := scope.IndirectValue()
|
||||
source := getRealValue(object, relation.AssociationForeignStructFieldNames)
|
||||
source := getRealValue(object, associationForeignStructFieldNames)
|
||||
field := object.FieldByName(field.Name)
|
||||
for _, link := range linkHash[toString(source)] {
|
||||
field.Set(reflect.Append(field, link))
|
||||
|
|
|
@ -2,6 +2,7 @@ package gorm_test
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
@ -603,14 +604,20 @@ func TestNestedPreload9(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestManyToManyPreload(t *testing.T) {
|
||||
func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) {
|
||||
if dialect := os.Getenv("GORM_DIALECT"); dialect == "" || dialect == "sqlite" {
|
||||
return
|
||||
}
|
||||
|
||||
type (
|
||||
Level1 struct {
|
||||
ID uint `gorm:"primary_key;"`
|
||||
LanguageCode string `gorm:"primary_key"`
|
||||
Value string
|
||||
}
|
||||
Level2 struct {
|
||||
ID uint `gorm:"primary_key;"`
|
||||
LanguageCode string `gorm:"primary_key"`
|
||||
Value string
|
||||
Level1s []Level1 `gorm:"many2many:levels;"`
|
||||
}
|
||||
|
@ -618,22 +625,23 @@ func TestManyToManyPreload(t *testing.T) {
|
|||
|
||||
DB.DropTableIfExists(&Level2{})
|
||||
DB.DropTableIfExists(&Level1{})
|
||||
DB.Table("levels").DropTableIfExists("levels")
|
||||
|
||||
if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
want := Level2{Value: "Bob", Level1s: []Level1{
|
||||
{Value: "ru"},
|
||||
{Value: "en"},
|
||||
want := Level2{Value: "Bob", LanguageCode: "ru", Level1s: []Level1{
|
||||
{Value: "ru", LanguageCode: "ru"},
|
||||
{Value: "en", LanguageCode: "en"},
|
||||
}}
|
||||
if err := DB.Save(&want).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
want2 := Level2{Value: "Tom", Level1s: []Level1{
|
||||
{Value: "zh"},
|
||||
{Value: "de"},
|
||||
want2 := Level2{Value: "Tom", LanguageCode: "zh", Level1s: []Level1{
|
||||
{Value: "zh", LanguageCode: "zh"},
|
||||
{Value: "de", LanguageCode: "de"},
|
||||
}}
|
||||
if err := DB.Save(&want2).Error; err != nil {
|
||||
panic(err)
|
||||
|
@ -698,6 +706,7 @@ func TestManyToManyPreloadForPointer(t *testing.T) {
|
|||
|
||||
DB.DropTableIfExists(&Level2{})
|
||||
DB.DropTableIfExists(&Level1{})
|
||||
DB.Table("levels").DropTableIfExists("levels")
|
||||
|
||||
if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil {
|
||||
panic(err)
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"reflect"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func fileWithLineNum() string {
|
||||
|
@ -72,8 +73,18 @@ func convertInterfaceToMap(values interface{}) map[string]interface{} {
|
|||
return attrs
|
||||
}
|
||||
|
||||
func toString(a interface{}) string {
|
||||
return fmt.Sprintf("%v", a)
|
||||
func toString(str interface{}) string {
|
||||
if values, ok := str.([]interface{}); ok {
|
||||
var results []string
|
||||
for _, value := range values {
|
||||
results = append(results, toString(value))
|
||||
}
|
||||
return strings.Join(results, "_")
|
||||
} else if bytes, ok := str.([]byte); ok {
|
||||
return string(bytes)
|
||||
} else {
|
||||
return fmt.Sprintf("%v", str)
|
||||
}
|
||||
}
|
||||
|
||||
func strInSlice(a string, list []string) bool {
|
||||
|
|
Loading…
Reference in New Issue