Make callback create pass all tests

This commit is contained in:
Jinzhu 2014-01-27 10:47:37 +08:00
parent ee6a6827a8
commit 3981baf65d
4 changed files with 50 additions and 26 deletions

View File

@ -20,16 +20,16 @@ func SaveBeforeAssociations(scope *Scope) {
newDB.Save(value.Addr().Interface())
} else {
// If can't take address, then clone the value and set it back
destValue := reflect.New(reflect.ValueOf(field.Value).Type()).Elem()
value = reflect.New(reflect.ValueOf(field.Value).Type()).Elem()
for _, f := range newDB.NewScope(field.Value).Fields() {
destValue.FieldByName(f.Name).Set(reflect.ValueOf(f.Value))
value.FieldByName(f.Name).Set(reflect.ValueOf(f.Value))
}
newDB.Save(destValue.Addr().Interface())
scope.SetColumn(field.Name, destValue.Interface())
newDB.Save(value.Addr().Interface())
scope.SetColumn(field.Name, value.Interface())
}
if len(field.foreignKey) > 0 {
scope.SetColumn(field.foreignKey, scope.PrimaryKeyValue())
if len(field.ForeignKey) > 0 {
scope.SetColumn(field.ForeignKey, newDB.NewScope(value.Interface()).PrimaryKeyValue())
}
}
}
@ -46,8 +46,8 @@ func SaveAfterAssociations(scope *Scope) {
newDB := scope.NewDB()
elem := value.Index(i).Addr().Interface()
if len(field.foreignKey) > 0 {
newDB.NewScope(elem).SetColumn(field.foreignKey, scope.PrimaryKeyValue())
if len(field.ForeignKey) > 0 {
newDB.NewScope(elem).SetColumn(field.ForeignKey, scope.PrimaryKeyValue())
}
newDB.Save(elem)
@ -55,17 +55,17 @@ func SaveAfterAssociations(scope *Scope) {
default:
newDB := scope.NewDB()
if value.CanAddr() {
newDB.NewScope(field.Value).SetColumn(field.foreignKey, scope.PrimaryKeyValue())
newDB.NewScope(field.Value).SetColumn(field.ForeignKey, scope.PrimaryKeyValue())
newDB.Save(field.Value)
} else {
destValue := reflect.New(reflect.TypeOf(field.Value)).Elem()
for _, f := range newDB.NewScope(destValue).Fields() {
for _, f := range newDB.NewScope(field.Value).Fields() {
destValue.FieldByName(f.Name).Set(reflect.ValueOf(f.Value))
}
elem := destValue.Addr().Interface()
newDB.NewScope(elem).SetColumn(field.foreignKey, scope.PrimaryKeyValue())
newDB.NewScope(elem).SetColumn(field.ForeignKey, scope.PrimaryKeyValue())
newDB.Save(elem)
scope.SetColumn(field.Name, destValue.Interface())
}

View File

@ -58,10 +58,8 @@ func (f *Field) parseAssociation() {
if _, ok := f.Value.([]byte); !ok {
foreignKey := typ.Name() + "Id"
if reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() {
f.ForeignKey = foreignKey
f.foreignKey = foreignKey
}
f.AfterAssociation = true
f.afterAssociation = true
}
case reflect.Struct:
@ -69,16 +67,12 @@ func (f *Field) parseAssociation() {
if elem.FieldByName(f.Name + "Id").IsValid() {
f.foreignKey = f.Name + "Id"
f.beforeAssociation = true
f.ForeignKey = f.Name + "Id"
f.BeforeAssociation = true
} else {
foreignKey := typ.Name() + "Id"
if reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() {
f.foreignKey = foreignKey
f.ForeignKey = foreignKey
}
f.afterAssociation = true
f.AfterAssociation = true
}
}
}

View File

@ -1318,7 +1318,7 @@ func TestRelated(t *testing.T) {
var credit_card CreditCard
var user3 User
db.First(&credit_card, "number = ?", "1234567890")
db.Model(&credit_card).Related(&user3)
db.Debug().Model(&credit_card).Related(&user3)
if user3.Id != user.Id || user3.Name != user.Name {
t.Errorf("Should get user from credit card correctly")
}

View File

@ -37,7 +37,7 @@ func (scope *Scope) New(value interface{}) *Scope {
}
func (scope *Scope) NewDB() *DB {
return scope.db.parent
return scope.db.new()
}
func (scope *Scope) DB() sqlCommon {
@ -199,9 +199,9 @@ func (scope *Scope) Fields() []*Field {
return fields
}
typ := indirect_value.Type()
for i := 0; i < typ.NumField(); i++ {
field_struct := typ.Field(i)
scope_typ := indirect_value.Type()
for i := 0; i < scope_typ.NumField(); i++ {
field_struct := scope_typ.Field(i)
if field_struct.Anonymous || !ast.IsExported(field_struct.Name) {
continue
}
@ -224,7 +224,35 @@ func (scope *Scope) Fields() []*Field {
field.IsIgnored = true
}
field.parseAssociation()
// parse association
elem := reflect.Indirect(value)
typ := elem.Type()
switch elem.Kind() {
case reflect.Slice:
typ = typ.Elem()
if _, ok := field.Value.([]byte); !ok {
foreignKey := scope_typ.Name() + "Id"
if reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() {
field.ForeignKey = foreignKey
}
field.AfterAssociation = true
}
case reflect.Struct:
if !field.IsTime() && !field.IsScanner() {
if scope.HasColumn(field.Name + "Id") {
field.ForeignKey = field.Name + "Id"
field.BeforeAssociation = true
} else {
foreignKey := scope_typ.Name() + "Id"
if reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() {
field.ForeignKey = foreignKey
}
field.AfterAssociation = true
}
}
}
fields = append(fields, &field)
}
@ -249,9 +277,11 @@ func (scope *Scope) Trace(t time.Time) {
}
func (scope *Scope) Begin() *Scope {
if tx, err := scope.DB().(sqlDb).Begin(); err == nil {
scope.db.db = interface{}(tx).(sqlCommon)
scope.startedTransaction = true
if db, ok := scope.DB().(sqlDb); ok {
if tx, err := db.Begin(); err == nil {
scope.db.db = interface{}(tx).(sqlCommon)
scope.startedTransaction = true
}
}
return scope
}