From 532e9cf4ccce927249bcb102c09e4a9093aae4fe Mon Sep 17 00:00:00 2001 From: Michael Anstis Date: Sat, 18 Feb 2023 01:06:43 +0000 Subject: [PATCH] Issue 6054: Unscoped not working with PreLoad on Joins (#6058) * Issue 6054: Unscoped not working with PreLoad on Joins * Formatting --------- Co-authored-by: Michael Anstis --- callbacks/query.go | 1 + clause/select_test.go | 12 +++++++----- migrator/migrator.go | 4 +--- model.go | 7 ++++--- schema/field.go | 2 +- schema/relationship.go | 23 +++++++++++----------- schema/serializer.go | 9 +++------ tests/connpool_test.go | 8 +++++--- tests/embedded_struct_test.go | 1 - tests/helper_test.go | 36 +++++++++++++++++++++++++++++----- tests/migrate_test.go | 3 +-- tests/preload_test.go | 37 +++++++++++++++++++++++++++++++++++ tests/table_test.go | 5 +++-- 13 files changed, 106 insertions(+), 42 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index 97fe8a49..9a6d4f4a 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -257,6 +257,7 @@ func Preload(db *gorm.DB) { return } preloadDB.Statement.ReflectValue = db.Statement.ReflectValue + preloadDB.Statement.Unscoped = db.Statement.Unscoped for _, name := range preloadNames { if rel := preloadDB.Statement.Schema.Relationships.Relations[name]; rel != nil { diff --git a/clause/select_test.go b/clause/select_test.go index 18bc2693..9c11b90d 100644 --- a/clause/select_test.go +++ b/clause/select_test.go @@ -49,16 +49,18 @@ func TestSelect(t *testing.T) { Exprs: []clause.Expression{ clause.Expr{ SQL: "? as name", - Vars: []interface{}{clause.Eq{ - Column: clause.Column{Name: "age"}, - Value: 18, - }, + Vars: []interface{}{ + clause.Eq{ + Column: clause.Column{Name: "age"}, + Value: 18, + }, }, }, }, }, }, clause.From{}}, - "SELECT `age` = ? as name FROM `users`", []interface{}{18}, + "SELECT `age` = ? as name FROM `users`", + []interface{}{18}, }, } diff --git a/migrator/migrator.go b/migrator/migrator.go index b8aaef2b..12c2df46 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -16,9 +16,7 @@ import ( "gorm.io/gorm/schema" ) -var ( - regFullDataType = regexp.MustCompile(`\D*(\d+)\D?`) -) +var regFullDataType = regexp.MustCompile(`\D*(\d+)\D?`) // Migrator m struct type Migrator struct { diff --git a/model.go b/model.go index 3334d17c..fa705df1 100644 --- a/model.go +++ b/model.go @@ -4,9 +4,10 @@ import "time" // Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt // It may be embedded into your model or you may build your own model without it -// type User struct { -// gorm.Model -// } +// +// type User struct { +// gorm.Model +// } type Model struct { ID uint `gorm:"primarykey"` CreatedAt time.Time diff --git a/schema/field.go b/schema/field.go index 1589d984..59151878 100644 --- a/schema/field.go +++ b/schema/field.go @@ -174,7 +174,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.DataType = String field.Serializer = v } else { - var serializerName = field.TagSettings["JSON"] + serializerName := field.TagSettings["JSON"] if serializerName == "" { serializerName = field.TagSettings["SERIALIZER"] } diff --git a/schema/relationship.go b/schema/relationship.go index 9436f283..b33b94a7 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -123,16 +123,17 @@ func (schema *Schema) parseRelation(field *Field) *Relationship { } // User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner` -// type User struct { -// Toys []Toy `gorm:"polymorphic:Owner;"` -// } -// type Pet struct { -// Toy Toy `gorm:"polymorphic:Owner;"` -// } -// type Toy struct { -// OwnerID int -// OwnerType string -// } +// +// type User struct { +// Toys []Toy `gorm:"polymorphic:Owner;"` +// } +// type Pet struct { +// Toy Toy `gorm:"polymorphic:Owner;"` +// } +// type Toy struct { +// OwnerID int +// OwnerType string +// } func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field, polymorphic string) { relation.Polymorphic = &Polymorphic{ Value: schema.Table, @@ -427,7 +428,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu foreignFields = append(foreignFields, f) } } else { - var primarySchemaName = primarySchema.Name + primarySchemaName := primarySchema.Name if primarySchemaName == "" { primarySchemaName = relation.FieldSchema.Name } diff --git a/schema/serializer.go b/schema/serializer.go index 9a6aa4fc..397edff0 100644 --- a/schema/serializer.go +++ b/schema/serializer.go @@ -70,8 +70,7 @@ type SerializerValuerInterface interface { } // JSONSerializer json serializer -type JSONSerializer struct { -} +type JSONSerializer struct{} // Scan implements serializer interface func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) { @@ -110,8 +109,7 @@ func (JSONSerializer) Value(ctx context.Context, field *Field, dst reflect.Value } // UnixSecondSerializer json serializer -type UnixSecondSerializer struct { -} +type UnixSecondSerializer struct{} // Scan implements serializer interface func (UnixSecondSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) { @@ -141,8 +139,7 @@ func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect } // GobSerializer gob serializer -type GobSerializer struct { -} +type GobSerializer struct{} // Scan implements serializer interface func (GobSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) { diff --git a/tests/connpool_test.go b/tests/connpool_test.go index 42e029bc..e0e1c771 100644 --- a/tests/connpool_test.go +++ b/tests/connpool_test.go @@ -48,9 +48,11 @@ func (c *wrapperConnPool) Ping() error { } // If you use BeginTx returned *sql.Tx as shown below then you can't record queries in a transaction. -// func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { -// return c.db.BeginTx(ctx, opts) -// } +// +// func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { +// return c.db.BeginTx(ctx, opts) +// } +// // You should use BeginTx returned gorm.Tx which could wrap *sql.Tx then you can record all queries. func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (gorm.ConnPool, error) { tx, err := c.db.BeginTx(ctx, opts) diff --git a/tests/embedded_struct_test.go b/tests/embedded_struct_test.go index ae69baca..63ec53ee 100644 --- a/tests/embedded_struct_test.go +++ b/tests/embedded_struct_test.go @@ -94,7 +94,6 @@ func TestEmbeddedStruct(t *testing.T) { t.Errorf("expected author %s got %s", want, post.Author.Name) } } - } func TestEmbeddedPointerTypeStruct(t *testing.T) { diff --git a/tests/helper_test.go b/tests/helper_test.go index d40fa5ce..c34e357c 100644 --- a/tests/helper_test.go +++ b/tests/helper_test.go @@ -8,6 +8,8 @@ import ( "testing" "time" + "gorm.io/gorm" + . "gorm.io/gorm/utils/tests" ) @@ -74,10 +76,18 @@ func GetUser(name string, config Config) *User { return &user } +func CheckPetUnscoped(t *testing.T, pet Pet, expect Pet) { + doCheckPet(t, pet, expect, true) +} + func CheckPet(t *testing.T, pet Pet, expect Pet) { + doCheckPet(t, pet, expect, false) +} + +func doCheckPet(t *testing.T, pet Pet, expect Pet, unscoped bool) { if pet.ID != 0 { var newPet Pet - if err := DB.Where("id = ?", pet.ID).First(&newPet).Error; err != nil { + if err := db(unscoped).Where("id = ?", pet.ID).First(&newPet).Error; err != nil { t.Fatalf("errors happened when query: %v", err) } else { AssertObjEqual(t, newPet, pet, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Name") @@ -94,10 +104,18 @@ func CheckPet(t *testing.T, pet Pet, expect Pet) { } } +func CheckUserUnscoped(t *testing.T, user User, expect User) { + doCheckUser(t, user, expect, true) +} + func CheckUser(t *testing.T, user User, expect User) { + doCheckUser(t, user, expect, false) +} + +func doCheckUser(t *testing.T, user User, expect User, unscoped bool) { if user.ID != 0 { var newUser User - if err := DB.Where("id = ?", user.ID).First(&newUser).Error; err != nil { + if err := db(unscoped).Where("id = ?", user.ID).First(&newUser).Error; err != nil { t.Fatalf("errors happened when query: %v", err) } else { AssertObjEqual(t, newUser, user, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") @@ -114,7 +132,7 @@ func CheckUser(t *testing.T, user User, expect User) { t.Errorf("Account's foreign key should be saved") } else { var account Account - DB.First(&account, "user_id = ?", user.ID) + db(unscoped).First(&account, "user_id = ?", user.ID) AssertObjEqual(t, account, user.Account, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Number") } } @@ -137,7 +155,7 @@ func CheckUser(t *testing.T, user User, expect User) { if pet == nil || expect.Pets[idx] == nil { t.Errorf("pets#%v should equal, expect: %v, got %v", idx, expect.Pets[idx], pet) } else { - CheckPet(t, *pet, *expect.Pets[idx]) + doCheckPet(t, *pet, *expect.Pets[idx], unscoped) } } }) @@ -174,7 +192,7 @@ func CheckUser(t *testing.T, user User, expect User) { t.Errorf("Manager's foreign key should be saved") } else { var manager User - DB.First(&manager, "id = ?", *user.ManagerID) + db(unscoped).First(&manager, "id = ?", *user.ManagerID) AssertObjEqual(t, manager, user.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") AssertObjEqual(t, manager, expect.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") } @@ -246,3 +264,11 @@ func tidbSkip(t *testing.T, reason string) { func isTiDB() bool { return os.Getenv("GORM_DIALECT") == "tidb" } + +func db(unscoped bool) *gorm.DB { + if unscoped { + return DB.Unscoped() + } else { + return DB + } +} diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 489da976..8794ccba 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -75,7 +75,6 @@ func TestMigrate(t *testing.T) { t.Fatalf("Failed to find index for many2many for %v %v", indexes[0], indexes[1]) } } - } func TestAutoMigrateInt8PG(t *testing.T) { @@ -1267,7 +1266,7 @@ func (mm mockMigrator) AlterColumn(dst interface{}, field string) error { } func TestMigrateDonotAlterColumn(t *testing.T) { - var wrapMockMigrator = func(m gorm.Migrator) mockMigrator { + wrapMockMigrator := func(m gorm.Migrator) mockMigrator { return mockMigrator{ Migrator: m, } diff --git a/tests/preload_test.go b/tests/preload_test.go index cb4343ec..e7223b3e 100644 --- a/tests/preload_test.go +++ b/tests/preload_test.go @@ -269,3 +269,40 @@ func TestPreloadWithDiffModel(t *testing.T) { CheckUser(t, user, result.User) } + +func TestNestedPreloadWithUnscoped(t *testing.T) { + user := *GetUser("nested_preload", Config{Pets: 1}) + pet := user.Pets[0] + pet.Toy = Toy{Name: "toy_nested_preload_" + strconv.Itoa(1)} + pet.Toy = Toy{Name: "toy_nested_preload_" + strconv.Itoa(2)} + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + var user2 User + DB.Preload("Pets.Toy").Find(&user2, "id = ?", user.ID) + CheckUser(t, user2, user) + + DB.Delete(&pet) + + var user3 User + DB.Preload(clause.Associations+"."+clause.Associations).Find(&user3, "id = ?", user.ID) + if len(user3.Pets) != 0 { + t.Fatalf("User.Pet[0] was deleted and should not exist.") + } + + var user4 *User + DB.Preload("Pets.Toy").Find(&user4, "id = ?", user.ID) + if len(user4.Pets) != 0 { + t.Fatalf("User.Pet[0] was deleted and should not exist.") + } + + var user5 User + DB.Unscoped().Preload(clause.Associations+"."+clause.Associations).Find(&user5, "id = ?", user.ID) + CheckUserUnscoped(t, user5, user) + + var user6 *User + DB.Unscoped().Preload("Pets.Toy").Find(&user6, "id = ?", user.ID) + CheckUserUnscoped(t, *user6, user) +} diff --git a/tests/table_test.go b/tests/table_test.go index f538c691..fa569d32 100644 --- a/tests/table_test.go +++ b/tests/table_test.go @@ -158,10 +158,11 @@ func (UserWithTableNamer) TableName(namer schema.Namer) string { } func TestTableWithNamer(t *testing.T) { - var db, _ = gorm.Open(tests.DummyDialector{}, &gorm.Config{ + db, _ := gorm.Open(tests.DummyDialector{}, &gorm.Config{ NamingStrategy: schema.NamingStrategy{ TablePrefix: "t_", - }}) + }, + }) sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB { return tx.Model(&UserWithTableNamer{}).Find(&UserWithTableNamer{})