From 619ae6549dfb06d887bc913671acd315f0834235 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 5 Nov 2013 22:34:49 +0800 Subject: [PATCH] Fix error when handle relations --- do.go | 23 ++++++++++++++++------- gorm_test.go | 24 +++++++++++++++++++++++- model.go | 8 ++++++++ 3 files changed, 47 insertions(+), 8 deletions(-) diff --git a/do.go b/do.go index 694b5868..c1bab3a5 100644 --- a/do.go +++ b/do.go @@ -131,6 +131,7 @@ func (s *Do) saveBeforeAssociations() { func (s *Do) saveAfterAssociations() { for _, field := range s.model.afterAssociations() { reflect_value := reflect.ValueOf(field.Value) + switch reflect.TypeOf(field.Value).Kind() { case reflect.Slice: for i := 0; i < reflect_value.Len(); i++ { @@ -143,7 +144,19 @@ func (s *Do) saveAfterAssociations() { } default: do := &Do{chain: s.chain, db: s.db, driver: s.driver} - do.setModel(field.Value).save() + if reflect_value.CanAddr() { + s.model.setValueByColumn(field.foreignKey, s.model.primaryKeyValue(), field.Value) + do.setModel(field.Value).save() + } else { + dest_value := reflect.New(reflect.TypeOf(field.Value)).Elem() + m := &Model{data: field.Value, driver: s.driver} + for _, f := range m.columnsHasValue("update") { + dest_value.FieldByName(f.Name).Set(reflect.ValueOf(f.Value)) + } + + setFieldValue(dest_value.FieldByName(field.foreignKey), s.model.primaryKeyValue()) + do.setModel(dest_value.Interface()).save() + } } } } @@ -311,12 +324,8 @@ func (s *Do) getForeignKey(from *Model, to *Model, foreign_key string) (err erro } else { foreign_value = value } - } else if has_column, is_slice, value := to.ColumnAndValue(foreign_key); has_column { - if is_slice { - foreign_value = from.primaryKeyValue() - } else { - foreign_value = value - } + } else if has_column, _, _ := to.ColumnAndValue(foreign_key); has_column { + foreign_value = from.primaryKeyValue() } else { err = errors.New("Can't find valid foreign Key") } diff --git a/gorm_test.go b/gorm_test.go index 69f5d2c9..4ef8a9ae 100644 --- a/gorm_test.go +++ b/gorm_test.go @@ -25,11 +25,18 @@ type User struct { BillingAddressId int64 // Embedded struct's foreign key ShippingAddress Address // Embedded struct ShippingAddressId int64 // Embedded struct's foreign key + CreditCard CreditCard +} + +type CreditCard struct { + Id int64 + Number string + UserId int64 } type Email struct { Id int64 - UserId int64 // Foreign key for above embedded structs + UserId int64 Email string Subscribed bool } @@ -87,6 +94,7 @@ func init() { db.Exec("drop table products;") db.Exec("drop table emails;") db.Exec("drop table addresses") + db.Exec("drop table credit_cards") err = db.CreateTable(&User{}).Error if err != nil { @@ -108,6 +116,11 @@ func init() { panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) } + err = db.CreateTable(CreditCard{}).Error + if err != nil { + panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) + } + var shortForm = "2006-01-02 15:04:05" t1, _ = time.Parse(shortForm, "2000-10-27 12:02:40") t2, _ = time.Parse(shortForm, "2002-01-01 00:00:00") @@ -1134,6 +1147,7 @@ func TestRelated(t *testing.T) { BillingAddress: Address{Address1: "Billing Address - Address 1"}, ShippingAddress: Address{Address1: "Shipping Address - Address 1"}, Email: []Email{{Email: "jinzhu@example.com"}, {Email: "jinzhu-2@example@example.com"}}, + CreditCard: CreditCard{Number: "1234567890"}, } db.Save(&user) @@ -1155,4 +1169,12 @@ func TestRelated(t *testing.T) { if user2.Id != user.Id || user2.Name != user.Name { t.Errorf("Should get user from email correctly") } + + var credit_card CreditCard + var user3 User + db.First(&credit_card, "number = ?", "1234567890") + db.Model(&credit_card).Related(&user3) + if user3.Id != user.Id || user3.Name != user.Name { + t.Errorf("Should get user from credit card correctly") + } } diff --git a/model.go b/model.go index 4d02f980..a30cc2d9 100644 --- a/model.go +++ b/model.go @@ -120,6 +120,10 @@ func (m *Model) fields(operation string) (fields []Field) { value.Set(reflect.ValueOf(time.Now())) } case "update": + if field.AutoCreateTime && time_value.IsZero() { + value.Set(reflect.ValueOf(time.Now())) + } + if field.AutoUpdateTime { value.Set(reflect.ValueOf(time.Now())) } @@ -153,6 +157,10 @@ func (m *Model) fields(operation string) (fields []Field) { field.foreignKey = p.Name + "Id" field.beforeAssociation = true } else { + foreign_key := typ.Name() + "Id" + if reflect.New(field_value.Type()).Elem().FieldByName(foreign_key).IsValid() { + field.foreignKey = foreign_key + } field.afterAssociation = true } }