forked from mirror/gorm
Some Tweaks for Preload Many2Many, Add tests with inline conditions
This commit is contained in:
parent
27511118fe
commit
f8e2f04562
|
@ -13,10 +13,9 @@ type JoinTableHandlerInterface interface {
|
||||||
Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error
|
Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error
|
||||||
Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error
|
Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error
|
||||||
JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB
|
JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB
|
||||||
PreloadWithJoin(handler JoinTableHandlerInterface, db *DB, source interface{}, conditions ...interface{}) *DB
|
PreloadWithJoin(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB
|
||||||
SourceForeignKeys() []JoinTableForeignKey
|
SourceForeignKeys() []JoinTableForeignKey
|
||||||
DestinationForeignKeys() []JoinTableForeignKey
|
DestinationForeignKeys() []JoinTableForeignKey
|
||||||
DestinationType() reflect.Type
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type JoinTableForeignKey struct {
|
type JoinTableForeignKey struct {
|
||||||
|
@ -139,8 +138,8 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so
|
||||||
var queryConditions []string
|
var queryConditions []string
|
||||||
var values []interface{}
|
var values []interface{}
|
||||||
if s.Source.ModelType == modelType {
|
if s.Source.ModelType == modelType {
|
||||||
for _, foreignKey := range s.Destination.ForeignKeys {
|
|
||||||
destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName()
|
destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName()
|
||||||
|
for _, foreignKey := range s.Destination.ForeignKeys {
|
||||||
joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName)))
|
joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -156,7 +155,7 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s JoinTableHandler) PreloadWithJoin(handler JoinTableHandlerInterface, db *DB, source interface{}, conditions ...interface{}) *DB {
|
func (s JoinTableHandler) PreloadWithJoin(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB {
|
||||||
quotedTable := handler.Table(db)
|
quotedTable := handler.Table(db)
|
||||||
|
|
||||||
scope := db.NewScope(source)
|
scope := db.NewScope(source)
|
||||||
|
@ -165,8 +164,8 @@ func (s JoinTableHandler) PreloadWithJoin(handler JoinTableHandlerInterface, db
|
||||||
var queryConditions []string
|
var queryConditions []string
|
||||||
var values []interface{}
|
var values []interface{}
|
||||||
if s.Source.ModelType == modelType {
|
if s.Source.ModelType == modelType {
|
||||||
|
destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName()
|
||||||
for _, foreignKey := range s.Destination.ForeignKeys {
|
for _, foreignKey := range s.Destination.ForeignKeys {
|
||||||
destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).inlineCondition(conditions...).QuotedTableName()
|
|
||||||
joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName)))
|
joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -179,11 +178,6 @@ func (s JoinTableHandler) PreloadWithJoin(handler JoinTableHandlerInterface, db
|
||||||
queryConditions = append(queryConditions, condString)
|
queryConditions = append(queryConditions, condString)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(conditions) > 0 {
|
|
||||||
queryConditions = append(queryConditions, toString(conditions[0]))
|
|
||||||
values = append(values, conditions[1:]...)
|
|
||||||
}
|
|
||||||
|
|
||||||
return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTable, strings.Join(joinConditions, " AND "))).
|
return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTable, strings.Join(joinConditions, " AND "))).
|
||||||
Where(strings.Join(queryConditions, " AND "), values...)
|
Where(strings.Join(queryConditions, " AND "), values...)
|
||||||
} else {
|
} else {
|
||||||
|
@ -191,7 +185,3 @@ func (s JoinTableHandler) PreloadWithJoin(handler JoinTableHandlerInterface, db
|
||||||
return db
|
return db
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s JoinTableHandler) DestinationType() reflect.Type {
|
|
||||||
return s.Destination.ModelType
|
|
||||||
}
|
|
||||||
|
|
25
preload.go
25
preload.go
|
@ -195,14 +195,16 @@ func (scope *Scope) handleHasManyToManyPreload(field *Field, conditions []interf
|
||||||
relation := field.Relationship
|
relation := field.Relationship
|
||||||
|
|
||||||
joinTableHandler := relation.JoinTableHandler
|
joinTableHandler := relation.JoinTableHandler
|
||||||
destType := joinTableHandler.DestinationType()
|
destType := field.StructField.Struct.Type.Elem()
|
||||||
|
var isPtr bool
|
||||||
db := scope.NewDB().Table(scope.db.NewScope(reflect.New(destType).Elem().Interface()).TableName())
|
if destType.Kind() == reflect.Ptr {
|
||||||
|
isPtr = true
|
||||||
|
destType = destType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
var destKeys []string
|
var destKeys []string
|
||||||
var sourceKeys []string
|
var sourceKeys []string
|
||||||
|
var linkHash = make(map[string][]string)
|
||||||
linkHash := make(map[string][]string)
|
|
||||||
|
|
||||||
for _, key := range joinTableHandler.DestinationForeignKeys() {
|
for _, key := range joinTableHandler.DestinationForeignKeys() {
|
||||||
destKeys = append(destKeys, key.DBName)
|
destKeys = append(destKeys, key.DBName)
|
||||||
|
@ -213,7 +215,13 @@ func (scope *Scope) handleHasManyToManyPreload(field *Field, conditions []interf
|
||||||
}
|
}
|
||||||
|
|
||||||
results := reflect.New(field.Struct.Type).Elem()
|
results := reflect.New(field.Struct.Type).Elem()
|
||||||
rows, err := joinTableHandler.PreloadWithJoin(joinTableHandler, db, scope.Value, conditions...).Rows()
|
|
||||||
|
db := scope.NewDB().Table(scope.New(reflect.New(destType).Interface()).TableName())
|
||||||
|
preloadJoinDB := joinTableHandler.PreloadWithJoin(joinTableHandler, db, scope.Value)
|
||||||
|
if len(conditions) > 0 {
|
||||||
|
preloadJoinDB = preloadJoinDB.Where(conditions[0], conditions[1:]...)
|
||||||
|
}
|
||||||
|
rows, err := preloadJoinDB.Rows()
|
||||||
|
|
||||||
if scope.Err(err) != nil {
|
if scope.Err(err) != nil {
|
||||||
return
|
return
|
||||||
|
@ -264,8 +272,11 @@ func (scope *Scope) handleHasManyToManyPreload(field *Field, conditions []interf
|
||||||
linkHash[toString(sourceKey)] = append(linkHash[toString(sourceKey)], toString(destKey))
|
linkHash[toString(sourceKey)] = append(linkHash[toString(sourceKey)], toString(destKey))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if isPtr {
|
||||||
|
results = reflect.Append(results, elem.Addr())
|
||||||
|
} else {
|
||||||
results = reflect.Append(results, elem)
|
results = reflect.Append(results, elem)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if scope.IndirectValue().Kind() == reflect.Slice {
|
if scope.IndirectValue().Kind() == reflect.Slice {
|
||||||
|
|
|
@ -612,7 +612,7 @@ func TestManyToManyPreload(t *testing.T) {
|
||||||
Level2 struct {
|
Level2 struct {
|
||||||
ID uint `gorm:"primary_key;"`
|
ID uint `gorm:"primary_key;"`
|
||||||
Value string
|
Value string
|
||||||
Level1s []Level1 `gorm:"many2many:levels;"`
|
Level1s []*Level1 `gorm:"many2many:levels;"`
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -623,7 +623,7 @@ func TestManyToManyPreload(t *testing.T) {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
want := Level2{Value: "Bob", Level1s: []Level1{
|
want := Level2{Value: "Bob", Level1s: []*Level1{
|
||||||
{Value: "ru"},
|
{Value: "ru"},
|
||||||
{Value: "en"},
|
{Value: "en"},
|
||||||
}}
|
}}
|
||||||
|
@ -631,7 +631,7 @@ func TestManyToManyPreload(t *testing.T) {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
want2 := Level2{Value: "Tom", Level1s: []Level1{
|
want2 := Level2{Value: "Tom", Level1s: []*Level1{
|
||||||
{Value: "zh"},
|
{Value: "zh"},
|
||||||
{Value: "de"},
|
{Value: "de"},
|
||||||
}}
|
}}
|
||||||
|
@ -665,6 +665,22 @@ func TestManyToManyPreload(t *testing.T) {
|
||||||
if !reflect.DeepEqual(got3, []Level2{got, got2}) {
|
if !reflect.DeepEqual(got3, []Level2{got, got2}) {
|
||||||
t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level2{got, got2}))
|
t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level2{got, got2}))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var got4 []Level2
|
||||||
|
if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var ruLevel1 Level1
|
||||||
|
var zhLevel1 Level1
|
||||||
|
DB.First(&ruLevel1, "value = ?", "ru")
|
||||||
|
DB.First(&zhLevel1, "value = ?", "zh")
|
||||||
|
|
||||||
|
got.Level1s = []*Level1{&ruLevel1}
|
||||||
|
got2.Level1s = []*Level1{&zhLevel1}
|
||||||
|
if !reflect.DeepEqual(got4, []Level2{got, got2}) {
|
||||||
|
t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level2{got, got2}))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func toJSONString(v interface{}) []byte {
|
func toJSONString(v interface{}) []byte {
|
||||||
|
|
Loading…
Reference in New Issue