Fix error when handle relations

This commit is contained in:
Jinzhu 2013-11-05 22:34:49 +08:00
parent 2f5991d088
commit 619ae6549d
3 changed files with 47 additions and 8 deletions

23
do.go
View File

@ -131,6 +131,7 @@ func (s *Do) saveBeforeAssociations() {
func (s *Do) saveAfterAssociations() { func (s *Do) saveAfterAssociations() {
for _, field := range s.model.afterAssociations() { for _, field := range s.model.afterAssociations() {
reflect_value := reflect.ValueOf(field.Value) reflect_value := reflect.ValueOf(field.Value)
switch reflect.TypeOf(field.Value).Kind() { switch reflect.TypeOf(field.Value).Kind() {
case reflect.Slice: case reflect.Slice:
for i := 0; i < reflect_value.Len(); i++ { for i := 0; i < reflect_value.Len(); i++ {
@ -143,7 +144,19 @@ func (s *Do) saveAfterAssociations() {
} }
default: default:
do := &Do{chain: s.chain, db: s.db, driver: s.driver} do := &Do{chain: s.chain, db: s.db, driver: s.driver}
do.setModel(field.Value).save() if reflect_value.CanAddr() {
s.model.setValueByColumn(field.foreignKey, s.model.primaryKeyValue(), field.Value)
do.setModel(field.Value).save()
} else {
dest_value := reflect.New(reflect.TypeOf(field.Value)).Elem()
m := &Model{data: field.Value, driver: s.driver}
for _, f := range m.columnsHasValue("update") {
dest_value.FieldByName(f.Name).Set(reflect.ValueOf(f.Value))
}
setFieldValue(dest_value.FieldByName(field.foreignKey), s.model.primaryKeyValue())
do.setModel(dest_value.Interface()).save()
}
} }
} }
} }
@ -311,12 +324,8 @@ func (s *Do) getForeignKey(from *Model, to *Model, foreign_key string) (err erro
} else { } else {
foreign_value = value foreign_value = value
} }
} else if has_column, is_slice, value := to.ColumnAndValue(foreign_key); has_column { } else if has_column, _, _ := to.ColumnAndValue(foreign_key); has_column {
if is_slice { foreign_value = from.primaryKeyValue()
foreign_value = from.primaryKeyValue()
} else {
foreign_value = value
}
} else { } else {
err = errors.New("Can't find valid foreign Key") err = errors.New("Can't find valid foreign Key")
} }

View File

@ -25,11 +25,18 @@ type User struct {
BillingAddressId int64 // Embedded struct's foreign key BillingAddressId int64 // Embedded struct's foreign key
ShippingAddress Address // Embedded struct ShippingAddress Address // Embedded struct
ShippingAddressId int64 // Embedded struct's foreign key ShippingAddressId int64 // Embedded struct's foreign key
CreditCard CreditCard
}
type CreditCard struct {
Id int64
Number string
UserId int64
} }
type Email struct { type Email struct {
Id int64 Id int64
UserId int64 // Foreign key for above embedded structs UserId int64
Email string Email string
Subscribed bool Subscribed bool
} }
@ -87,6 +94,7 @@ func init() {
db.Exec("drop table products;") db.Exec("drop table products;")
db.Exec("drop table emails;") db.Exec("drop table emails;")
db.Exec("drop table addresses") db.Exec("drop table addresses")
db.Exec("drop table credit_cards")
err = db.CreateTable(&User{}).Error err = db.CreateTable(&User{}).Error
if err != nil { if err != nil {
@ -108,6 +116,11 @@ func init() {
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
} }
err = db.CreateTable(CreditCard{}).Error
if err != nil {
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
}
var shortForm = "2006-01-02 15:04:05" var shortForm = "2006-01-02 15:04:05"
t1, _ = time.Parse(shortForm, "2000-10-27 12:02:40") t1, _ = time.Parse(shortForm, "2000-10-27 12:02:40")
t2, _ = time.Parse(shortForm, "2002-01-01 00:00:00") t2, _ = time.Parse(shortForm, "2002-01-01 00:00:00")
@ -1134,6 +1147,7 @@ func TestRelated(t *testing.T) {
BillingAddress: Address{Address1: "Billing Address - Address 1"}, BillingAddress: Address{Address1: "Billing Address - Address 1"},
ShippingAddress: Address{Address1: "Shipping Address - Address 1"}, ShippingAddress: Address{Address1: "Shipping Address - Address 1"},
Email: []Email{{Email: "jinzhu@example.com"}, {Email: "jinzhu-2@example@example.com"}}, Email: []Email{{Email: "jinzhu@example.com"}, {Email: "jinzhu-2@example@example.com"}},
CreditCard: CreditCard{Number: "1234567890"},
} }
db.Save(&user) db.Save(&user)
@ -1155,4 +1169,12 @@ func TestRelated(t *testing.T) {
if user2.Id != user.Id || user2.Name != user.Name { if user2.Id != user.Id || user2.Name != user.Name {
t.Errorf("Should get user from email correctly") t.Errorf("Should get user from email correctly")
} }
var credit_card CreditCard
var user3 User
db.First(&credit_card, "number = ?", "1234567890")
db.Model(&credit_card).Related(&user3)
if user3.Id != user.Id || user3.Name != user.Name {
t.Errorf("Should get user from credit card correctly")
}
} }

View File

@ -120,6 +120,10 @@ func (m *Model) fields(operation string) (fields []Field) {
value.Set(reflect.ValueOf(time.Now())) value.Set(reflect.ValueOf(time.Now()))
} }
case "update": case "update":
if field.AutoCreateTime && time_value.IsZero() {
value.Set(reflect.ValueOf(time.Now()))
}
if field.AutoUpdateTime { if field.AutoUpdateTime {
value.Set(reflect.ValueOf(time.Now())) value.Set(reflect.ValueOf(time.Now()))
} }
@ -153,6 +157,10 @@ func (m *Model) fields(operation string) (fields []Field) {
field.foreignKey = p.Name + "Id" field.foreignKey = p.Name + "Id"
field.beforeAssociation = true field.beforeAssociation = true
} else { } else {
foreign_key := typ.Name() + "Id"
if reflect.New(field_value.Type()).Elem().FieldByName(foreign_key).IsValid() {
field.foreignKey = foreign_key
}
field.afterAssociation = true field.afterAssociation = true
} }
} }