Add Slice Association for BelongsTo

This commit is contained in:
Jinzhu 2020-05-24 20:44:37 +08:00
parent 91a695893c
commit 2db33730b6
5 changed files with 122 additions and 50 deletions

View File

@ -366,6 +366,11 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
if clear && len(values) == 0 { if clear && len(values) == 0 {
for i := 0; i < reflectValue.Len(); i++ { for i := 0; i < reflectValue.Len(); i++ {
association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface())
for _, ref := range association.Relationship.References {
if !ref.OwnPrimaryKey {
ref.ForeignKey.Set(reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface())
}
}
} }
break break
} }
@ -382,6 +387,11 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
case reflect.Struct: case reflect.Struct:
if clear && len(values) == 0 { if clear && len(values) == 0 {
association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface())
for _, ref := range association.Relationship.References {
if !ref.OwnPrimaryKey {
ref.ForeignKey.Set(reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface())
}
}
} }
for idx, value := range values { for idx, value := range values {
@ -392,10 +402,12 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
_, hasZero = association.DB.Statement.Schema.PrioritizedPrimaryField.ValueOf(reflectValue) _, hasZero = association.DB.Statement.Schema.PrioritizedPrimaryField.ValueOf(reflectValue)
} }
if len(values) > 0 {
if hasZero { if hasZero {
association.DB.Save(reflectValue.Addr().Interface()) association.DB.Create(reflectValue.Addr().Interface())
} else { } else {
association.DB.Select(selectedColumns).Save(reflectValue.Addr().Interface()) association.DB.Select(selectedColumns).Model(nil).Save(reflectValue.Addr().Interface())
}
} }
for _, assignBack := range assignBacks { for _, assignBack := range assignBacks {

View File

@ -173,12 +173,30 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
} }
if stmt.Dest != stmt.Model { if stmt.Dest != stmt.Model {
reflectValue := reflect.ValueOf(stmt.Model) reflectValue := reflect.Indirect(reflect.ValueOf(stmt.Model))
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
var priamryKeyExprs []clause.Expression
for i := 0; i < reflectValue.Len(); i++ {
var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields))
var notZero bool
for idx, field := range stmt.Schema.PrimaryFields {
value, isZero := field.ValueOf(reflectValue.Index(i))
exprs[idx] = clause.Eq{Column: field.DBName, Value: value}
notZero = notZero || !isZero
}
if notZero {
priamryKeyExprs = append(priamryKeyExprs, clause.And(exprs...))
}
}
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(priamryKeyExprs...)}})
case reflect.Struct:
for _, field := range stmt.Schema.PrimaryFields { for _, field := range stmt.Schema.PrimaryFields {
if value, isZero := field.ValueOf(reflectValue); !isZero { if value, isZero := field.ValueOf(reflectValue); !isZero {
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
} }
} }
} }
}
return return
} }

View File

@ -19,4 +19,6 @@ var (
ErrMissingWhereClause = errors.New("missing WHERE clause while deleting") ErrMissingWhereClause = errors.New("missing WHERE clause while deleting")
// ErrUnsupportedRelation unsupported relations // ErrUnsupportedRelation unsupported relations
ErrUnsupportedRelation = errors.New("unsupported relations") ErrUnsupportedRelation = errors.New("unsupported relations")
// ErrPtrStructSupported only ptr of struct supported
ErrPtrStructSupported = errors.New("only ptr of struct supported")
) )

View File

