Refactor preload error check

This commit is contained in:
Jinzhu 2022-03-17 11:22:25 +08:00
parent 61b4c31236
commit 6befa0c947
4 changed files with 17 additions and 8 deletions

View File

@ -186,6 +186,11 @@ func BuildQuerySQL(db *gorm.DB) {
func Preload(db *gorm.DB) { func Preload(db *gorm.DB) {
if db.Error == nil && len(db.Statement.Preloads) > 0 { if db.Error == nil && len(db.Statement.Preloads) > 0 {
if db.Statement.Schema == nil {
db.AddError(fmt.Errorf("%w when using preload", gorm.ErrModelValueRequired))
return
}
preloadMap := map[string]map[string][]interface{}{} preloadMap := map[string]map[string][]interface{}{}
for name := range db.Statement.Preloads { for name := range db.Statement.Preloads {
preloadFields := strings.Split(name, ".") preloadFields := strings.Split(name, ".")

View File

@ -369,10 +369,6 @@ func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) {
func (db *DB) Count(count *int64) (tx *DB) { func (db *DB) Count(count *int64) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
if len(tx.Statement.Preloads) > 0 {
tx.AddError(ErrPreloadNotAllowed)
return
}
if tx.Statement.Model == nil { if tx.Statement.Model == nil {
tx.Statement.Model = tx.Statement.Dest tx.Statement.Model = tx.Statement.Dest
defer func() { defer func() {

View File

@ -150,8 +150,16 @@ func TestCount(t *testing.T) {
Where("name in ?", []string{user1.Name, user2.Name, user3.Name}). Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).
Preload("Toys", func(db *gorm.DB) *gorm.DB { Preload("Toys", func(db *gorm.DB) *gorm.DB {
return db.Table("toys").Select("name") return db.Table("toys").Select("name")
}).Count(&count12).Error; err != gorm.ErrPreloadNotAllowed { }).Count(&count12).Error; err == nil {
t.Errorf("should returns preload not allowed error, but got %v", err) t.Errorf("error should raise when using preload without schema")
}
var count13 int64
if err := DB.Model(User{}).
Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).
Preload("Toys", func(db *gorm.DB) *gorm.DB {
return db.Table("toys").Select("name")
}).Count(&count13).Error; err != nil {
t.Errorf("no error should raise when using count with preload, but got %v", err)
} }
} }

View File

@ -9,7 +9,7 @@ require (
github.com/jinzhu/now v1.1.4 github.com/jinzhu/now v1.1.4
github.com/lib/pq v1.10.4 github.com/lib/pq v1.10.4
github.com/mattn/go-sqlite3 v1.14.12 // indirect github.com/mattn/go-sqlite3 v1.14.12 // indirect
golang.org/x/crypto v0.0.0-20220214200702-86341886e292 // indirect golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd // indirect
gorm.io/driver/mysql v1.3.2 gorm.io/driver/mysql v1.3.2
gorm.io/driver/postgres v1.3.1 gorm.io/driver/postgres v1.3.1
gorm.io/driver/sqlite v1.3.1 gorm.io/driver/sqlite v1.3.1