From 2400a46ff7e6a402f9443d1d5c538f13f691e74b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 6 Nov 2013 08:03:39 +0800 Subject: [PATCH] Set associated struct's value after save even they are not pointer --- do.go | 25 +++++++++++++++++++++---- gorm_test.go | 39 ++++++++++++++++++++++++++++++--------- 2 files changed, 51 insertions(+), 13 deletions(-) diff --git a/do.go b/do.go index c1bab3a5..47dbaa18 100644 --- a/do.go +++ b/do.go @@ -120,8 +120,23 @@ func (s *Do) prepareCreateSql() { func (s *Do) saveBeforeAssociations() { for _, field := range s.model.beforeAssociations() { + var id interface{} + do := &Do{chain: s.chain, db: s.db, driver: s.driver} - id := do.setModel(field.Value).save() + + reflect_value := reflect.ValueOf(field.Value) + if reflect_value.CanAddr() { + id = do.setModel(reflect_value.Addr().Interface()).save() + } else { + dest_value := reflect.New(reflect_value.Type()).Elem() + m := &Model{data: field.Value, driver: s.driver} + for _, f := range m.columnsHasValue("other") { + dest_value.FieldByName(f.Name).Set(reflect.ValueOf(f.Value)) + } + id = do.setModel(dest_value.Addr().Interface()).save() + m.setValueByColumn(field.Name, dest_value.Interface(), s.value) + } + if len(field.foreignKey) > 0 { s.model.setValueByColumn(field.foreignKey, id, s.model.data) } @@ -150,12 +165,14 @@ func (s *Do) saveAfterAssociations() { } else { dest_value := reflect.New(reflect.TypeOf(field.Value)).Elem() m := &Model{data: field.Value, driver: s.driver} - for _, f := range m.columnsHasValue("update") { + for _, f := range m.columnsHasValue("other") { dest_value.FieldByName(f.Name).Set(reflect.ValueOf(f.Value)) } setFieldValue(dest_value.FieldByName(field.foreignKey), s.model.primaryKeyValue()) - do.setModel(dest_value.Interface()).save() + do.setModel(dest_value.Addr().Interface()).save() + + m.setValueByColumn(field.Name, dest_value.Interface(), s.value) } } } @@ -686,7 +703,7 @@ func (s *Do) combinedSql() string { func (s *Do) createTable() *Do { var sqls []string - for _, field := range s.model.fields("create") { + for _, field := range s.model.fields("other") { if len(field.SqlType) > 0 { sqls = append(sqls, field.DbName+" "+field.SqlType) } diff --git a/gorm_test.go b/gorm_test.go index 4ef8a9ae..9b263018 100644 --- a/gorm_test.go +++ b/gorm_test.go @@ -20,7 +20,7 @@ type User struct { CreatedAt time.Time // CreatedAt: Time of record is created, will be insert automatically UpdatedAt time.Time // UpdatedAt: Time of record is updated, will be updated automatically DeletedAt time.Time // DeletedAt: Time of record is deleted, refer Soft Delete for more - Email []Email // Embedded structs + Emails []Email // Embedded structs BillingAddress Address // Embedded struct BillingAddressId int64 // Embedded struct's foreign key ShippingAddress Address // Embedded struct @@ -29,9 +29,12 @@ type User struct { } type CreditCard struct { - Id int64 - Number string - UserId int64 + Id int64 + Number string + UserId int64 + CreatedAt time.Time + UpdatedAt time.Time + DeletedAt time.Time } type Email struct { @@ -39,13 +42,19 @@ type Email struct { UserId int64 Email string Subscribed bool + CreatedAt time.Time + UpdatedAt time.Time + DeletedAt time.Time } type Address struct { - Id int64 - Address1 string - Address2 string - Post string + Id int64 + Address1 string + Address2 string + Post string + CreatedAt time.Time + UpdatedAt time.Time + DeletedAt time.Time } type Product struct { @@ -1146,12 +1155,24 @@ func TestRelated(t *testing.T) { Name: "jinzhu", BillingAddress: Address{Address1: "Billing Address - Address 1"}, ShippingAddress: Address{Address1: "Shipping Address - Address 1"}, - Email: []Email{{Email: "jinzhu@example.com"}, {Email: "jinzhu-2@example@example.com"}}, + Emails: []Email{{Email: "jinzhu@example.com"}, {Email: "jinzhu-2@example@example.com"}}, CreditCard: CreditCard{Number: "1234567890"}, } db.Save(&user) + if user.CreditCard.Id == 0 { + t.Errorf("After user save, credit card should have id") + } + + if user.BillingAddress.Id == 0 { + t.Errorf("After user save, billing address should have id") + } + + if user.Emails[0].Id == 0 { + t.Errorf("After user save, billing address should have id") + } + var emails []Email db.Model(&user).Related(&emails) if len(emails) != 2 {