diff --git a/join_table_handler.go b/join_table_handler.go index 10e1e848..162af4a8 100644 --- a/join_table_handler.go +++ b/join_table_handler.go @@ -13,6 +13,7 @@ type JoinTableHandlerInterface interface { Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB + PreloadWithJoin(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB SourceForeignKeys() []JoinTableForeignKey DestinationForeignKeys() []JoinTableForeignKey } @@ -137,8 +138,8 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so var queryConditions []string var values []interface{} if s.Source.ModelType == modelType { + destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName() for _, foreignKey := range s.Destination.ForeignKeys { - destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName() joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName))) } @@ -153,3 +154,34 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so return db } } + +func (s JoinTableHandler) PreloadWithJoin(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB { + quotedTable := handler.Table(db) + + scope := db.NewScope(source) + modelType := scope.GetModelStruct().ModelType + var joinConditions []string + var queryConditions []string + var values []interface{} + if s.Source.ModelType == modelType { + destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName() + for _, foreignKey := range s.Destination.ForeignKeys { + joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName))) + } + + for _, foreignKey := range s.Source.ForeignKeys { + condString := fmt.Sprintf("%v.%v in (?)", quotedTable, scope.Quote(foreignKey.DBName)) + + keys := scope.getColumnAsArray([]string{scope.Fields()[foreignKey.AssociationDBName].Name}) + values = append(values, toQueryValues(keys)) + + queryConditions = append(queryConditions, condString) + } + + return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTable, strings.Join(joinConditions, " AND "))). + Where(strings.Join(queryConditions, " AND "), values...) + } else { + db.Error = errors.New("wrong source type for join table handler") + return db + } +} diff --git a/model_struct.go b/model_struct.go index 26c58fc5..db6d9a88 100644 --- a/model_struct.go +++ b/model_struct.go @@ -62,14 +62,16 @@ func (structField *StructField) clone() *StructField { } type Relationship struct { - Kind string - PolymorphicType string - PolymorphicDBName string - ForeignFieldNames []string - ForeignDBNames []string - AssociationForeignFieldNames []string - AssociationForeignDBNames []string - JoinTableHandler JoinTableHandlerInterface + Kind string + PolymorphicType string + PolymorphicDBName string + ForeignFieldNames []string + ForeignStructFieldNames []string + ForeignDBNames []string + AssociationForeignFieldNames []string + AssociationForeignStructFieldNames []string + AssociationForeignDBNames []string + JoinTableHandler JoinTableHandlerInterface } func (scope *Scope) GetModelStruct() *ModelStruct { @@ -224,6 +226,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { for _, foreignKey := range foreignKeys { if field, ok := scope.FieldByName(foreignKey); ok { relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, field.DBName) + relationship.ForeignStructFieldNames = append(relationship.ForeignFieldNames, field.Name) joinTableDBName := ToDBName(scopeType.Name()) + "_" + field.DBName relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBName) } @@ -242,6 +245,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { for _, name := range associationForeignKeys { if field, ok := toScope.FieldByName(name); ok { relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName) + relationship.AssociationForeignStructFieldNames = append(relationship.AssociationForeignFieldNames, field.Name) joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) } diff --git a/preload.go b/preload.go index 0db6fbde..c0869a2a 100644 --- a/preload.go +++ b/preload.go @@ -10,11 +10,13 @@ import ( func getRealValue(value reflect.Value, columns []string) (results []interface{}) { for _, column := range columns { - result := reflect.Indirect(value).FieldByName(column).Interface() - if r, ok := result.(driver.Valuer); ok { - result, _ = r.Value() + if reflect.Indirect(value).FieldByName(column).IsValid() { + result := reflect.Indirect(value).FieldByName(column).Interface() + if r, ok := result.(driver.Valuer); ok { + result, _ = r.Value() + } + results = append(results, result) } - results = append(results, result) } return } @@ -61,7 +63,7 @@ func Preload(scope *Scope) { case "belongs_to": currentScope.handleBelongsToPreload(field, conditions) case "many_to_many": - fallthrough + currentScope.handleHasManyToManyPreload(field, conditions) default: currentScope.Err(errors.New("not supported relation")) } @@ -189,6 +191,133 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{ } } +func (scope *Scope) handleHasManyToManyPreload(field *Field, conditions []interface{}) { + relation := field.Relationship + + joinTableHandler := relation.JoinTableHandler + destType := field.StructField.Struct.Type.Elem() + var isPtr bool + if destType.Kind() == reflect.Ptr { + isPtr = true + destType = destType.Elem() + } + + var destKeys []string + var sourceKeys []string + var linkHash = make(map[string][]string) + + for _, key := range joinTableHandler.DestinationForeignKeys() { + destKeys = append(destKeys, key.DBName) + } + + for _, key := range joinTableHandler.SourceForeignKeys() { + sourceKeys = append(sourceKeys, key.DBName) + } + + results := reflect.New(field.Struct.Type).Elem() + + db := scope.NewDB().Table(scope.New(reflect.New(destType).Interface()).TableName()) + preloadJoinDB := joinTableHandler.PreloadWithJoin(joinTableHandler, db, scope.Value) + if len(conditions) > 0 { + preloadJoinDB = preloadJoinDB.Where(conditions[0], conditions[1:]...) + } + rows, err := preloadJoinDB.Rows() + + if scope.Err(err) != nil { + return + } + defer rows.Close() + + columns, _ := rows.Columns() + for rows.Next() { + elem := reflect.New(destType).Elem() + var values = make([]interface{}, len(columns)) + + fields := scope.New(elem.Addr().Interface()).Fields() + + for index, column := range columns { + if field, ok := fields[column]; ok { + if field.Field.Kind() == reflect.Ptr { + values[index] = field.Field.Addr().Interface() + } else { + values[index] = reflect.New(reflect.PtrTo(field.Field.Type())).Interface() + } + } else { + var i interface{} + values[index] = &i + } + } + + scope.Err(rows.Scan(values...)) + + var destKey []interface{} + var sourceKey []interface{} + + for index, column := range columns { + value := values[index] + if field, ok := fields[column]; ok { + if field.Field.Kind() == reflect.Ptr { + field.Field.Set(reflect.ValueOf(value).Elem()) + } else if v := reflect.ValueOf(value).Elem().Elem(); v.IsValid() { + field.Field.Set(v) + } + } else if strInSlice(column, destKeys) { + destKey = append(destKey, *(value.(*interface{}))) + } else if strInSlice(column, sourceKeys) { + sourceKey = append(sourceKey, *(value.(*interface{}))) + } + } + + if len(destKey) != 0 && len(sourceKey) != 0 { + linkHash[toString(sourceKey)] = append(linkHash[toString(sourceKey)], toString(destKey)) + } + + if isPtr { + results = reflect.Append(results, elem.Addr()) + } else { + results = reflect.Append(results, elem) + } + } + + if scope.IndirectValue().Kind() == reflect.Slice { + objects := scope.IndirectValue() + for j := 0; j < objects.Len(); j++ { + var checked []string + + object := reflect.Indirect(objects.Index(j)) + source := getRealValue(object, relation.AssociationForeignStructFieldNames) + + for i := 0; i < results.Len(); i++ { + result := results.Index(i) + value := getRealValue(result, relation.ForeignStructFieldNames) + + if strInSlice(toString(value), linkHash[toString(source)]) && !strInSlice(toString(value), checked) { + f := object.FieldByName(field.Name) + f.Set(reflect.Append(f, result)) + checked = append(checked, toString(value)) + continue + } + } + } + } else { + object := scope.IndirectValue() + var checked []string + source := getRealValue(object, relation.AssociationForeignStructFieldNames) + + for i := 0; i < results.Len(); i++ { + result := results.Index(i) + value := getRealValue(result, relation.ForeignStructFieldNames) + + if strInSlice(toString(value), linkHash[toString(source)]) && !strInSlice(toString(value), checked) { + f := object.FieldByName(field.Name) + f.Set(reflect.Append(f, result)) + checked = append(checked, toString(value)) + continue + } + } + } +} + func (scope *Scope) getColumnAsArray(columns []string) (results [][]interface{}) { values := scope.IndirectValue() switch values.Kind() { diff --git a/preload_test.go b/preload_test.go index a6647bbd..043e24d6 100644 --- a/preload_test.go +++ b/preload_test.go @@ -603,6 +603,86 @@ func TestNestedPreload9(t *testing.T) { } } +func TestManyToManyPreload(t *testing.T) { + type ( + Level1 struct { + ID uint `gorm:"primary_key;"` + Value string + } + Level2 struct { + ID uint `gorm:"primary_key;"` + Value string + Level1s []*Level1 `gorm:"many2many:levels;"` + } + ) + + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + + if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil { + panic(err) + } + + want := Level2{Value: "Bob", Level1s: []*Level1{ + {Value: "ru"}, + {Value: "en"}, + }} + if err := DB.Save(&want).Error; err != nil { + panic(err) + } + + want2 := Level2{Value: "Tom", Level1s: []*Level1{ + {Value: "zh"}, + {Value: "de"}, + }} + if err := DB.Save(&want2).Error; err != nil { + panic(err) + } + + var got Level2 + if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } + + var got2 Level2 + if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got2, want2) { + t.Errorf("got %s; want %s", toJSONString(got2), toJSONString(want2)) + } + + var got3 []Level2 + if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got3, []Level2{got, got2}) { + t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level2{got, got2})) + } + + var got4 []Level2 + if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { + panic(err) + } + + var ruLevel1 Level1 + var zhLevel1 Level1 + DB.First(&ruLevel1, "value = ?", "ru") + DB.First(&zhLevel1, "value = ?", "zh") + + got.Level1s = []*Level1{&ruLevel1} + got2.Level1s = []*Level1{&zhLevel1} + if !reflect.DeepEqual(got4, []Level2{got, got2}) { + t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level2{got, got2})) + } +} + func toJSONString(v interface{}) []byte { r, _ := json.MarshalIndent(v, "", " ") return r diff --git a/utils_private.go b/utils_private.go index 6f609ae0..b82aa807 100644 --- a/utils_private.go +++ b/utils_private.go @@ -71,3 +71,16 @@ func convertInterfaceToMap(values interface{}) map[string]interface{} { } return attrs } + +func toString(a interface{}) string { + return fmt.Sprintf("%v", a) +} + +func strInSlice(a string, list []string) bool { + for _, b := range list { + if b == a { + return true + } + } + return false +}