mirror of https://github.com/go-gorm/gorm.git
feat: support nested join (#6067)
* feat: support nested join * fix: empty rel value
This commit is contained in:
parent
654b5f2006
commit
8bf1f269cf
|
@ -8,6 +8,8 @@ import (
|
|||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
func Query(db *gorm.DB) {
|
||||
|
@ -109,13 +111,46 @@ func BuildQuerySQL(db *gorm.DB) {
|
|||
}
|
||||
}
|
||||
|
||||
specifiedRelationsName := make(map[string]interface{})
|
||||
for _, join := range db.Statement.Joins {
|
||||
if db.Statement.Schema == nil {
|
||||
fromClause.Joins = append(fromClause.Joins, clause.Join{
|
||||
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
|
||||
})
|
||||
} else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok {
|
||||
if db.Statement.Schema != nil {
|
||||
var isRelations bool // is relations or raw sql
|
||||
var relations []*schema.Relationship
|
||||
relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]
|
||||
if ok {
|
||||
isRelations = true
|
||||
relations = append(relations, relation)
|
||||
} else {
|
||||
// handle nested join like "Manager.Company"
|
||||
nestedJoinNames := strings.Split(join.Name, ".")
|
||||
if len(nestedJoinNames) > 1 {
|
||||
isNestedJoin := true
|
||||
gussNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames))
|
||||
currentRelations := db.Statement.Schema.Relationships.Relations
|
||||
for _, relname := range nestedJoinNames {
|
||||
// incomplete match, only treated as raw sql
|
||||
if relation, ok = currentRelations[relname]; ok {
|
||||
gussNestedRelations = append(gussNestedRelations, relation)
|
||||
currentRelations = relation.FieldSchema.Relationships.Relations
|
||||
} else {
|
||||
isNestedJoin = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if isNestedJoin {
|
||||
isRelations = true
|
||||
relations = gussNestedRelations
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if isRelations {
|
||||
genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join {
|
||||
tableAliasName := relation.Name
|
||||
if parentTableName != clause.CurrentTable {
|
||||
tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName)
|
||||
}
|
||||
|
||||
columnStmt := gorm.Statement{
|
||||
Table: tableAliasName, DB: db, Schema: relation.FieldSchema,
|
||||
|
@ -128,7 +163,7 @@ func BuildQuerySQL(db *gorm.DB) {
|
|||
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
|
||||
Table: tableAliasName,
|
||||
Name: s,
|
||||
Alias: tableAliasName + "__" + s,
|
||||
Alias: utils.NestedRelationName(tableAliasName, s),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -137,13 +172,13 @@ func BuildQuerySQL(db *gorm.DB) {
|
|||
for idx, ref := range relation.References {
|
||||
if ref.OwnPrimaryKey {
|
||||
exprs[idx] = clause.Eq{
|
||||
Column: clause.Column{Table: clause.CurrentTable, Name: ref.PrimaryKey.DBName},
|
||||
Column: clause.Column{Table: parentTableName, Name: ref.PrimaryKey.DBName},
|
||||
Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
|
||||
}
|
||||
} else {
|
||||
if ref.PrimaryValue == "" {
|
||||
exprs[idx] = clause.Eq{
|
||||
Column: clause.Column{Table: clause.CurrentTable, Name: ref.ForeignKey.DBName},
|
||||
Column: clause.Column{Table: parentTableName, Name: ref.ForeignKey.DBName},
|
||||
Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName},
|
||||
}
|
||||
} else {
|
||||
|
@ -184,11 +219,28 @@ func BuildQuerySQL(db *gorm.DB) {
|
|||
}
|
||||
}
|
||||
|
||||
fromClause.Joins = append(fromClause.Joins, clause.Join{
|
||||
Type: join.JoinType,
|
||||
return clause.Join{
|
||||
Type: joinType,
|
||||
Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName},
|
||||
ON: clause.Where{Exprs: exprs},
|
||||
}
|
||||
}
|
||||
|
||||
parentTableName := clause.CurrentTable
|
||||
for _, rel := range relations {
|
||||
// joins table alias like "Manager, Company, Manager__Company"
|
||||
nestedAlias := utils.NestedRelationName(parentTableName, rel.Name)
|
||||
if _, ok := specifiedRelationsName[nestedAlias]; !ok {
|
||||
fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, parentTableName, rel))
|
||||
specifiedRelationsName[nestedAlias] = nil
|
||||
}
|
||||
parentTableName = rel.Name
|
||||
}
|
||||
} else {
|
||||
fromClause.Joins = append(fromClause.Joins, clause.Join{
|
||||
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
|
||||
})
|
||||
}
|
||||
} else {
|
||||
fromClause.Joins = append(fromClause.Joins, clause.Join{
|
||||
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
|
||||
|
|
60
scan.go
60
scan.go
|
@ -4,10 +4,10 @@ import (
|
|||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
// prepareValues prepare values slice
|
||||
|
@ -50,7 +50,7 @@ func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns
|
|||
}
|
||||
}
|
||||
|
||||
func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][2]*schema.Field) {
|
||||
func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][]*schema.Field) {
|
||||
for idx, field := range fields {
|
||||
if field != nil {
|
||||
values[idx] = field.NewValuePool.Get()
|
||||
|
@ -65,28 +65,45 @@ func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []int
|
|||
|
||||
db.RowsAffected++
|
||||
db.AddError(rows.Scan(values...))
|
||||
joinedSchemaMap := make(map[*schema.Field]interface{})
|
||||
joinedNestedSchemaMap := make(map[string]interface{})
|
||||
for idx, field := range fields {
|
||||
if field == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(joinFields) == 0 || joinFields[idx][0] == nil {
|
||||
if len(joinFields) == 0 || len(joinFields[idx]) == 0 {
|
||||
db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx]))
|
||||
} else {
|
||||
joinSchema := joinFields[idx][0]
|
||||
relValue := joinSchema.ReflectValueOf(db.Statement.Context, reflectValue)
|
||||
} else { // joinFields count is larger than 2 when using join
|
||||
var isNilPtrValue bool
|
||||
var relValue reflect.Value
|
||||
// does not contain raw dbname
|
||||
nestedJoinSchemas := joinFields[idx][:len(joinFields[idx])-1]
|
||||
// current reflect value
|
||||
currentReflectValue := reflectValue
|
||||
fullRels := make([]string, 0, len(nestedJoinSchemas))
|
||||
for _, joinSchema := range nestedJoinSchemas {
|
||||
fullRels = append(fullRels, joinSchema.Name)
|
||||
relValue = joinSchema.ReflectValueOf(db.Statement.Context, currentReflectValue)
|
||||
if relValue.Kind() == reflect.Ptr {
|
||||
if _, ok := joinedSchemaMap[joinSchema]; !ok {
|
||||
fullRelsName := utils.JoinNestedRelationNames(fullRels)
|
||||
// same nested structure
|
||||
if _, ok := joinedNestedSchemaMap[fullRelsName]; !ok {
|
||||
if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() {
|
||||
continue
|
||||
isNilPtrValue = true
|
||||
break
|
||||
}
|
||||
|
||||
relValue.Set(reflect.New(relValue.Type().Elem()))
|
||||
joinedSchemaMap[joinSchema] = nil
|
||||
joinedNestedSchemaMap[fullRelsName] = nil
|
||||
}
|
||||
}
|
||||
db.AddError(joinFields[idx][1].Set(db.Statement.Context, relValue, values[idx]))
|
||||
currentReflectValue = relValue
|
||||
}
|
||||
|
||||
if !isNilPtrValue { // ignore if value is nil
|
||||
f := joinFields[idx][len(joinFields[idx])-1]
|
||||
db.AddError(f.Set(db.Statement.Context, relValue, values[idx]))
|
||||
}
|
||||
}
|
||||
|
||||
// release data to pool
|
||||
|
@ -163,7 +180,7 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
|
|||
default:
|
||||
var (
|
||||
fields = make([]*schema.Field, len(columns))
|
||||
joinFields [][2]*schema.Field
|
||||
joinFields [][]*schema.Field
|
||||
sch = db.Statement.Schema
|
||||
reflectValue = db.Statement.ReflectValue
|
||||
)
|
||||
|
@ -217,15 +234,26 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
|
|||
} else {
|
||||
matchedFieldCount[column] = 1
|
||||
}
|
||||
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
||||
} else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation
|
||||
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
|
||||
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
||||
subNameCount := len(names)
|
||||
// nested relation fields
|
||||
relFields := make([]*schema.Field, 0, subNameCount-1)
|
||||
relFields = append(relFields, rel.Field)
|
||||
for _, name := range names[1 : subNameCount-1] {
|
||||
rel = rel.FieldSchema.Relationships.Relations[name]
|
||||
relFields = append(relFields, rel.Field)
|
||||
}
|
||||
// lastest name is raw dbname
|
||||
dbName := names[subNameCount-1]
|
||||
if field := rel.FieldSchema.LookUpField(dbName); field != nil && field.Readable {
|
||||
fields[idx] = field
|
||||
|
||||
if len(joinFields) == 0 {
|
||||
joinFields = make([][2]*schema.Field, len(columns))
|
||||
joinFields = make([][]*schema.Field, len(columns))
|
||||
}
|
||||
joinFields[idx] = [2]*schema.Field{rel.Field, field}
|
||||
relFields = append(relFields, field)
|
||||
joinFields[idx] = relFields
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
|
|
@ -325,3 +325,66 @@ func TestJoinArgsWithDB(t *testing.T) {
|
|||
}
|
||||
AssertEqual(t, user4.NamedPet.Name, "")
|
||||
}
|
||||
|
||||
func TestNestedJoins(t *testing.T) {
|
||||
users := []User{
|
||||
{
|
||||
Name: "nested-joins-1",
|
||||
Manager: GetUser("nested-joins-manager-1", Config{Company: true, NamedPet: true}),
|
||||
NamedPet: &Pet{Name: "nested-joins-namepet-1", Toy: Toy{Name: "nested-joins-namepet-toy-1"}},
|
||||
},
|
||||
{
|
||||
Name: "nested-joins-2",
|
||||
Manager: GetUser("nested-joins-manager-2", Config{Company: true, NamedPet: true}),
|
||||
NamedPet: &Pet{Name: "nested-joins-namepet-2", Toy: Toy{Name: "nested-joins-namepet-toy-2"}},
|
||||
},
|
||||
}
|
||||
|
||||
DB.Create(&users)
|
||||
|
||||
var userIDs []uint
|
||||
for _, user := range users {
|
||||
userIDs = append(userIDs, user.ID)
|
||||
}
|
||||
|
||||
var users2 []User
|
||||
if err := DB.
|
||||
Joins("Manager").
|
||||
Joins("Manager.Company").
|
||||
Joins("Manager.NamedPet").
|
||||
Joins("NamedPet").
|
||||
Joins("NamedPet.Toy").
|
||||
Find(&users2, "users.id IN ?", userIDs).Error; err != nil {
|
||||
t.Fatalf("Failed to load with joins, got error: %v", err)
|
||||
} else if len(users2) != len(users) {
|
||||
t.Fatalf("Failed to load join users, got: %v, expect: %v", len(users2), len(users))
|
||||
}
|
||||
|
||||
sort.Slice(users2, func(i, j int) bool {
|
||||
return users2[i].ID > users2[j].ID
|
||||
})
|
||||
|
||||
sort.Slice(users, func(i, j int) bool {
|
||||
return users[i].ID > users[j].ID
|
||||
})
|
||||
|
||||
for idx, user := range users {
|
||||
// user
|
||||
CheckUser(t, user, users2[idx])
|
||||
if users2[idx].Manager == nil {
|
||||
t.Fatalf("Failed to load Manager")
|
||||
}
|
||||
// manager
|
||||
CheckUser(t, *user.Manager, *users2[idx].Manager)
|
||||
// user pet
|
||||
if users2[idx].NamedPet == nil {
|
||||
t.Fatalf("Failed to load NamedPet")
|
||||
}
|
||||
CheckPet(t, *user.NamedPet, *users2[idx].NamedPet)
|
||||
// manager pet
|
||||
if users2[idx].Manager.NamedPet == nil {
|
||||
t.Fatalf("Failed to load NamedPet")
|
||||
}
|
||||
CheckPet(t, *user.Manager.NamedPet, *users2[idx].Manager.NamedPet)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -13,8 +13,14 @@ import (
|
|||
|
||||
func AssertObjEqual(t *testing.T, r, e interface{}, names ...string) {
|
||||
for _, name := range names {
|
||||
got := reflect.Indirect(reflect.ValueOf(r)).FieldByName(name).Interface()
|
||||
expect := reflect.Indirect(reflect.ValueOf(e)).FieldByName(name).Interface()
|
||||
rv := reflect.Indirect(reflect.ValueOf(r))
|
||||
ev := reflect.Indirect(reflect.ValueOf(e))
|
||||
if rv.IsValid() != ev.IsValid() {
|
||||
t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), r, e)
|
||||
return
|
||||
}
|
||||
got := rv.FieldByName(name).Interface()
|
||||
expect := ev.FieldByName(name).Interface()
|
||||
t.Run(name, func(t *testing.T) {
|
||||
AssertEqual(t, got, expect)
|
||||
})
|
||||
|
|
|
@ -131,3 +131,20 @@ func ToString(value interface{}) string {
|
|||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
const nestedRelationSplit = "__"
|
||||
|
||||
// NestedRelationName nested relationships like `Manager__Company`
|
||||
func NestedRelationName(prefix, name string) string {
|
||||
return prefix + nestedRelationSplit + name
|
||||
}
|
||||
|
||||
// SplitNestedRelationName Split nested relationships to `[]string{"Manager","Company"}`
|
||||
func SplitNestedRelationName(name string) []string {
|
||||
return strings.Split(name, nestedRelationSplit)
|
||||
}
|
||||
|
||||
// JoinNestedRelationNames nested relationships like `Manager__Company`
|
||||
func JoinNestedRelationNames(relationNames []string) string {
|
||||
return strings.Join(relationNames, nestedRelationSplit)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue