Make Preload works with has_one, has_many, belongs_to relations

This commit is contained in:
Jinzhu 2015-02-11 15:37:04 +08:00
parent dd635a17c2
commit f97e2c088e
4 changed files with 153 additions and 10 deletions

View File

@ -1093,8 +1093,6 @@ db.Where("email = ?", "x@example.org").Attrs(User{RegisteredIp: "111.111.111.111
DefaultTimeZone, R/W Splitting, Validation
* Github Pages
* AlertColumn, DropColumn
* db.Preload("Addresses.Map", "active = ?", true).Preload("Profile").Find(&users)
* db.Find(&users).Related(&users)
# Author

View File

@ -15,6 +15,15 @@ type relationship struct {
Kind string
}
// FIXME
func (r relationship) ForeignDBName() string {
return ToSnake(r.ForeignKey)
}
func (r relationship) AssociationForeignDBName(name string) string {
return ToSnake(r.AssociationForeignKey)
}
type Field struct {
Name string
DBName string

View File

@ -1,15 +1,30 @@
package gorm
import (
"database/sql/driver"
"errors"
"fmt"
"reflect"
)
func getFieldValue(value reflect.Value, field string) interface{} {
result := value.FieldByName(field).Interface()
if r, ok := result.(driver.Valuer); ok {
result, _ = r.Value()
}
return result
}
func equalAsString(a interface{}, b interface{}) bool {
return fmt.Sprintf("%v", a) == fmt.Sprintf("%v", b)
}
func Preload(scope *Scope) {
// Get Fields
var fields map[string]*Field
var isSlice bool
if scope.IndirectValue().Kind() == reflect.Slice {
isSlice = true
elem := reflect.New(scope.IndirectValue().Type().Elem()).Elem()
fields = scope.New(elem.Addr().Interface()).Fields()
} else {
@ -23,16 +38,69 @@ func Preload(scope *Scope) {
results := makeSlice(field.Field)
relation := field.Relationship
primaryName := scope.PrimaryKeyField().Name
associationPrimaryKey := scope.New(results).PrimaryKeyField().Name
switch relation.Kind {
case "has_one":
sql := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignKey))
scope.NewDB().Find(results, sql, scope.getColumnAsArray(primaryName))
condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName()))
scope.NewDB().Find(results, condition, scope.getColumnAsArray(primaryName))
resultValues := reflect.Indirect(reflect.ValueOf(results))
for i := 0; i < resultValues.Len(); i++ {
result := resultValues.Index(i)
if isSlice {
value := getFieldValue(result, relation.ForeignKey)
objects := scope.IndirectValue()
for j := 0; j < objects.Len(); j++ {
if equalAsString(getFieldValue(objects.Index(j), primaryName), value) {
objects.Index(j).FieldByName(field.Name).Set(result)
break
}
}
} else {
scope.SetColumn(field, result)
}
}
case "has_many":
sql := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignKey))
scope.NewDB().Find(results, sql, scope.getColumnAsArray(primaryName))
condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName()))
scope.NewDB().Find(results, condition, scope.getColumnAsArray(primaryName))
resultValues := reflect.Indirect(reflect.ValueOf(results))
if isSlice {
for i := 0; i < resultValues.Len(); i++ {
result := resultValues.Index(i)
value := getFieldValue(result, relation.ForeignKey)
objects := scope.IndirectValue()
for j := 0; j < objects.Len(); j++ {
object := objects.Index(j)
if equalAsString(getFieldValue(object, primaryName), value) {
f := object.FieldByName(field.Name)
f.Set(reflect.Append(f, result))
break
}
}
}
} else {
scope.SetColumn(field, resultValues)
}
case "belongs_to":
scope.NewDB().Find(results, scope.getColumnAsArray(relation.ForeignKey))
resultValues := reflect.Indirect(reflect.ValueOf(results))
for i := 0; i < resultValues.Len(); i++ {
result := resultValues.Index(i)
if isSlice {
value := getFieldValue(result, associationPrimaryKey)
objects := scope.IndirectValue()
for j := 0; j < objects.Len(); j++ {
object := objects.Index(j)
if equalAsString(getFieldValue(object, relation.ForeignKey), value) {
object.FieldByName(field.Name).Set(result)
break
}
}
} else {
scope.SetColumn(field, result)
}
}
case "many_to_many":
scope.Err(errors.New("not supported relation"))
default:

View File

@ -1,11 +1,79 @@
package gorm_test
import "testing"
import (
"fmt"
"testing"
)
func getPreloadUser(name string) User {
return User{
Name: name,
Role: Role{"Preload"},
BillingAddress: Address{Address1: fmt.Sprintf("Billing Address %v", name)},
ShippingAddress: Address{Address1: fmt.Sprintf("Shipping Address %v", name)},
CreditCard: CreditCard{Number: fmt.Sprintf("123456%v", name)},
Emails: []Email{
{Email: fmt.Sprintf("user_%v@example1.com", name)}, {Email: fmt.Sprintf("user_%v@example2.com", name)},
},
Languages: []Language{
{Name: fmt.Sprintf("lang_1_%v", name)},
{Name: fmt.Sprintf("lang_2_%v", name)},
},
}
}
func checkUserHasPreloadData(user User, t *testing.T) {
u := getPreloadUser(user.Name)
if user.BillingAddress.Address1 != u.BillingAddress.Address1 {
t.Error("Failed to preload user's BillingAddress")
}
if user.ShippingAddress.Address1 != u.ShippingAddress.Address1 {
t.Error("Failed to preload user's ShippingAddress")
}
if user.CreditCard.Number != u.CreditCard.Number {
t.Error("Failed to preload user's CreditCard")
}
if len(user.Emails) != len(u.Emails) {
t.Error("Failed to preload user's Emails")
} else {
var found int
for _, e1 := range u.Emails {
for _, e2 := range user.Emails {
if e1.Email == e2.Email {
found++
break
}
}
}
if found != len(u.Emails) {
t.Error("Failed to preload user's email details")
}
}
}
func TestPreload(t *testing.T) {
user := User{Name: "PreloadUser", BillingAddress: Address{Address1: "Billing Address"}, ShippingAddress: Address{Address1: "Shipping Address"}, Languages: []Language{{Name: "Preload L1"}, {Name: "Preload L2"}}}
DB.Save(&user)
user1 := getPreloadUser("user1")
DB.Save(&user1)
var user User
DB.Where("role = ?", "Preload").Preload("BillingAddress").Preload("ShippingAddress").
Preload("CreditCard").Preload("Emails").Find(&user)
checkUserHasPreloadData(user, t)
user2 := getPreloadUser("user2")
DB.Save(&user2)
user3 := getPreloadUser("user3")
DB.Save(&user3)
var users []User
DB.Preload("BillingAddress").Preload("ShippingAddress").Preload("Languages").Find(&users)
DB.Where("role = ?", "Preload").Preload("BillingAddress").Preload("ShippingAddress").
Preload("CreditCard").Preload("Emails").Find(&users)
for _, user := range users {
checkUserHasPreloadData(user, t)
}
}