Merge pull request #2721 from rubensayshi/isforeignkeyrace

fix a race condition on IsForeignKey that is being detected by -race
This commit is contained in:
Emir Beganović 2019-10-28 12:09:32 +04:00 committed by GitHub
commit 2586a05016
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 110 additions and 2 deletions

View File

@ -17,6 +17,10 @@ var DefaultTableNameHandler = func(db *DB, defaultTableName string) string {
return defaultTableName return defaultTableName
} }
// lock for mutating global cached model metadata
var structsLock sync.Mutex
// global cache of model metadata
var modelStructsMap sync.Map var modelStructsMap sync.Map
// ModelStruct model definition // ModelStruct model definition
@ -419,8 +423,12 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
for idx, foreignKey := range foreignKeys { for idx, foreignKey := range foreignKeys {
if foreignField := getForeignField(foreignKey, toFields); foreignField != nil { if foreignField := getForeignField(foreignKey, toFields); foreignField != nil {
if associationField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); associationField != nil { if associationField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); associationField != nil {
// source foreign keys // mark field as foreignkey, use global lock to avoid race
structsLock.Lock()
foreignField.IsForeignKey = true foreignField.IsForeignKey = true
structsLock.Unlock()
// association foreign keys
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name) relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name)
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName) relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName)
@ -523,8 +531,12 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
for idx, foreignKey := range foreignKeys { for idx, foreignKey := range foreignKeys {
if foreignField := getForeignField(foreignKey, toFields); foreignField != nil { if foreignField := getForeignField(foreignKey, toFields); foreignField != nil {
if scopeField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); scopeField != nil { if scopeField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); scopeField != nil {
// mark field as foreignkey, use global lock to avoid race
structsLock.Lock()
foreignField.IsForeignKey = true foreignField.IsForeignKey = true
// source foreign keys structsLock.Unlock()
// association foreign keys
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scopeField.Name) relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scopeField.Name)
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scopeField.DBName) relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scopeField.DBName)
@ -582,7 +594,10 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
for idx, foreignKey := range foreignKeys { for idx, foreignKey := range foreignKeys {
if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil { if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil {
if associationField := getForeignField(associationForeignKeys[idx], toFields); associationField != nil { if associationField := getForeignField(associationForeignKeys[idx], toFields); associationField != nil {
// mark field as foreignkey, use global lock to avoid race
structsLock.Lock()
foreignField.IsForeignKey = true foreignField.IsForeignKey = true
structsLock.Unlock()
// association foreign keys // association foreign keys
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name) relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name)

93
model_struct_test.go Normal file
View File

@ -0,0 +1,93 @@
package gorm_test
import (
"sync"
"testing"
"github.com/jinzhu/gorm"
)
type ModelA struct {
gorm.Model
Name string
ModelCs []ModelC `gorm:"foreignkey:OtherAID"`
}
type ModelB struct {
gorm.Model
Name string
ModelCs []ModelC `gorm:"foreignkey:OtherBID"`
}
type ModelC struct {
gorm.Model
Name string
OtherAID uint64
OtherA *ModelA `gorm:"foreignkey:OtherAID"`
OtherBID uint64
OtherB *ModelB `gorm:"foreignkey:OtherBID"`
}
// This test will try to cause a race condition on the model's foreignkey metadata
func TestModelStructRaceSameModel(t *testing.T) {
// use a WaitGroup to execute as much in-sync as possible
// it's more likely to hit a race condition than without
n := 32
start := sync.WaitGroup{}
start.Add(n)
// use another WaitGroup to know when the test is done
done := sync.WaitGroup{}
done.Add(n)
for i := 0; i < n; i++ {
go func() {
start.Wait()
// call GetStructFields, this had a race condition before we fixed it
DB.NewScope(&ModelA{}).GetStructFields()
done.Done()
}()
start.Done()
}
done.Wait()
}
// This test will try to cause a race condition on the model's foreignkey metadata
func TestModelStructRaceDifferentModel(t *testing.T) {
// use a WaitGroup to execute as much in-sync as possible
// it's more likely to hit a race condition than without
n := 32
start := sync.WaitGroup{}
start.Add(n)
// use another WaitGroup to know when the test is done
done := sync.WaitGroup{}
done.Add(n)
for i := 0; i < n; i++ {
i := i
go func() {
start.Wait()
// call GetStructFields, this had a race condition before we fixed it
if i%2 == 0 {
DB.NewScope(&ModelA{}).GetStructFields()
} else {
DB.NewScope(&ModelB{}).GetStructFields()
}
done.Done()
}()
start.Done()
}
done.Wait()
}