Fix preload many2many with multiple primary keys

This commit is contained in:
Jinzhu 2015-08-18 09:08:33 +08:00
parent 6a6c1bf762
commit 9982134955
5 changed files with 58 additions and 25 deletions

View File

@ -142,17 +142,25 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so
joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName))) joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName)))
} }
for _, foreignKey := range s.Source.ForeignKeys { var foreignDBNames []string
condString := fmt.Sprintf("%v.%v in (?)", quotedTable, scope.Quote(foreignKey.DBName)) var foreignFieldNames []string
keys := scope.getColumnAsArray([]string{scope.Fields()[foreignKey.AssociationDBName].Name}) for _, foreignKey := range s.Source.ForeignKeys {
foreignDBNames = append(foreignDBNames, foreignKey.DBName)
foreignFieldNames = append(foreignFieldNames, scope.Fields()[foreignKey.AssociationDBName].Name)
}
foreignFieldValues := scope.getColumnAsArray(foreignFieldNames)
condString := fmt.Sprintf("%v in (%v)", toQueryCondition(scope, foreignDBNames), toQueryMarks(foreignFieldValues))
keys := scope.getColumnAsArray(foreignFieldNames)
values = append(values, toQueryValues(keys)) values = append(values, toQueryValues(keys))
queryConditions = append(queryConditions, condString) queryConditions = append(queryConditions, condString)
}
return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTable, strings.Join(joinConditions, " AND "))). return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTable, strings.Join(joinConditions, " AND "))).
Where(strings.Join(queryConditions, " AND "), values...) Where(condString, toQueryValues(foreignFieldValues)...)
} else { } else {
db.Error = errors.New("wrong source type for join table handler") db.Error = errors.New("wrong source type for join table handler")
return db return db

View File

@ -66,7 +66,6 @@ type Relationship struct {
PolymorphicType string PolymorphicType string
PolymorphicDBName string PolymorphicDBName string
ForeignFieldNames []string ForeignFieldNames []string
ForeignStructFieldNames []string
ForeignDBNames []string ForeignDBNames []string
AssociationForeignFieldNames []string AssociationForeignFieldNames []string
AssociationForeignStructFieldNames []string AssociationForeignStructFieldNames []string
@ -226,7 +225,6 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
for _, foreignKey := range foreignKeys { for _, foreignKey := range foreignKeys {
if field, ok := scope.FieldByName(foreignKey); ok { if field, ok := scope.FieldByName(foreignKey); ok {
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, field.DBName) relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, field.DBName)
relationship.ForeignStructFieldNames = append(relationship.ForeignFieldNames, field.Name)
joinTableDBName := ToDBName(scopeType.Name()) + "_" + field.DBName joinTableDBName := ToDBName(scopeType.Name()) + "_" + field.DBName
relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBName) relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBName)
} }

View File

@ -267,11 +267,18 @@ func (scope *Scope) handleHasManyToManyPreload(field *Field, conditions []interf
} }
} }
var associationForeignStructFieldNames []string
for _, dbName := range relation.AssociationForeignFieldNames {
if field, ok := scope.FieldByName(dbName); ok {
associationForeignStructFieldNames = append(associationForeignStructFieldNames, field.Name)
}
}
if scope.IndirectValue().Kind() == reflect.Slice { if scope.IndirectValue().Kind() == reflect.Slice {
objects := scope.IndirectValue() objects := scope.IndirectValue()
for j := 0; j < objects.Len(); j++ { for j := 0; j < objects.Len(); j++ {
object := reflect.Indirect(objects.Index(j)) object := reflect.Indirect(objects.Index(j))
source := getRealValue(object, relation.AssociationForeignStructFieldNames) source := getRealValue(object, associationForeignStructFieldNames)
field := object.FieldByName(field.Name) field := object.FieldByName(field.Name)
for _, link := range linkHash[toString(source)] { for _, link := range linkHash[toString(source)] {
field.Set(reflect.Append(field, link)) field.Set(reflect.Append(field, link))
@ -279,7 +286,7 @@ func (scope *Scope) handleHasManyToManyPreload(field *Field, conditions []interf
} }
} else { } else {
object := scope.IndirectValue() object := scope.IndirectValue()
source := getRealValue(object, relation.AssociationForeignStructFieldNames) source := getRealValue(object, associationForeignStructFieldNames)
field := object.FieldByName(field.Name) field := object.FieldByName(field.Name)
for _, link := range linkHash[toString(source)] { for _, link := range linkHash[toString(source)] {
field.Set(reflect.Append(field, link)) field.Set(reflect.Append(field, link))

View File

@ -2,6 +2,7 @@ package gorm_test
import ( import (
"encoding/json" "encoding/json"
"os"
"reflect" "reflect"
"testing" "testing"
) )
@ -603,14 +604,20 @@ func TestNestedPreload9(t *testing.T) {
} }
} }
func TestManyToManyPreload(t *testing.T) { func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) {
if dialect := os.Getenv("GORM_DIALECT"); dialect == "" || dialect == "sqlite" {
return
}
type ( type (
Level1 struct { Level1 struct {
ID uint `gorm:"primary_key;"` ID uint `gorm:"primary_key;"`
LanguageCode string `gorm:"primary_key"`
Value string Value string
} }
Level2 struct { Level2 struct {
ID uint `gorm:"primary_key;"` ID uint `gorm:"primary_key;"`
LanguageCode string `gorm:"primary_key"`
Value string Value string
Level1s []Level1 `gorm:"many2many:levels;"` Level1s []Level1 `gorm:"many2many:levels;"`
} }
@ -618,22 +625,23 @@ func TestManyToManyPreload(t *testing.T) {
DB.DropTableIfExists(&Level2{}) DB.DropTableIfExists(&Level2{})
DB.DropTableIfExists(&Level1{}) DB.DropTableIfExists(&Level1{})
DB.Table("levels").DropTableIfExists("levels")
if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil { if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil {
panic(err) panic(err)
} }
want := Level2{Value: "Bob", Level1s: []Level1{ want := Level2{Value: "Bob", LanguageCode: "ru", Level1s: []Level1{
{Value: "ru"}, {Value: "ru", LanguageCode: "ru"},
{Value: "en"}, {Value: "en", LanguageCode: "en"},
}} }}
if err := DB.Save(&want).Error; err != nil { if err := DB.Save(&want).Error; err != nil {
panic(err) panic(err)
} }
want2 := Level2{Value: "Tom", Level1s: []Level1{ want2 := Level2{Value: "Tom", LanguageCode: "zh", Level1s: []Level1{
{Value: "zh"}, {Value: "zh", LanguageCode: "zh"},
{Value: "de"}, {Value: "de", LanguageCode: "de"},
}} }}
if err := DB.Save(&want2).Error; err != nil { if err := DB.Save(&want2).Error; err != nil {
panic(err) panic(err)
@ -698,6 +706,7 @@ func TestManyToManyPreloadForPointer(t *testing.T) {
DB.DropTableIfExists(&Level2{}) DB.DropTableIfExists(&Level2{})
DB.DropTableIfExists(&Level1{}) DB.DropTableIfExists(&Level1{})
DB.Table("levels").DropTableIfExists("levels")
if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil { if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil {
panic(err) panic(err)

View File

@ -5,6 +5,7 @@ import (
"reflect" "reflect"
"regexp" "regexp"
"runtime" "runtime"
"strings"
) )
func fileWithLineNum() string { func fileWithLineNum() string {
@ -72,8 +73,18 @@ func convertInterfaceToMap(values interface{}) map[string]interface{} {
return attrs return attrs
} }
func toString(a interface{}) string { func toString(str interface{}) string {
return fmt.Sprintf("%v", a) if values, ok := str.([]interface{}); ok {
var results []string
for _, value := range values {
results = append(results, toString(value))
}
return strings.Join(results, "_")
} else if bytes, ok := str.([]byte); ok {
return string(bytes)
} else {
return fmt.Sprintf("%v", str)
}
} }
func strInSlice(a string, list []string) bool { func strInSlice(a string, list []string) bool {