From 42c3f39163a1676bafcce69bd1d34252fd6bf653 Mon Sep 17 00:00:00 2001 From: kimiby Date: Sun, 16 Aug 2015 10:10:11 +0300 Subject: [PATCH] m2m preload --- join_table_handler.go | 42 +++++++++++++++ preload.go | 121 ++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 158 insertions(+), 5 deletions(-) diff --git a/join_table_handler.go b/join_table_handler.go index 10e1e848..1fb25e5d 100644 --- a/join_table_handler.go +++ b/join_table_handler.go @@ -13,8 +13,10 @@ type JoinTableHandlerInterface interface { Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB + PreloadWithJoin(handler JoinTableHandlerInterface, db *DB, source interface{}, conditions ...interface{}) *DB SourceForeignKeys() []JoinTableForeignKey DestinationForeignKeys() []JoinTableForeignKey + DestinationType() reflect.Type } type JoinTableForeignKey struct { @@ -153,3 +155,43 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so return db } } + +func (s JoinTableHandler) PreloadWithJoin(handler JoinTableHandlerInterface, db *DB, source interface{}, conditions ...interface{}) *DB { + quotedTable := handler.Table(db) + + scope := db.NewScope(source) + modelType := scope.GetModelStruct().ModelType + var joinConditions []string + var queryConditions []string + var values []interface{} + if s.Source.ModelType == modelType { + for _, foreignKey := range s.Destination.ForeignKeys { + destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).inlineCondition(conditions...).QuotedTableName() + joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName))) + } + + for _, foreignKey := range s.Source.ForeignKeys { + condString := fmt.Sprintf("%v.%v in (?)", quotedTable, scope.Quote(foreignKey.DBName)) + + keys := scope.getColumnAsArray([]string{scope.Fields()[foreignKey.AssociationDBName].Name}) + values = append(values, toQueryValues(keys)) + + queryConditions = append(queryConditions, condString) + } + + if len(conditions) > 0 { + queryConditions = append(queryConditions, toString(conditions[0])) + values = append(values, conditions[1:]...) + } + + return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTable, strings.Join(joinConditions, " AND "))). + Where(strings.Join(queryConditions, " AND "), values...) + } else { + db.Error = errors.New("wrong source type for join table handler") + return db + } +} + +func (s JoinTableHandler) DestinationType() reflect.Type { + return s.Destination.ModelType +} diff --git a/preload.go b/preload.go index 0db6fbde..75be26dc 100644 --- a/preload.go +++ b/preload.go @@ -10,11 +10,22 @@ import ( func getRealValue(value reflect.Value, columns []string) (results []interface{}) { for _, column := range columns { - result := reflect.Indirect(value).FieldByName(column).Interface() - if r, ok := result.(driver.Valuer); ok { - result, _ = r.Value() + if reflect.Indirect(value).FieldByName(column).IsValid() { + result := reflect.Indirect(value).FieldByName(column).Interface() + if r, ok := result.(driver.Valuer); ok { + result, _ = r.Value() + } + results = append(results, result) + } else { + column = upFL(column) + if reflect.Indirect(value).FieldByName(column).IsValid() { + result := reflect.Indirect(value).FieldByName(column).Interface() + if r, ok := result.(driver.Valuer); ok { + result, _ = r.Value() + } + results = append(results, result) + } } - results = append(results, result) } return } @@ -61,7 +72,7 @@ func Preload(scope *Scope) { case "belongs_to": currentScope.handleBelongsToPreload(field, conditions) case "many_to_many": - fallthrough + currentScope.handleHasManyToManyPreload(field, conditions) default: currentScope.Err(errors.New("not supported relation")) } @@ -189,6 +200,106 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{ } } +func (scope *Scope) handleHasManyToManyPreload(field *Field, conditions []interface{}) { + relation := field.Relationship + + joinTableHandler := relation.JoinTableHandler + destType := joinTableHandler.DestinationType() + + db := scope.NewDB().Table(scope.db.NewScope(reflect.New(destType).Elem().Interface()).TableName()) + + var destKeys []string + var sourceKeys []string + + linkHash := make(map[string][]string) + + for _, key := range joinTableHandler.DestinationForeignKeys() { + destKeys = append(destKeys, key.DBName) + } + + for _, key := range joinTableHandler.SourceForeignKeys() { + sourceKeys = append(sourceKeys, key.DBName) + } + + results := reflect.New(field.Struct.Type).Elem() + rows, err := joinTableHandler.PreloadWithJoin(joinTableHandler, db, scope.Value, conditions...).Rows() + + if scope.Err(err) != nil { + return + } + defer rows.Close() + + columns, _ := rows.Columns() + for rows.Next() { + elem := reflect.New(destType).Elem() + var values = make([]interface{}, len(columns)) + + fields := scope.New(elem.Addr().Interface()).Fields() + + for index, column := range columns { + if field, ok := fields[column]; ok { + if field.Field.Kind() == reflect.Ptr { + values[index] = field.Field.Addr().Interface() + } else { + values[index] = reflect.New(reflect.PtrTo(field.Field.Type())).Interface() + } + } else { + var i interface{} + values[index] = &i + } + } + + scope.Err(rows.Scan(values...)) + + var destKey []interface{} + var sourceKey []interface{} + + for index, column := range columns { + value := values[index] + if field, ok := fields[column]; ok { + if field.Field.Kind() == reflect.Ptr { + field.Field.Set(reflect.ValueOf(value).Elem()) + } else if v := reflect.ValueOf(value).Elem().Elem(); v.IsValid() { + field.Field.Set(v) + } + } else if strInSlice(column, destKeys) { + destKey = append(destKey, *(value.(*interface{}))) + } else if strInSlice(column, sourceKeys) { + sourceKey = append(sourceKey, *(value.(*interface{}))) + } + } + + if len(destKey) != 0 && len(sourceKey) != 0 { + linkHash[toString(sourceKey)] = append(linkHash[toString(sourceKey)], toString(destKey)) + } + + results = reflect.Append(results, elem) + + } + + if scope.IndirectValue().Kind() == reflect.Slice { + objects := scope.IndirectValue() + for j := 0; j < objects.Len(); j++ { + var checked []string + + object := reflect.Indirect(objects.Index(j)) + source := getRealValue(object, relation.AssociationForeignFieldNames) + + for i := 0; i < results.Len(); i++ { + result := results.Index(i) + value := getRealValue(result, relation.ForeignFieldNames) + + if strInSlice(toString(value), linkHash[toString(source)]) && !strInSlice(toString(value), checked) { + f := object.FieldByName(field.Name) + f.Set(reflect.Append(f, result)) + checked = append(checked, toString(value)) + continue + } + } + } + } +} + func (scope *Scope) getColumnAsArray(columns []string) (results [][]interface{}) { values := scope.IndirectValue() switch values.Kind() {