Fix preload many2many with multiple primary keys

This commit is contained in:
Jinzhu 2015-08-18 09:08:33 +08:00
parent 6a6c1bf762
commit 9982134955
5 changed files with 58 additions and 25 deletions

View File

@ -142,17 +142,25 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so
joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName)))
}
for _, foreignKey := range s.Source.ForeignKeys {
condString := fmt.Sprintf("%v.%v in (?)", quotedTable, scope.Quote(foreignKey.DBName))
var foreignDBNames []string
var foreignFieldNames []string
keys := scope.getColumnAsArray([]string{scope.Fields()[foreignKey.AssociationDBName].Name})
for _, foreignKey := range s.Source.ForeignKeys {
foreignDBNames = append(foreignDBNames, foreignKey.DBName)
foreignFieldNames = append(foreignFieldNames, scope.Fields()[foreignKey.AssociationDBName].Name)
}
foreignFieldValues := scope.getColumnAsArray(foreignFieldNames)
condString := fmt.Sprintf("%v in (%v)", toQueryCondition(scope, foreignDBNames), toQueryMarks(foreignFieldValues))
keys := scope.getColumnAsArray(foreignFieldNames)
values = append(values, toQueryValues(keys))
queryConditions = append(queryConditions, condString)
}
return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTable, strings.Join(joinConditions, " AND "))).
Where(strings.Join(queryConditions, " AND "), values...)
Where(condString, toQueryValues(foreignFieldValues)...)
} else {
db.Error = errors.New("wrong source type for join table handler")
return db

View File

@ -66,7 +66,6 @@ type Relationship struct {
PolymorphicType string
PolymorphicDBName string
ForeignFieldNames []string
ForeignStructFieldNames []string
ForeignDBNames []string
AssociationForeignFieldNames []string
AssociationForeignStructFieldNames []string
@ -226,7 +225,6 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
for _, foreignKey := range foreignKeys {
if field, ok := scope.FieldByName(foreignKey); ok {
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, field.DBName)
relationship.ForeignStructFieldNames = append(relationship.ForeignFieldNames, field.Name)
joinTableDBName := ToDBName(scopeType.Name()) + "_" + field.DBName
relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBName)
}

View File

@ -267,11 +267,18 @@ func (scope *Scope) handleHasManyToManyPreload(field *Field, conditions []interf
}
}
var associationForeignStructFieldNames []string
for _, dbName := range relation.AssociationForeignFieldNames {
if field, ok := scope.FieldByName(dbName); ok {
associationForeignStructFieldNames = append(associationForeignStructFieldNames, field.Name)
}
}
if scope.IndirectValue().Kind() == reflect.Slice {
objects := scope.IndirectValue()
for j := 0; j < objects.Len(); j++ {
object := reflect.Indirect(objects.Index(j))
source := getRealValue(object, relation.AssociationForeignStructFieldNames)
source := getRealValue(object, associationForeignStructFieldNames)
field := object.FieldByName(field.Name)
for _, link := range linkHash[toString(source)] {
field.Set(reflect.Append(field, link))
@ -279,7 +286,7 @@ func (scope *Scope) handleHasManyToManyPreload(field *Field, conditions []interf
}
} else {
object := scope.IndirectValue()
source := getRealValue(object, relation.AssociationForeignStructFieldNames)
source := getRealValue(object, associationForeignStructFieldNames)
field := object.FieldByName(field.Name)
for _, link := range linkHash[toString(source)] {
field.Set(reflect.Append(field, link))

View File

@ -2,6 +2,7 @@ package gorm_test
import (
"encoding/json"
"os"
"reflect"
"testing"
)
@ -603,14 +604,20 @@ func TestNestedPreload9(t *testing.T) {
}
}
func TestManyToManyPreload(t *testing.T) {
func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) {
if dialect := os.Getenv("GORM_DIALECT"); dialect == "" || dialect == "sqlite" {
return
}
type (
Level1 struct {
ID uint `gorm:"primary_key;"`
LanguageCode string `gorm:"primary_key"`
Value string
}
Level2 struct {
ID uint `gorm:"primary_key;"`
LanguageCode string `gorm:"primary_key"`
Value string
Level1s []Level1 `gorm:"many2many:levels;"`
}
@ -618,22 +625,23 @@ func TestManyToManyPreload(t *testing.T) {
DB.DropTableIfExists(&Level2{})
DB.DropTableIfExists(&Level1{})
DB.Table("levels").DropTableIfExists("levels")
if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil {
panic(err)
}
want := Level2{Value: "Bob", Level1s: []Level1{
{Value: "ru"},
{Value: "en"},
want := Level2{Value: "Bob", LanguageCode: "ru", Level1s: []Level1{
{Value: "ru", LanguageCode: "ru"},
{Value: "en", LanguageCode: "en"},
}}
if err := DB.Save(&want).Error; err != nil {
panic(err)
}
want2 := Level2{Value: "Tom", Level1s: []Level1{
{Value: "zh"},
{Value: "de"},
want2 := Level2{Value: "Tom", LanguageCode: "zh", Level1s: []Level1{
{Value: "zh", LanguageCode: "zh"},
{Value: "de", LanguageCode: "de"},
}}
if err := DB.Save(&want2).Error; err != nil {
panic(err)
@ -698,6 +706,7 @@ func TestManyToManyPreloadForPointer(t *testing.T) {
DB.DropTableIfExists(&Level2{})
DB.DropTableIfExists(&Level1{})
DB.Table("levels").DropTableIfExists("levels")
if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil {
panic(err)

View File

@ -5,6 +5,7 @@ import (
"reflect"
"regexp"
"runtime"
"strings"
)
func fileWithLineNum() string {
@ -72,8 +73,18 @@ func convertInterfaceToMap(values interface{}) map[string]interface{} {
return attrs
}
func toString(a interface{}) string {
return fmt.Sprintf("%v", a)
func toString(str interface{}) string {
if values, ok := str.([]interface{}); ok {
var results []string
for _, value := range values {
results = append(results, toString(value))
}
return strings.Join(results, "_")
} else if bytes, ok := str.([]byte); ok {
return string(bytes)
} else {
return fmt.Sprintf("%v", str)
}
}
func strInSlice(a string, list []string) bool {