From 6befa0c947e0107f241663e4312a74bddd0a4ffe Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 17 Mar 2022 11:22:25 +0800 Subject: [PATCH] Refactor preload error check --- callbacks/query.go | 5 +++++ finisher_api.go | 4 ---- tests/count_test.go | 14 +++++++++++--- tests/go.mod | 2 +- 4 files changed, 17 insertions(+), 8 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index 03798859..04f35c7e 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -186,6 +186,11 @@ func BuildQuerySQL(db *gorm.DB) { func Preload(db *gorm.DB) { if db.Error == nil && len(db.Statement.Preloads) > 0 { + if db.Statement.Schema == nil { + db.AddError(fmt.Errorf("%w when using preload", gorm.ErrModelValueRequired)) + return + } + preloadMap := map[string]map[string][]interface{}{} for name := range db.Statement.Preloads { preloadFields := strings.Split(name, ".") diff --git a/finisher_api.go b/finisher_api.go index 4b428a59..b4d29b71 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -369,10 +369,6 @@ func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) { func (db *DB) Count(count *int64) (tx *DB) { tx = db.getInstance() - if len(tx.Statement.Preloads) > 0 { - tx.AddError(ErrPreloadNotAllowed) - return - } if tx.Statement.Model == nil { tx.Statement.Model = tx.Statement.Dest defer func() { diff --git a/tests/count_test.go b/tests/count_test.go index b63a55fc..b71e3de5 100644 --- a/tests/count_test.go +++ b/tests/count_test.go @@ -150,8 +150,16 @@ func TestCount(t *testing.T) { Where("name in ?", []string{user1.Name, user2.Name, user3.Name}). Preload("Toys", func(db *gorm.DB) *gorm.DB { return db.Table("toys").Select("name") - }).Count(&count12).Error; err != gorm.ErrPreloadNotAllowed { - t.Errorf("should returns preload not allowed error, but got %v", err) + }).Count(&count12).Error; err == nil { + t.Errorf("error should raise when using preload without schema") + } + + var count13 int64 + if err := DB.Model(User{}). + Where("name in ?", []string{user1.Name, user2.Name, user3.Name}). + Preload("Toys", func(db *gorm.DB) *gorm.DB { + return db.Table("toys").Select("name") + }).Count(&count13).Error; err != nil { + t.Errorf("no error should raise when using count with preload, but got %v", err) } - } diff --git a/tests/go.mod b/tests/go.mod index c65ea953..4ef7fbe2 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,7 +9,7 @@ require ( github.com/jinzhu/now v1.1.4 github.com/lib/pq v1.10.4 github.com/mattn/go-sqlite3 v1.14.12 // indirect - golang.org/x/crypto v0.0.0-20220214200702-86341886e292 // indirect + golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd // indirect gorm.io/driver/mysql v1.3.2 gorm.io/driver/postgres v1.3.1 gorm.io/driver/sqlite v1.3.1