forked from mirror/gorm
Merge pull request #2721 from rubensayshi/isforeignkeyrace
fix a race condition on IsForeignKey that is being detected by -race
This commit is contained in:
commit
2586a05016
|
@ -17,6 +17,10 @@ var DefaultTableNameHandler = func(db *DB, defaultTableName string) string {
|
||||||
return defaultTableName
|
return defaultTableName
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// lock for mutating global cached model metadata
|
||||||
|
var structsLock sync.Mutex
|
||||||
|
|
||||||
|
// global cache of model metadata
|
||||||
var modelStructsMap sync.Map
|
var modelStructsMap sync.Map
|
||||||
|
|
||||||
// ModelStruct model definition
|
// ModelStruct model definition
|
||||||
|
@ -419,8 +423,12 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||||
for idx, foreignKey := range foreignKeys {
|
for idx, foreignKey := range foreignKeys {
|
||||||
if foreignField := getForeignField(foreignKey, toFields); foreignField != nil {
|
if foreignField := getForeignField(foreignKey, toFields); foreignField != nil {
|
||||||
if associationField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); associationField != nil {
|
if associationField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); associationField != nil {
|
||||||
// source foreign keys
|
// mark field as foreignkey, use global lock to avoid race
|
||||||
|
structsLock.Lock()
|
||||||
foreignField.IsForeignKey = true
|
foreignField.IsForeignKey = true
|
||||||
|
structsLock.Unlock()
|
||||||
|
|
||||||
|
// association foreign keys
|
||||||
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name)
|
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name)
|
||||||
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName)
|
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName)
|
||||||
|
|
||||||
|
@ -523,8 +531,12 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||||
for idx, foreignKey := range foreignKeys {
|
for idx, foreignKey := range foreignKeys {
|
||||||
if foreignField := getForeignField(foreignKey, toFields); foreignField != nil {
|
if foreignField := getForeignField(foreignKey, toFields); foreignField != nil {
|
||||||
if scopeField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); scopeField != nil {
|
if scopeField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); scopeField != nil {
|
||||||
|
// mark field as foreignkey, use global lock to avoid race
|
||||||
|
structsLock.Lock()
|
||||||
foreignField.IsForeignKey = true
|
foreignField.IsForeignKey = true
|
||||||
// source foreign keys
|
structsLock.Unlock()
|
||||||
|
|
||||||
|
// association foreign keys
|
||||||
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scopeField.Name)
|
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scopeField.Name)
|
||||||
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scopeField.DBName)
|
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scopeField.DBName)
|
||||||
|
|
||||||
|
@ -582,7 +594,10 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||||
for idx, foreignKey := range foreignKeys {
|
for idx, foreignKey := range foreignKeys {
|
||||||
if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil {
|
if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil {
|
||||||
if associationField := getForeignField(associationForeignKeys[idx], toFields); associationField != nil {
|
if associationField := getForeignField(associationForeignKeys[idx], toFields); associationField != nil {
|
||||||
|
// mark field as foreignkey, use global lock to avoid race
|
||||||
|
structsLock.Lock()
|
||||||
foreignField.IsForeignKey = true
|
foreignField.IsForeignKey = true
|
||||||
|
structsLock.Unlock()
|
||||||
|
|
||||||
// association foreign keys
|
// association foreign keys
|
||||||
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name)
|
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name)
|
||||||
|
|
|
@ -0,0 +1,93 @@
|
||||||
|
package gorm_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ModelA struct {
|
||||||
|
gorm.Model
|
||||||
|
Name string
|
||||||
|
|
||||||
|
ModelCs []ModelC `gorm:"foreignkey:OtherAID"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ModelB struct {
|
||||||
|
gorm.Model
|
||||||
|
Name string
|
||||||
|
|
||||||
|
ModelCs []ModelC `gorm:"foreignkey:OtherBID"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ModelC struct {
|
||||||
|
gorm.Model
|
||||||
|
Name string
|
||||||
|
|
||||||
|
OtherAID uint64
|
||||||
|
OtherA *ModelA `gorm:"foreignkey:OtherAID"`
|
||||||
|
OtherBID uint64
|
||||||
|
OtherB *ModelB `gorm:"foreignkey:OtherBID"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// This test will try to cause a race condition on the model's foreignkey metadata
|
||||||
|
func TestModelStructRaceSameModel(t *testing.T) {
|
||||||
|
// use a WaitGroup to execute as much in-sync as possible
|
||||||
|
// it's more likely to hit a race condition than without
|
||||||
|
n := 32
|
||||||
|
start := sync.WaitGroup{}
|
||||||
|
start.Add(n)
|
||||||
|
|
||||||
|
// use another WaitGroup to know when the test is done
|
||||||
|
done := sync.WaitGroup{}
|
||||||
|
done.Add(n)
|
||||||
|
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
go func() {
|
||||||
|
start.Wait()
|
||||||
|
|
||||||
|
// call GetStructFields, this had a race condition before we fixed it
|
||||||
|
DB.NewScope(&ModelA{}).GetStructFields()
|
||||||
|
|
||||||
|
done.Done()
|
||||||
|
}()
|
||||||
|
|
||||||
|
start.Done()
|
||||||
|
}
|
||||||
|
|
||||||
|
done.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
// This test will try to cause a race condition on the model's foreignkey metadata
|
||||||
|
func TestModelStructRaceDifferentModel(t *testing.T) {
|
||||||
|
// use a WaitGroup to execute as much in-sync as possible
|
||||||
|
// it's more likely to hit a race condition than without
|
||||||
|
n := 32
|
||||||
|
start := sync.WaitGroup{}
|
||||||
|
start.Add(n)
|
||||||
|
|
||||||
|
// use another WaitGroup to know when the test is done
|
||||||
|
done := sync.WaitGroup{}
|
||||||
|
done.Add(n)
|
||||||
|
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
i := i
|
||||||
|
go func() {
|
||||||
|
start.Wait()
|
||||||
|
|
||||||
|
// call GetStructFields, this had a race condition before we fixed it
|
||||||
|
if i%2 == 0 {
|
||||||
|
DB.NewScope(&ModelA{}).GetStructFields()
|
||||||
|
} else {
|
||||||
|
DB.NewScope(&ModelB{}).GetStructFields()
|
||||||
|
}
|
||||||
|
|
||||||
|
done.Done()
|
||||||
|
}()
|
||||||
|
|
||||||
|
start.Done()
|
||||||
|
}
|
||||||
|
|
||||||
|
done.Wait()
|
||||||
|
}
|
Loading…
Reference in New Issue