mirror of https://github.com/go-gorm/gorm.git
Fix preload many2many with multiple primary keys
This commit is contained in:
parent
6a6c1bf762
commit
9982134955
|
@ -142,17 +142,25 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so
|
||||||
joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName)))
|
joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName)))
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, foreignKey := range s.Source.ForeignKeys {
|
var foreignDBNames []string
|
||||||
condString := fmt.Sprintf("%v.%v in (?)", quotedTable, scope.Quote(foreignKey.DBName))
|
var foreignFieldNames []string
|
||||||
|
|
||||||
keys := scope.getColumnAsArray([]string{scope.Fields()[foreignKey.AssociationDBName].Name})
|
for _, foreignKey := range s.Source.ForeignKeys {
|
||||||
|
foreignDBNames = append(foreignDBNames, foreignKey.DBName)
|
||||||
|
foreignFieldNames = append(foreignFieldNames, scope.Fields()[foreignKey.AssociationDBName].Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
foreignFieldValues := scope.getColumnAsArray(foreignFieldNames)
|
||||||
|
|
||||||
|
condString := fmt.Sprintf("%v in (%v)", toQueryCondition(scope, foreignDBNames), toQueryMarks(foreignFieldValues))
|
||||||
|
|
||||||
|
keys := scope.getColumnAsArray(foreignFieldNames)
|
||||||
values = append(values, toQueryValues(keys))
|
values = append(values, toQueryValues(keys))
|
||||||
|
|
||||||
queryConditions = append(queryConditions, condString)
|
queryConditions = append(queryConditions, condString)
|
||||||
}
|
|
||||||
|
|
||||||
return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTable, strings.Join(joinConditions, " AND "))).
|
return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTable, strings.Join(joinConditions, " AND "))).
|
||||||
Where(strings.Join(queryConditions, " AND "), values...)
|
Where(condString, toQueryValues(foreignFieldValues)...)
|
||||||
} else {
|
} else {
|
||||||
db.Error = errors.New("wrong source type for join table handler")
|
db.Error = errors.New("wrong source type for join table handler")
|
||||||
return db
|
return db
|
||||||
|
|
|
@ -66,7 +66,6 @@ type Relationship struct {
|
||||||
PolymorphicType string
|
PolymorphicType string
|
||||||
PolymorphicDBName string
|
PolymorphicDBName string
|
||||||
ForeignFieldNames []string
|
ForeignFieldNames []string
|
||||||
ForeignStructFieldNames []string
|
|
||||||
ForeignDBNames []string
|
ForeignDBNames []string
|
||||||
AssociationForeignFieldNames []string
|
AssociationForeignFieldNames []string
|
||||||
AssociationForeignStructFieldNames []string
|
AssociationForeignStructFieldNames []string
|
||||||
|
@ -226,7 +225,6 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||||
for _, foreignKey := range foreignKeys {
|
for _, foreignKey := range foreignKeys {
|
||||||
if field, ok := scope.FieldByName(foreignKey); ok {
|
if field, ok := scope.FieldByName(foreignKey); ok {
|
||||||
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, field.DBName)
|
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, field.DBName)
|
||||||
relationship.ForeignStructFieldNames = append(relationship.ForeignFieldNames, field.Name)
|
|
||||||
joinTableDBName := ToDBName(scopeType.Name()) + "_" + field.DBName
|
joinTableDBName := ToDBName(scopeType.Name()) + "_" + field.DBName
|
||||||
relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBName)
|
relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBName)
|
||||||
}
|
}
|
||||||
|
|
11
preload.go
11
preload.go
|
@ -267,11 +267,18 @@ func (scope *Scope) handleHasManyToManyPreload(field *Field, conditions []interf
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var associationForeignStructFieldNames []string
|
||||||
|
for _, dbName := range relation.AssociationForeignFieldNames {
|
||||||
|
if field, ok := scope.FieldByName(dbName); ok {
|
||||||
|
associationForeignStructFieldNames = append(associationForeignStructFieldNames, field.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if scope.IndirectValue().Kind() == reflect.Slice {
|
if scope.IndirectValue().Kind() == reflect.Slice {
|
||||||
objects := scope.IndirectValue()
|
objects := scope.IndirectValue()
|
||||||
for j := 0; j < objects.Len(); j++ {
|
for j := 0; j < objects.Len(); j++ {
|
||||||
object := reflect.Indirect(objects.Index(j))
|
object := reflect.Indirect(objects.Index(j))
|
||||||
source := getRealValue(object, relation.AssociationForeignStructFieldNames)
|
source := getRealValue(object, associationForeignStructFieldNames)
|
||||||
field := object.FieldByName(field.Name)
|
field := object.FieldByName(field.Name)
|
||||||
for _, link := range linkHash[toString(source)] {
|
for _, link := range linkHash[toString(source)] {
|
||||||
field.Set(reflect.Append(field, link))
|
field.Set(reflect.Append(field, link))
|
||||||
|
@ -279,7 +286,7 @@ func (scope *Scope) handleHasManyToManyPreload(field *Field, conditions []interf
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
object := scope.IndirectValue()
|
object := scope.IndirectValue()
|
||||||
source := getRealValue(object, relation.AssociationForeignStructFieldNames)
|
source := getRealValue(object, associationForeignStructFieldNames)
|
||||||
field := object.FieldByName(field.Name)
|
field := object.FieldByName(field.Name)
|
||||||
for _, link := range linkHash[toString(source)] {
|
for _, link := range linkHash[toString(source)] {
|
||||||
field.Set(reflect.Append(field, link))
|
field.Set(reflect.Append(field, link))
|
||||||
|
|
|
@ -2,6 +2,7 @@ package gorm_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
@ -603,14 +604,20 @@ func TestNestedPreload9(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManyToManyPreload(t *testing.T) {
|
func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) {
|
||||||
|
if dialect := os.Getenv("GORM_DIALECT"); dialect == "" || dialect == "sqlite" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
type (
|
type (
|
||||||
Level1 struct {
|
Level1 struct {
|
||||||
ID uint `gorm:"primary_key;"`
|
ID uint `gorm:"primary_key;"`
|
||||||
|
LanguageCode string `gorm:"primary_key"`
|
||||||
Value string
|
Value string
|
||||||
}
|
}
|
||||||
Level2 struct {
|
Level2 struct {
|
||||||
ID uint `gorm:"primary_key;"`
|
ID uint `gorm:"primary_key;"`
|
||||||
|
LanguageCode string `gorm:"primary_key"`
|
||||||
Value string
|
Value string
|
||||||
Level1s []Level1 `gorm:"many2many:levels;"`
|
Level1s []Level1 `gorm:"many2many:levels;"`
|
||||||
}
|
}
|
||||||
|
@ -618,22 +625,23 @@ func TestManyToManyPreload(t *testing.T) {
|
||||||
|
|
||||||
DB.DropTableIfExists(&Level2{})
|
DB.DropTableIfExists(&Level2{})
|
||||||
DB.DropTableIfExists(&Level1{})
|
DB.DropTableIfExists(&Level1{})
|
||||||
|
DB.Table("levels").DropTableIfExists("levels")
|
||||||
|
|
||||||
if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil {
|
if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
want := Level2{Value: "Bob", Level1s: []Level1{
|
want := Level2{Value: "Bob", LanguageCode: "ru", Level1s: []Level1{
|
||||||
{Value: "ru"},
|
{Value: "ru", LanguageCode: "ru"},
|
||||||
{Value: "en"},
|
{Value: "en", LanguageCode: "en"},
|
||||||
}}
|
}}
|
||||||
if err := DB.Save(&want).Error; err != nil {
|
if err := DB.Save(&want).Error; err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
want2 := Level2{Value: "Tom", Level1s: []Level1{
|
want2 := Level2{Value: "Tom", LanguageCode: "zh", Level1s: []Level1{
|
||||||
{Value: "zh"},
|
{Value: "zh", LanguageCode: "zh"},
|
||||||
{Value: "de"},
|
{Value: "de", LanguageCode: "de"},
|
||||||
}}
|
}}
|
||||||
if err := DB.Save(&want2).Error; err != nil {
|
if err := DB.Save(&want2).Error; err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
|
@ -698,6 +706,7 @@ func TestManyToManyPreloadForPointer(t *testing.T) {
|
||||||
|
|
||||||
DB.DropTableIfExists(&Level2{})
|
DB.DropTableIfExists(&Level2{})
|
||||||
DB.DropTableIfExists(&Level1{})
|
DB.DropTableIfExists(&Level1{})
|
||||||
|
DB.Table("levels").DropTableIfExists("levels")
|
||||||
|
|
||||||
if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil {
|
if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"reflect"
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
func fileWithLineNum() string {
|
func fileWithLineNum() string {
|
||||||
|
@ -72,8 +73,18 @@ func convertInterfaceToMap(values interface{}) map[string]interface{} {
|
||||||
return attrs
|
return attrs
|
||||||
}
|
}
|
||||||
|
|
||||||
func toString(a interface{}) string {
|
func toString(str interface{}) string {
|
||||||
return fmt.Sprintf("%v", a)
|
if values, ok := str.([]interface{}); ok {
|
||||||
|
var results []string
|
||||||
|
for _, value := range values {
|
||||||
|
results = append(results, toString(value))
|
||||||
|
}
|
||||||
|
return strings.Join(results, "_")
|
||||||
|
} else if bytes, ok := str.([]byte); ok {
|
||||||
|
return string(bytes)
|
||||||
|
} else {
|
||||||
|
return fmt.Sprintf("%v", str)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func strInSlice(a string, list []string) bool {
|
func strInSlice(a string, list []string) bool {
|
||||||
|
|
Loading…
Reference in New Issue