forked from mirror/gorm
Fix error when handle relations
This commit is contained in:
parent
2f5991d088
commit
619ae6549d
23
do.go
23
do.go
|
@ -131,6 +131,7 @@ func (s *Do) saveBeforeAssociations() {
|
||||||
func (s *Do) saveAfterAssociations() {
|
func (s *Do) saveAfterAssociations() {
|
||||||
for _, field := range s.model.afterAssociations() {
|
for _, field := range s.model.afterAssociations() {
|
||||||
reflect_value := reflect.ValueOf(field.Value)
|
reflect_value := reflect.ValueOf(field.Value)
|
||||||
|
|
||||||
switch reflect.TypeOf(field.Value).Kind() {
|
switch reflect.TypeOf(field.Value).Kind() {
|
||||||
case reflect.Slice:
|
case reflect.Slice:
|
||||||
for i := 0; i < reflect_value.Len(); i++ {
|
for i := 0; i < reflect_value.Len(); i++ {
|
||||||
|
@ -143,7 +144,19 @@ func (s *Do) saveAfterAssociations() {
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
do := &Do{chain: s.chain, db: s.db, driver: s.driver}
|
do := &Do{chain: s.chain, db: s.db, driver: s.driver}
|
||||||
do.setModel(field.Value).save()
|
if reflect_value.CanAddr() {
|
||||||
|
s.model.setValueByColumn(field.foreignKey, s.model.primaryKeyValue(), field.Value)
|
||||||
|
do.setModel(field.Value).save()
|
||||||
|
} else {
|
||||||
|
dest_value := reflect.New(reflect.TypeOf(field.Value)).Elem()
|
||||||
|
m := &Model{data: field.Value, driver: s.driver}
|
||||||
|
for _, f := range m.columnsHasValue("update") {
|
||||||
|
dest_value.FieldByName(f.Name).Set(reflect.ValueOf(f.Value))
|
||||||
|
}
|
||||||
|
|
||||||
|
setFieldValue(dest_value.FieldByName(field.foreignKey), s.model.primaryKeyValue())
|
||||||
|
do.setModel(dest_value.Interface()).save()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -311,12 +324,8 @@ func (s *Do) getForeignKey(from *Model, to *Model, foreign_key string) (err erro
|
||||||
} else {
|
} else {
|
||||||
foreign_value = value
|
foreign_value = value
|
||||||
}
|
}
|
||||||
} else if has_column, is_slice, value := to.ColumnAndValue(foreign_key); has_column {
|
} else if has_column, _, _ := to.ColumnAndValue(foreign_key); has_column {
|
||||||
if is_slice {
|
foreign_value = from.primaryKeyValue()
|
||||||
foreign_value = from.primaryKeyValue()
|
|
||||||
} else {
|
|
||||||
foreign_value = value
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
err = errors.New("Can't find valid foreign Key")
|
err = errors.New("Can't find valid foreign Key")
|
||||||
}
|
}
|
||||||
|
|
24
gorm_test.go
24
gorm_test.go
|
@ -25,11 +25,18 @@ type User struct {
|
||||||
BillingAddressId int64 // Embedded struct's foreign key
|
BillingAddressId int64 // Embedded struct's foreign key
|
||||||
ShippingAddress Address // Embedded struct
|
ShippingAddress Address // Embedded struct
|
||||||
ShippingAddressId int64 // Embedded struct's foreign key
|
ShippingAddressId int64 // Embedded struct's foreign key
|
||||||
|
CreditCard CreditCard
|
||||||
|
}
|
||||||
|
|
||||||
|
type CreditCard struct {
|
||||||
|
Id int64
|
||||||
|
Number string
|
||||||
|
UserId int64
|
||||||
}
|
}
|
||||||
|
|
||||||
type Email struct {
|
type Email struct {
|
||||||
Id int64
|
Id int64
|
||||||
UserId int64 // Foreign key for above embedded structs
|
UserId int64
|
||||||
Email string
|
Email string
|
||||||
Subscribed bool
|
Subscribed bool
|
||||||
}
|
}
|
||||||
|
@ -87,6 +94,7 @@ func init() {
|
||||||
db.Exec("drop table products;")
|
db.Exec("drop table products;")
|
||||||
db.Exec("drop table emails;")
|
db.Exec("drop table emails;")
|
||||||
db.Exec("drop table addresses")
|
db.Exec("drop table addresses")
|
||||||
|
db.Exec("drop table credit_cards")
|
||||||
|
|
||||||
err = db.CreateTable(&User{}).Error
|
err = db.CreateTable(&User{}).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -108,6 +116,11 @@ func init() {
|
||||||
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
|
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = db.CreateTable(CreditCard{}).Error
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
|
||||||
|
}
|
||||||
|
|
||||||
var shortForm = "2006-01-02 15:04:05"
|
var shortForm = "2006-01-02 15:04:05"
|
||||||
t1, _ = time.Parse(shortForm, "2000-10-27 12:02:40")
|
t1, _ = time.Parse(shortForm, "2000-10-27 12:02:40")
|
||||||
t2, _ = time.Parse(shortForm, "2002-01-01 00:00:00")
|
t2, _ = time.Parse(shortForm, "2002-01-01 00:00:00")
|
||||||
|
@ -1134,6 +1147,7 @@ func TestRelated(t *testing.T) {
|
||||||
BillingAddress: Address{Address1: "Billing Address - Address 1"},
|
BillingAddress: Address{Address1: "Billing Address - Address 1"},
|
||||||
ShippingAddress: Address{Address1: "Shipping Address - Address 1"},
|
ShippingAddress: Address{Address1: "Shipping Address - Address 1"},
|
||||||
Email: []Email{{Email: "jinzhu@example.com"}, {Email: "jinzhu-2@example@example.com"}},
|
Email: []Email{{Email: "jinzhu@example.com"}, {Email: "jinzhu-2@example@example.com"}},
|
||||||
|
CreditCard: CreditCard{Number: "1234567890"},
|
||||||
}
|
}
|
||||||
|
|
||||||
db.Save(&user)
|
db.Save(&user)
|
||||||
|
@ -1155,4 +1169,12 @@ func TestRelated(t *testing.T) {
|
||||||
if user2.Id != user.Id || user2.Name != user.Name {
|
if user2.Id != user.Id || user2.Name != user.Name {
|
||||||
t.Errorf("Should get user from email correctly")
|
t.Errorf("Should get user from email correctly")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var credit_card CreditCard
|
||||||
|
var user3 User
|
||||||
|
db.First(&credit_card, "number = ?", "1234567890")
|
||||||
|
db.Model(&credit_card).Related(&user3)
|
||||||
|
if user3.Id != user.Id || user3.Name != user.Name {
|
||||||
|
t.Errorf("Should get user from credit card correctly")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
8
model.go
8
model.go
|
@ -120,6 +120,10 @@ func (m *Model) fields(operation string) (fields []Field) {
|
||||||
value.Set(reflect.ValueOf(time.Now()))
|
value.Set(reflect.ValueOf(time.Now()))
|
||||||
}
|
}
|
||||||
case "update":
|
case "update":
|
||||||
|
if field.AutoCreateTime && time_value.IsZero() {
|
||||||
|
value.Set(reflect.ValueOf(time.Now()))
|
||||||
|
}
|
||||||
|
|
||||||
if field.AutoUpdateTime {
|
if field.AutoUpdateTime {
|
||||||
value.Set(reflect.ValueOf(time.Now()))
|
value.Set(reflect.ValueOf(time.Now()))
|
||||||
}
|
}
|
||||||
|
@ -153,6 +157,10 @@ func (m *Model) fields(operation string) (fields []Field) {
|
||||||
field.foreignKey = p.Name + "Id"
|
field.foreignKey = p.Name + "Id"
|
||||||
field.beforeAssociation = true
|
field.beforeAssociation = true
|
||||||
} else {
|
} else {
|
||||||
|
foreign_key := typ.Name() + "Id"
|
||||||
|
if reflect.New(field_value.Type()).Elem().FieldByName(foreign_key).IsValid() {
|
||||||
|
field.foreignKey = foreign_key
|
||||||
|
}
|
||||||
field.afterAssociation = true
|
field.afterAssociation = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue