forked from mirror/gorm
Make Preload works with has_one, has_many, belongs_to relations
This commit is contained in:
parent
dd635a17c2
commit
f97e2c088e
|
@ -1093,8 +1093,6 @@ db.Where("email = ?", "x@example.org").Attrs(User{RegisteredIp: "111.111.111.111
|
|||
DefaultTimeZone, R/W Splitting, Validation
|
||||
* Github Pages
|
||||
* AlertColumn, DropColumn
|
||||
* db.Preload("Addresses.Map", "active = ?", true).Preload("Profile").Find(&users)
|
||||
* db.Find(&users).Related(&users)
|
||||
|
||||
# Author
|
||||
|
||||
|
|
9
field.go
9
field.go
|
@ -15,6 +15,15 @@ type relationship struct {
|
|||
Kind string
|
||||
}
|
||||
|
||||
// FIXME
|
||||
func (r relationship) ForeignDBName() string {
|
||||
return ToSnake(r.ForeignKey)
|
||||
}
|
||||
|
||||
func (r relationship) AssociationForeignDBName(name string) string {
|
||||
return ToSnake(r.AssociationForeignKey)
|
||||
}
|
||||
|
||||
type Field struct {
|
||||
Name string
|
||||
DBName string
|
||||
|
|
76
preload.go
76
preload.go
|
@ -1,15 +1,30 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
func getFieldValue(value reflect.Value, field string) interface{} {
|
||||
result := value.FieldByName(field).Interface()
|
||||
if r, ok := result.(driver.Valuer); ok {
|
||||
result, _ = r.Value()
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func equalAsString(a interface{}, b interface{}) bool {
|
||||
return fmt.Sprintf("%v", a) == fmt.Sprintf("%v", b)
|
||||
}
|
||||
|
||||
func Preload(scope *Scope) {
|
||||
// Get Fields
|
||||
var fields map[string]*Field
|
||||
var isSlice bool
|
||||
if scope.IndirectValue().Kind() == reflect.Slice {
|
||||
isSlice = true
|
||||
elem := reflect.New(scope.IndirectValue().Type().Elem()).Elem()
|
||||
fields = scope.New(elem.Addr().Interface()).Fields()
|
||||
} else {
|
||||
|
@ -23,16 +38,69 @@ func Preload(scope *Scope) {
|
|||
results := makeSlice(field.Field)
|
||||
relation := field.Relationship
|
||||
primaryName := scope.PrimaryKeyField().Name
|
||||
associationPrimaryKey := scope.New(results).PrimaryKeyField().Name
|
||||
|
||||
switch relation.Kind {
|
||||
case "has_one":
|
||||
sql := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignKey))
|
||||
scope.NewDB().Find(results, sql, scope.getColumnAsArray(primaryName))
|
||||
condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName()))
|
||||
scope.NewDB().Find(results, condition, scope.getColumnAsArray(primaryName))
|
||||
|
||||
resultValues := reflect.Indirect(reflect.ValueOf(results))
|
||||
for i := 0; i < resultValues.Len(); i++ {
|
||||
result := resultValues.Index(i)
|
||||
if isSlice {
|
||||
value := getFieldValue(result, relation.ForeignKey)
|
||||
objects := scope.IndirectValue()
|
||||
for j := 0; j < objects.Len(); j++ {
|
||||
if equalAsString(getFieldValue(objects.Index(j), primaryName), value) {
|
||||
objects.Index(j).FieldByName(field.Name).Set(result)
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
scope.SetColumn(field, result)
|
||||
}
|
||||
}
|
||||
case "has_many":
|
||||
sql := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignKey))
|
||||
scope.NewDB().Find(results, sql, scope.getColumnAsArray(primaryName))
|
||||
condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName()))
|
||||
scope.NewDB().Find(results, condition, scope.getColumnAsArray(primaryName))
|
||||
resultValues := reflect.Indirect(reflect.ValueOf(results))
|
||||
if isSlice {
|
||||
for i := 0; i < resultValues.Len(); i++ {
|
||||
result := resultValues.Index(i)
|
||||
value := getFieldValue(result, relation.ForeignKey)
|
||||
objects := scope.IndirectValue()
|
||||
for j := 0; j < objects.Len(); j++ {
|
||||
object := objects.Index(j)
|
||||
if equalAsString(getFieldValue(object, primaryName), value) {
|
||||
f := object.FieldByName(field.Name)
|
||||
f.Set(reflect.Append(f, result))
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
scope.SetColumn(field, resultValues)
|
||||
}
|
||||
case "belongs_to":
|
||||
scope.NewDB().Find(results, scope.getColumnAsArray(relation.ForeignKey))
|
||||
resultValues := reflect.Indirect(reflect.ValueOf(results))
|
||||
for i := 0; i < resultValues.Len(); i++ {
|
||||
result := resultValues.Index(i)
|
||||
if isSlice {
|
||||
value := getFieldValue(result, associationPrimaryKey)
|
||||
objects := scope.IndirectValue()
|
||||
for j := 0; j < objects.Len(); j++ {
|
||||
object := objects.Index(j)
|
||||
if equalAsString(getFieldValue(object, relation.ForeignKey), value) {
|
||||
object.FieldByName(field.Name).Set(result)
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
scope.SetColumn(field, result)
|
||||
}
|
||||
}
|
||||
case "many_to_many":
|
||||
scope.Err(errors.New("not supported relation"))
|
||||
default:
|
||||
|
|
|
@ -1,11 +1,79 @@
|
|||
package gorm_test
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func getPreloadUser(name string) User {
|
||||
return User{
|
||||
Name: name,
|
||||
Role: Role{"Preload"},
|
||||
BillingAddress: Address{Address1: fmt.Sprintf("Billing Address %v", name)},
|
||||
ShippingAddress: Address{Address1: fmt.Sprintf("Shipping Address %v", name)},
|
||||
CreditCard: CreditCard{Number: fmt.Sprintf("123456%v", name)},
|
||||
Emails: []Email{
|
||||
{Email: fmt.Sprintf("user_%v@example1.com", name)}, {Email: fmt.Sprintf("user_%v@example2.com", name)},
|
||||
},
|
||||
Languages: []Language{
|
||||
{Name: fmt.Sprintf("lang_1_%v", name)},
|
||||
{Name: fmt.Sprintf("lang_2_%v", name)},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func checkUserHasPreloadData(user User, t *testing.T) {
|
||||
u := getPreloadUser(user.Name)
|
||||
if user.BillingAddress.Address1 != u.BillingAddress.Address1 {
|
||||
t.Error("Failed to preload user's BillingAddress")
|
||||
}
|
||||
|
||||
if user.ShippingAddress.Address1 != u.ShippingAddress.Address1 {
|
||||
t.Error("Failed to preload user's ShippingAddress")
|
||||
}
|
||||
|
||||
if user.CreditCard.Number != u.CreditCard.Number {
|
||||
t.Error("Failed to preload user's CreditCard")
|
||||
}
|
||||
|
||||
if len(user.Emails) != len(u.Emails) {
|
||||
t.Error("Failed to preload user's Emails")
|
||||
} else {
|
||||
var found int
|
||||
for _, e1 := range u.Emails {
|
||||
for _, e2 := range user.Emails {
|
||||
if e1.Email == e2.Email {
|
||||
found++
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if found != len(u.Emails) {
|
||||
t.Error("Failed to preload user's email details")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPreload(t *testing.T) {
|
||||
user := User{Name: "PreloadUser", BillingAddress: Address{Address1: "Billing Address"}, ShippingAddress: Address{Address1: "Shipping Address"}, Languages: []Language{{Name: "Preload L1"}, {Name: "Preload L2"}}}
|
||||
DB.Save(&user)
|
||||
user1 := getPreloadUser("user1")
|
||||
DB.Save(&user1)
|
||||
|
||||
var user User
|
||||
DB.Where("role = ?", "Preload").Preload("BillingAddress").Preload("ShippingAddress").
|
||||
Preload("CreditCard").Preload("Emails").Find(&user)
|
||||
checkUserHasPreloadData(user, t)
|
||||
|
||||
user2 := getPreloadUser("user2")
|
||||
DB.Save(&user2)
|
||||
|
||||
user3 := getPreloadUser("user3")
|
||||
DB.Save(&user3)
|
||||
|
||||
var users []User
|
||||
DB.Preload("BillingAddress").Preload("ShippingAddress").Preload("Languages").Find(&users)
|
||||
DB.Where("role = ?", "Preload").Preload("BillingAddress").Preload("ShippingAddress").
|
||||
Preload("CreditCard").Preload("Emails").Find(&users)
|
||||
|
||||
for _, user := range users {
|
||||
checkUserHasPreloadData(user, t)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue