From 4da2c28d4dc43372aba7d8ac9260c026e2c46098 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 1 Oct 2015 07:09:00 +0800 Subject: [PATCH] Fix data race warning when get cached model struct --- main.go | 2 +- model_struct.go | 30 ++++++++++++++++++++++++++---- 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/main.go b/main.go index 657342e9..870b7cfb 100644 --- a/main.go +++ b/main.go @@ -130,7 +130,7 @@ func (s *DB) LogMode(enable bool) *DB { } func (s *DB) SingularTable(enable bool) { - modelStructs = map[reflect.Type]*ModelStruct{} + modelStructsMap = newModelStructsMap() s.parent.singularTable = enable } diff --git a/model_struct.go b/model_struct.go index f5fc5797..bafea8e8 100644 --- a/model_struct.go +++ b/model_struct.go @@ -7,17 +7,39 @@ import ( "reflect" "strconv" "strings" + "sync" "time" "github.com/qor/inflection" ) -var modelStructs = map[reflect.Type]*ModelStruct{} - var DefaultTableNameHandler = func(db *DB, defaultTableName string) string { return defaultTableName } +type safeModelStructsMap struct { + m map[reflect.Type]*ModelStruct + l *sync.RWMutex +} + +func (s *safeModelStructsMap) Set(key reflect.Type, value *ModelStruct) { + s.l.Lock() + defer s.l.Unlock() + s.m[key] = value +} + +func (s *safeModelStructsMap) Get(key reflect.Type) *ModelStruct { + s.l.RLock() + defer s.l.RUnlock() + return s.m[key] +} + +func newModelStructsMap() *safeModelStructsMap { + return &safeModelStructsMap{l: new(sync.RWMutex), m: make(map[reflect.Type]*ModelStruct)} +} + +var modelStructsMap = newModelStructsMap() + type ModelStruct struct { PrimaryFields []*StructField StructFields []*StructField @@ -92,7 +114,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { scopeType = scopeType.Elem() } - if value, ok := modelStructs[scopeType]; ok { + if value := modelStructsMap.Get(scopeType); value != nil { return value } @@ -370,7 +392,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { finished <- true }(finished) - modelStructs[scopeType] = &modelStruct + modelStructsMap.Set(scopeType, &modelStruct) <-finished modelStruct.cached = true