diff --git a/main.go b/main.go index fda63d29..cc8ac68c 100644 --- a/main.go +++ b/main.go @@ -12,6 +12,7 @@ import ( // DB contains information for current db connection type DB struct { + sync.Mutex Value interface{} Error error RowsAffected int64 @@ -170,7 +171,8 @@ func (s *DB) HasBlockGlobalUpdate() bool { // SingularTable use singular table by default func (s *DB) SingularTable(enable bool) { - modelStructsMap = sync.Map{} + s.parent.Lock() + defer s.parent.Unlock() s.parent.singularTable = enable } diff --git a/main_test.go b/main_test.go index ac40c32b..1dc30093 100644 --- a/main_test.go +++ b/main_test.go @@ -9,6 +9,7 @@ import ( "reflect" "strconv" "strings" + "sync" "testing" "time" @@ -277,6 +278,30 @@ func TestTableName(t *testing.T) { DB.SingularTable(false) } +func TestTableNameConcurrently(t *testing.T) { + DB := DB.Model("") + if DB.NewScope(Order{}).TableName() != "orders" { + t.Errorf("Order's table name should be orders") + } + + var wg sync.WaitGroup + wg.Add(10) + + for i := 1; i <= 10; i++ { + go func(db *gorm.DB) { + DB.SingularTable(true) + wg.Done() + }(DB) + } + wg.Wait() + + if DB.NewScope(Order{}).TableName() != "order" { + t.Errorf("Order's singular table name should be order") + } + + DB.SingularTable(false) +} + func TestNullValues(t *testing.T) { DB.DropTable(&NullValue{}) DB.AutoMigrate(&NullValue{}) @@ -1066,12 +1091,12 @@ func TestCountWithHaving(t *testing.T) { DB.Create(getPreparedUser("user1", "pluck_user")) DB.Create(getPreparedUser("user2", "pluck_user")) - user3:=getPreparedUser("user3", "pluck_user") - user3.Languages=[]Language{} + user3 := getPreparedUser("user3", "pluck_user") + user3.Languages = []Language{} DB.Create(user3) var count int - err:=db.Model(User{}).Select("users.id"). + err := db.Model(User{}).Select("users.id"). Joins("LEFT JOIN user_languages ON user_languages.user_id = users.id"). Joins("LEFT JOIN languages ON user_languages.language_id = languages.id"). Group("users.id").Having("COUNT(languages.id) > 1").Count(&count).Error @@ -1080,7 +1105,7 @@ func TestCountWithHaving(t *testing.T) { t.Error("Unexpected error on query count with having") } - if count!=2{ + if count != 2 { t.Error("Unexpected result on query count with having") } } diff --git a/model_struct.go b/model_struct.go index f646910a..8d6313fb 100644 --- a/model_struct.go +++ b/model_struct.go @@ -40,9 +40,11 @@ func (s *ModelStruct) TableName(db *DB) string { s.defaultTableName = tabler.TableName() } else { tableName := ToTableName(s.ModelType.Name()) + db.parent.Lock() if db == nil || (db.parent != nil && !db.parent.singularTable) { tableName = inflection.Plural(tableName) } + db.parent.Unlock() s.defaultTableName = tableName } } @@ -163,7 +165,18 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } // Get Cached model struct - if value, ok := modelStructsMap.Load(reflectType); ok && value != nil { + isSingularTable := false + if scope.db != nil && scope.db.parent != nil { + scope.db.parent.Lock() + isSingularTable = scope.db.parent.singularTable + scope.db.parent.Unlock() + } + + hashKey := struct { + singularTable bool + reflectType reflect.Type + }{isSingularTable, reflectType} + if value, ok := modelStructsMap.Load(hashKey); ok && value != nil { return value.(*ModelStruct) } @@ -612,7 +625,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } } - modelStructsMap.Store(reflectType, &modelStruct) + modelStructsMap.Store(hashKey, &modelStruct) return &modelStruct }