Fix failed to parse relations when using goroutinue, close #3790

commit ee0ec43e8dfa85c1c1a562c2d0d47776cf8abd92
Author: Jinzhu <wosmvp@gmail.com>
Date:   Fri Nov 27 14:31:57 2020 +0800

    Fix failed to parse relations when using goroutinue, close #3790

commit 590e73ff95
Author: rokeyzhao <rokeyzhao@tencent.com>
Date:   Thu Nov 26 20:27:55 2020 +0800

    test: no cache preload in goroutine
This commit is contained in:
Jinzhu 2020-11-27 14:32:20 +08:00
parent 557b874ee3
commit 6950007d6a
6 changed files with 54 additions and 7 deletions

View File

@ -330,7 +330,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
cacheStore := &sync.Map{}
cacheStore.Store(embeddedCacheKey, true)
if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), cacheStore, embeddedNamer{Table: schema.Table, Namer: schema.namer}); err != nil {
if field.EmbeddedSchema, err = getOrParse(fieldValue.Interface(), cacheStore, embeddedNamer{Table: schema.Table, Namer: schema.namer}); err != nil {
schema.err = err
}

View File

@ -71,7 +71,7 @@ func (schema *Schema) parseRelation(field *Field) {
cacheStore = field.OwnerSchema.cacheStore
}
if relation.FieldSchema, err = Parse(fieldValue, cacheStore, schema.namer); err != nil {
if relation.FieldSchema, err = getOrParse(fieldValue, cacheStore, schema.namer); err != nil {
schema.err = err
return
}

View File

@ -38,6 +38,7 @@ type Schema struct {
BeforeSave, AfterSave bool
AfterFind bool
err error
initialized chan struct{}
namer Namer
cacheStore *sync.Map
}
@ -89,7 +90,9 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
}
if v, ok := cacheStore.Load(modelType); ok {
return v.(*Schema), nil
s := v.(*Schema)
<-s.initialized
return s, nil
}
modelValue := reflect.New(modelType)
@ -110,6 +113,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
Relationships: Relationships{Relations: map[string]*Relationship{}},
cacheStore: cacheStore,
namer: namer,
initialized: make(chan struct{}),
}
defer func() {
@ -219,7 +223,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
}
}
if _, loaded := cacheStore.LoadOrStore(modelType, schema); !loaded {
if s, loaded := cacheStore.LoadOrStore(modelType, schema); !loaded {
if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded {
for _, field := range schema.Fields {
if field.DataType == "" && (field.Creatable || field.Updatable || field.Readable) {
@ -245,8 +249,31 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...)
}
}
close(schema.initialized)
}
} else {
return s.(*Schema), nil
}
return schema, schema.err
}
func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
modelType := reflect.ValueOf(dest).Type()
for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
if modelType.Kind() != reflect.Struct {
if modelType.PkgPath() == "" {
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
}
return nil, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
}
if v, ok := cacheStore.Load(modelType); ok {
return v.(*Schema), nil
}
return Parse(dest, cacheStore, namer)
}

View File

@ -6,6 +6,7 @@ require (
github.com/google/uuid v1.1.1
github.com/jinzhu/now v1.1.1
github.com/lib/pq v1.6.0
github.com/stretchr/testify v1.5.1
gorm.io/driver/mysql v1.0.3
gorm.io/driver/postgres v1.0.5
gorm.io/driver/sqlite v1.1.3

View File

@ -5,6 +5,7 @@ import (
"encoding/json"
"reflect"
"sort"
"sync/atomic"
"testing"
"gorm.io/gorm"
@ -1497,10 +1498,9 @@ func TestPreloadManyToManyCallbacks(t *testing.T) {
}
DB.Save(&lvl)
called := 0
var called int64
DB.Callback().Query().After("gorm:query").Register("TestPreloadManyToManyCallbacks", func(_ *gorm.DB) {
called = called + 1
atomic.AddInt64(&called, 1)
})
DB.Preload("Level2s").First(&Level1{}, "id = ?", lvl.ID)

View File

@ -5,6 +5,7 @@ import (
"regexp"
"sort"
"strconv"
"sync"
"testing"
"gorm.io/gorm"
@ -212,3 +213,21 @@ func TestPreloadEmptyData(t *testing.T) {
t.Errorf("json marshal is not empty slice, got %v", string(r))
}
}
func TestPreloadGoroutine(t *testing.T) {
var wg sync.WaitGroup
wg.Add(10)
for i := 0; i < 10; i++ {
go func() {
defer wg.Done()
var user2 []User
tx := DB.Where("id = ?", 1).Session(&gorm.Session{})
if err := tx.Preload("Team").Find(&user2).Error; err != nil {
t.Error(err)
}
}()
}
wg.Wait()
}