forked from mirror/gorm
Make Preload works with has_one, has_many, belongs_to relations
This commit is contained in:
parent
dd635a17c2
commit
f97e2c088e
|
@ -1093,8 +1093,6 @@ db.Where("email = ?", "x@example.org").Attrs(User{RegisteredIp: "111.111.111.111
|
||||||
DefaultTimeZone, R/W Splitting, Validation
|
DefaultTimeZone, R/W Splitting, Validation
|
||||||
* Github Pages
|
* Github Pages
|
||||||
* AlertColumn, DropColumn
|
* AlertColumn, DropColumn
|
||||||
* db.Preload("Addresses.Map", "active = ?", true).Preload("Profile").Find(&users)
|
|
||||||
* db.Find(&users).Related(&users)
|
|
||||||
|
|
||||||
# Author
|
# Author
|
||||||
|
|
||||||
|
|
9
field.go
9
field.go
|
@ -15,6 +15,15 @@ type relationship struct {
|
||||||
Kind string
|
Kind string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FIXME
|
||||||
|
func (r relationship) ForeignDBName() string {
|
||||||
|
return ToSnake(r.ForeignKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r relationship) AssociationForeignDBName(name string) string {
|
||||||
|
return ToSnake(r.AssociationForeignKey)
|
||||||
|
}
|
||||||
|
|
||||||
type Field struct {
|
type Field struct {
|
||||||
Name string
|
Name string
|
||||||
DBName string
|
DBName string
|
||||||
|
|
76
preload.go
76
preload.go
|
@ -1,15 +1,30 @@
|
||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func getFieldValue(value reflect.Value, field string) interface{} {
|
||||||
|
result := value.FieldByName(field).Interface()
|
||||||
|
if r, ok := result.(driver.Valuer); ok {
|
||||||
|
result, _ = r.Value()
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func equalAsString(a interface{}, b interface{}) bool {
|
||||||
|
return fmt.Sprintf("%v", a) == fmt.Sprintf("%v", b)
|
||||||
|
}
|
||||||
|
|
||||||
func Preload(scope *Scope) {
|
func Preload(scope *Scope) {
|
||||||
// Get Fields
|
// Get Fields
|
||||||
var fields map[string]*Field
|
var fields map[string]*Field
|
||||||
|
var isSlice bool
|
||||||
if scope.IndirectValue().Kind() == reflect.Slice {
|
if scope.IndirectValue().Kind() == reflect.Slice {
|
||||||
|
isSlice = true
|
||||||
elem := reflect.New(scope.IndirectValue().Type().Elem()).Elem()
|
elem := reflect.New(scope.IndirectValue().Type().Elem()).Elem()
|
||||||
fields = scope.New(elem.Addr().Interface()).Fields()
|
fields = scope.New(elem.Addr().Interface()).Fields()
|
||||||
} else {
|
} else {
|
||||||
|
@ -23,16 +38,69 @@ func Preload(scope *Scope) {
|
||||||
results := makeSlice(field.Field)
|
results := makeSlice(field.Field)
|
||||||
relation := field.Relationship
|
relation := field.Relationship
|
||||||
primaryName := scope.PrimaryKeyField().Name
|
primaryName := scope.PrimaryKeyField().Name
|
||||||
|
associationPrimaryKey := scope.New(results).PrimaryKeyField().Name
|
||||||
|
|
||||||
switch relation.Kind {
|
switch relation.Kind {
|
||||||
case "has_one":
|
case "has_one":
|
||||||
sql := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignKey))
|
condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName()))
|
||||||
scope.NewDB().Find(results, sql, scope.getColumnAsArray(primaryName))
|
scope.NewDB().Find(results, condition, scope.getColumnAsArray(primaryName))
|
||||||
|
|
||||||
|
resultValues := reflect.Indirect(reflect.ValueOf(results))
|
||||||
|
for i := 0; i < resultValues.Len(); i++ {
|
||||||
|
result := resultValues.Index(i)
|
||||||
|
if isSlice {
|
||||||
|
value := getFieldValue(result, relation.ForeignKey)
|
||||||
|
objects := scope.IndirectValue()
|
||||||
|
for j := 0; j < objects.Len(); j++ {
|
||||||
|
if equalAsString(getFieldValue(objects.Index(j), primaryName), value) {
|
||||||
|
objects.Index(j).FieldByName(field.Name).Set(result)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
scope.SetColumn(field, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
case "has_many":
|
case "has_many":
|
||||||
sql := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignKey))
|
condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName()))
|
||||||
scope.NewDB().Find(results, sql, scope.getColumnAsArray(primaryName))
|
scope.NewDB().Find(results, condition, scope.getColumnAsArray(primaryName))
|
||||||
|
resultValues := reflect.Indirect(reflect.ValueOf(results))
|
||||||
|
if isSlice {
|
||||||
|
for i := 0; i < resultValues.Len(); i++ {
|
||||||
|
result := resultValues.Index(i)
|
||||||
|
value := getFieldValue(result, relation.ForeignKey)
|
||||||
|
objects := scope.IndirectValue()
|
||||||
|
for j := 0; j < objects.Len(); j++ {
|
||||||
|
object := objects.Index(j)
|
||||||
|
if equalAsString(getFieldValue(object, primaryName), value) {
|
||||||
|
f := object.FieldByName(field.Name)
|
||||||
|
f.Set(reflect.Append(f, result))
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
scope.SetColumn(field, resultValues)
|
||||||
|
}
|
||||||
case "belongs_to":
|
case "belongs_to":
|
||||||
scope.NewDB().Find(results, scope.getColumnAsArray(relation.ForeignKey))
|
scope.NewDB().Find(results, scope.getColumnAsArray(relation.ForeignKey))
|
||||||
|
resultValues := reflect.Indirect(reflect.ValueOf(results))
|
||||||
|
for i := 0; i < resultValues.Len(); i++ {
|
||||||
|
result := resultValues.Index(i)
|
||||||
|
if isSlice {
|
||||||
|
value := getFieldValue(result, associationPrimaryKey)
|
||||||
|
objects := scope.IndirectValue()
|
||||||
|
for j := 0; j < objects.Len(); j++ {
|
||||||
|
object := objects.Index(j)
|
||||||
|
if equalAsString(getFieldValue(object, relation.ForeignKey), value) {
|
||||||
|
object.FieldByName(field.Name).Set(result)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
scope.SetColumn(field, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
case "many_to_many":
|
case "many_to_many":
|
||||||
scope.Err(errors.New("not supported relation"))
|
scope.Err(errors.New("not supported relation"))
|
||||||
default:
|
default:
|
||||||
|
|
|
@ -1,11 +1,79 @@
|
||||||
package gorm_test
|
package gorm_test
|
||||||
|
|
||||||
import "testing"
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func getPreloadUser(name string) User {
|
||||||
|
return User{
|
||||||
|
Name: name,
|
||||||
|
Role: Role{"Preload"},
|
||||||
|
BillingAddress: Address{Address1: fmt.Sprintf("Billing Address %v", name)},
|
||||||
|
ShippingAddress: Address{Address1: fmt.Sprintf("Shipping Address %v", name)},
|
||||||
|
CreditCard: CreditCard{Number: fmt.Sprintf("123456%v", name)},
|
||||||
|
Emails: []Email{
|
||||||
|
{Email: fmt.Sprintf("user_%v@example1.com", name)}, {Email: fmt.Sprintf("user_%v@example2.com", name)},
|
||||||
|
},
|
||||||
|
Languages: []Language{
|
||||||
|
{Name: fmt.Sprintf("lang_1_%v", name)},
|
||||||
|
{Name: fmt.Sprintf("lang_2_%v", name)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkUserHasPreloadData(user User, t *testing.T) {
|
||||||
|
u := getPreloadUser(user.Name)
|
||||||
|
if user.BillingAddress.Address1 != u.BillingAddress.Address1 {
|
||||||
|
t.Error("Failed to preload user's BillingAddress")
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.ShippingAddress.Address1 != u.ShippingAddress.Address1 {
|
||||||
|
t.Error("Failed to preload user's ShippingAddress")
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.CreditCard.Number != u.CreditCard.Number {
|
||||||
|
t.Error("Failed to preload user's CreditCard")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(user.Emails) != len(u.Emails) {
|
||||||
|
t.Error("Failed to preload user's Emails")
|
||||||
|
} else {
|
||||||
|
var found int
|
||||||
|
for _, e1 := range u.Emails {
|
||||||
|
for _, e2 := range user.Emails {
|
||||||
|
if e1.Email == e2.Email {
|
||||||
|
found++
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if found != len(u.Emails) {
|
||||||
|
t.Error("Failed to preload user's email details")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestPreload(t *testing.T) {
|
func TestPreload(t *testing.T) {
|
||||||
user := User{Name: "PreloadUser", BillingAddress: Address{Address1: "Billing Address"}, ShippingAddress: Address{Address1: "Shipping Address"}, Languages: []Language{{Name: "Preload L1"}, {Name: "Preload L2"}}}
|
user1 := getPreloadUser("user1")
|
||||||
DB.Save(&user)
|
DB.Save(&user1)
|
||||||
|
|
||||||
|
var user User
|
||||||
|
DB.Where("role = ?", "Preload").Preload("BillingAddress").Preload("ShippingAddress").
|
||||||
|
Preload("CreditCard").Preload("Emails").Find(&user)
|
||||||
|
checkUserHasPreloadData(user, t)
|
||||||
|
|
||||||
|
user2 := getPreloadUser("user2")
|
||||||
|
DB.Save(&user2)
|
||||||
|
|
||||||
|
user3 := getPreloadUser("user3")
|
||||||
|
DB.Save(&user3)
|
||||||
|
|
||||||
var users []User
|
var users []User
|
||||||
DB.Preload("BillingAddress").Preload("ShippingAddress").Preload("Languages").Find(&users)
|
DB.Where("role = ?", "Preload").Preload("BillingAddress").Preload("ShippingAddress").
|
||||||
|
Preload("CreditCard").Preload("Emails").Find(&users)
|
||||||
|
|
||||||
|
for _, user := range users {
|
||||||
|
checkUserHasPreloadData(user, t)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue