forked from mirror/gorm
Fix error when handle relations
This commit is contained in:
parent
2f5991d088
commit
619ae6549d
19
do.go
19
do.go
|
@ -131,6 +131,7 @@ func (s *Do) saveBeforeAssociations() {
|
|||
func (s *Do) saveAfterAssociations() {
|
||||
for _, field := range s.model.afterAssociations() {
|
||||
reflect_value := reflect.ValueOf(field.Value)
|
||||
|
||||
switch reflect.TypeOf(field.Value).Kind() {
|
||||
case reflect.Slice:
|
||||
for i := 0; i < reflect_value.Len(); i++ {
|
||||
|
@ -143,7 +144,19 @@ func (s *Do) saveAfterAssociations() {
|
|||
}
|
||||
default:
|
||||
do := &Do{chain: s.chain, db: s.db, driver: s.driver}
|
||||
if reflect_value.CanAddr() {
|
||||
s.model.setValueByColumn(field.foreignKey, s.model.primaryKeyValue(), field.Value)
|
||||
do.setModel(field.Value).save()
|
||||
} else {
|
||||
dest_value := reflect.New(reflect.TypeOf(field.Value)).Elem()
|
||||
m := &Model{data: field.Value, driver: s.driver}
|
||||
for _, f := range m.columnsHasValue("update") {
|
||||
dest_value.FieldByName(f.Name).Set(reflect.ValueOf(f.Value))
|
||||
}
|
||||
|
||||
setFieldValue(dest_value.FieldByName(field.foreignKey), s.model.primaryKeyValue())
|
||||
do.setModel(dest_value.Interface()).save()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -311,12 +324,8 @@ func (s *Do) getForeignKey(from *Model, to *Model, foreign_key string) (err erro
|
|||
} else {
|
||||
foreign_value = value
|
||||
}
|
||||
} else if has_column, is_slice, value := to.ColumnAndValue(foreign_key); has_column {
|
||||
if is_slice {
|
||||
} else if has_column, _, _ := to.ColumnAndValue(foreign_key); has_column {
|
||||
foreign_value = from.primaryKeyValue()
|
||||
} else {
|
||||
foreign_value = value
|
||||
}
|
||||
} else {
|
||||
err = errors.New("Can't find valid foreign Key")
|
||||
}
|
||||
|
|
24
gorm_test.go
24
gorm_test.go
|
@ -25,11 +25,18 @@ type User struct {
|
|||
BillingAddressId int64 // Embedded struct's foreign key
|
||||
ShippingAddress Address // Embedded struct
|
||||
ShippingAddressId int64 // Embedded struct's foreign key
|
||||
CreditCard CreditCard
|
||||
}
|
||||
|
||||
type CreditCard struct {
|
||||
Id int64
|
||||
Number string
|
||||
UserId int64
|
||||
}
|
||||
|
||||
type Email struct {
|
||||
Id int64
|
||||
UserId int64 // Foreign key for above embedded structs
|
||||
UserId int64
|
||||
Email string
|
||||
Subscribed bool
|
||||
}
|
||||
|
@ -87,6 +94,7 @@ func init() {
|
|||
db.Exec("drop table products;")
|
||||
db.Exec("drop table emails;")
|
||||
db.Exec("drop table addresses")
|
||||
db.Exec("drop table credit_cards")
|
||||
|
||||
err = db.CreateTable(&User{}).Error
|
||||
if err != nil {
|
||||
|
@ -108,6 +116,11 @@ func init() {
|
|||
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
|
||||
}
|
||||
|
||||
err = db.CreateTable(CreditCard{}).Error
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
|
||||
}
|
||||
|
||||
var shortForm = "2006-01-02 15:04:05"
|
||||
t1, _ = time.Parse(shortForm, "2000-10-27 12:02:40")
|
||||
t2, _ = time.Parse(shortForm, "2002-01-01 00:00:00")
|
||||
|
@ -1134,6 +1147,7 @@ func TestRelated(t *testing.T) {
|
|||
BillingAddress: Address{Address1: "Billing Address - Address 1"},
|
||||
ShippingAddress: Address{Address1: "Shipping Address - Address 1"},
|
||||
Email: []Email{{Email: "jinzhu@example.com"}, {Email: "jinzhu-2@example@example.com"}},
|
||||
CreditCard: CreditCard{Number: "1234567890"},
|
||||
}
|
||||
|
||||
db.Save(&user)
|
||||
|
@ -1155,4 +1169,12 @@ func TestRelated(t *testing.T) {
|
|||
if user2.Id != user.Id || user2.Name != user.Name {
|
||||
t.Errorf("Should get user from email correctly")
|
||||
}
|
||||
|
||||
var credit_card CreditCard
|
||||
var user3 User
|
||||
db.First(&credit_card, "number = ?", "1234567890")
|
||||
db.Model(&credit_card).Related(&user3)
|
||||
if user3.Id != user.Id || user3.Name != user.Name {
|
||||
t.Errorf("Should get user from credit card correctly")
|
||||
}
|
||||
}
|
||||
|
|
8
model.go
8
model.go
|
@ -120,6 +120,10 @@ func (m *Model) fields(operation string) (fields []Field) {
|
|||
value.Set(reflect.ValueOf(time.Now()))
|
||||
}
|
||||
case "update":
|
||||
if field.AutoCreateTime && time_value.IsZero() {
|
||||
value.Set(reflect.ValueOf(time.Now()))
|
||||
}
|
||||
|
||||
if field.AutoUpdateTime {
|
||||
value.Set(reflect.ValueOf(time.Now()))
|
||||
}
|
||||
|
@ -153,6 +157,10 @@ func (m *Model) fields(operation string) (fields []Field) {
|
|||
field.foreignKey = p.Name + "Id"
|
||||
field.beforeAssociation = true
|
||||
} else {
|
||||
foreign_key := typ.Name() + "Id"
|
||||
if reflect.New(field_value.Type()).Elem().FieldByName(foreign_key).IsValid() {
|
||||
field.foreignKey = foreign_key
|
||||
}
|
||||
field.afterAssociation = true
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue