mirror of https://github.com/go-gorm/gorm.git
Add Slice Association for BelongsTo
This commit is contained in:
parent
91a695893c
commit
2db33730b6
|
@ -366,6 +366,11 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
|
|||
if clear && len(values) == 0 {
|
||||
for i := 0; i < reflectValue.Len(); i++ {
|
||||
association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface())
|
||||
for _, ref := range association.Relationship.References {
|
||||
if !ref.OwnPrimaryKey {
|
||||
ref.ForeignKey.Set(reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface())
|
||||
}
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
|
@ -382,6 +387,11 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
|
|||
case reflect.Struct:
|
||||
if clear && len(values) == 0 {
|
||||
association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface())
|
||||
for _, ref := range association.Relationship.References {
|
||||
if !ref.OwnPrimaryKey {
|
||||
ref.ForeignKey.Set(reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for idx, value := range values {
|
||||
|
@ -392,10 +402,12 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
|
|||
_, hasZero = association.DB.Statement.Schema.PrioritizedPrimaryField.ValueOf(reflectValue)
|
||||
}
|
||||
|
||||
if len(values) > 0 {
|
||||
if hasZero {
|
||||
association.DB.Save(reflectValue.Addr().Interface())
|
||||
association.DB.Create(reflectValue.Addr().Interface())
|
||||
} else {
|
||||
association.DB.Select(selectedColumns).Save(reflectValue.Addr().Interface())
|
||||
association.DB.Select(selectedColumns).Model(nil).Save(reflectValue.Addr().Interface())
|
||||
}
|
||||
}
|
||||
|
||||
for _, assignBack := range assignBacks {
|
||||
|
|
|
@ -173,12 +173,30 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
|||
}
|
||||
|
||||
if stmt.Dest != stmt.Model {
|
||||
reflectValue := reflect.ValueOf(stmt.Model)
|
||||
reflectValue := reflect.Indirect(reflect.ValueOf(stmt.Model))
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
var priamryKeyExprs []clause.Expression
|
||||
for i := 0; i < reflectValue.Len(); i++ {
|
||||
var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields))
|
||||
var notZero bool
|
||||
for idx, field := range stmt.Schema.PrimaryFields {
|
||||
value, isZero := field.ValueOf(reflectValue.Index(i))
|
||||
exprs[idx] = clause.Eq{Column: field.DBName, Value: value}
|
||||
notZero = notZero || !isZero
|
||||
}
|
||||
if notZero {
|
||||
priamryKeyExprs = append(priamryKeyExprs, clause.And(exprs...))
|
||||
}
|
||||
}
|
||||
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(priamryKeyExprs...)}})
|
||||
case reflect.Struct:
|
||||
for _, field := range stmt.Schema.PrimaryFields {
|
||||
if value, isZero := field.ValueOf(reflectValue); !isZero {
|
||||
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
|
@ -19,4 +19,6 @@ var (
|
|||
ErrMissingWhereClause = errors.New("missing WHERE clause while deleting")
|
||||
// ErrUnsupportedRelation unsupported relations
|
||||
ErrUnsupportedRelation = errors.New("unsupported relations")
|
||||
// ErrPtrStructSupported only ptr of struct supported
|
||||
ErrPtrStructSupported = errors.New("only ptr of struct supported")
|
||||
)
|
||||
|
|
|
@ -23,7 +23,11 @@ func (db *DB) Save(value interface{}) (tx *DB) {
|
|||
|
||||
if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil {
|
||||
where := clause.Where{Exprs: make([]clause.Expression, len(tx.Statement.Schema.PrimaryFields))}
|
||||
reflectValue := reflect.ValueOf(value)
|
||||
reflectValue := reflect.Indirect(reflect.ValueOf(value))
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
tx.AddError(ErrPtrStructSupported)
|
||||
case reflect.Struct:
|
||||
for idx, pf := range tx.Statement.Schema.PrimaryFields {
|
||||
if pv, isZero := pf.ValueOf(reflectValue); isZero {
|
||||
tx.callbacks.Create().Execute(tx)
|
||||
|
@ -31,6 +35,7 @@ func (db *DB) Save(value interface{}) (tx *DB) {
|
|||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tx.Statement.AddClause(where)
|
||||
}
|
||||
|
|
|
@ -6,7 +6,26 @@ import (
|
|||
. "github.com/jinzhu/gorm/tests"
|
||||
)
|
||||
|
||||
func TestAssociationForBelongsTo(t *testing.T) {
|
||||
func AssertAssociationCount(t *testing.T, data interface{}, name string, result int64, reason string) {
|
||||
if count := DB.Model(data).Association(name).Count(); count != result {
|
||||
t.Errorf("invalid %v count %v, expects: %v got %v", name, reason, result, count)
|
||||
}
|
||||
|
||||
var newUser User
|
||||
if user, ok := data.(User); ok {
|
||||
DB.Find(&newUser, "id = ?", user.ID)
|
||||
} else if user, ok := data.(*User); ok {
|
||||
DB.Find(&newUser, "id = ?", user.ID)
|
||||
}
|
||||
|
||||
if newUser.ID != 0 {
|
||||
if count := DB.Model(&newUser).Association(name).Count(); count != result {
|
||||
t.Errorf("invalid %v count %v, expects: %v got %v", name, reason, result, count)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBelongsToAssociation(t *testing.T) {
|
||||
var user = *GetUser("belongs-to", Config{Company: true, Manager: true})
|
||||
|
||||
if err := DB.Create(&user).Error; err != nil {
|
||||
|
@ -24,13 +43,8 @@ func TestAssociationForBelongsTo(t *testing.T) {
|
|||
CheckUser(t, user2, user)
|
||||
|
||||
// Count
|
||||
if count := DB.Model(&user).Association("Company").Count(); count != 1 {
|
||||
t.Errorf("invalid company count, got %v", count)
|
||||
}
|
||||
|
||||
if count := DB.Model(&user).Association("Manager").Count(); count != 1 {
|
||||
t.Errorf("invalid manager count, got %v", count)
|
||||
}
|
||||
AssertAssociationCount(t, user, "Company", 1, "")
|
||||
AssertAssociationCount(t, user, "Manager", 1, "")
|
||||
|
||||
// Append
|
||||
var company = Company{Name: "company-belongs-to-append"}
|
||||
|
@ -58,6 +72,9 @@ func TestAssociationForBelongsTo(t *testing.T) {
|
|||
user.ManagerID = &manager.ID
|
||||
CheckUser(t, user2, user)
|
||||
|
||||
AssertAssociationCount(t, user2, "Company", 1, "AfterAppend")
|
||||
AssertAssociationCount(t, user2, "Manager", 1, "AfterAppend")
|
||||
|
||||
// Replace
|
||||
var company2 = Company{Name: "company-belongs-to-replace"}
|
||||
var manager2 = GetUser("manager-belongs-to-replace", Config{})
|
||||
|
@ -84,40 +101,31 @@ func TestAssociationForBelongsTo(t *testing.T) {
|
|||
user.ManagerID = &manager2.ID
|
||||
CheckUser(t, user2, user)
|
||||
|
||||
AssertAssociationCount(t, user2, "Company", 1, "AfterReplace")
|
||||
AssertAssociationCount(t, user2, "Manager", 1, "AfterReplace")
|
||||
|
||||
// Delete
|
||||
if err := DB.Model(&user2).Association("Company").Delete(&Company{}); err != nil {
|
||||
t.Fatalf("Error happened when delete Company, got %v", err)
|
||||
}
|
||||
|
||||
if count := DB.Model(&user2).Association("Company").Count(); count != 1 {
|
||||
t.Errorf("Invalid company count after delete non-existing association, got %v", count)
|
||||
}
|
||||
AssertAssociationCount(t, user2, "Company", 1, "after delete non-existing data")
|
||||
|
||||
if err := DB.Model(&user2).Association("Company").Delete(&company2); err != nil {
|
||||
t.Fatalf("Error happened when delete Company, got %v", err)
|
||||
}
|
||||
|
||||
if count := DB.Model(&user2).Association("Company").Count(); count != 0 {
|
||||
t.Errorf("Invalid company count after delete, got %v", count)
|
||||
}
|
||||
AssertAssociationCount(t, user2, "Company", 0, "after delete")
|
||||
|
||||
if err := DB.Model(&user2).Association("Manager").Delete(&User{}); err != nil {
|
||||
t.Fatalf("Error happened when delete Manager, got %v", err)
|
||||
}
|
||||
|
||||
if count := DB.Model(&user2).Association("Manager").Count(); count != 1 {
|
||||
t.Errorf("Invalid manager count after delete non-existing association, got %v", count)
|
||||
}
|
||||
AssertAssociationCount(t, user2, "Manager", 1, "after delete non-existing data")
|
||||
|
||||
if err := DB.Model(&user2).Association("Manager").Delete(manager2); err != nil {
|
||||
t.Fatalf("Error happened when delete Manager, got %v", err)
|
||||
}
|
||||
AssertAssociationCount(t, user2, "Manager", 0, "after delete")
|
||||
|
||||
if count := DB.Model(&user2).Association("Manager").Count(); count != 0 {
|
||||
t.Errorf("Invalid manager count after delete, got %v", count)
|
||||
}
|
||||
|
||||
// Prepare Data
|
||||
// Prepare Data for Clear
|
||||
if err := DB.Model(&user2).Association("Company").Append(&company); err != nil {
|
||||
t.Fatalf("Error happened when append Company, got %v", err)
|
||||
}
|
||||
|
@ -126,13 +134,8 @@ func TestAssociationForBelongsTo(t *testing.T) {
|
|||
t.Fatalf("Error happened when append Manager, got %v", err)
|
||||
}
|
||||
|
||||
if count := DB.Model(&user2).Association("Company").Count(); count != 1 {
|
||||
t.Errorf("Invalid company count after append, got %v", count)
|
||||
}
|
||||
|
||||
if count := DB.Model(&user2).Association("Manager").Count(); count != 1 {
|
||||
t.Errorf("Invalid manager count after append, got %v", count)
|
||||
}
|
||||
AssertAssociationCount(t, user2, "Company", 1, "after prepare data")
|
||||
AssertAssociationCount(t, user2, "Manager", 1, "after prepare data")
|
||||
|
||||
// Clear
|
||||
if err := DB.Model(&user2).Association("Company").Clear(); err != nil {
|
||||
|
@ -143,11 +146,43 @@ func TestAssociationForBelongsTo(t *testing.T) {
|
|||
t.Errorf("Error happened when clear Manager, got %v", err)
|
||||
}
|
||||
|
||||
if count := DB.Model(&user2).Association("Company").Count(); count != 0 {
|
||||
t.Errorf("Invalid company count after clear, got %v", count)
|
||||
AssertAssociationCount(t, user2, "Company", 0, "after clear")
|
||||
AssertAssociationCount(t, user2, "Manager", 0, "after clear")
|
||||
}
|
||||
|
||||
if count := DB.Model(&user2).Association("Manager").Count(); count != 0 {
|
||||
t.Errorf("Invalid manager count after clear, got %v", count)
|
||||
func TestBelongsToAssociationForSlice(t *testing.T) {
|
||||
var users = []User{
|
||||
*GetUser("slice-belongs-to-1", Config{Company: true, Manager: true}),
|
||||
*GetUser("slice-belongs-to-2", Config{Company: true, Manager: false}),
|
||||
*GetUser("slice-belongs-to-3", Config{Company: true, Manager: true}),
|
||||
}
|
||||
|
||||
DB.Create(&users)
|
||||
|
||||
AssertAssociationCount(t, users, "Company", 3, "")
|
||||
AssertAssociationCount(t, users, "Manager", 2, "")
|
||||
|
||||
// Find
|
||||
var companies []Company
|
||||
if DB.Model(users).Association("Company").Find(&companies); len(companies) != 3 {
|
||||
t.Errorf("companies count should be %v, but got %v", 3, len(companies))
|
||||
}
|
||||
|
||||
var managers []User
|
||||
if DB.Model(users).Association("Manager").Find(&managers); len(managers) != 2 {
|
||||
t.Errorf("managers count should be %v, but got %v", 2, len(managers))
|
||||
}
|
||||
|
||||
// Append
|
||||
|
||||
// Replace
|
||||
|
||||
// Delete
|
||||
|
||||
// Clear
|
||||
DB.Model(&users).Association("Company").Clear()
|
||||
AssertAssociationCount(t, users, "Company", 0, "After Clear")
|
||||
|
||||
DB.Model(&users).Association("Manager").Clear()
|
||||
AssertAssociationCount(t, users, "Manager", 0, "After Clear")
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue