Make callback create pass all tests

This commit is contained in:
Jinzhu 2014-01-27 10:47:37 +08:00
parent ee6a6827a8
commit 3981baf65d
4 changed files with 50 additions and 26 deletions

View File

@ -20,16 +20,16 @@ func SaveBeforeAssociations(scope *Scope) {
newDB.Save(value.Addr().Interface()) newDB.Save(value.Addr().Interface())
} else { } else {
// If can't take address, then clone the value and set it back // If can't take address, then clone the value and set it back
destValue := reflect.New(reflect.ValueOf(field.Value).Type()).Elem() value = reflect.New(reflect.ValueOf(field.Value).Type()).Elem()
for _, f := range newDB.NewScope(field.Value).Fields() { for _, f := range newDB.NewScope(field.Value).Fields() {
destValue.FieldByName(f.Name).Set(reflect.ValueOf(f.Value)) value.FieldByName(f.Name).Set(reflect.ValueOf(f.Value))
} }
newDB.Save(destValue.Addr().Interface()) newDB.Save(value.Addr().Interface())
scope.SetColumn(field.Name, destValue.Interface()) scope.SetColumn(field.Name, value.Interface())
} }
if len(field.foreignKey) > 0 { if len(field.ForeignKey) > 0 {
scope.SetColumn(field.foreignKey, scope.PrimaryKeyValue()) scope.SetColumn(field.ForeignKey, newDB.NewScope(value.Interface()).PrimaryKeyValue())
} }
} }
} }
@ -46,8 +46,8 @@ func SaveAfterAssociations(scope *Scope) {
newDB := scope.NewDB() newDB := scope.NewDB()
elem := value.Index(i).Addr().Interface() elem := value.Index(i).Addr().Interface()
if len(field.foreignKey) > 0 { if len(field.ForeignKey) > 0 {
newDB.NewScope(elem).SetColumn(field.foreignKey, scope.PrimaryKeyValue()) newDB.NewScope(elem).SetColumn(field.ForeignKey, scope.PrimaryKeyValue())
} }
newDB.Save(elem) newDB.Save(elem)
@ -55,17 +55,17 @@ func SaveAfterAssociations(scope *Scope) {
default: default:
newDB := scope.NewDB() newDB := scope.NewDB()
if value.CanAddr() { if value.CanAddr() {
newDB.NewScope(field.Value).SetColumn(field.foreignKey, scope.PrimaryKeyValue()) newDB.NewScope(field.Value).SetColumn(field.ForeignKey, scope.PrimaryKeyValue())
newDB.Save(field.Value) newDB.Save(field.Value)
} else { } else {
destValue := reflect.New(reflect.TypeOf(field.Value)).Elem() destValue := reflect.New(reflect.TypeOf(field.Value)).Elem()
for _, f := range newDB.NewScope(destValue).Fields() { for _, f := range newDB.NewScope(field.Value).Fields() {
destValue.FieldByName(f.Name).Set(reflect.ValueOf(f.Value)) destValue.FieldByName(f.Name).Set(reflect.ValueOf(f.Value))
} }
elem := destValue.Addr().Interface() elem := destValue.Addr().Interface()
newDB.NewScope(elem).SetColumn(field.foreignKey, scope.PrimaryKeyValue()) newDB.NewScope(elem).SetColumn(field.ForeignKey, scope.PrimaryKeyValue())
newDB.Save(elem) newDB.Save(elem)
scope.SetColumn(field.Name, destValue.Interface()) scope.SetColumn(field.Name, destValue.Interface())
} }

View File

@ -58,10 +58,8 @@ func (f *Field) parseAssociation() {
if _, ok := f.Value.([]byte); !ok { if _, ok := f.Value.([]byte); !ok {
foreignKey := typ.Name() + "Id" foreignKey := typ.Name() + "Id"
if reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() { if reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() {
f.ForeignKey = foreignKey
f.foreignKey = foreignKey f.foreignKey = foreignKey
} }
f.AfterAssociation = true
f.afterAssociation = true f.afterAssociation = true
} }
case reflect.Struct: case reflect.Struct:
@ -69,16 +67,12 @@ func (f *Field) parseAssociation() {
if elem.FieldByName(f.Name + "Id").IsValid() { if elem.FieldByName(f.Name + "Id").IsValid() {
f.foreignKey = f.Name + "Id" f.foreignKey = f.Name + "Id"
f.beforeAssociation = true f.beforeAssociation = true
f.ForeignKey = f.Name + "Id"
f.BeforeAssociation = true
} else { } else {
foreignKey := typ.Name() + "Id" foreignKey := typ.Name() + "Id"
if reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() { if reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() {
f.foreignKey = foreignKey f.foreignKey = foreignKey
f.ForeignKey = foreignKey
} }
f.afterAssociation = true f.afterAssociation = true
f.AfterAssociation = true
} }
} }
} }

View File

@ -1318,7 +1318,7 @@ func TestRelated(t *testing.T) {
var credit_card CreditCard var credit_card CreditCard
var user3 User var user3 User
db.First(&credit_card, "number = ?", "1234567890") db.First(&credit_card, "number = ?", "1234567890")
db.Model(&credit_card).Related(&user3) db.Debug().Model(&credit_card).Related(&user3)
if user3.Id != user.Id || user3.Name != user.Name { if user3.Id != user.Id || user3.Name != user.Name {
t.Errorf("Should get user from credit card correctly") t.Errorf("Should get user from credit card correctly")
} }

View File

@ -37,7 +37,7 @@ func (scope *Scope) New(value interface{}) *Scope {
} }
func (scope *Scope) NewDB() *DB { func (scope *Scope) NewDB() *DB {
return scope.db.parent return scope.db.new()
} }
func (scope *Scope) DB() sqlCommon { func (scope *Scope) DB() sqlCommon {
@ -199,9 +199,9 @@ func (scope *Scope) Fields() []*Field {
return fields return fields
} }
typ := indirect_value.Type() scope_typ := indirect_value.Type()
for i := 0; i < typ.NumField(); i++ { for i := 0; i < scope_typ.NumField(); i++ {
field_struct := typ.Field(i) field_struct := scope_typ.Field(i)
if field_struct.Anonymous || !ast.IsExported(field_struct.Name) { if field_struct.Anonymous || !ast.IsExported(field_struct.Name) {
continue continue
} }
@ -224,7 +224,35 @@ func (scope *Scope) Fields() []*Field {
field.IsIgnored = true field.IsIgnored = true
} }
field.parseAssociation() // parse association
elem := reflect.Indirect(value)
typ := elem.Type()
switch elem.Kind() {
case reflect.Slice:
typ = typ.Elem()
if _, ok := field.Value.([]byte); !ok {
foreignKey := scope_typ.Name() + "Id"
if reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() {
field.ForeignKey = foreignKey
}
field.AfterAssociation = true
}
case reflect.Struct:
if !field.IsTime() && !field.IsScanner() {
if scope.HasColumn(field.Name + "Id") {
field.ForeignKey = field.Name + "Id"
field.BeforeAssociation = true
} else {
foreignKey := scope_typ.Name() + "Id"
if reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() {
field.ForeignKey = foreignKey
}
field.AfterAssociation = true
}
}
}
fields = append(fields, &field) fields = append(fields, &field)
} }
@ -249,9 +277,11 @@ func (scope *Scope) Trace(t time.Time) {
} }
func (scope *Scope) Begin() *Scope { func (scope *Scope) Begin() *Scope {
if tx, err := scope.DB().(sqlDb).Begin(); err == nil { if db, ok := scope.DB().(sqlDb); ok {
scope.db.db = interface{}(tx).(sqlCommon) if tx, err := db.Begin(); err == nil {
scope.startedTransaction = true scope.db.db = interface{}(tx).(sqlCommon)
scope.startedTransaction = true
}
} }
return scope return scope
} }