diff --git a/README.md b/README.md index e336dd7e..ce1dae9b 100644 --- a/README.md +++ b/README.md @@ -1093,8 +1093,6 @@ db.Where("email = ?", "x@example.org").Attrs(User{RegisteredIp: "111.111.111.111 DefaultTimeZone, R/W Splitting, Validation * Github Pages * AlertColumn, DropColumn -* db.Preload("Addresses.Map", "active = ?", true).Preload("Profile").Find(&users) -* db.Find(&users).Related(&users) # Author diff --git a/field.go b/field.go index 0f9bb611..b1deab9b 100644 --- a/field.go +++ b/field.go @@ -15,6 +15,15 @@ type relationship struct { Kind string } +// FIXME +func (r relationship) ForeignDBName() string { + return ToSnake(r.ForeignKey) +} + +func (r relationship) AssociationForeignDBName(name string) string { + return ToSnake(r.AssociationForeignKey) +} + type Field struct { Name string DBName string diff --git a/preload.go b/preload.go index eb6cb080..1e6c26b3 100644 --- a/preload.go +++ b/preload.go @@ -1,15 +1,30 @@ package gorm import ( + "database/sql/driver" "errors" "fmt" "reflect" ) +func getFieldValue(value reflect.Value, field string) interface{} { + result := value.FieldByName(field).Interface() + if r, ok := result.(driver.Valuer); ok { + result, _ = r.Value() + } + return result +} + +func equalAsString(a interface{}, b interface{}) bool { + return fmt.Sprintf("%v", a) == fmt.Sprintf("%v", b) +} + func Preload(scope *Scope) { // Get Fields var fields map[string]*Field + var isSlice bool if scope.IndirectValue().Kind() == reflect.Slice { + isSlice = true elem := reflect.New(scope.IndirectValue().Type().Elem()).Elem() fields = scope.New(elem.Addr().Interface()).Fields() } else { @@ -23,16 +38,69 @@ func Preload(scope *Scope) { results := makeSlice(field.Field) relation := field.Relationship primaryName := scope.PrimaryKeyField().Name + associationPrimaryKey := scope.New(results).PrimaryKeyField().Name switch relation.Kind { case "has_one": - sql := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignKey)) - scope.NewDB().Find(results, sql, scope.getColumnAsArray(primaryName)) + condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName())) + scope.NewDB().Find(results, condition, scope.getColumnAsArray(primaryName)) + + resultValues := reflect.Indirect(reflect.ValueOf(results)) + for i := 0; i < resultValues.Len(); i++ { + result := resultValues.Index(i) + if isSlice { + value := getFieldValue(result, relation.ForeignKey) + objects := scope.IndirectValue() + for j := 0; j < objects.Len(); j++ { + if equalAsString(getFieldValue(objects.Index(j), primaryName), value) { + objects.Index(j).FieldByName(field.Name).Set(result) + break + } + } + } else { + scope.SetColumn(field, result) + } + } case "has_many": - sql := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignKey)) - scope.NewDB().Find(results, sql, scope.getColumnAsArray(primaryName)) + condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName())) + scope.NewDB().Find(results, condition, scope.getColumnAsArray(primaryName)) + resultValues := reflect.Indirect(reflect.ValueOf(results)) + if isSlice { + for i := 0; i < resultValues.Len(); i++ { + result := resultValues.Index(i) + value := getFieldValue(result, relation.ForeignKey) + objects := scope.IndirectValue() + for j := 0; j < objects.Len(); j++ { + object := objects.Index(j) + if equalAsString(getFieldValue(object, primaryName), value) { + f := object.FieldByName(field.Name) + f.Set(reflect.Append(f, result)) + break + } + } + } + } else { + scope.SetColumn(field, resultValues) + } case "belongs_to": scope.NewDB().Find(results, scope.getColumnAsArray(relation.ForeignKey)) + resultValues := reflect.Indirect(reflect.ValueOf(results)) + for i := 0; i < resultValues.Len(); i++ { + result := resultValues.Index(i) + if isSlice { + value := getFieldValue(result, associationPrimaryKey) + objects := scope.IndirectValue() + for j := 0; j < objects.Len(); j++ { + object := objects.Index(j) + if equalAsString(getFieldValue(object, relation.ForeignKey), value) { + object.FieldByName(field.Name).Set(result) + break + } + } + } else { + scope.SetColumn(field, result) + } + } case "many_to_many": scope.Err(errors.New("not supported relation")) default: diff --git a/preload_test.go b/preload_test.go index 939dbd3e..1107f234 100644 --- a/preload_test.go +++ b/preload_test.go @@ -1,11 +1,79 @@ package gorm_test -import "testing" +import ( + "fmt" + "testing" +) + +func getPreloadUser(name string) User { + return User{ + Name: name, + Role: Role{"Preload"}, + BillingAddress: Address{Address1: fmt.Sprintf("Billing Address %v", name)}, + ShippingAddress: Address{Address1: fmt.Sprintf("Shipping Address %v", name)}, + CreditCard: CreditCard{Number: fmt.Sprintf("123456%v", name)}, + Emails: []Email{ + {Email: fmt.Sprintf("user_%v@example1.com", name)}, {Email: fmt.Sprintf("user_%v@example2.com", name)}, + }, + Languages: []Language{ + {Name: fmt.Sprintf("lang_1_%v", name)}, + {Name: fmt.Sprintf("lang_2_%v", name)}, + }, + } +} + +func checkUserHasPreloadData(user User, t *testing.T) { + u := getPreloadUser(user.Name) + if user.BillingAddress.Address1 != u.BillingAddress.Address1 { + t.Error("Failed to preload user's BillingAddress") + } + + if user.ShippingAddress.Address1 != u.ShippingAddress.Address1 { + t.Error("Failed to preload user's ShippingAddress") + } + + if user.CreditCard.Number != u.CreditCard.Number { + t.Error("Failed to preload user's CreditCard") + } + + if len(user.Emails) != len(u.Emails) { + t.Error("Failed to preload user's Emails") + } else { + var found int + for _, e1 := range u.Emails { + for _, e2 := range user.Emails { + if e1.Email == e2.Email { + found++ + break + } + } + } + if found != len(u.Emails) { + t.Error("Failed to preload user's email details") + } + } +} func TestPreload(t *testing.T) { - user := User{Name: "PreloadUser", BillingAddress: Address{Address1: "Billing Address"}, ShippingAddress: Address{Address1: "Shipping Address"}, Languages: []Language{{Name: "Preload L1"}, {Name: "Preload L2"}}} - DB.Save(&user) + user1 := getPreloadUser("user1") + DB.Save(&user1) + + var user User + DB.Where("role = ?", "Preload").Preload("BillingAddress").Preload("ShippingAddress"). + Preload("CreditCard").Preload("Emails").Find(&user) + checkUserHasPreloadData(user, t) + + user2 := getPreloadUser("user2") + DB.Save(&user2) + + user3 := getPreloadUser("user3") + DB.Save(&user3) var users []User - DB.Preload("BillingAddress").Preload("ShippingAddress").Preload("Languages").Find(&users) + DB.Where("role = ?", "Preload").Preload("BillingAddress").Preload("ShippingAddress"). + Preload("CreditCard").Preload("Emails").Find(&users) + + for _, user := range users { + checkUserHasPreloadData(user, t) + } }