forked from mirror/gorm
Test HasOne Association
This commit is contained in:
parent
677c745b62
commit
68a7a8207a
|
@ -97,28 +97,34 @@ func (association *Association) Replace(values ...interface{}) error {
|
||||||
}
|
}
|
||||||
case schema.HasOne, schema.HasMany:
|
case schema.HasOne, schema.HasMany:
|
||||||
var (
|
var (
|
||||||
|
tx = association.DB
|
||||||
primaryFields []*schema.Field
|
primaryFields []*schema.Field
|
||||||
foreignKeys []string
|
foreignKeys []string
|
||||||
updateMap = map[string]interface{}{}
|
updateMap = map[string]interface{}{}
|
||||||
|
relPrimaryKeys = []string{}
|
||||||
|
relValues = schema.GetRelationsValues(reflectValue, []*schema.Relationship{rel})
|
||||||
modelValue = reflect.New(rel.FieldSchema.ModelType).Interface()
|
modelValue = reflect.New(rel.FieldSchema.ModelType).Interface()
|
||||||
)
|
)
|
||||||
if rel.Type == schema.BelongsTo {
|
|
||||||
modelValue = reflect.New(rel.Schema.ModelType).Interface()
|
for _, field := range rel.FieldSchema.PrimaryFields {
|
||||||
|
relPrimaryKeys = append(relPrimaryKeys, field.DBName)
|
||||||
|
}
|
||||||
|
if _, qvs := schema.GetIdentityFieldValuesMap(relValues, rel.FieldSchema.PrimaryFields); len(qvs) > 0 {
|
||||||
|
if column, values := schema.ToQueryValues(relPrimaryKeys, qvs); len(values) > 0 {
|
||||||
|
tx = tx.Not(clause.IN{Column: column, Values: values})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, ref := range rel.References {
|
for _, ref := range rel.References {
|
||||||
if ref.OwnPrimaryKey {
|
if ref.OwnPrimaryKey {
|
||||||
primaryFields = append(primaryFields, ref.PrimaryKey)
|
primaryFields = append(primaryFields, ref.PrimaryKey)
|
||||||
} else {
|
|
||||||
foreignKeys = append(foreignKeys, ref.ForeignKey.DBName)
|
foreignKeys = append(foreignKeys, ref.ForeignKey.DBName)
|
||||||
updateMap[ref.ForeignKey.DBName] = nil
|
updateMap[ref.ForeignKey.DBName] = nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if _, qvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields); len(qvs) > 0 {
|
||||||
_, values := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
|
column, values := schema.ToQueryValues(foreignKeys, qvs)
|
||||||
if len(values) == 0 {
|
tx.Model(modelValue).Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap)
|
||||||
column, queryValues := schema.ToQueryValues(foreignKeys, values)
|
|
||||||
association.DB.Model(modelValue).Where(clause.IN{Column: column, Values: queryValues}).UpdateColumns(updateMap)
|
|
||||||
}
|
}
|
||||||
case schema.Many2Many:
|
case schema.Many2Many:
|
||||||
var primaryFields, relPrimaryFields []*schema.Field
|
var primaryFields, relPrimaryFields []*schema.Field
|
||||||
|
@ -413,7 +419,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(values) > 0 {
|
if len(values) > 0 {
|
||||||
association.DB.Select(selectedColumns).Model(nil).Save(reflectValue.Addr().Interface())
|
association.DB.Session(&Session{}).Select(selectedColumns).Model(nil).Save(reflectValue.Addr().Interface())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -124,6 +124,8 @@ func SaveAfterAssociations(db *gorm.DB) {
|
||||||
|
|
||||||
if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(rv); isZero {
|
if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(rv); isZero {
|
||||||
elems = reflect.Append(elems, rv)
|
elems = reflect.Append(elems, rv)
|
||||||
|
} else {
|
||||||
|
db.Session(&gorm.Session{}).Save(rv.Interface())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -149,6 +151,8 @@ func SaveAfterAssociations(db *gorm.DB) {
|
||||||
|
|
||||||
if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(f); isZero {
|
if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(f); isZero {
|
||||||
db.Session(&gorm.Session{}).Create(f.Interface())
|
db.Session(&gorm.Session{}).Create(f.Interface())
|
||||||
|
} else {
|
||||||
|
db.Session(&gorm.Session{}).Save(f.Interface())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -187,6 +191,8 @@ func SaveAfterAssociations(db *gorm.DB) {
|
||||||
} else {
|
} else {
|
||||||
elems = reflect.Append(elems, elem.Addr())
|
elems = reflect.Append(elems, elem.Addr())
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
db.Session(&gorm.Session{}).Save(elem.Interface())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -45,7 +45,11 @@ func BeforeUpdate(db *gorm.DB) {
|
||||||
|
|
||||||
func Update(db *gorm.DB) {
|
func Update(db *gorm.DB) {
|
||||||
db.Statement.AddClauseIfNotExists(clause.Update{})
|
db.Statement.AddClauseIfNotExists(clause.Update{})
|
||||||
db.Statement.AddClause(ConvertToAssignments(db.Statement))
|
if set := ConvertToAssignments(db.Statement); len(set) != 0 {
|
||||||
|
db.Statement.AddClause(set)
|
||||||
|
} else {
|
||||||
|
return
|
||||||
|
}
|
||||||
db.Statement.Build("UPDATE", "SET", "WHERE")
|
db.Statement.Build("UPDATE", "SET", "WHERE")
|
||||||
|
|
||||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||||
|
@ -198,5 +202,6 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -55,9 +55,11 @@ func (in IN) NegationBuild(builder Builder) {
|
||||||
switch len(in.Values) {
|
switch len(in.Values) {
|
||||||
case 0:
|
case 0:
|
||||||
case 1:
|
case 1:
|
||||||
|
builder.WriteQuoted(in.Column)
|
||||||
builder.WriteString(" <> ")
|
builder.WriteString(" <> ")
|
||||||
builder.AddVar(builder, in.Values...)
|
builder.AddVar(builder, in.Values...)
|
||||||
default:
|
default:
|
||||||
|
builder.WriteQuoted(in.Column)
|
||||||
builder.WriteString(" NOT IN (")
|
builder.WriteString(" NOT IN (")
|
||||||
builder.AddVar(builder, in.Values...)
|
builder.AddVar(builder, in.Values...)
|
||||||
builder.WriteByte(')')
|
builder.WriteByte(')')
|
||||||
|
|
|
@ -603,6 +603,9 @@ func (field *Field) setupValuerAndSetter() {
|
||||||
if _, ok := fieldValue.Interface().(sql.Scanner); ok {
|
if _, ok := fieldValue.Interface().(sql.Scanner); ok {
|
||||||
// struct scanner
|
// struct scanner
|
||||||
field.Set = func(value reflect.Value, v interface{}) (err error) {
|
field.Set = func(value reflect.Value, v interface{}) (err error) {
|
||||||
|
if v == nil {
|
||||||
|
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
|
||||||
|
} else {
|
||||||
reflectV := reflect.ValueOf(v)
|
reflectV := reflect.ValueOf(v)
|
||||||
if reflectV.Type().ConvertibleTo(field.FieldType) {
|
if reflectV.Type().ConvertibleTo(field.FieldType) {
|
||||||
field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType))
|
field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType))
|
||||||
|
@ -613,11 +616,15 @@ func (field *Field) setupValuerAndSetter() {
|
||||||
} else {
|
} else {
|
||||||
err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v)
|
err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok {
|
} else if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok {
|
||||||
// pointer scanner
|
// pointer scanner
|
||||||
field.Set = func(value reflect.Value, v interface{}) (err error) {
|
field.Set = func(value reflect.Value, v interface{}) (err error) {
|
||||||
|
if v == nil {
|
||||||
|
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
|
||||||
|
} else {
|
||||||
reflectV := reflect.ValueOf(v)
|
reflectV := reflect.ValueOf(v)
|
||||||
if reflectV.Type().ConvertibleTo(field.FieldType) {
|
if reflectV.Type().ConvertibleTo(field.FieldType) {
|
||||||
field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType))
|
field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType))
|
||||||
|
@ -630,6 +637,7 @@ func (field *Field) setupValuerAndSetter() {
|
||||||
} else {
|
} else {
|
||||||
err = field.ReflectValueOf(value).Interface().(sql.Scanner).Scan(v)
|
err = field.ReflectValueOf(value).Interface().(sql.Scanner).Scan(v)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -233,3 +233,79 @@ func TestBelongsToAssociationForSlice(t *testing.T) {
|
||||||
AssertAssociationCount(t, users[0], "Company", 0, "After Delete")
|
AssertAssociationCount(t, users[0], "Company", 0, "After Delete")
|
||||||
AssertAssociationCount(t, users[1], "Company", 1, "After other user Delete")
|
AssertAssociationCount(t, users[1], "Company", 1, "After other user Delete")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHasOneAssociation(t *testing.T) {
|
||||||
|
var user = *GetUser("hasone", Config{Account: true})
|
||||||
|
|
||||||
|
if err := DB.Create(&user).Error; err != nil {
|
||||||
|
t.Fatalf("errors happened when create: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
CheckUser(t, user, user)
|
||||||
|
|
||||||
|
// Find
|
||||||
|
var user2 User
|
||||||
|
DB.Find(&user2, "id = ?", user.ID)
|
||||||
|
DB.Model(&user2).Association("Account").Find(&user2.Account)
|
||||||
|
CheckUser(t, user2, user)
|
||||||
|
|
||||||
|
// Count
|
||||||
|
AssertAssociationCount(t, user, "Account", 1, "")
|
||||||
|
|
||||||
|
// Append
|
||||||
|
var account = Account{Number: "account-has-one-append"}
|
||||||
|
|
||||||
|
if err := DB.Model(&user2).Association("Account").Append(&account); err != nil {
|
||||||
|
t.Fatalf("Error happened when append account, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if account.ID == 0 {
|
||||||
|
t.Fatalf("Account's ID should be created")
|
||||||
|
}
|
||||||
|
|
||||||
|
user.Account = account
|
||||||
|
CheckUser(t, user2, user)
|
||||||
|
|
||||||
|
AssertAssociationCount(t, user, "Account", 1, "AfterAppend")
|
||||||
|
|
||||||
|
// Replace
|
||||||
|
var account2 = Account{Number: "account-has-one-replace"}
|
||||||
|
|
||||||
|
if err := DB.Model(&user2).Association("Account").Replace(&account2); err != nil {
|
||||||
|
t.Fatalf("Error happened when append Account, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if account2.ID == 0 {
|
||||||
|
t.Fatalf("account2's ID should be created")
|
||||||
|
}
|
||||||
|
|
||||||
|
user.Account = account2
|
||||||
|
CheckUser(t, user2, user)
|
||||||
|
|
||||||
|
AssertAssociationCount(t, user2, "Account", 1, "AfterReplace")
|
||||||
|
|
||||||
|
// Delete
|
||||||
|
if err := DB.Model(&user2).Association("Account").Delete(&Company{}); err != nil {
|
||||||
|
t.Fatalf("Error happened when delete account, got %v", err)
|
||||||
|
}
|
||||||
|
AssertAssociationCount(t, user2, "Account", 1, "after delete non-existing data")
|
||||||
|
|
||||||
|
if err := DB.Model(&user2).Association("Account").Delete(&account2); err != nil {
|
||||||
|
t.Fatalf("Error happened when delete Account, got %v", err)
|
||||||
|
}
|
||||||
|
AssertAssociationCount(t, user2, "Account", 0, "after delete")
|
||||||
|
|
||||||
|
// Prepare Data for Clear
|
||||||
|
if err := DB.Model(&user2).Association("Account").Append(&account); err != nil {
|
||||||
|
t.Fatalf("Error happened when append Account, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
AssertAssociationCount(t, user2, "Account", 1, "after prepare data")
|
||||||
|
|
||||||
|
// Clear
|
||||||
|
if err := DB.Model(&user2).Association("Account").Clear(); err != nil {
|
||||||
|
t.Errorf("Error happened when clear Account, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
AssertAssociationCount(t, user2, "Account", 0, "after clear")
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue