feat: support nested join (#6067)

* feat: support nested join

* fix: empty rel value
This commit is contained in:
Cr 2023-03-10 17:21:56 +08:00 committed by GitHub
parent 654b5f2006
commit 8bf1f269cf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 255 additions and 89 deletions

View File

@ -8,6 +8,8 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
) )
func Query(db *gorm.DB) { func Query(db *gorm.DB) {
@ -109,86 +111,136 @@ func BuildQuerySQL(db *gorm.DB) {
} }
} }
specifiedRelationsName := make(map[string]interface{})
for _, join := range db.Statement.Joins { for _, join := range db.Statement.Joins {
if db.Statement.Schema == nil { if db.Statement.Schema != nil {
fromClause.Joins = append(fromClause.Joins, clause.Join{ var isRelations bool // is relations or raw sql
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, var relations []*schema.Relationship
}) relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]
} else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok { if ok {
tableAliasName := relation.Name isRelations = true
relations = append(relations, relation)
columnStmt := gorm.Statement{ } else {
Table: tableAliasName, DB: db, Schema: relation.FieldSchema, // handle nested join like "Manager.Company"
Selects: join.Selects, Omits: join.Omits, nestedJoinNames := strings.Split(join.Name, ".")
} if len(nestedJoinNames) > 1 {
isNestedJoin := true
selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false) gussNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames))
for _, s := range relation.FieldSchema.DBNames { currentRelations := db.Statement.Schema.Relationships.Relations
if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) { for _, relname := range nestedJoinNames {
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ // incomplete match, only treated as raw sql
Table: tableAliasName, if relation, ok = currentRelations[relname]; ok {
Name: s, gussNestedRelations = append(gussNestedRelations, relation)
Alias: tableAliasName + "__" + s, currentRelations = relation.FieldSchema.Relationships.Relations
}) } else {
} isNestedJoin = false
} break
}
exprs := make([]clause.Expression, len(relation.References))
for idx, ref := range relation.References {
if ref.OwnPrimaryKey {
exprs[idx] = clause.Eq{
Column: clause.Column{Table: clause.CurrentTable, Name: ref.PrimaryKey.DBName},
Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
} }
} else {
if ref.PrimaryValue == "" { if isNestedJoin {
exprs[idx] = clause.Eq{ isRelations = true
Column: clause.Column{Table: clause.CurrentTable, Name: ref.ForeignKey.DBName}, relations = gussNestedRelations
Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName},
}
} else {
exprs[idx] = clause.Eq{
Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
Value: ref.PrimaryValue,
}
} }
} }
} }
{ if isRelations {
onStmt := gorm.Statement{Table: tableAliasName, DB: db, Clauses: map[string]clause.Clause{}} genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join {
for _, c := range relation.FieldSchema.QueryClauses { tableAliasName := relation.Name
onStmt.AddClause(c) if parentTableName != clause.CurrentTable {
} tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName)
}
if join.On != nil { columnStmt := gorm.Statement{
onStmt.AddClause(join.On) Table: tableAliasName, DB: db, Schema: relation.FieldSchema,
} Selects: join.Selects, Omits: join.Omits,
}
if cs, ok := onStmt.Clauses["WHERE"]; ok { selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false)
if where, ok := cs.Expression.(clause.Where); ok { for _, s := range relation.FieldSchema.DBNames {
where.Build(&onStmt) if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) {
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
Table: tableAliasName,
Name: s,
Alias: utils.NestedRelationName(tableAliasName, s),
})
}
}
if onSQL := onStmt.SQL.String(); onSQL != "" { exprs := make([]clause.Expression, len(relation.References))
vars := onStmt.Vars for idx, ref := range relation.References {
for idx, v := range vars { if ref.OwnPrimaryKey {
bindvar := strings.Builder{} exprs[idx] = clause.Eq{
onStmt.Vars = vars[0 : idx+1] Column: clause.Column{Table: parentTableName, Name: ref.PrimaryKey.DBName},
db.Dialector.BindVarTo(&bindvar, &onStmt, v) Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1) }
} else {
if ref.PrimaryValue == "" {
exprs[idx] = clause.Eq{
Column: clause.Column{Table: parentTableName, Name: ref.ForeignKey.DBName},
Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName},
}
} else {
exprs[idx] = clause.Eq{
Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
Value: ref.PrimaryValue,
}
} }
exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars})
} }
} }
}
}
fromClause.Joins = append(fromClause.Joins, clause.Join{ {
Type: join.JoinType, onStmt := gorm.Statement{Table: tableAliasName, DB: db, Clauses: map[string]clause.Clause{}}
Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, for _, c := range relation.FieldSchema.QueryClauses {
ON: clause.Where{Exprs: exprs}, onStmt.AddClause(c)
}) }
if join.On != nil {
onStmt.AddClause(join.On)
}
if cs, ok := onStmt.Clauses["WHERE"]; ok {
if where, ok := cs.Expression.(clause.Where); ok {
where.Build(&onStmt)
if onSQL := onStmt.SQL.String(); onSQL != "" {
vars := onStmt.Vars
for idx, v := range vars {
bindvar := strings.Builder{}
onStmt.Vars = vars[0 : idx+1]
db.Dialector.BindVarTo(&bindvar, &onStmt, v)
onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1)
}
exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars})
}
}
}
}
return clause.Join{
Type: joinType,
Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName},
ON: clause.Where{Exprs: exprs},
}
}
parentTableName := clause.CurrentTable
for _, rel := range relations {
// joins table alias like "Manager, Company, Manager__Company"
nestedAlias := utils.NestedRelationName(parentTableName, rel.Name)
if _, ok := specifiedRelationsName[nestedAlias]; !ok {
fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, parentTableName, rel))
specifiedRelationsName[nestedAlias] = nil
}
parentTableName = rel.Name
}
} else {
fromClause.Joins = append(fromClause.Joins, clause.Join{
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
})
}
} else { } else {
fromClause.Joins = append(fromClause.Joins, clause.Join{ fromClause.Joins = append(fromClause.Joins, clause.Join{
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},

68
scan.go
View File

@ -4,10 +4,10 @@ import (
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
"reflect" "reflect"
"strings"
"time" "time"
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
"gorm.io/gorm/utils"
) )
// prepareValues prepare values slice // prepareValues prepare values slice
@ -50,7 +50,7 @@ func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns
} }
} }
func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][2]*schema.Field) { func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][]*schema.Field) {
for idx, field := range fields { for idx, field := range fields {
if field != nil { if field != nil {
values[idx] = field.NewValuePool.Get() values[idx] = field.NewValuePool.Get()
@ -65,28 +65,45 @@ func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []int
db.RowsAffected++ db.RowsAffected++
db.AddError(rows.Scan(values...)) db.AddError(rows.Scan(values...))
joinedSchemaMap := make(map[*schema.Field]interface{}) joinedNestedSchemaMap := make(map[string]interface{})
for idx, field := range fields { for idx, field := range fields {
if field == nil { if field == nil {
continue continue
} }
if len(joinFields) == 0 || joinFields[idx][0] == nil { if len(joinFields) == 0 || len(joinFields[idx]) == 0 {
db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx])) db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx]))
} else { } else { // joinFields count is larger than 2 when using join
joinSchema := joinFields[idx][0] var isNilPtrValue bool
relValue := joinSchema.ReflectValueOf(db.Statement.Context, reflectValue) var relValue reflect.Value
if relValue.Kind() == reflect.Ptr { // does not contain raw dbname
if _, ok := joinedSchemaMap[joinSchema]; !ok { nestedJoinSchemas := joinFields[idx][:len(joinFields[idx])-1]
if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() { // current reflect value
continue currentReflectValue := reflectValue
} fullRels := make([]string, 0, len(nestedJoinSchemas))
for _, joinSchema := range nestedJoinSchemas {
fullRels = append(fullRels, joinSchema.Name)
relValue = joinSchema.ReflectValueOf(db.Statement.Context, currentReflectValue)
if relValue.Kind() == reflect.Ptr {
fullRelsName := utils.JoinNestedRelationNames(fullRels)
// same nested structure
if _, ok := joinedNestedSchemaMap[fullRelsName]; !ok {
if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() {
isNilPtrValue = true
break
}
relValue.Set(reflect.New(relValue.Type().Elem())) relValue.Set(reflect.New(relValue.Type().Elem()))
joinedSchemaMap[joinSchema] = nil joinedNestedSchemaMap[fullRelsName] = nil
}
} }
currentReflectValue = relValue
}
if !isNilPtrValue { // ignore if value is nil
f := joinFields[idx][len(joinFields[idx])-1]
db.AddError(f.Set(db.Statement.Context, relValue, values[idx]))
} }
db.AddError(joinFields[idx][1].Set(db.Statement.Context, relValue, values[idx]))
} }
// release data to pool // release data to pool
@ -163,7 +180,7 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
default: default:
var ( var (
fields = make([]*schema.Field, len(columns)) fields = make([]*schema.Field, len(columns))
joinFields [][2]*schema.Field joinFields [][]*schema.Field
sch = db.Statement.Schema sch = db.Statement.Schema
reflectValue = db.Statement.ReflectValue reflectValue = db.Statement.ReflectValue
) )
@ -217,15 +234,26 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
} else { } else {
matchedFieldCount[column] = 1 matchedFieldCount[column] = 1
} }
} else if names := strings.Split(column, "__"); len(names) > 1 { } else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation
if rel, ok := sch.Relationships.Relations[names[0]]; ok { if rel, ok := sch.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { subNameCount := len(names)
// nested relation fields
relFields := make([]*schema.Field, 0, subNameCount-1)
relFields = append(relFields, rel.Field)
for _, name := range names[1 : subNameCount-1] {
rel = rel.FieldSchema.Relationships.Relations[name]
relFields = append(relFields, rel.Field)
}
// lastest name is raw dbname
dbName := names[subNameCount-1]
if field := rel.FieldSchema.LookUpField(dbName); field != nil && field.Readable {
fields[idx] = field fields[idx] = field
if len(joinFields) == 0 { if len(joinFields) == 0 {
joinFields = make([][2]*schema.Field, len(columns)) joinFields = make([][]*schema.Field, len(columns))
} }
joinFields[idx] = [2]*schema.Field{rel.Field, field} relFields = append(relFields, field)
joinFields[idx] = relFields
continue continue
} }
} }

View File

@ -325,3 +325,66 @@ func TestJoinArgsWithDB(t *testing.T) {
} }
AssertEqual(t, user4.NamedPet.Name, "") AssertEqual(t, user4.NamedPet.Name, "")
} }
func TestNestedJoins(t *testing.T) {
users := []User{
{
Name: "nested-joins-1",
Manager: GetUser("nested-joins-manager-1", Config{Company: true, NamedPet: true}),
NamedPet: &Pet{Name: "nested-joins-namepet-1", Toy: Toy{Name: "nested-joins-namepet-toy-1"}},
},
{
Name: "nested-joins-2",
Manager: GetUser("nested-joins-manager-2", Config{Company: true, NamedPet: true}),
NamedPet: &Pet{Name: "nested-joins-namepet-2", Toy: Toy{Name: "nested-joins-namepet-toy-2"}},
},
}
DB.Create(&users)
var userIDs []uint
for _, user := range users {
userIDs = append(userIDs, user.ID)
}
var users2 []User
if err := DB.
Joins("Manager").
Joins("Manager.Company").
Joins("Manager.NamedPet").
Joins("NamedPet").
Joins("NamedPet.Toy").
Find(&users2, "users.id IN ?", userIDs).Error; err != nil {
t.Fatalf("Failed to load with joins, got error: %v", err)
} else if len(users2) != len(users) {
t.Fatalf("Failed to load join users, got: %v, expect: %v", len(users2), len(users))
}
sort.Slice(users2, func(i, j int) bool {
return users2[i].ID > users2[j].ID
})
sort.Slice(users, func(i, j int) bool {
return users[i].ID > users[j].ID
})
for idx, user := range users {
// user
CheckUser(t, user, users2[idx])
if users2[idx].Manager == nil {
t.Fatalf("Failed to load Manager")
}
// manager
CheckUser(t, *user.Manager, *users2[idx].Manager)
// user pet
if users2[idx].NamedPet == nil {
t.Fatalf("Failed to load NamedPet")
}
CheckPet(t, *user.NamedPet, *users2[idx].NamedPet)
// manager pet
if users2[idx].Manager.NamedPet == nil {
t.Fatalf("Failed to load NamedPet")
}
CheckPet(t, *user.Manager.NamedPet, *users2[idx].Manager.NamedPet)
}
}

View File

@ -13,8 +13,14 @@ import (
func AssertObjEqual(t *testing.T, r, e interface{}, names ...string) { func AssertObjEqual(t *testing.T, r, e interface{}, names ...string) {
for _, name := range names { for _, name := range names {
got := reflect.Indirect(reflect.ValueOf(r)).FieldByName(name).Interface() rv := reflect.Indirect(reflect.ValueOf(r))
expect := reflect.Indirect(reflect.ValueOf(e)).FieldByName(name).Interface() ev := reflect.Indirect(reflect.ValueOf(e))
if rv.IsValid() != ev.IsValid() {
t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), r, e)
return
}
got := rv.FieldByName(name).Interface()
expect := ev.FieldByName(name).Interface()
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
AssertEqual(t, got, expect) AssertEqual(t, got, expect)
}) })

View File

@ -131,3 +131,20 @@ func ToString(value interface{}) string {
} }
return "" return ""
} }
const nestedRelationSplit = "__"
// NestedRelationName nested relationships like `Manager__Company`
func NestedRelationName(prefix, name string) string {
return prefix + nestedRelationSplit + name
}
// SplitNestedRelationName Split nested relationships to `[]string{"Manager","Company"}`
func SplitNestedRelationName(name string) []string {
return strings.Split(name, nestedRelationSplit)
}
// JoinNestedRelationNames nested relationships like `Manager__Company`
func JoinNestedRelationNames(relationNames []string) string {
return strings.Join(relationNames, nestedRelationSplit)
}