forked from mirror/gorm
Test inner joins
This commit is contained in:
parent
85246682c8
commit
9dfed613db
|
@ -28,6 +28,7 @@ func Query(db *gorm.DB) {
|
||||||
if len(db.Statement.Selects) == 0 {
|
if len(db.Statement.Selects) == 0 {
|
||||||
for _, dbName := range db.Statement.Schema.DBNames {
|
for _, dbName := range db.Statement.Schema.DBNames {
|
||||||
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
|
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
|
||||||
|
Table: db.Statement.Table,
|
||||||
Name: dbName,
|
Name: dbName,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -37,8 +38,9 @@ func Query(db *gorm.DB) {
|
||||||
if relation, ok := db.Statement.Schema.Relationships.Relations[name]; ok {
|
if relation, ok := db.Statement.Schema.Relationships.Relations[name]; ok {
|
||||||
for _, s := range relation.FieldSchema.DBNames {
|
for _, s := range relation.FieldSchema.DBNames {
|
||||||
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
|
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
|
||||||
Table: relation.FieldSchema.Table,
|
Table: relation.Name,
|
||||||
Name: s,
|
Name: s,
|
||||||
|
Alias: relation.Name + "__" + s,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -46,16 +48,16 @@ func Query(db *gorm.DB) {
|
||||||
for _, ref := range relation.References {
|
for _, ref := range relation.References {
|
||||||
if ref.OwnPrimaryKey {
|
if ref.OwnPrimaryKey {
|
||||||
exprs = append(exprs, clause.Expr{
|
exprs = append(exprs, clause.Expr{
|
||||||
SQL: fmt.Sprintf("%s.%s = %s.%s", db.Statement.Schema.Table, ref.PrimaryKey.DBName, relation.FieldSchema.Table, ref.ForeignKey.DBName),
|
SQL: fmt.Sprintf("%s.%s = %s.%s", db.Statement.Schema.Table, ref.PrimaryKey.DBName, relation.Name, ref.ForeignKey.DBName),
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
if ref.PrimaryValue == "" {
|
if ref.PrimaryValue == "" {
|
||||||
exprs = append(exprs, clause.Expr{
|
exprs = append(exprs, clause.Expr{
|
||||||
SQL: fmt.Sprintf("%s.%s = %s.%s", db.Statement.Schema.Table, ref.ForeignKey.DBName, relation.FieldSchema.Table, ref.PrimaryKey.DBName),
|
SQL: fmt.Sprintf("%s.%s = %s.%s", db.Statement.Schema.Table, ref.ForeignKey.DBName, relation.Name, ref.PrimaryKey.DBName),
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
exprs = append(exprs, clause.Expr{
|
exprs = append(exprs, clause.Expr{
|
||||||
SQL: fmt.Sprintf("%s.%s = ?", relation.FieldSchema.Table, ref.PrimaryKey.DBName),
|
SQL: fmt.Sprintf("%s.%s = ?", relation.Name, ref.PrimaryKey.DBName),
|
||||||
Vars: []interface{}{ref.PrimaryValue},
|
Vars: []interface{}{ref.PrimaryValue},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -64,7 +66,7 @@ func Query(db *gorm.DB) {
|
||||||
|
|
||||||
joins = append(joins, clause.Join{
|
joins = append(joins, clause.Join{
|
||||||
Type: clause.LeftJoin,
|
Type: clause.LeftJoin,
|
||||||
Table: clause.Table{Name: relation.FieldSchema.Table},
|
Table: clause.Table{Name: relation.FieldSchema.Table, Alias: relation.Name},
|
||||||
ON: clause.Where{Exprs: exprs},
|
ON: clause.Where{Exprs: exprs},
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -3,6 +3,7 @@ package callbacks
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/jinzhu/gorm"
|
"github.com/jinzhu/gorm"
|
||||||
"github.com/jinzhu/gorm/schema"
|
"github.com/jinzhu/gorm/schema"
|
||||||
|
@ -54,12 +55,21 @@ func Scan(rows *sql.Rows, db *gorm.DB) {
|
||||||
isPtr := db.Statement.ReflectValue.Type().Elem().Kind() == reflect.Ptr
|
isPtr := db.Statement.ReflectValue.Type().Elem().Kind() == reflect.Ptr
|
||||||
db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 0))
|
db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 0))
|
||||||
fields := make([]*schema.Field, len(columns))
|
fields := make([]*schema.Field, len(columns))
|
||||||
|
joinFields := make([][2]*schema.Field, len(columns))
|
||||||
|
|
||||||
for idx, column := range columns {
|
for idx, column := range columns {
|
||||||
if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable {
|
if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable {
|
||||||
fields[idx] = field
|
fields[idx] = field
|
||||||
|
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
||||||
|
if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok {
|
||||||
|
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
||||||
|
joinFields[idx] = [2]*schema.Field{rel.Field, field}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
values[idx] = &sql.RawBytes{}
|
||||||
} else {
|
} else {
|
||||||
values[idx] = sql.RawBytes{}
|
values[idx] = &sql.RawBytes{}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -68,6 +78,9 @@ func Scan(rows *sql.Rows, db *gorm.DB) {
|
||||||
for idx, field := range fields {
|
for idx, field := range fields {
|
||||||
if field != nil {
|
if field != nil {
|
||||||
values[idx] = field.ReflectValueOf(elem).Addr().Interface()
|
values[idx] = field.ReflectValueOf(elem).Addr().Interface()
|
||||||
|
} else if joinFields[idx][0] != nil {
|
||||||
|
relValue := joinFields[idx][0].ReflectValueOf(elem)
|
||||||
|
values[idx] = joinFields[idx][1].ReflectValueOf(relValue).Addr().Interface()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -86,8 +99,17 @@ func Scan(rows *sql.Rows, db *gorm.DB) {
|
||||||
for idx, column := range columns {
|
for idx, column := range columns {
|
||||||
if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable {
|
if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable {
|
||||||
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
|
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
|
||||||
|
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
||||||
|
if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok {
|
||||||
|
relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue)
|
||||||
|
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
||||||
|
values[idx] = field.ReflectValueOf(relValue).Addr().Interface()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
values[idx] = &sql.RawBytes{}
|
||||||
} else {
|
} else {
|
||||||
values[idx] = sql.RawBytes{}
|
values[idx] = &sql.RawBytes{}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -7,9 +7,75 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestJoins(t *testing.T, db *gorm.DB) {
|
func TestJoins(t *testing.T, db *gorm.DB) {
|
||||||
db.Migrator().DropTable(&User{})
|
db.Migrator().DropTable(&User{}, &Account{}, &Company{})
|
||||||
db.AutoMigrate(&User{})
|
db.AutoMigrate(&User{}, &Account{}, &Company{})
|
||||||
|
|
||||||
|
check := func(t *testing.T, oldUser, newUser User) {
|
||||||
|
if newUser.Company.ID != oldUser.Company.ID {
|
||||||
|
t.Errorf("Company is not equal when load with joins, loaded company id: %v", newUser.Company.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if newUser.Manager == nil || newUser.Manager.ID != oldUser.Manager.ID {
|
||||||
|
t.Errorf("Manager is not equal when load with joins: loaded manager: %+v", newUser.Manager)
|
||||||
|
}
|
||||||
|
|
||||||
|
if newUser.Account.ID != oldUser.Account.ID {
|
||||||
|
t.Errorf("Account is not equal when load with joins, loaded account id: %v, expect: %v", newUser.Account.ID, oldUser.Account.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
t.Run("Joins", func(t *testing.T) {
|
t.Run("Joins", func(t *testing.T) {
|
||||||
|
user := User{
|
||||||
|
Name: "joins-1",
|
||||||
|
Company: Company{Name: "company"},
|
||||||
|
Manager: &User{Name: "manager"},
|
||||||
|
Account: Account{Number: "account-has-one-association"},
|
||||||
|
}
|
||||||
|
|
||||||
|
db.Create(&user)
|
||||||
|
|
||||||
|
var user2 User
|
||||||
|
if err := db.Joins("Company").Joins("Manager").Joins("Account").First(&user2, "users.name = ?", user.Name).Error; err != nil {
|
||||||
|
t.Fatalf("Failed to load with joins, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
check(t, user, user2)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("JoinsForSlice", func(t *testing.T) {
|
||||||
|
users := []User{{
|
||||||
|
Name: "slice-joins-1",
|
||||||
|
Company: Company{Name: "company"},
|
||||||
|
Manager: &User{Name: "manager"},
|
||||||
|
Account: Account{Number: "account-has-one-association"},
|
||||||
|
}, {
|
||||||
|
Name: "slice-joins-2",
|
||||||
|
Company: Company{Name: "company2"},
|
||||||
|
Manager: &User{Name: "manager2"},
|
||||||
|
Account: Account{Number: "account-has-one-association2"},
|
||||||
|
}, {
|
||||||
|
Name: "slice-joins-3",
|
||||||
|
Company: Company{Name: "company3"},
|
||||||
|
Manager: &User{Name: "manager3"},
|
||||||
|
Account: Account{Number: "account-has-one-association3"},
|
||||||
|
}}
|
||||||
|
|
||||||
|
db.Create(&users)
|
||||||
|
|
||||||
|
var users2 []User
|
||||||
|
if err := db.Joins("Company").Joins("Manager").Joins("Account").Find(&users2, "users.name LIKE ?", "slice-joins%").Error; err != nil {
|
||||||
|
t.Fatalf("Failed to load with joins, got error: %v", err)
|
||||||
|
} else if len(users2) != len(users) {
|
||||||
|
t.Fatalf("Failed to load join users, got: %v, expect: %v", len(users2), len(users))
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, u2 := range users2 {
|
||||||
|
for _, u := range users {
|
||||||
|
if u.Name == u2.Name {
|
||||||
|
check(t, u, u2)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue