mirror of https://github.com/go-gorm/gorm.git
* feat: unscoped association (#5899) * modify name because mysql character is latin1 * work only on has association * format * Unscoped on belongs_to association
This commit is contained in:
parent
67642abfff
commit
32045fdd7d
|
@ -14,6 +14,7 @@ import (
|
|||
type Association struct {
|
||||
DB *DB
|
||||
Relationship *schema.Relationship
|
||||
Unscope bool
|
||||
Error error
|
||||
}
|
||||
|
||||
|
@ -40,6 +41,15 @@ func (db *DB) Association(column string) *Association {
|
|||
return association
|
||||
}
|
||||
|
||||
func (association *Association) Unscoped() *Association {
|
||||
return &Association{
|
||||
DB: association.DB,
|
||||
Relationship: association.Relationship,
|
||||
Error: association.Error,
|
||||
Unscope: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (association *Association) Find(out interface{}, conds ...interface{}) error {
|
||||
if association.Error == nil {
|
||||
association.Error = association.buildCondition().Find(out, conds...).Error
|
||||
|
@ -64,14 +74,30 @@ func (association *Association) Append(values ...interface{}) error {
|
|||
|
||||
func (association *Association) Replace(values ...interface{}) error {
|
||||
if association.Error == nil {
|
||||
reflectValue := association.DB.Statement.ReflectValue
|
||||
rel := association.Relationship
|
||||
|
||||
var oldBelongsToExpr clause.Expression
|
||||
// we have to record the old BelongsTo value
|
||||
if association.Unscope && rel.Type == schema.BelongsTo {
|
||||
var foreignFields []*schema.Field
|
||||
for _, ref := range rel.References {
|
||||
if !ref.OwnPrimaryKey {
|
||||
foreignFields = append(foreignFields, ref.ForeignKey)
|
||||
}
|
||||
}
|
||||
if _, fvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, foreignFields); len(fvs) > 0 {
|
||||
column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, fvs)
|
||||
oldBelongsToExpr = clause.IN{Column: column, Values: values}
|
||||
}
|
||||
}
|
||||
|
||||
// save associations
|
||||
if association.saveAssociation( /*clear*/ true, values...); association.Error != nil {
|
||||
return association.Error
|
||||
}
|
||||
|
||||
// set old associations's foreign key to null
|
||||
reflectValue := association.DB.Statement.ReflectValue
|
||||
rel := association.Relationship
|
||||
switch rel.Type {
|
||||
case schema.BelongsTo:
|
||||
if len(values) == 0 {
|
||||
|
@ -91,6 +117,9 @@ func (association *Association) Replace(values ...interface{}) error {
|
|||
|
||||
association.Error = association.DB.UpdateColumns(updateMap).Error
|
||||
}
|
||||
if association.Unscope && oldBelongsToExpr != nil {
|
||||
association.Error = association.DB.Model(nil).Where(oldBelongsToExpr).Delete(reflect.New(rel.FieldSchema.ModelType).Interface()).Error
|
||||
}
|
||||
case schema.HasOne, schema.HasMany:
|
||||
var (
|
||||
primaryFields []*schema.Field
|
||||
|
@ -119,7 +148,11 @@ func (association *Association) Replace(values ...interface{}) error {
|
|||
|
||||
if _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields); len(pvs) > 0 {
|
||||
column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs)
|
||||
association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error
|
||||
if association.Unscope {
|
||||
association.Error = tx.Where(clause.IN{Column: column, Values: values}).Delete(modelValue).Error
|
||||
} else {
|
||||
association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error
|
||||
}
|
||||
}
|
||||
case schema.Many2Many:
|
||||
var (
|
||||
|
@ -184,7 +217,8 @@ func (association *Association) Delete(values ...interface{}) error {
|
|||
|
||||
switch rel.Type {
|
||||
case schema.BelongsTo:
|
||||
tx := association.DB.Model(reflect.New(rel.Schema.ModelType).Interface())
|
||||
associationDB := association.DB.Session(&Session{})
|
||||
tx := associationDB.Model(reflect.New(rel.Schema.ModelType).Interface())
|
||||
|
||||
_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, rel.Schema.PrimaryFields)
|
||||
if pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs); len(pvalues) > 0 {
|
||||
|
@ -198,8 +232,21 @@ func (association *Association) Delete(values ...interface{}) error {
|
|||
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
|
||||
|
||||
association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
|
||||
if association.Unscope {
|
||||
var foreignFields []*schema.Field
|
||||
for _, ref := range rel.References {
|
||||
if !ref.OwnPrimaryKey {
|
||||
foreignFields = append(foreignFields, ref.ForeignKey)
|
||||
}
|
||||
}
|
||||
if _, fvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, foreignFields); len(fvs) > 0 {
|
||||
column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, fvs)
|
||||
association.Error = associationDB.Model(nil).Where(clause.IN{Column: column, Values: values}).Delete(reflect.New(rel.FieldSchema.ModelType).Interface()).Error
|
||||
}
|
||||
}
|
||||
case schema.HasOne, schema.HasMany:
|
||||
tx := association.DB.Model(reflect.New(rel.FieldSchema.ModelType).Interface())
|
||||
model := reflect.New(rel.FieldSchema.ModelType).Interface()
|
||||
tx := association.DB.Model(model)
|
||||
|
||||
_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields)
|
||||
if pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs); len(pvalues) > 0 {
|
||||
|
@ -212,7 +259,11 @@ func (association *Association) Delete(values ...interface{}) error {
|
|||
relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs)
|
||||
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
|
||||
|
||||
association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
|
||||
if association.Unscope {
|
||||
association.Error = tx.Clauses(conds...).Delete(model).Error
|
||||
} else {
|
||||
association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
|
||||
}
|
||||
case schema.Many2Many:
|
||||
var (
|
||||
primaryFields, relPrimaryFields []*schema.Field
|
||||
|
|
|
@ -251,3 +251,58 @@ func TestBelongsToDefaultValue(t *testing.T) {
|
|||
err := DB.Create(&user).Error
|
||||
AssertEqual(t, err, nil)
|
||||
}
|
||||
|
||||
func TestBelongsToAssociationUnscoped(t *testing.T) {
|
||||
type ItemParent struct {
|
||||
gorm.Model
|
||||
Logo string `gorm:"not null;type:varchar(50)"`
|
||||
}
|
||||
type ItemChild struct {
|
||||
gorm.Model
|
||||
Name string `gorm:"type:varchar(50)"`
|
||||
ItemParentID uint
|
||||
ItemParent ItemParent
|
||||
}
|
||||
|
||||
tx := DB.Session(&gorm.Session{})
|
||||
tx.Migrator().DropTable(&ItemParent{}, &ItemChild{})
|
||||
tx.AutoMigrate(&ItemParent{}, &ItemChild{})
|
||||
|
||||
item := ItemChild{
|
||||
Name: "name",
|
||||
ItemParent: ItemParent{
|
||||
Logo: "logo",
|
||||
},
|
||||
}
|
||||
if err := tx.Create(&item).Error; err != nil {
|
||||
t.Fatalf("failed to create items, got error: %v", err)
|
||||
}
|
||||
|
||||
tx = tx.Debug()
|
||||
|
||||
// test replace
|
||||
if err := tx.Model(&item).Association("ItemParent").Unscoped().Replace(&ItemParent{
|
||||
Logo: "updated logo",
|
||||
}); err != nil {
|
||||
t.Errorf("failed to replace item parent, got error: %v", err)
|
||||
}
|
||||
|
||||
var parents []ItemParent
|
||||
if err := tx.Find(&parents).Error; err != nil {
|
||||
t.Errorf("failed to find item parent, got error: %v", err)
|
||||
}
|
||||
if len(parents) != 1 {
|
||||
t.Errorf("expected %d parents, got %d", 1, len(parents))
|
||||
}
|
||||
|
||||
// test delete
|
||||
if err := tx.Model(&item).Association("ItemParent").Unscoped().Delete(&parents); err != nil {
|
||||
t.Errorf("failed to delete item parent, got error: %v", err)
|
||||
}
|
||||
if err := tx.Find(&parents).Error; err != nil {
|
||||
t.Errorf("failed to find item parent, got error: %v", err)
|
||||
}
|
||||
if len(parents) != 0 {
|
||||
t.Errorf("expected %d parents, got %d", 0, len(parents))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package tests_test
|
|||
import (
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
|
@ -471,3 +472,76 @@ func TestPolymorphicHasManyAssociationForSlice(t *testing.T) {
|
|||
DB.Model(&users).Association("Toys").Clear()
|
||||
AssertAssociationCount(t, users, "Toys", 0, "After Clear")
|
||||
}
|
||||
|
||||
func TestHasManyAssociationUnscoped(t *testing.T) {
|
||||
type ItemContent struct {
|
||||
gorm.Model
|
||||
ItemID uint `gorm:"not null"`
|
||||
Name string `gorm:"not null;type:varchar(50)"`
|
||||
LanguageCode string `gorm:"not null;type:varchar(2)"`
|
||||
}
|
||||
type Item struct {
|
||||
gorm.Model
|
||||
Logo string `gorm:"not null;type:varchar(50)"`
|
||||
Contents []ItemContent `gorm:"foreignKey:ItemID"`
|
||||
}
|
||||
|
||||
tx := DB.Session(&gorm.Session{})
|
||||
tx.Migrator().DropTable(&ItemContent{}, &Item{})
|
||||
tx.AutoMigrate(&ItemContent{}, &Item{})
|
||||
|
||||
item := Item{
|
||||
Logo: "logo",
|
||||
Contents: []ItemContent{
|
||||
{Name: "name", LanguageCode: "en"},
|
||||
{Name: "ar name", LanguageCode: "ar"},
|
||||
},
|
||||
}
|
||||
if err := tx.Create(&item).Error; err != nil {
|
||||
t.Fatalf("failed to create items, got error: %v", err)
|
||||
}
|
||||
|
||||
// test Replace
|
||||
if err := tx.Model(&item).Association("Contents").Unscoped().Replace([]ItemContent{
|
||||
{Name: "updated name", LanguageCode: "en"},
|
||||
{Name: "ar updated name", LanguageCode: "ar"},
|
||||
{Name: "le nom", LanguageCode: "fr"},
|
||||
}); err != nil {
|
||||
t.Errorf("failed to replace item content, got error: %v", err)
|
||||
}
|
||||
|
||||
if count := tx.Model(&item).Association("Contents").Count(); count != 3 {
|
||||
t.Errorf("expected %d contents, got %d", 3, count)
|
||||
}
|
||||
|
||||
var contents []ItemContent
|
||||
if err := tx.Find(&contents).Error; err != nil {
|
||||
t.Errorf("failed to find contents, got error: %v", err)
|
||||
}
|
||||
if len(contents) != 3 {
|
||||
t.Errorf("expected %d contents, got %d", 3, len(contents))
|
||||
}
|
||||
|
||||
// test delete
|
||||
if err := tx.Model(&item).Association("Contents").Unscoped().Delete(&contents[0]); err != nil {
|
||||
t.Errorf("failed to delete Contents, got error: %v", err)
|
||||
}
|
||||
if count := tx.Model(&item).Association("Contents").Count(); count != 2 {
|
||||
t.Errorf("expected %d contents, got %d", 2, count)
|
||||
}
|
||||
|
||||
// test clear
|
||||
if err := tx.Model(&item).Association("Contents").Unscoped().Clear(); err != nil {
|
||||
t.Errorf("failed to clear contents association, got error: %v", err)
|
||||
}
|
||||
if count := tx.Model(&item).Association("Contents").Count(); count != 0 {
|
||||
t.Errorf("expected %d contents, got %d", 0, count)
|
||||
}
|
||||
|
||||
if err := tx.Find(&contents).Error; err != nil {
|
||||
t.Errorf("failed to find contents, got error: %v", err)
|
||||
}
|
||||
if len(contents) != 0 {
|
||||
t.Errorf("expected %d contents, got %d", 0, len(contents))
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue