diff --git a/schema/field.go b/schema/field.go index b303fb30..86b4a061 100644 --- a/schema/field.go +++ b/schema/field.go @@ -330,7 +330,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { cacheStore := &sync.Map{} cacheStore.Store(embeddedCacheKey, true) - if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), cacheStore, embeddedNamer{Table: schema.Table, Namer: schema.namer}); err != nil { + if field.EmbeddedSchema, err = getOrParse(fieldValue.Interface(), cacheStore, embeddedNamer{Table: schema.Table, Namer: schema.namer}); err != nil { schema.err = err } diff --git a/schema/relationship.go b/schema/relationship.go index 35af111f..9cfc10be 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -71,7 +71,7 @@ func (schema *Schema) parseRelation(field *Field) { cacheStore = field.OwnerSchema.cacheStore } - if relation.FieldSchema, err = Parse(fieldValue, cacheStore, schema.namer); err != nil { + if relation.FieldSchema, err = getOrParse(fieldValue, cacheStore, schema.namer); err != nil { schema.err = err return } diff --git a/schema/schema.go b/schema/schema.go index 05db641f..89392643 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -38,6 +38,7 @@ type Schema struct { BeforeSave, AfterSave bool AfterFind bool err error + initialized chan struct{} namer Namer cacheStore *sync.Map } @@ -89,7 +90,9 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) } if v, ok := cacheStore.Load(modelType); ok { - return v.(*Schema), nil + s := v.(*Schema) + <-s.initialized + return s, nil } modelValue := reflect.New(modelType) @@ -110,6 +113,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) Relationships: Relationships{Relations: map[string]*Relationship{}}, cacheStore: cacheStore, namer: namer, + initialized: make(chan struct{}), } defer func() { @@ -219,7 +223,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) } } - if _, loaded := cacheStore.LoadOrStore(modelType, schema); !loaded { + if s, loaded := cacheStore.LoadOrStore(modelType, schema); !loaded { if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded { for _, field := range schema.Fields { if field.DataType == "" && (field.Creatable || field.Updatable || field.Readable) { @@ -245,8 +249,31 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...) } } + close(schema.initialized) } + } else { + return s.(*Schema), nil } return schema, schema.err } + +func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { + modelType := reflect.ValueOf(dest).Type() + for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + + if modelType.Kind() != reflect.Struct { + if modelType.PkgPath() == "" { + return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) + } + return nil, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) + } + + if v, ok := cacheStore.Load(modelType); ok { + return v.(*Schema), nil + } + + return Parse(dest, cacheStore, namer) +} diff --git a/tests/go.mod b/tests/go.mod index 55495de3..fa293987 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,6 +6,7 @@ require ( github.com/google/uuid v1.1.1 github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 + github.com/stretchr/testify v1.5.1 gorm.io/driver/mysql v1.0.3 gorm.io/driver/postgres v1.0.5 gorm.io/driver/sqlite v1.1.3 diff --git a/tests/preload_suits_test.go b/tests/preload_suits_test.go index d40309e7..0ef8890b 100644 --- a/tests/preload_suits_test.go +++ b/tests/preload_suits_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "reflect" "sort" + "sync/atomic" "testing" "gorm.io/gorm" @@ -1497,10 +1498,9 @@ func TestPreloadManyToManyCallbacks(t *testing.T) { } DB.Save(&lvl) - called := 0 - + var called int64 DB.Callback().Query().After("gorm:query").Register("TestPreloadManyToManyCallbacks", func(_ *gorm.DB) { - called = called + 1 + atomic.AddInt64(&called, 1) }) DB.Preload("Level2s").First(&Level1{}, "id = ?", lvl.ID) diff --git a/tests/preload_test.go b/tests/preload_test.go index d9035661..4b31b12c 100644 --- a/tests/preload_test.go +++ b/tests/preload_test.go @@ -5,6 +5,7 @@ import ( "regexp" "sort" "strconv" + "sync" "testing" "gorm.io/gorm" @@ -212,3 +213,21 @@ func TestPreloadEmptyData(t *testing.T) { t.Errorf("json marshal is not empty slice, got %v", string(r)) } } + +func TestPreloadGoroutine(t *testing.T) { + var wg sync.WaitGroup + + wg.Add(10) + for i := 0; i < 10; i++ { + go func() { + defer wg.Done() + var user2 []User + tx := DB.Where("id = ?", 1).Session(&gorm.Session{}) + + if err := tx.Preload("Team").Find(&user2).Error; err != nil { + t.Error(err) + } + }() + } + wg.Wait() +}