From 109f239fae95b310f64236bcc0fbe87bbfb1edfc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9rgio=20Prata=20Almeida?= <2896760+sprataa@users.noreply.github.com> Date: Mon, 17 Jun 2024 04:59:06 +0100 Subject: [PATCH] add DB level propagation for the Unscoped flag (#7007) * adds PropagateUnscoped to db Config * adds PropagateUnscoped test * adds PropagateUnscoped to Session and sets it accordingly --- gorm.go | 10 ++++++++++ tests/hooks_test.go | 43 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+) diff --git a/gorm.go b/gorm.go index 775cd3de..117d2fd0 100644 --- a/gorm.go +++ b/gorm.go @@ -50,6 +50,8 @@ type Config struct { CreateBatchSize int // TranslateError enabling error translation TranslateError bool + // PropagateUnscoped propagate Unscoped to every other nested statement + PropagateUnscoped bool // ClauseBuilders clause builder ClauseBuilders map[string]clause.ClauseBuilder @@ -110,6 +112,7 @@ type Session struct { DisableNestedTransaction bool AllowGlobalUpdate bool FullSaveAssociations bool + PropagateUnscoped bool QueryFields bool Context context.Context Logger logger.Interface @@ -241,6 +244,10 @@ func (db *DB) Session(config *Session) *DB { txConfig.FullSaveAssociations = true } + if config.PropagateUnscoped { + txConfig.PropagateUnscoped = true + } + if config.Context != nil || config.PrepareStmt || config.SkipHooks { tx.Statement = tx.Statement.clone() tx.Statement.DB = tx @@ -409,6 +416,9 @@ func (db *DB) getInstance() *DB { Vars: make([]interface{}, 0, 8), SkipHooks: db.Statement.SkipHooks, } + if db.Config.PropagateUnscoped { + tx.Statement.Unscoped = db.Statement.Unscoped + } } else { // with clone statement tx.Statement = db.Statement.clone() diff --git a/tests/hooks_test.go b/tests/hooks_test.go index 0753dd0b..04f62bde 100644 --- a/tests/hooks_test.go +++ b/tests/hooks_test.go @@ -2,6 +2,8 @@ package tests_test import ( "errors" + "log" + "os" "reflect" "strings" "testing" @@ -566,3 +568,44 @@ func TestUpdateCallbacks(t *testing.T) { t.Fatalf("before update should not be called") } } + +type Product6 struct { + gorm.Model + Name string + Item *ProductItem2 +} + +type ProductItem2 struct { + gorm.Model + Product6ID uint +} + +func (p *Product6) BeforeDelete(tx *gorm.DB) error { + if err := tx.Delete(&p.Item).Error; err != nil { + return err + } + return nil +} + +func TestPropagateUnscoped(t *testing.T) { + _DB, err := OpenTestConnection(&gorm.Config{ + PropagateUnscoped: true, + }) + if err != nil { + log.Printf("failed to connect database, got error %v", err) + os.Exit(1) + } + + _DB.Migrator().DropTable(&Product6{}, &ProductItem2{}) + _DB.AutoMigrate(&Product6{}, &ProductItem2{}) + + p := Product6{ + Name: "unique_code", + Item: &ProductItem2{}, + } + _DB.Model(&Product6{}).Create(&p) + + if err := _DB.Unscoped().Delete(&p).Error; err != nil { + t.Fatalf("unscoped did not propagate") + } +}