mirror of https://github.com/go-gorm/gorm.git
m2m preload
This commit is contained in:
parent
dd0d4d931f
commit
42c3f39163
|
@ -13,8 +13,10 @@ type JoinTableHandlerInterface interface {
|
|||
Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error
|
||||
Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error
|
||||
JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB
|
||||
PreloadWithJoin(handler JoinTableHandlerInterface, db *DB, source interface{}, conditions ...interface{}) *DB
|
||||
SourceForeignKeys() []JoinTableForeignKey
|
||||
DestinationForeignKeys() []JoinTableForeignKey
|
||||
DestinationType() reflect.Type
|
||||
}
|
||||
|
||||
type JoinTableForeignKey struct {
|
||||
|
@ -153,3 +155,43 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so
|
|||
return db
|
||||
}
|
||||
}
|
||||
|
||||
func (s JoinTableHandler) PreloadWithJoin(handler JoinTableHandlerInterface, db *DB, source interface{}, conditions ...interface{}) *DB {
|
||||
quotedTable := handler.Table(db)
|
||||
|
||||
scope := db.NewScope(source)
|
||||
modelType := scope.GetModelStruct().ModelType
|
||||
var joinConditions []string
|
||||
var queryConditions []string
|
||||
var values []interface{}
|
||||
if s.Source.ModelType == modelType {
|
||||
for _, foreignKey := range s.Destination.ForeignKeys {
|
||||
destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).inlineCondition(conditions...).QuotedTableName()
|
||||
joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName)))
|
||||
}
|
||||
|
||||
for _, foreignKey := range s.Source.ForeignKeys {
|
||||
condString := fmt.Sprintf("%v.%v in (?)", quotedTable, scope.Quote(foreignKey.DBName))
|
||||
|
||||
keys := scope.getColumnAsArray([]string{scope.Fields()[foreignKey.AssociationDBName].Name})
|
||||
values = append(values, toQueryValues(keys))
|
||||
|
||||
queryConditions = append(queryConditions, condString)
|
||||
}
|
||||
|
||||
if len(conditions) > 0 {
|
||||
queryConditions = append(queryConditions, toString(conditions[0]))
|
||||
values = append(values, conditions[1:]...)
|
||||
}
|
||||
|
||||
return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTable, strings.Join(joinConditions, " AND "))).
|
||||
Where(strings.Join(queryConditions, " AND "), values...)
|
||||
} else {
|
||||
db.Error = errors.New("wrong source type for join table handler")
|
||||
return db
|
||||
}
|
||||
}
|
||||
|
||||
func (s JoinTableHandler) DestinationType() reflect.Type {
|
||||
return s.Destination.ModelType
|
||||
}
|
||||
|
|
121
preload.go
121
preload.go
|
@ -10,11 +10,22 @@ import (
|
|||
|
||||
func getRealValue(value reflect.Value, columns []string) (results []interface{}) {
|
||||
for _, column := range columns {
|
||||
result := reflect.Indirect(value).FieldByName(column).Interface()
|
||||
if r, ok := result.(driver.Valuer); ok {
|
||||
result, _ = r.Value()
|
||||
if reflect.Indirect(value).FieldByName(column).IsValid() {
|
||||
result := reflect.Indirect(value).FieldByName(column).Interface()
|
||||
if r, ok := result.(driver.Valuer); ok {
|
||||
result, _ = r.Value()
|
||||
}
|
||||
results = append(results, result)
|
||||
} else {
|
||||
column = upFL(column)
|
||||
if reflect.Indirect(value).FieldByName(column).IsValid() {
|
||||
result := reflect.Indirect(value).FieldByName(column).Interface()
|
||||
if r, ok := result.(driver.Valuer); ok {
|
||||
result, _ = r.Value()
|
||||
}
|
||||
results = append(results, result)
|
||||
}
|
||||
}
|
||||
results = append(results, result)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -61,7 +72,7 @@ func Preload(scope *Scope) {
|
|||
case "belongs_to":
|
||||
currentScope.handleBelongsToPreload(field, conditions)
|
||||
case "many_to_many":
|
||||
fallthrough
|
||||
currentScope.handleHasManyToManyPreload(field, conditions)
|
||||
default:
|
||||
currentScope.Err(errors.New("not supported relation"))
|
||||
}
|
||||
|
@ -189,6 +200,106 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{
|
|||
}
|
||||
}
|
||||
|
||||
func (scope *Scope) handleHasManyToManyPreload(field *Field, conditions []interface{}) {
|
||||
relation := field.Relationship
|
||||
|
||||
joinTableHandler := relation.JoinTableHandler
|
||||
destType := joinTableHandler.DestinationType()
|
||||
|
||||
db := scope.NewDB().Table(scope.db.NewScope(reflect.New(destType).Elem().Interface()).TableName())
|
||||
|
||||
var destKeys []string
|
||||
var sourceKeys []string
|
||||
|
||||
linkHash := make(map[string][]string)
|
||||
|
||||
for _, key := range joinTableHandler.DestinationForeignKeys() {
|
||||
destKeys = append(destKeys, key.DBName)
|
||||
}
|
||||
|
||||
for _, key := range joinTableHandler.SourceForeignKeys() {
|
||||
sourceKeys = append(sourceKeys, key.DBName)
|
||||
}
|
||||
|
||||
results := reflect.New(field.Struct.Type).Elem()
|
||||
rows, err := joinTableHandler.PreloadWithJoin(joinTableHandler, db, scope.Value, conditions...).Rows()
|
||||
|
||||
if scope.Err(err) != nil {
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
columns, _ := rows.Columns()
|
||||
for rows.Next() {
|
||||
elem := reflect.New(destType).Elem()
|
||||
var values = make([]interface{}, len(columns))
|
||||
|
||||
fields := scope.New(elem.Addr().Interface()).Fields()
|
||||
|
||||
for index, column := range columns {
|
||||
if field, ok := fields[column]; ok {
|
||||
if field.Field.Kind() == reflect.Ptr {
|
||||
values[index] = field.Field.Addr().Interface()
|
||||
} else {
|
||||
values[index] = reflect.New(reflect.PtrTo(field.Field.Type())).Interface()
|
||||
}
|
||||
} else {
|
||||
var i interface{}
|
||||
values[index] = &i
|
||||
}
|
||||
}
|
||||
|
||||
scope.Err(rows.Scan(values...))
|
||||
|
||||
var destKey []interface{}
|
||||
var sourceKey []interface{}
|
||||
|
||||
for index, column := range columns {
|
||||
value := values[index]
|
||||
if field, ok := fields[column]; ok {
|
||||
if field.Field.Kind() == reflect.Ptr {
|
||||
field.Field.Set(reflect.ValueOf(value).Elem())
|
||||
} else if v := reflect.ValueOf(value).Elem().Elem(); v.IsValid() {
|
||||
field.Field.Set(v)
|
||||
}
|
||||
} else if strInSlice(column, destKeys) {
|
||||
destKey = append(destKey, *(value.(*interface{})))
|
||||
} else if strInSlice(column, sourceKeys) {
|
||||
sourceKey = append(sourceKey, *(value.(*interface{})))
|
||||
}
|
||||
}
|
||||
|
||||
if len(destKey) != 0 && len(sourceKey) != 0 {
|
||||
linkHash[toString(sourceKey)] = append(linkHash[toString(sourceKey)], toString(destKey))
|
||||
}
|
||||
|
||||
results = reflect.Append(results, elem)
|
||||
|
||||
}
|
||||
|
||||
if scope.IndirectValue().Kind() == reflect.Slice {
|
||||
objects := scope.IndirectValue()
|
||||
for j := 0; j < objects.Len(); j++ {
|
||||
var checked []string
|
||||
|
||||
object := reflect.Indirect(objects.Index(j))
|
||||
source := getRealValue(object, relation.AssociationForeignFieldNames)
|
||||
|
||||
for i := 0; i < results.Len(); i++ {
|
||||
result := results.Index(i)
|
||||
value := getRealValue(result, relation.ForeignFieldNames)
|
||||
|
||||
if strInSlice(toString(value), linkHash[toString(source)]) && !strInSlice(toString(value), checked) {
|
||||
f := object.FieldByName(field.Name)
|
||||
f.Set(reflect.Append(f, result))
|
||||
checked = append(checked, toString(value))
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (scope *Scope) getColumnAsArray(columns []string) (results [][]interface{}) {
|
||||
values := scope.IndirectValue()
|
||||
switch values.Kind() {
|
||||
|
|
Loading…
Reference in New Issue