diff --git a/callback_shared.go b/callback_shared.go index 00400230..4a53db4c 100644 --- a/callback_shared.go +++ b/callback_shared.go @@ -20,16 +20,16 @@ func SaveBeforeAssociations(scope *Scope) { newDB.Save(value.Addr().Interface()) } else { // If can't take address, then clone the value and set it back - destValue := reflect.New(reflect.ValueOf(field.Value).Type()).Elem() + value = reflect.New(reflect.ValueOf(field.Value).Type()).Elem() for _, f := range newDB.NewScope(field.Value).Fields() { - destValue.FieldByName(f.Name).Set(reflect.ValueOf(f.Value)) + value.FieldByName(f.Name).Set(reflect.ValueOf(f.Value)) } - newDB.Save(destValue.Addr().Interface()) - scope.SetColumn(field.Name, destValue.Interface()) + newDB.Save(value.Addr().Interface()) + scope.SetColumn(field.Name, value.Interface()) } - if len(field.foreignKey) > 0 { - scope.SetColumn(field.foreignKey, scope.PrimaryKeyValue()) + if len(field.ForeignKey) > 0 { + scope.SetColumn(field.ForeignKey, newDB.NewScope(value.Interface()).PrimaryKeyValue()) } } } @@ -46,8 +46,8 @@ func SaveAfterAssociations(scope *Scope) { newDB := scope.NewDB() elem := value.Index(i).Addr().Interface() - if len(field.foreignKey) > 0 { - newDB.NewScope(elem).SetColumn(field.foreignKey, scope.PrimaryKeyValue()) + if len(field.ForeignKey) > 0 { + newDB.NewScope(elem).SetColumn(field.ForeignKey, scope.PrimaryKeyValue()) } newDB.Save(elem) @@ -55,17 +55,17 @@ func SaveAfterAssociations(scope *Scope) { default: newDB := scope.NewDB() if value.CanAddr() { - newDB.NewScope(field.Value).SetColumn(field.foreignKey, scope.PrimaryKeyValue()) + newDB.NewScope(field.Value).SetColumn(field.ForeignKey, scope.PrimaryKeyValue()) newDB.Save(field.Value) } else { destValue := reflect.New(reflect.TypeOf(field.Value)).Elem() - for _, f := range newDB.NewScope(destValue).Fields() { + for _, f := range newDB.NewScope(field.Value).Fields() { destValue.FieldByName(f.Name).Set(reflect.ValueOf(f.Value)) } elem := destValue.Addr().Interface() - newDB.NewScope(elem).SetColumn(field.foreignKey, scope.PrimaryKeyValue()) + newDB.NewScope(elem).SetColumn(field.ForeignKey, scope.PrimaryKeyValue()) newDB.Save(elem) scope.SetColumn(field.Name, destValue.Interface()) } diff --git a/field.go b/field.go index 011289f8..e394f2de 100644 --- a/field.go +++ b/field.go @@ -58,10 +58,8 @@ func (f *Field) parseAssociation() { if _, ok := f.Value.([]byte); !ok { foreignKey := typ.Name() + "Id" if reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() { - f.ForeignKey = foreignKey f.foreignKey = foreignKey } - f.AfterAssociation = true f.afterAssociation = true } case reflect.Struct: @@ -69,16 +67,12 @@ func (f *Field) parseAssociation() { if elem.FieldByName(f.Name + "Id").IsValid() { f.foreignKey = f.Name + "Id" f.beforeAssociation = true - f.ForeignKey = f.Name + "Id" - f.BeforeAssociation = true } else { foreignKey := typ.Name() + "Id" if reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() { f.foreignKey = foreignKey - f.ForeignKey = foreignKey } f.afterAssociation = true - f.AfterAssociation = true } } } diff --git a/gorm_test.go b/gorm_test.go index 78bf1d3b..45dc9b8f 100644 --- a/gorm_test.go +++ b/gorm_test.go @@ -1318,7 +1318,7 @@ func TestRelated(t *testing.T) { var credit_card CreditCard var user3 User db.First(&credit_card, "number = ?", "1234567890") - db.Model(&credit_card).Related(&user3) + db.Debug().Model(&credit_card).Related(&user3) if user3.Id != user.Id || user3.Name != user.Name { t.Errorf("Should get user from credit card correctly") } diff --git a/scope.go b/scope.go index ea2a015c..4e5eca3e 100644 --- a/scope.go +++ b/scope.go @@ -37,7 +37,7 @@ func (scope *Scope) New(value interface{}) *Scope { } func (scope *Scope) NewDB() *DB { - return scope.db.parent + return scope.db.new() } func (scope *Scope) DB() sqlCommon { @@ -199,9 +199,9 @@ func (scope *Scope) Fields() []*Field { return fields } - typ := indirect_value.Type() - for i := 0; i < typ.NumField(); i++ { - field_struct := typ.Field(i) + scope_typ := indirect_value.Type() + for i := 0; i < scope_typ.NumField(); i++ { + field_struct := scope_typ.Field(i) if field_struct.Anonymous || !ast.IsExported(field_struct.Name) { continue } @@ -224,7 +224,35 @@ func (scope *Scope) Fields() []*Field { field.IsIgnored = true } - field.parseAssociation() + // parse association + elem := reflect.Indirect(value) + typ := elem.Type() + + switch elem.Kind() { + case reflect.Slice: + typ = typ.Elem() + + if _, ok := field.Value.([]byte); !ok { + foreignKey := scope_typ.Name() + "Id" + if reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() { + field.ForeignKey = foreignKey + } + field.AfterAssociation = true + } + case reflect.Struct: + if !field.IsTime() && !field.IsScanner() { + if scope.HasColumn(field.Name + "Id") { + field.ForeignKey = field.Name + "Id" + field.BeforeAssociation = true + } else { + foreignKey := scope_typ.Name() + "Id" + if reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() { + field.ForeignKey = foreignKey + } + field.AfterAssociation = true + } + } + } fields = append(fields, &field) } @@ -249,9 +277,11 @@ func (scope *Scope) Trace(t time.Time) { } func (scope *Scope) Begin() *Scope { - if tx, err := scope.DB().(sqlDb).Begin(); err == nil { - scope.db.db = interface{}(tx).(sqlCommon) - scope.startedTransaction = true + if db, ok := scope.DB().(sqlDb); ok { + if tx, err := db.Begin(); err == nil { + scope.db.db = interface{}(tx).(sqlCommon) + scope.startedTransaction = true + } } return scope }