mirror of https://github.com/go-gorm/gorm.git
Add Slice Association for BelongsTo
This commit is contained in:
parent
91a695893c
commit
2db33730b6
|
@ -366,6 +366,11 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
|
||||||
if clear && len(values) == 0 {
|
if clear && len(values) == 0 {
|
||||||
for i := 0; i < reflectValue.Len(); i++ {
|
for i := 0; i < reflectValue.Len(); i++ {
|
||||||
association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface())
|
association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface())
|
||||||
|
for _, ref := range association.Relationship.References {
|
||||||
|
if !ref.OwnPrimaryKey {
|
||||||
|
ref.ForeignKey.Set(reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface())
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
@ -382,6 +387,11 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
if clear && len(values) == 0 {
|
if clear && len(values) == 0 {
|
||||||
association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface())
|
association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface())
|
||||||
|
for _, ref := range association.Relationship.References {
|
||||||
|
if !ref.OwnPrimaryKey {
|
||||||
|
ref.ForeignKey.Set(reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface())
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for idx, value := range values {
|
for idx, value := range values {
|
||||||
|
@ -392,10 +402,12 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
|
||||||
_, hasZero = association.DB.Statement.Schema.PrioritizedPrimaryField.ValueOf(reflectValue)
|
_, hasZero = association.DB.Statement.Schema.PrioritizedPrimaryField.ValueOf(reflectValue)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(values) > 0 {
|
||||||
if hasZero {
|
if hasZero {
|
||||||
association.DB.Save(reflectValue.Addr().Interface())
|
association.DB.Create(reflectValue.Addr().Interface())
|
||||||
} else {
|
} else {
|
||||||
association.DB.Select(selectedColumns).Save(reflectValue.Addr().Interface())
|
association.DB.Select(selectedColumns).Model(nil).Save(reflectValue.Addr().Interface())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, assignBack := range assignBacks {
|
for _, assignBack := range assignBacks {
|
||||||
|
|
|
@ -173,12 +173,30 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if stmt.Dest != stmt.Model {
|
if stmt.Dest != stmt.Model {
|
||||||
reflectValue := reflect.ValueOf(stmt.Model)
|
reflectValue := reflect.Indirect(reflect.ValueOf(stmt.Model))
|
||||||
|
switch reflectValue.Kind() {
|
||||||
|
case reflect.Slice, reflect.Array:
|
||||||
|
var priamryKeyExprs []clause.Expression
|
||||||
|
for i := 0; i < reflectValue.Len(); i++ {
|
||||||
|
var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields))
|
||||||
|
var notZero bool
|
||||||
|
for idx, field := range stmt.Schema.PrimaryFields {
|
||||||
|
value, isZero := field.ValueOf(reflectValue.Index(i))
|
||||||
|
exprs[idx] = clause.Eq{Column: field.DBName, Value: value}
|
||||||
|
notZero = notZero || !isZero
|
||||||
|
}
|
||||||
|
if notZero {
|
||||||
|
priamryKeyExprs = append(priamryKeyExprs, clause.And(exprs...))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(priamryKeyExprs...)}})
|
||||||
|
case reflect.Struct:
|
||||||
for _, field := range stmt.Schema.PrimaryFields {
|
for _, field := range stmt.Schema.PrimaryFields {
|
||||||
if value, isZero := field.ValueOf(reflectValue); !isZero {
|
if value, isZero := field.ValueOf(reflectValue); !isZero {
|
||||||
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
|
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,4 +19,6 @@ var (
|
||||||
ErrMissingWhereClause = errors.New("missing WHERE clause while deleting")
|
ErrMissingWhereClause = errors.New("missing WHERE clause while deleting")
|
||||||
// ErrUnsupportedRelation unsupported relations
|
// ErrUnsupportedRelation unsupported relations
|
||||||
ErrUnsupportedRelation = errors.New("unsupported relations")
|
ErrUnsupportedRelation = errors.New("unsupported relations")
|
||||||
|
// ErrPtrStructSupported only ptr of struct supported
|
||||||
|
ErrPtrStructSupported = errors.New("only ptr of struct supported")
|
||||||
)
|
)
|
||||||
|
|
|
@ -23,7 +23,11 @@ func (db *DB) Save(value interface{}) (tx *DB) {
|
||||||
|
|
||||||
if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil {
|
if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil {
|
||||||
where := clause.Where{Exprs: make([]clause.Expression, len(tx.Statement.Schema.PrimaryFields))}
|
where := clause.Where{Exprs: make([]clause.Expression, len(tx.Statement.Schema.PrimaryFields))}
|
||||||
reflectValue := reflect.ValueOf(value)
|
reflectValue := reflect.Indirect(reflect.ValueOf(value))
|
||||||
|
switch reflectValue.Kind() {
|
||||||
|
case reflect.Slice, reflect.Array:
|
||||||
|
tx.AddError(ErrPtrStructSupported)
|
||||||
|
case reflect.Struct:
|
||||||
for idx, pf := range tx.Statement.Schema.PrimaryFields {
|
for idx, pf := range tx.Statement.Schema.PrimaryFields {
|
||||||
if pv, isZero := pf.ValueOf(reflectValue); isZero {
|
if pv, isZero := pf.ValueOf(reflectValue); isZero {
|
||||||
tx.callbacks.Create().Execute(tx)
|
tx.callbacks.Create().Execute(tx)
|
||||||
|
@ -31,6 +35,7 @@ func (db *DB) Save(value interface{}) (tx *DB) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
tx.Statement.AddClause(where)
|
tx.Statement.AddClause(where)
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,7 +6,26 @@ import (
|
||||||
. "github.com/jinzhu/gorm/tests"
|
. "github.com/jinzhu/gorm/tests"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestAssociationForBelongsTo(t *testing.T) {
|
func AssertAssociationCount(t *testing.T, data interface{}, name string, result int64, reason string) {
|
||||||
|
if count := DB.Model(data).Association(name).Count(); count != result {
|
||||||
|
t.Errorf("invalid %v count %v, expects: %v got %v", name, reason, result, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
var newUser User
|
||||||
|
if user, ok := data.(User); ok {
|
||||||
|
DB.Find(&newUser, "id = ?", user.ID)
|
||||||
|
} else if user, ok := data.(*User); ok {
|
||||||
|
DB.Find(&newUser, "id = ?", user.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if newUser.ID != 0 {
|
||||||
|
if count := DB.Model(&newUser).Association(name).Count(); count != result {
|
||||||
|
t.Errorf("invalid %v count %v, expects: %v got %v", name, reason, result, count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBelongsToAssociation(t *testing.T) {
|
||||||
var user = *GetUser("belongs-to", Config{Company: true, Manager: true})
|
var user = *GetUser("belongs-to", Config{Company: true, Manager: true})
|
||||||
|
|
||||||
if err := DB.Create(&user).Error; err != nil {
|
if err := DB.Create(&user).Error; err != nil {
|
||||||
|
@ -24,13 +43,8 @@ func TestAssociationForBelongsTo(t *testing.T) {
|
||||||
CheckUser(t, user2, user)
|
CheckUser(t, user2, user)
|
||||||
|
|
||||||
// Count
|
// Count
|
||||||
if count := DB.Model(&user).Association("Company").Count(); count != 1 {
|
AssertAssociationCount(t, user, "Company", 1, "")
|
||||||
t.Errorf("invalid company count, got %v", count)
|
AssertAssociationCount(t, user, "Manager", 1, "")
|
||||||
}
|
|
||||||
|
|
||||||
if count := DB.Model(&user).Association("Manager").Count(); count != 1 {
|
|
||||||
t.Errorf("invalid manager count, got %v", count)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Append
|
// Append
|
||||||
var company = Company{Name: "company-belongs-to-append"}
|
var company = Company{Name: "company-belongs-to-append"}
|
||||||
|
@ -58,6 +72,9 @@ func TestAssociationForBelongsTo(t *testing.T) {
|
||||||
user.ManagerID = &manager.ID
|
user.ManagerID = &manager.ID
|
||||||
CheckUser(t, user2, user)
|
CheckUser(t, user2, user)
|
||||||
|
|
||||||
|
AssertAssociationCount(t, user2, "Company", 1, "AfterAppend")
|
||||||
|
AssertAssociationCount(t, user2, "Manager", 1, "AfterAppend")
|
||||||
|
|
||||||
// Replace
|
// Replace
|
||||||
var company2 = Company{Name: "company-belongs-to-replace"}
|
var company2 = Company{Name: "company-belongs-to-replace"}
|
||||||
var manager2 = GetUser("manager-belongs-to-replace", Config{})
|
var manager2 = GetUser("manager-belongs-to-replace", Config{})
|
||||||
|
@ -84,40 +101,31 @@ func TestAssociationForBelongsTo(t *testing.T) {
|
||||||
user.ManagerID = &manager2.ID
|
user.ManagerID = &manager2.ID
|
||||||
CheckUser(t, user2, user)
|
CheckUser(t, user2, user)
|
||||||
|
|
||||||
|
AssertAssociationCount(t, user2, "Company", 1, "AfterReplace")
|
||||||
|
AssertAssociationCount(t, user2, "Manager", 1, "AfterReplace")
|
||||||
|
|
||||||
// Delete
|
// Delete
|
||||||
if err := DB.Model(&user2).Association("Company").Delete(&Company{}); err != nil {
|
if err := DB.Model(&user2).Association("Company").Delete(&Company{}); err != nil {
|
||||||
t.Fatalf("Error happened when delete Company, got %v", err)
|
t.Fatalf("Error happened when delete Company, got %v", err)
|
||||||
}
|
}
|
||||||
|
AssertAssociationCount(t, user2, "Company", 1, "after delete non-existing data")
|
||||||
if count := DB.Model(&user2).Association("Company").Count(); count != 1 {
|
|
||||||
t.Errorf("Invalid company count after delete non-existing association, got %v", count)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := DB.Model(&user2).Association("Company").Delete(&company2); err != nil {
|
if err := DB.Model(&user2).Association("Company").Delete(&company2); err != nil {
|
||||||
t.Fatalf("Error happened when delete Company, got %v", err)
|
t.Fatalf("Error happened when delete Company, got %v", err)
|
||||||
}
|
}
|
||||||
|
AssertAssociationCount(t, user2, "Company", 0, "after delete")
|
||||||
if count := DB.Model(&user2).Association("Company").Count(); count != 0 {
|
|
||||||
t.Errorf("Invalid company count after delete, got %v", count)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := DB.Model(&user2).Association("Manager").Delete(&User{}); err != nil {
|
if err := DB.Model(&user2).Association("Manager").Delete(&User{}); err != nil {
|
||||||
t.Fatalf("Error happened when delete Manager, got %v", err)
|
t.Fatalf("Error happened when delete Manager, got %v", err)
|
||||||
}
|
}
|
||||||
|
AssertAssociationCount(t, user2, "Manager", 1, "after delete non-existing data")
|
||||||
if count := DB.Model(&user2).Association("Manager").Count(); count != 1 {
|
|
||||||
t.Errorf("Invalid manager count after delete non-existing association, got %v", count)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := DB.Model(&user2).Association("Manager").Delete(manager2); err != nil {
|
if err := DB.Model(&user2).Association("Manager").Delete(manager2); err != nil {
|
||||||
t.Fatalf("Error happened when delete Manager, got %v", err)
|
t.Fatalf("Error happened when delete Manager, got %v", err)
|
||||||
}
|
}
|
||||||
|
AssertAssociationCount(t, user2, "Manager", 0, "after delete")
|
||||||
|
|
||||||
if count := DB.Model(&user2).Association("Manager").Count(); count != 0 {
|
// Prepare Data for Clear
|
||||||
t.Errorf("Invalid manager count after delete, got %v", count)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Prepare Data
|
|
||||||
if err := DB.Model(&user2).Association("Company").Append(&company); err != nil {
|
if err := DB.Model(&user2).Association("Company").Append(&company); err != nil {
|
||||||
t.Fatalf("Error happened when append Company, got %v", err)
|
t.Fatalf("Error happened when append Company, got %v", err)
|
||||||
}
|
}
|
||||||
|
@ -126,13 +134,8 @@ func TestAssociationForBelongsTo(t *testing.T) {
|
||||||
t.Fatalf("Error happened when append Manager, got %v", err)
|
t.Fatalf("Error happened when append Manager, got %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if count := DB.Model(&user2).Association("Company").Count(); count != 1 {
|
AssertAssociationCount(t, user2, "Company", 1, "after prepare data")
|
||||||
t.Errorf("Invalid company count after append, got %v", count)
|
AssertAssociationCount(t, user2, "Manager", 1, "after prepare data")
|
||||||
}
|
|
||||||
|
|
||||||
if count := DB.Model(&user2).Association("Manager").Count(); count != 1 {
|
|
||||||
t.Errorf("Invalid manager count after append, got %v", count)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clear
|
// Clear
|
||||||
if err := DB.Model(&user2).Association("Company").Clear(); err != nil {
|
if err := DB.Model(&user2).Association("Company").Clear(); err != nil {
|
||||||
|
@ -143,11 +146,43 @@ func TestAssociationForBelongsTo(t *testing.T) {
|
||||||
t.Errorf("Error happened when clear Manager, got %v", err)
|
t.Errorf("Error happened when clear Manager, got %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if count := DB.Model(&user2).Association("Company").Count(); count != 0 {
|
AssertAssociationCount(t, user2, "Company", 0, "after clear")
|
||||||
t.Errorf("Invalid company count after clear, got %v", count)
|
AssertAssociationCount(t, user2, "Manager", 0, "after clear")
|
||||||
}
|
}
|
||||||
|
|
||||||
if count := DB.Model(&user2).Association("Manager").Count(); count != 0 {
|
func TestBelongsToAssociationForSlice(t *testing.T) {
|
||||||
t.Errorf("Invalid manager count after clear, got %v", count)
|
var users = []User{
|
||||||
|
*GetUser("slice-belongs-to-1", Config{Company: true, Manager: true}),
|
||||||
|
*GetUser("slice-belongs-to-2", Config{Company: true, Manager: false}),
|
||||||
|
*GetUser("slice-belongs-to-3", Config{Company: true, Manager: true}),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
DB.Create(&users)
|
||||||
|
|
||||||
|
AssertAssociationCount(t, users, "Company", 3, "")
|
||||||
|
AssertAssociationCount(t, users, "Manager", 2, "")
|
||||||
|
|
||||||
|
// Find
|
||||||
|
var companies []Company
|
||||||
|
if DB.Model(users).Association("Company").Find(&companies); len(companies) != 3 {
|
||||||
|
t.Errorf("companies count should be %v, but got %v", 3, len(companies))
|
||||||
|
}
|
||||||
|
|
||||||
|
var managers []User
|
||||||
|
if DB.Model(users).Association("Manager").Find(&managers); len(managers) != 2 {
|
||||||
|
t.Errorf("managers count should be %v, but got %v", 2, len(managers))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Append
|
||||||
|
|
||||||
|
// Replace
|
||||||
|
|
||||||
|
// Delete
|
||||||
|
|
||||||
|
// Clear
|
||||||
|
DB.Model(&users).Association("Company").Clear()
|
||||||
|
AssertAssociationCount(t, users, "Company", 0, "After Clear")
|
||||||
|
|
||||||
|
DB.Model(&users).Association("Manager").Clear()
|
||||||
|
AssertAssociationCount(t, users, "Manager", 0, "After Clear")
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue