From 9dfed613db7e2cb92a6e463bf063bb8fc1f9fd83 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 29 Apr 2020 23:47:18 +0800 Subject: [PATCH] Test inner joins --- callbacks/query.go | 14 ++++++---- callbacks/scan.go | 26 +++++++++++++++-- tests/joins.go | 70 ++++++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 100 insertions(+), 10 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index ae22f4d0..a3b59b48 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -28,7 +28,8 @@ func Query(db *gorm.DB) { if len(db.Statement.Selects) == 0 { for _, dbName := range db.Statement.Schema.DBNames { clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ - Name: dbName, + Table: db.Statement.Table, + Name: dbName, }) } } @@ -37,8 +38,9 @@ func Query(db *gorm.DB) { if relation, ok := db.Statement.Schema.Relationships.Relations[name]; ok { for _, s := range relation.FieldSchema.DBNames { clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ - Table: relation.FieldSchema.Table, + Table: relation.Name, Name: s, + Alias: relation.Name + "__" + s, }) } @@ -46,16 +48,16 @@ func Query(db *gorm.DB) { for _, ref := range relation.References { if ref.OwnPrimaryKey { exprs = append(exprs, clause.Expr{ - SQL: fmt.Sprintf("%s.%s = %s.%s", db.Statement.Schema.Table, ref.PrimaryKey.DBName, relation.FieldSchema.Table, ref.ForeignKey.DBName), + SQL: fmt.Sprintf("%s.%s = %s.%s", db.Statement.Schema.Table, ref.PrimaryKey.DBName, relation.Name, ref.ForeignKey.DBName), }) } else { if ref.PrimaryValue == "" { exprs = append(exprs, clause.Expr{ - SQL: fmt.Sprintf("%s.%s = %s.%s", db.Statement.Schema.Table, ref.ForeignKey.DBName, relation.FieldSchema.Table, ref.PrimaryKey.DBName), + SQL: fmt.Sprintf("%s.%s = %s.%s", db.Statement.Schema.Table, ref.ForeignKey.DBName, relation.Name, ref.PrimaryKey.DBName), }) } else { exprs = append(exprs, clause.Expr{ - SQL: fmt.Sprintf("%s.%s = ?", relation.FieldSchema.Table, ref.PrimaryKey.DBName), + SQL: fmt.Sprintf("%s.%s = ?", relation.Name, ref.PrimaryKey.DBName), Vars: []interface{}{ref.PrimaryValue}, }) } @@ -64,7 +66,7 @@ func Query(db *gorm.DB) { joins = append(joins, clause.Join{ Type: clause.LeftJoin, - Table: clause.Table{Name: relation.FieldSchema.Table}, + Table: clause.Table{Name: relation.FieldSchema.Table, Alias: relation.Name}, ON: clause.Where{Exprs: exprs}, }) } else { diff --git a/callbacks/scan.go b/callbacks/scan.go index 2bd0143c..6ea8bf23 100644 --- a/callbacks/scan.go +++ b/callbacks/scan.go @@ -3,6 +3,7 @@ package callbacks import ( "database/sql" "reflect" + "strings" "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/schema" @@ -54,12 +55,21 @@ func Scan(rows *sql.Rows, db *gorm.DB) { isPtr := db.Statement.ReflectValue.Type().Elem().Kind() == reflect.Ptr db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 0)) fields := make([]*schema.Field, len(columns)) + joinFields := make([][2]*schema.Field, len(columns)) for idx, column := range columns { if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable { fields[idx] = field + } else if names := strings.Split(column, "__"); len(names) > 1 { + if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok { + if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { + joinFields[idx] = [2]*schema.Field{rel.Field, field} + continue + } + } + values[idx] = &sql.RawBytes{} } else { - values[idx] = sql.RawBytes{} + values[idx] = &sql.RawBytes{} } } @@ -68,6 +78,9 @@ func Scan(rows *sql.Rows, db *gorm.DB) { for idx, field := range fields { if field != nil { values[idx] = field.ReflectValueOf(elem).Addr().Interface() + } else if joinFields[idx][0] != nil { + relValue := joinFields[idx][0].ReflectValueOf(elem) + values[idx] = joinFields[idx][1].ReflectValueOf(relValue).Addr().Interface() } } @@ -86,8 +99,17 @@ func Scan(rows *sql.Rows, db *gorm.DB) { for idx, column := range columns { if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable { values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() + } else if names := strings.Split(column, "__"); len(names) > 1 { + if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok { + relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue) + if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { + values[idx] = field.ReflectValueOf(relValue).Addr().Interface() + continue + } + } + values[idx] = &sql.RawBytes{} } else { - values[idx] = sql.RawBytes{} + values[idx] = &sql.RawBytes{} } } diff --git a/tests/joins.go b/tests/joins.go index 2a8cdc8b..86f9f104 100644 --- a/tests/joins.go +++ b/tests/joins.go @@ -7,9 +7,75 @@ import ( ) func TestJoins(t *testing.T, db *gorm.DB) { - db.Migrator().DropTable(&User{}) - db.AutoMigrate(&User{}) + db.Migrator().DropTable(&User{}, &Account{}, &Company{}) + db.AutoMigrate(&User{}, &Account{}, &Company{}) + + check := func(t *testing.T, oldUser, newUser User) { + if newUser.Company.ID != oldUser.Company.ID { + t.Errorf("Company is not equal when load with joins, loaded company id: %v", newUser.Company.ID) + } + + if newUser.Manager == nil || newUser.Manager.ID != oldUser.Manager.ID { + t.Errorf("Manager is not equal when load with joins: loaded manager: %+v", newUser.Manager) + } + + if newUser.Account.ID != oldUser.Account.ID { + t.Errorf("Account is not equal when load with joins, loaded account id: %v, expect: %v", newUser.Account.ID, oldUser.Account.ID) + } + } t.Run("Joins", func(t *testing.T) { + user := User{ + Name: "joins-1", + Company: Company{Name: "company"}, + Manager: &User{Name: "manager"}, + Account: Account{Number: "account-has-one-association"}, + } + + db.Create(&user) + + var user2 User + if err := db.Joins("Company").Joins("Manager").Joins("Account").First(&user2, "users.name = ?", user.Name).Error; err != nil { + t.Fatalf("Failed to load with joins, got error: %v", err) + } + + check(t, user, user2) + }) + + t.Run("JoinsForSlice", func(t *testing.T) { + users := []User{{ + Name: "slice-joins-1", + Company: Company{Name: "company"}, + Manager: &User{Name: "manager"}, + Account: Account{Number: "account-has-one-association"}, + }, { + Name: "slice-joins-2", + Company: Company{Name: "company2"}, + Manager: &User{Name: "manager2"}, + Account: Account{Number: "account-has-one-association2"}, + }, { + Name: "slice-joins-3", + Company: Company{Name: "company3"}, + Manager: &User{Name: "manager3"}, + Account: Account{Number: "account-has-one-association3"}, + }} + + db.Create(&users) + + var users2 []User + if err := db.Joins("Company").Joins("Manager").Joins("Account").Find(&users2, "users.name LIKE ?", "slice-joins%").Error; err != nil { + t.Fatalf("Failed to load with joins, got error: %v", err) + } else if len(users2) != len(users) { + t.Fatalf("Failed to load join users, got: %v, expect: %v", len(users2), len(users)) + } + + for _, u2 := range users2 { + for _, u := range users { + if u.Name == u2.Name { + check(t, u, u2) + continue + } + } + } }) }