@ -23,7 +23,11 @@ func (db *DB) Save(value interface{}) (tx *DB) {
if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil { if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil {
where := clause.Where{Exprs: make([]clause.Expression, len(tx.Statement.Schema.PrimaryFields))} where := clause.Where{Exprs: make([]clause.Expression, len(tx.Statement.Schema.PrimaryFields))}
reflectValue := reflect.ValueOf(value) reflectValue := reflect.Indirect(reflect.ValueOf(value))
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
tx.AddError(ErrPtrStructSupported)
case reflect.Struct:
for idx, pf := range tx.Statement.Schema.PrimaryFields { for idx, pf := range tx.Statement.Schema.PrimaryFields {
if pv, isZero := pf.ValueOf(reflectValue); isZero { if pv, isZero := pf.ValueOf(reflectValue); isZero {
tx.callbacks.Create().Execute(tx) tx.callbacks.Create().Execute(tx)
@ -31,6 +35,7 @@ func (db *DB) Save(value interface{}) (tx *DB) {
return return
} }
} }
}
tx.Statement.AddClause(where) tx.Statement.AddClause(where)
} }

View File

@ -6,7 +6,26 @@ import (
. "github.com/jinzhu/gorm/tests" . "github.com/jinzhu/gorm/tests"
) )
func TestAssociationForBelongsTo(t *testing.T) { func AssertAssociationCount(t *testing.T, data interface{}, name string, result int64, reason string) {
if count := DB.Model(data).Association(name).Count(); count != result {
t.Errorf("invalid %v count %v, expects: %v got %v", name, reason, result, count)
}
var newUser User
if user, ok := data.(User); ok {
DB.Find(&newUser, "id = ?", user.ID)
} else if user, ok := data.(*User); ok {
DB.Find(&newUser, "id = ?", user.ID)
}
if newUser.ID != 0 {
if count := DB.Model(&newUser).Association(name).Count(); count != result {
t.Errorf("invalid %v count %v, expects: %v got %v", name, reason, result, count)
}
}
}
func TestBelongsToAssociation(t *testing.T) {
var user = *GetUser("belongs-to", Config{Company: true, Manager: true}) var user = *GetUser("belongs-to", Config{Company: true, Manager: true})
if err := DB.Create(&user).Error; err != nil { if err := DB.Create(&user).Error; err != nil {
@ -24,13 +43,8 @@ func TestAssociationForBelongsTo(t *testing.T) {
CheckUser(t, user2, user) CheckUser(t, user2, user)
// Count // Count
if count := DB.Model(&user).Association("Company").Count(); count != 1 { AssertAssociationCount(t, user, "Company", 1, "")
t.Errorf("invalid company count, got %v", count) AssertAssociationCount(t, user, "Manager", 1, "")
}
if count := DB.Model(&user).Association("Manager").Count(); count != 1 {
t.Errorf("invalid manager count, got %v", count)
}
// Append // Append
var company = Company{Name: "company-belongs-to-append"} var company = Company{Name: "company-belongs-to-append"}
@ -58,6 +72,9 @@ func TestAssociationForBelongsTo(t *testing.T) {
user.ManagerID = &manager.ID user.ManagerID = &manager.ID
CheckUser(t, user2, user) CheckUser(t, user2, user)
AssertAssociationCount(t, user2, "Company", 1, "AfterAppend")
AssertAssociationCount(t, user2, "Manager", 1, "AfterAppend")
// Replace // Replace
var company2 = Company{Name: "company-belongs-to-replace"} var company2 = Company{Name: "company-belongs-to-replace"}
var manager2 = GetUser("manager-belongs-to-replace", Config{}) var manager2 = GetUser("manager-belongs-to-replace", Config{})
@ -84,40 +101,31 @@ func TestAssociationForBelongsTo(t *testing.T) {
user.ManagerID = &manager2.ID user.ManagerID = &manager2.ID
CheckUser(t, user2, user) CheckUser(t, user2, user)
AssertAssociationCount(t, user2, "Company", 1, "AfterReplace")
AssertAssociationCount(t, user2, "Manager", 1, "AfterReplace")
// Delete // Delete
if err := DB.Model(&user2).Association("Company").Delete(&Company{}); err != nil { if err := DB.Model(&user2).Association("Company").Delete(&Company{}); err != nil {
t.Fatalf("Error happened when delete Company, got %v", err) t.Fatalf("Error happened when delete Company, got %v", err)
} }
AssertAssociationCount(t, user2, "Company", 1, "after delete non-existing data")
if count := DB.Model(&user2).Association("Company").Count(); count != 1 {
t.Errorf("Invalid company count after delete non-existing association, got %v", count)
}
if err := DB.Model(&user2).Association("Company").Delete(&company2); err != nil { if err := DB.Model(&user2).Association("Company").Delete(&company2); err != nil {
t.Fatalf("Error happened when delete Company, got %v", err) t.Fatalf("Error happened when delete Company, got %v", err)
} }
AssertAssociationCount(t, user2, "Company", 0, "after delete")
if count := DB.Model(&user2).Association("Company").Count(); count != 0 {
t.Errorf("Invalid company count after delete, got %v", count)
}
if err := DB.Model(&user2).Association("Manager").Delete(&User{}); err != nil { if err := DB.Model(&user2).Association("Manager").Delete(&User{}); err != nil {
t.Fatalf("Error happened when delete Manager, got %v", err) t.Fatalf("Error happened when delete Manager, got %v", err)
} }
AssertAssociationCount(t, user2, "Manager", 1, "after delete non-existing data")
if count := DB.Model(&user2).Association("Manager").Count(); count != 1 {
t.Errorf("Invalid manager count after delete non-existing association, got %v", count)
}
if err := DB.Model(&user2).Association("Manager").Delete(manager2); err != nil { if err := DB.Model(&user2).Association("Manager").Delete(manager2); err != nil {
t.Fatalf("Error happened when delete Manager, got %v", err) t.Fatalf("Error happened when delete Manager, got %v", err)
} }
AssertAssociationCount(t, user2, "Manager", 0, "after delete")
if count := DB.Model(&user2).Association("Manager").Count(); count != 0 { // Prepare Data for Clear
t.Errorf("Invalid manager count after delete, got %v", count)
}
// Prepare Data
if err := DB.Model(&user2).Association("Company").Append(&company); err != nil { if err := DB.Model(&user2).Association("Company").Append(&company); err != nil {
t.Fatalf("Error happened when append Company, got %v", err) t.Fatalf("Error happened when append Company, got %v", err)
} }
@ -126,13 +134,8 @@ func TestAssociationForBelongsTo(t *testing.T) {
t.Fatalf("Error happened when append Manager, got %v", err) t.Fatalf("Error happened when append Manager, got %v", err)
} }
if count := DB.Model(&user2).Association("Company").Count(); count != 1 { AssertAssociationCount(t, user2, "Company", 1, "after prepare data")
t.Errorf("Invalid company count after append, got %v", count) AssertAssociationCount(t, user2, "Manager", 1, "after prepare data")
}
if count := DB.Model(&user2).Association("Manager").Count(); count != 1 {
t.Errorf("Invalid manager count after append, got %v", count)
}
// Clear // Clear
if err := DB.Model(&user2).Association("Company").Clear(); err != nil { if err := DB.Model(&user2).Association("Company").Clear(); err != nil {
@ -143,11 +146,43 @@ func TestAssociationForBelongsTo(t *testing.T) {
t.Errorf("Error happened when clear Manager, got %v", err) t.Errorf("Error happened when clear Manager, got %v", err)
} }
if count := DB.Model(&user2).Association("Company").Count(); count != 0 { AssertAssociationCount(t, user2, "Company", 0, "after clear")
t.Errorf("Invalid company count after clear, got %v", count) AssertAssociationCount(t, user2, "Manager", 0, "after clear")
} }
if count := DB.Model(&user2).Association("Manager").Count(); count != 0 { func TestBelongsToAssociationForSlice(t *testing.T) {
t.Errorf("Invalid manager count after clear, got %v", count) var users = []User{
*GetUser("slice-belongs-to-1", Config{Company: true, Manager: true}),
*GetUser("slice-belongs-to-2", Config{Company: true, Manager: false}),
*GetUser("slice-belongs-to-3", Config{Company: true, Manager: true}),
} }
DB.Create(&users)
AssertAssociationCount(t, users, "Company", 3, "")
AssertAssociationCount(t, users, "Manager", 2, "")
// Find
var companies []Company
if DB.Model(users).Association("Company").Find(&companies); len(companies) != 3 {
t.Errorf("companies count should be %v, but got %v", 3, len(companies))
}
var managers []User
if DB.Model(users).Association("Manager").Find(&managers); len(managers) != 2 {
t.Errorf("managers count should be %v, but got %v", 2, len(managers))
}
// Append
// Replace
// Delete
// Clear
DB.Model(&users).Association("Company").Clear()
AssertAssociationCount(t, users, "Company", 0, "After Clear")
DB.Model(&users).Association("Manager").Clear()
AssertAssociationCount(t, users, "Manager", 0, "After Clear")
} }